Skip to content

Commit

Permalink
fix bug for index impl ascend p3
Browse files Browse the repository at this point in the history
  • Loading branch information
jingguo-st committed Aug 5, 2024
1 parent 8705e77 commit 2c40285
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
4 changes: 2 additions & 2 deletions impl/ascend/ascend_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ AscendTensor& AscendTensor::expand(std::vector<int64_t> shape) {
}
}

int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), 1, std::multiplies<>());
std::vector<int64_t> newStride(expandDims, 0);
auto tStride = this->stride();
newStride.insert(newStride.end(), tStride.begin(), tStride.end());
Expand All @@ -152,7 +152,7 @@ AscendTensor& AscendTensor::expand(std::vector<int64_t> shape) {
}

AscendTensor& AscendTensor::resize(const std::vector<int64_t>& shape) {
int64_t numElem = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
int64_t numElem = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
std::vector<int64_t> stride(shape.size(), 1);
for (int64_t j = shape.size() - 2; j >= 0; j--) {
stride[j] = stride[j + 1] * shape[j + 1];
Expand Down
9 changes: 4 additions & 5 deletions impl/ascend/functions/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ static std::vector<AscendTensor> castIntIndicesToLongIndices(diopiContextHandle_
std::vector<AscendTensor> result;
for (auto& t : indices) {
if (!t.defined()) {
result.push_back(AscendTensor(nullptr));
result.emplace_back(nullptr);
continue;
}
if (t.dtype() == diopi_dtype_int32) {
Expand All @@ -25,10 +25,10 @@ static std::vector<AscendTensor> castIntIndicesToLongIndices(diopiContextHandle_
diopiSize_t size = vectorToDiopiSize(shape);
diopiRequireTensor(ctx, &indexHandle, &size, nullptr, diopi_dtype_int64, diopi_device);
DIOPI_ASCEND_CALL_ACLNN(aclnnCast, ctx, t, diopi_dtype_int64, indexHandle);
result.push_back(AscendTensor(indexHandle));
result.emplace_back(indexHandle);
} else {
if (t.device() == diopi_host) {
result.push_back(AscendTensor(hostToDevice(ctx, t.tensorHandle())));
result.emplace_back(hostToDevice(ctx, t.tensorHandle()));
} else {
result.emplace_back(t);
}
Expand Down Expand Up @@ -204,7 +204,7 @@ static std::vector<int64_t> indexReshape(std::vector<AscendTensor> endIndices, i
static std::vector<int64_t> indexOutputSize(const AscendTensor& self, std::vector<AscendTensor>& indices) {
std::vector<AscendTensor> midIndices = indicesExpandedOutplace(indices);
while (midIndices.size() < (size_t)self.dim()) {
midIndices.push_back(AscendTensor(nullptr));
midIndices.emplace_back(nullptr);
}

AscendTensor src = self;
Expand Down Expand Up @@ -278,7 +278,6 @@ diopiError_t diopiIndex(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diop
}
}

std::vector<int64_t> outShapeRef{34, 2, 6, 197};
std::vector<int64_t> outShape = indexOutputSize(inputAt, indicesExpanded);

diopiSize_t outSize = vectorToDiopiSize(outShape);
Expand Down

0 comments on commit 2c40285

Please sign in to comment.