[Ascend] fuj/acl-index (#1332)
* impl index for aclnn p1

* impl index for ascend p2

* fix bug for index impl ascend p3

* fix a warning bug for index
jingguo-st authored Aug 9, 2024
1 parent fdf9527 commit e13aea4
Showing 6 changed files with 438 additions and 3 deletions.
10 changes: 9 additions & 1 deletion impl/ascend/aclnn/adaptor.hpp
@@ -149,6 +149,10 @@ struct IsBoolStdArray<std::array<bool, N>> : std::true_type {};

inline aclIntArray* createAclIntArrayFromIntVector(const std::vector<int64_t>& vec) { return ::aclCreateIntArray(vec.data(), vec.size()); }

inline aclTensorList* createAclTensorListFromAclTensorVector(const std::vector<aclTensor*>& tensorsVec) {
return ::aclCreateTensorList(tensorsVec.data(), tensorsVec.size());
}

inline aclTensorList* createAclTensorListFromAscendTensorVector(const std::vector<AscendTensor>& tensorsVec) {
std::vector<const aclTensor*> tList(tensorsVec.size());
for (size_t i = 0; i < tensorsVec.size(); i++) {
@@ -175,7 +179,11 @@ inline aclTensorList* createAclTensorListFromConstDiopiTensorVector(const std::v

template <class T, class U = std::remove_cv_t<std::remove_reference_t<T>>>
decltype(auto) convertType(T&& param) {
-if constexpr (std::is_same_v<U, AscendTensor>) {
+if constexpr (std::is_same_v<U, aclTensor*>) {
+return std::forward<T>(param);
+} else if constexpr (std::is_same_v<U, std::vector<aclTensor*>>) {
+return createAclTensorListFromAclTensorVector(std::forward<T>(param));
+} else if constexpr (std::is_same_v<U, AscendTensor>) {
return createAclTensorFromAscendTensor(std::forward<T>(param));
} else if constexpr (std::is_same_v<U, diopiTensorHandle_t> || std::is_same_v<U, diopiConstTensorHandle_t>) {
return createAclTensorFromDiopiTensor(std::forward<T>(param));
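For context, the new `std::vector<aclTensor*>` branch hands the vector to `aclCreateTensorList` (via `createAclTensorListFromAclTensorVector`) so tensor-list arguments such as index lists can flow straight through the adaptor. Below is a minimal, self-contained sketch of the same `if constexpr` dispatch pattern, using toy stand-in types rather than the real ACL handles; everything in it is illustrative, not part of the commit.

    #include <iostream>
    #include <type_traits>
    #include <utility>
    #include <vector>

    // Toy stand-ins for aclTensor* / aclTensorList* (illustration only).
    struct Tensor {};
    struct TensorList { std::vector<Tensor*> items; };

    TensorList makeList(const std::vector<Tensor*>& v) { return TensorList{v}; }

    // Same compile-time dispatch shape as the adaptor's convertType:
    // raw handles pass through unchanged, vectors of handles become a list.
    template <class T, class U = std::remove_cv_t<std::remove_reference_t<T>>>
    decltype(auto) convertType(T&& param) {
        if constexpr (std::is_same_v<U, Tensor*>) {
            return std::forward<T>(param);
        } else if constexpr (std::is_same_v<U, std::vector<Tensor*>>) {
            return makeList(std::forward<T>(param));
        }
    }

    int main() {
        Tensor a, b;
        std::vector<Tensor*> vec{&a, &b};
        Tensor* passedThrough = convertType(&a);  // pointer passes through
        TensorList wrapped = convertType(vec);    // vector wrapped into a list
        std::cout << (passedThrough == &a) << " " << wrapped.items.size() << "\n";  // prints: 1 2
    }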
102 changes: 102 additions & 0 deletions impl/ascend/ascend_tensor.cpp
@@ -6,9 +6,11 @@

#include "ascend_tensor.hpp"

// #include <algorithm>
#include <array>
#include <cstdint>
#include <mutex>
#include <numeric>
#include <utility>

#include "common/debug.hpp"
@@ -82,6 +84,106 @@ AscendTensor& AscendTensor::asStrided(const std::vector<int64_t>& shape, const s
return *this;
}

AscendTensor& AscendTensor::permute(std::vector<int64_t> dims) {
ASCEND_CHECK_ABORT(this->dim() == static_cast<int64_t>(dims.size()), "permute dims does not match the tensor dims.");

std::vector<int64_t> newShape(dims.size(), 0);
std::vector<int64_t> newStride(dims.size(), 0);

for (size_t i = 0; i < dims.size(); i++) {
newShape[i] = this->shape(dims[i]);
newStride[i] = this->stride(dims[i]);
}

this->shape_ = newShape;
this->stride_ = newStride;

return *this;
}
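Taken on its own, the permute bookkeeping above is a gather of shape and stride through `dims`. A standalone sketch with assumed values (a 2x3 row-major tensor, transposed), not the committed class:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        // A 2x3 row-major tensor: shape {2, 3}, stride {3, 1}.
        std::vector<int64_t> shape{2, 3}, stride{3, 1};
        std::vector<int64_t> dims{1, 0};  // transpose

        std::vector<int64_t> newShape(dims.size()), newStride(dims.size());
        for (size_t i = 0; i < dims.size(); i++) {
            newShape[i] = shape[dims[i]];
            newStride[i] = stride[dims[i]];
        }
        // The same buffer is now viewed as 3x2 with stride {1, 3}: no copy;
        // element (r, c) of the view reads data[r * 1 + c * 3].
        std::cout << newShape[0] << "x" << newShape[1] << " strides {"
                  << newStride[0] << "," << newStride[1] << "}\n";  // 3x2 strides {1,3}
    }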

AscendTensor& AscendTensor::expand(std::vector<int64_t> shape) {
ASCEND_CHECK_ABORT(static_cast<int64_t>(shape.size()) >= this->dim(),
"the number of sizes provided [%ld] must be greater than or equal to the number of dimensions of the tensor [%ld].",
static_cast<int64_t>(shape.size()),
this->dim());

// todo: dim() == 0
int64_t expandDims = shape.size() - this->shape().size();
std::vector<int64_t> tShapeExp(expandDims, 0);
auto tShape = this->shape();
tShapeExp.insert(tShapeExp.end(), tShape.begin(), tShape.end());
std::vector<int64_t> newShape = shape;

for (int64_t i = 0; i < static_cast<int64_t>(newShape.size()); i++) {
if (newShape[i] < 0 && i < expandDims) {
ASCEND_CHECK_ABORT(false, "The expanded size of the tensor (%ld) isn't allowed in a leading, non-existing dimension %ld", newShape[i], i);
}

if (i >= expandDims) {
if (newShape[i] == -1) {
newShape[i] = tShapeExp[i];
} else {
ASCEND_CHECK_ABORT(tShapeExp[i] == 1 || newShape[i] == tShapeExp[i],
"The expanded size of the tensor (%ld) must match the existing size (%ld) at non-singleton dimension %ld.",
newShape[i],
tShapeExp[i],
i);
}
}
}

int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), int64_t{1}, std::multiplies<>());
std::vector<int64_t> newStride(expandDims, 0);
auto tStride = this->stride();
newStride.insert(newStride.end(), tStride.begin(), tStride.end());
for (int64_t i = expandDims; i < static_cast<int64_t>(shape.size()); i++) {
if (shape[i] == -1 || shape[i] == tShapeExp[i]) {
continue;
} else {
newStride[i] = 0;
}
}

this->numel_ = numElem;
this->shape_ = newShape;
this->stride_ = newStride;

return *this;
}
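The key effect of expand is that a broadcast dimension gets stride 0, so stepping along it revisits the same memory. A standalone sketch under simplifying assumptions (equal rank, legal broadcast; the committed code additionally handles -1 sizes and leading new dimensions):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        // Expanding shape {1, 3} (stride {3, 1}) to {4, 3}: the broadcast
        // dimension has no real extent in memory, so its stride becomes 0
        // and all 4 rows alias the same 3 elements.
        std::vector<int64_t> shape{1, 3}, stride{3, 1}, target{4, 3};
        std::vector<int64_t> newStride = stride;
        for (size_t i = 0; i < target.size(); i++) {
            if (target[i] != shape[i]) {
                newStride[i] = 0;  // broadcast: stepping this dim stays in place
            }
        }
        // Element (r, c) of the view reads data[r * 0 + c * 1] == data[c].
        std::cout << "strides {" << newStride[0] << "," << newStride[1] << "}\n";  // {0,1}
    }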

AscendTensor& AscendTensor::resize(const std::vector<int64_t>& shape) {
int64_t numElem = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
std::vector<int64_t> stride(shape.size(), 1);
for (int64_t j = static_cast<int64_t>(shape.size()) - 2; j >= 0; j--) {
stride[j] = stride[j + 1] * shape[j + 1];
}

this->numel_ = numElem;
this->shape_ = shape;
this->stride_ = stride;

return *this;
}
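resize recomputes contiguous row-major strides back-to-front: the innermost stride is 1 and each outer stride is the product of all inner extents. A standalone sketch with an assumed shape of {2, 3, 4}:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        std::vector<int64_t> shape{2, 3, 4};
        std::vector<int64_t> stride(shape.size(), 1);
        for (int64_t j = static_cast<int64_t>(shape.size()) - 2; j >= 0; j--) {
            stride[j] = stride[j + 1] * shape[j + 1];
        }
        // stride = {12, 4, 1}; numel = 24.
        std::cout << stride[0] << "," << stride[1] << "," << stride[2] << "\n";  // 12,4,1
    }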

AscendTensor& AscendTensor::select(int64_t dim, int64_t index) {
auto shape = this->shape();
auto stride = this->stride();

ASCEND_CHECK_ABORT(dim >= 0 && dim < static_cast<int64_t>(shape.size()), "selected dim [%ld] exceeds the tensor dims [%ld].", dim, static_cast<int64_t>(shape.size()));

ASCEND_CHECK_ABORT(index >= 0 && index < shape[dim], "select index [%ld] out of range for dim [%ld] of size [%ld].", index, dim, shape[dim]);

// Advance the storage offset to the selected slice along dim.
this->storageOffset_ += index * stride[dim];
this->numel_ /= shape[dim];

shape.erase(shape.begin() + dim);
stride.erase(stride.begin() + dim);
this->shape_ = shape;
this->stride_ = stride;

return *this;
}
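select drops one dimension and repositions the view at the chosen slice, so the storage offset advances by index * stride[dim] (as in the corrected body above). A standalone sketch with assumed values:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        // select(dim=0, index=1) on a 2x3 row-major tensor: drop dim 0 and
        // view row 1 as a 1-D tensor of 3 elements, without copying.
        std::vector<int64_t> shape{2, 3}, stride{3, 1};
        int64_t dim = 0, index = 1;

        int64_t offset = index * stride[dim];
        shape.erase(shape.begin() + dim);
        stride.erase(stride.begin() + dim);
        std::cout << "offset " << offset << ", shape {" << shape[0]
                  << "}, stride {" << stride[0] << "}\n";  // offset 3, shape {3}, stride {1}
    }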

AscendTensor& AscendTensor::unsqueeze(int dim) {
// Note: `channels_last` tensor uses this will become uncontiguous
// which is same with pytorch
4 changes: 4 additions & 0 deletions impl/ascend/ascend_tensor.hpp
@@ -245,6 +245,10 @@ class AscendTensor final {
AscendTensor& asStrided(const std::vector<int64_t>& shape, const std::vector<int64_t>& stride);
AscendTensor& unsqueeze(int dim);
AscendTensor& view(const std::vector<int64_t>& shape);
AscendTensor& resize(const std::vector<int64_t>& shape);
AscendTensor& select(int64_t dim, int64_t index);
AscendTensor& permute(std::vector<int64_t> dims);
AscendTensor& expand(std::vector<int64_t> shape);

private:
// diopi origin tensor