impl index for ascend p2
jingguo-st committed Aug 4, 2024
1 parent 5ce42c0 commit 8705e77
Showing 4 changed files with 224 additions and 63 deletions.
67 changes: 67 additions & 0 deletions impl/ascend/ascend_tensor.cpp
@@ -84,6 +84,73 @@ AscendTensor& AscendTensor::asStrided(const std::vector<int64_t>& shape, const s
return *this;
}

AscendTensor& AscendTensor::permute(std::vector<int64_t> dims) {
ASCEND_CHECK_ABORT(this->dim() == dims.size(), "permute dims does not match the tensor dims.");

std::vector<int64_t> newShape(dims.size(), 0);
std::vector<int64_t> newStride(dims.size(), 0);

for (size_t i = 0; i < dims.size(); i++) {
newShape[i] = this->shape(dims[i]);
newStride[i] = this->stride(dims[i]);
}

this->shape_ = newShape;
this->stride_ = newStride;

return *this;
}
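
// Illustrative sketch (hypothetical tensor t, not from this commit): permute
// only rewrites the shape/stride metadata and never moves data. Assuming t is
// contiguous with shape {2, 3, 4} and strides {12, 4, 1}:
//   t.permute({2, 0, 1});  // shape -> {4, 2, 3}, strides -> {1, 12, 4}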

AscendTensor& AscendTensor::expand(std::vector<int64_t> shape) {
ASCEND_CHECK_ABORT(shape.size() >= static_cast<size_t>(this->dim()),
"the number of sizes provided [%ld] must be greater or equal to the number of dimensions of the tensor [%ld].",
shape.size(),
this->dim());

// TODO: handle the dim() == 0 case
int64_t expandDims = shape.size() - this->shape().size();
std::vector<int64_t> tShapeExp(expandDims, 0);
auto tShape = this->shape();
tShapeExp.insert(tShapeExp.end(), tShape.begin(), tShape.end());
std::vector<int64_t> newShape = shape;

for (int64_t i = 0; i < static_cast<int64_t>(newShape.size()); i++) {
if (newShape[i] < 0 && i < expandDims) {
ASCEND_CHECK_ABORT(false, "The expanded size of the tensor (%ld) isn't allowed in a leading, non-existing dimension %ld", newShape[i], i);
}

if (i >= expandDims) {
if (newShape[i] == -1) {
newShape[i] = tShapeExp[i];
} else {
ASCEND_CHECK_ABORT(tShapeExp[i] == 1 || newShape[i] == tShapeExp[i],
"The expanded size of the tensor (%ld) must match the existing size (%ld) at non-singleton dimension %ld.",
newShape[i],
tShapeExp[i],
i);
}
}
}

int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
std::vector<int64_t> newStride(expandDims, 0);
auto tStride = this->stride();
newStride.insert(newStride.end(), tStride.begin(), tStride.end());
for (int64_t i = expandDims; i < static_cast<int64_t>(shape.size()); i++) {
if (shape[i] == -1 || shape[i] == tShapeExp[i]) {
continue;
} else {
newStride[i] = 0;
}
}

this->numel_ = numElem;
this->shape_ = newShape;
this->stride_ = newStride;

return *this;
}
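
// Illustrative sketch (hypothetical tensor t, not from this commit): expand
// implements broadcasting by assigning stride 0 to every broadcast dimension,
// so reads repeat the same element. Assuming t is contiguous with shape
// {3, 1} and strides {1, 1}:
//   t.expand({2, 3, 4});  // shape -> {2, 3, 4}, strides -> {0, 1, 0}, numel 24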

AscendTensor& AscendTensor::resize(const std::vector<int64_t>& shape) {
int64_t numElem = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
std::vector<int64_t> stride(shape.size(), 1);
2 changes: 2 additions & 0 deletions impl/ascend/ascend_tensor.hpp
@@ -247,6 +247,8 @@ class AscendTensor final {
AscendTensor& view(const std::vector<int64_t>& shape);
AscendTensor& resize(const std::vector<int64_t>& shape);
AscendTensor& select(int64_t dim, int64_t index);
AscendTensor& permute(std::vector<int64_t> dims);
AscendTensor& expand(std::vector<int64_t> shape);

private:
// diopi origin tensor
216 changes: 154 additions & 62 deletions impl/ascend/functions/index.cpp
@@ -12,17 +12,6 @@
namespace impl {
namespace ascend {

static void printVec(std::vector<int64_t> vec, std::string msg = "") {
if (msg != "") {
std::cout << msg << ": ";
}
std::cout << "[ ";
for (auto i : vec) {
std::cout << i << " ";
}
std::cout << std::endl;
}

static std::vector<AscendTensor> castIntIndicesToLongIndices(diopiContextHandle_t ctx, std::vector<AscendTensor>& indices) {
std::vector<AscendTensor> result;
for (auto& t : indices) {
@@ -128,13 +117,6 @@ static std::vector<AscendTensor> expandIndicesTensors(diopiContextHandle_t ctx,
return result;
}

static AscendTensor emptyAscendTensor(const AscendTensor& self, std::vector<int64_t> shape) {
diopiTensorHandle_t empty = nullptr;
diopiSize_t size = vectorToDiopiSize(shape);

return AscendTensor(empty);
}

static aclTensor* createEmptyAclTensor() {
std::vector<int64_t> nShape{0};
std::vector<int64_t> nStride{1};
@@ -144,6 +126,134 @@ static aclTensor* createEmptyAclTensor() {
return ::aclCreateTensor(nShape.data(), nShape.size(), aclDataType::ACL_FLOAT16, nStride.data(), 0, aclFormat::ACL_FORMAT_ND, &storageSize, 0, storage);
}

static std::vector<AscendTensor> indicesExpandedOutplace(std::vector<AscendTensor> indices) {
bool first = true;
std::vector<int64_t> sizes;

for (auto& idx : indices) {
if (!idx.defined()) {
continue;
} else if (first) {
sizes = idx.shape();
first = false;
} else {
sizes = inferSize(sizes, idx.shape());
}
}

std::vector<AscendTensor> result;
for (auto& idx : indices) {
if (!idx.defined() || (idx.shape() == sizes)) {
result.push_back(idx);
} else {
result.push_back(idx.expand(sizes));
}
}
return result;
}
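
// Illustrative sketch: defined indices of shapes {3, 1} and {4} broadcast to
// the common shape inferSize({3, 1}, {4}) == {3, 4}; undefined (null) entries
// are passed through unchanged.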

static bool hasContiguousSubspace(std::vector<AscendTensor> indices) { // true if all the non-null tensors are adjacent
auto isDefined = [](const AscendTensor& tensor) { return tensor.defined(); };
auto isNull = [](const AscendTensor& tensor) { return !tensor.defined(); };
auto start = std::find_if(indices.begin(), indices.end(), isDefined);
auto stop = std::find_if(indices.rbegin(), indices.rend(), isDefined);
auto it = std::find_if(start, stop.base(), isNull);
return it == stop.base();
}
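
// Illustrative sketch, writing "i" for a defined index and "_" for an
// undefined one:
//   {_, i, i, _} -> true   (the indexed dims form one contiguous block)
//   {i, _, i}    -> false  (the indexed dims are separated by a gap)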

static std::tuple<AscendTensor, std::vector<AscendTensor>> transposeToFront(AscendTensor self, std::vector<AscendTensor> indices) {
std::vector<int64_t> dims;
std::vector<AscendTensor> transposedIndices;

dims.reserve(self.dim());
for (int64_t i = 0; i < self.dim(); i++) {
if (indices[i].defined()) {
dims.push_back(i);
transposedIndices.push_back(indices[i]);
}
}

for (int64_t i = 0; i < self.dim(); i++) {
if (!indices[i].defined()) {
dims.push_back(i);
transposedIndices.push_back(indices[i]);
}
}

return std::make_tuple(self.permute(dims), transposedIndices);
}
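
// Illustrative sketch: for a 4-d tensor with indices {_, i1, _, i2}, dims
// becomes {1, 3, 0, 2}: self is permuted so the indexed dims lead, and the
// indices are reordered to {i1, i2, _, _}.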

static std::vector<int64_t> indexReshape(std::vector<AscendTensor> endIndices, int64_t dimsBefore, int64_t dimsAfter) {
std::vector<int64_t> indexShape;
for (auto& idx : endIndices) {
if (idx.defined()) {
std::vector<int64_t> shape;
shape.insert(shape.end(), dimsBefore, 1);
shape.insert(shape.end(), idx.shape().begin(), idx.shape().end());
shape.insert(shape.end(), dimsAfter, 1);
if (indexShape.empty()) {
indexShape = shape;
} else {
indexShape = inferSize(indexShape, shape);
}
}
}
return indexShape;
}
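
// Illustrative sketch: with dimsBefore == 1, dimsAfter == 2 and a defined
// index of shape {2, 3}, the index is reshaped to {1, 2, 3, 1, 1} so it
// broadcasts against the non-indexed dimensions of self.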

static std::vector<int64_t> indexOutputSize(const AscendTensor& self, std::vector<AscendTensor>& indices) {
std::vector<AscendTensor> midIndices = indicesExpandedOutplace(indices);
while (midIndices.size() < (size_t)self.dim()) {
midIndices.push_back(AscendTensor(nullptr));
}

AscendTensor src = self;
std::vector<AscendTensor> endIndices = midIndices;
if (!hasContiguousSubspace(midIndices)) {
endIndices.clear();
std::tie(src, endIndices) = transposeToFront(self, midIndices);
}

int64_t dimsBefore = 0;
int64_t dimsAfter = 0;
int64_t dimsIndexed = 0;

std::vector<int64_t> replaceShape;
std::vector<int64_t> indexedSizes;

for (size_t dim = 0; dim < endIndices.size(); dim++) {
if (!endIndices[dim].defined()) {
if (dimsIndexed == 0) {
dimsBefore++;
} else {
dimsAfter++;
}
} else {
dimsIndexed++;
replaceShape = endIndices[dim].shape();
indexedSizes.push_back(src.shape(dim));
}
}

if (std::find(indexedSizes.begin(), indexedSizes.end(), 0) != indexedSizes.end() &&
std::find(replaceShape.begin(), replaceShape.end(), 0) == replaceShape.end()) {
ASCEND_CHECK_ABORT(false, "index is out of bounds for dimension with size 0");
}

auto selfShape = src.shape();
int64_t end = dimsBefore + dimsIndexed;
selfShape.erase(selfShape.begin() + dimsBefore, selfShape.begin() + end);
selfShape.insert(selfShape.begin() + dimsBefore, replaceShape.begin(), replaceShape.end());

std::vector<int64_t> indexShape = indexReshape(endIndices, dimsBefore, dimsAfter);
std::vector<int64_t> outputSize = indexShape;
if (indexShape != selfShape) {
outputSize = inferSize(indexShape, selfShape);
}

return outputSize;
}
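
// Illustrative sketch: for self of shape {5, 6, 7} and a single index of
// shape {2, 3} on dim 0, the indexed dim is replaced by the index shape:
//   indexOutputSize(self, {idx}) -> {2, 3, 6, 7}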

diopiError_t diopiIndex(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t* indices, int64_t nums) {
AscendTensor inputAt(input);
std::vector<AscendTensor> indicesOrigin(nums);
@@ -153,27 +263,11 @@ diopiError_t diopiIndex(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diop
}
}

// indices live on the device (dipu_device); nullptr handles become AscendTensor(nullptr)
std::vector<AscendTensor> indicesList = castIntIndicesToLongIndices(ctx, indicesOrigin);

// check index tensor types
checkIndexTensorTypes(indicesList);

// expand tensors
auto indicesExpanded = expandIndicesTensors(ctx, inputAt, indicesList);

std::vector<aclTensor*> allDefinedIndices;
auto emptyTensor = createEmptyAclTensor();
for (const auto& idx : indicesExpanded) {
@@ -184,44 +278,42 @@ diopiError_t diopiIndex(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diop
}
}

// infer the output shape from the input tensor and the broadcast indices
std::vector<int64_t> outShape = indexOutputSize(inputAt, indicesExpanded);
diopiSize_t outSize = vectorToDiopiSize(outShape);
diopiRequireTensor(ctx, out, &outSize, nullptr, inputAt.dtype(), diopi_device);

DIOPI_ASCEND_CALL_ACLNN(aclnnIndex, ctx, inputAt, allDefinedIndices, *out);

return diopiSuccess;
}
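
// Illustrative sketch (hypothetical shapes): diopiIndex implements advanced
// indexing analogous to out = input[idx0, idx1, ...]; e.g. a {4, 5} input
// indexed by one Long tensor of shape {3} on dim 0 yields a {3, 5} output.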

diopiError_t diopiIndexBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t zerosLikeInput, diopiConstTensorHandle_t* indices,
int64_t nums, diopiConstTensorHandle_t gradOutput) {
AscendTensor gradInputTensor(gradInput);
AscendTensor gradOutputTensor(gradOutput);
if (gradInputTensor.numel() == 0 || gradOutputTensor.numel() == 0) {
return diopiSuccess;
}

std::vector<diopiConstTensorHandle_t> indicesVec;
indicesVec.reserve(nums);

for (int i = 0; i < nums; i++) {
if (indices[i] != nullptr) {
indicesVec.emplace_back(indices[i]);
} else {
int64_t array[1] = {0};
diopiSize_t size = {array, 1};
diopiTensorHandle_t emptyTensor = nullptr;
diopiRequireTensor(ctx, &emptyTensor, &size, nullptr, gradOutputTensor.dtype(), diopi_device);
indicesVec.emplace_back(emptyTensor);
}
}

DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, gradInput, zerosLikeInput);
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, gradInput, indicesVec, gradOutput, true, false);

return diopiSuccess;
}

2 changes: 1 addition & 1 deletion impl/ascend_npu/ascend_config.yaml
@@ -113,6 +113,7 @@ ascend:
- diopiHardtanhBackward
- diopiHardtanhInp
- diopiIndex
- diopiIndexBackward
- diopiIndexPut
- diopiIndexPutInp
- diopiIndexSelect
@@ -266,7 +267,6 @@ ascend_npu:
- diopiApplyPenalty
- diopiContextAttentionInference
- diopiGetNativeMemoryFormat
- diopiIndexBackward
- diopiNLLLoss
- diopiNLLLossBackward
- diopiNLLLossV2
