diff --git a/impl/ascend/aclnn/adaptor.hpp b/impl/ascend/aclnn/adaptor.hpp
index 117423c78..f0c4ff953 100644
--- a/impl/ascend/aclnn/adaptor.hpp
+++ b/impl/ascend/aclnn/adaptor.hpp
@@ -149,6 +149,10 @@ struct IsBoolStdArray<std::array<bool, N>> : std::true_type {};
 
 inline aclIntArray* createAclIntArrayFromIntVector(const std::vector<int64_t>& vec) { return ::aclCreateIntArray(vec.data(), vec.size()); }
 
+inline aclTensorList* createAclTensorListFromAclTensorVector(const std::vector<aclTensor*>& tensorsVec) {
+    return ::aclCreateTensorList(tensorsVec.data(), tensorsVec.size());
+}
+
 inline aclTensorList* createAclTensorListFromAscendTensorVector(const std::vector<AscendTensor>& tensorsVec) {
     std::vector<aclTensor*> tList(tensorsVec.size());
     for (size_t i = 0; i < tensorsVec.size(); i++) {
@@ -175,7 +179,11 @@ inline aclTensorList* createAclTensorListFromConstDiopiTensorVector(const std::vector<diopiConstTensorHandle_t>& tensorsVec) {
 
 template <class T, class U = std::remove_cv_t<std::remove_reference_t<T>>>
 decltype(auto) convertType(T&& param) {
-    if constexpr (std::is_same_v<U, AscendTensor>) {
+    if constexpr (std::is_same_v<U, aclTensor*>) {
+        return std::forward<T>(param);
+    } else if constexpr (std::is_same_v<U, std::vector<aclTensor*>>) {
+        return createAclTensorListFromAclTensorVector(std::forward<T>(param));
+    } else if constexpr (std::is_same_v<U, AscendTensor>) {
         return createAclTensorFromAscendTensor(std::forward<T>(param));
     } else if constexpr (std::is_same_v<U, diopiTensorHandle_t> || std::is_same_v<U, diopiConstTensorHandle_t>) {
         return createAclTensorFromDiopiTensor(std::forward<T>(param));
diff --git a/impl/ascend/ascend_tensor.cpp b/impl/ascend/ascend_tensor.cpp
index e966bc5f4..f39f87902 100644
--- a/impl/ascend/ascend_tensor.cpp
+++ b/impl/ascend/ascend_tensor.cpp
@@ -6,9 +6,11 @@
 
 #include "ascend_tensor.hpp"
 
+// #include
 #include
 #include
 #include
+#include <numeric>
 #include
 
 #include "common/debug.hpp"
@@ -82,6 +84,106 @@ AscendTensor& AscendTensor::asStrided(const std::vector<int64_t>& shape, const std::vector<int64_t>& stride) {
     return *this;
 }
 
+AscendTensor& AscendTensor::permute(std::vector<int64_t> dims) {
+    ASCEND_CHECK_ABORT(this->dim() == dims.size(), "permute dims does not match the tensor dims.");
+
+    std::vector<int64_t> newShape(dims.size(), 0);
+    std::vector<int64_t> newStride(dims.size(), 0);
+
+    for (size_t i = 0; i < dims.size(); i++) {
+        newShape[i] = this->shape(dims[i]);
+        newStride[i] = this->stride(dims[i]);
+    }
+
+    this->shape_ = newShape;
+    this->stride_ = newStride;
+
+    return *this;
+}
+
+AscendTensor& AscendTensor::expand(std::vector<int64_t> shape) {
+    ASCEND_CHECK_ABORT(shape.size() >= this->dim(),
+                       "the number of sizes provided [%ld] must be greater than or equal to the number of dimensions of the tensor [%ld].",
+                       shape.size(),
+                       this->dim());
+
+    // TODO: handle dim() == 0
+    int64_t expandDims = shape.size() - this->shape().size();
+    std::vector<int64_t> tShapeExp(expandDims, 0);
+    auto tShape = this->shape();
+    tShapeExp.insert(tShapeExp.end(), tShape.begin(), tShape.end());
+    std::vector<int64_t> newShape = shape;
+
+    for (int64_t i = 0; i < newShape.size(); i++) {
+        if (newShape[i] < 0 && i < expandDims) {
+            ASCEND_CHECK_ABORT(false, "The expanded size of the tensor (%ld) isn't allowed in a leading, non-existing dimension %ld", newShape[i], i);
+        }
+
+        if (i >= expandDims) {
+            if (newShape[i] == -1) {
+                newShape[i] = tShapeExp[i];
+            } else {
+                ASCEND_CHECK_ABORT(tShapeExp[i] == 1 || newShape[i] == tShapeExp[i],
+                                   "The expanded size of the tensor (%ld) must match the existing size (%ld) at non-singleton dimension %ld.",
+                                   newShape[i],
+                                   tShapeExp[i],
+                                   i);
+            }
+        }
+    }
+
+    int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), 1, std::multiplies<>());
+    std::vector<int64_t> newStride(expandDims, 0);
+    auto tStride = this->stride();
+    newStride.insert(newStride.end(), tStride.begin(), tStride.end());
+    for (int64_t i = expandDims; i < shape.size(); i++) {
+        if (shape[i] == -1 || shape[i] == tShapeExp[i]) {
+            continue;
+        } else {
+            newStride[i] = 0;
+        }
+    }
+
+    this->numel_ = numElem;
+    this->shape_ = newShape;
+    this->stride_ = newStride;
+
+    return *this;
+}
+
+AscendTensor& AscendTensor::resize(const std::vector<int64_t>& shape) {
+    int64_t numElem = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
+    std::vector<int64_t> stride(shape.size(), 1);
+    for (int64_t j = shape.size() - 2; j >= 0; j--) {
+        stride[j] = stride[j + 1] * shape[j + 1];
+    }
+
+    this->numel_ = numElem;
+    this->shape_ = shape;
+    this->stride_ = stride;
+
+    return *this;
+}
+AscendTensor& AscendTensor::select(int64_t dim, int64_t index) {
+    auto shape = this->shape();
+    auto stride = this->stride();
+
+    ASCEND_CHECK_ABORT(dim >= 0 && dim < shape.size(), "selected dim [%ld] exceeds the tensor dims [%ld].", dim, shape.size());
+
+    // move the storage offset to the beginning of the selected slice along `dim`
+    int64_t offset = index * stride[dim];
+    this->storageOffset_ += offset;
+
+    this->numel_ /= shape[dim];
+
+    shape.erase(shape.begin() + dim);
+    stride.erase(stride.begin() + dim);
+    this->shape_ = shape;
+    this->stride_ = stride;
+
+    return *this;
+}
+
 AscendTensor& AscendTensor::unsqueeze(int dim) {
     // Note: `channels_last` tensor uses this will become uncontiguous
     // which is same with pytorch
diff --git a/impl/ascend/ascend_tensor.hpp b/impl/ascend/ascend_tensor.hpp
index 5c20faab4..cf295e87b 100644
--- a/impl/ascend/ascend_tensor.hpp
+++ b/impl/ascend/ascend_tensor.hpp
@@ -245,6 +245,10 @@ class AscendTensor final {
     AscendTensor& asStrided(const std::vector<int64_t>& shape, const std::vector<int64_t>& stride);
     AscendTensor& unsqueeze(int dim);
     AscendTensor& view(const std::vector<int64_t>& shape);
+    AscendTensor& resize(const std::vector<int64_t>& shape);
+    AscendTensor& select(int64_t dim, int64_t index);
+    AscendTensor& permute(std::vector<int64_t> dims);
+    AscendTensor& expand(std::vector<int64_t> shape);
 
 private:
     // diopi origin tensor
diff --git a/impl/ascend/functions/index.cpp b/impl/ascend/functions/index.cpp
new file mode 100644
index 000000000..43d9b9d7e
--- /dev/null
+++ b/impl/ascend/functions/index.cpp
@@ -0,0 +1,320 @@
+/**
+ * @file
+ * @author DeepLink
+ * @copyright (c) 2024, DeepLink.
+ */
+
+#include <algorithm>
+
+#include "../aclnn/acl_scalar.hpp"
+#include "../aclnn/adaptor.hpp"
+
+namespace impl {
+namespace ascend {
+
+static std::vector<AscendTensor> castIntIndicesToLongIndices(diopiContextHandle_t ctx, std::vector<AscendTensor>& indices) {
+    std::vector<AscendTensor> result;
+    for (auto& t : indices) {
+        if (!t.defined()) {
+            result.emplace_back(nullptr);
+            continue;
+        }
+        if (t.dtype() == diopi_dtype_int32) {
+            diopiTensorHandle_t indexHandle = nullptr;
+            auto shape = t.shape();
+            diopiSize_t size = vectorToDiopiSize(shape);
+            diopiRequireTensor(ctx, &indexHandle, &size, nullptr, diopi_dtype_int64, diopi_device);
+            DIOPI_ASCEND_CALL_ACLNN(aclnnCast, ctx, t, diopi_dtype_int64, indexHandle);
+            result.emplace_back(indexHandle);
+        } else {
+            if (t.device() == diopi_host) {
+                result.emplace_back(hostToDevice(ctx, t.tensorHandle()));
+            } else {
+                result.emplace_back(t);
+            }
+        }
+    }
+    return result;
+}
+
+static void checkIndexTensorTypes(const std::vector<AscendTensor>& indices) {
+    for (const auto& t : indices) {
+        if (t.defined()) {
+            diopiDtype_t type = t.dtype();
+            ASCEND_CHECK_ABORT(type == diopi_dtype_int64 || type == diopi_dtype_bool || type == diopi_dtype_uint8,
+                               "tensors used as indices must be long, byte or bool tensors");
+        }
+    }
+}
+
+static AscendTensor nonZeroTensor(diopiContextHandle_t ctx, const AscendTensor& self) {
+    int64_t numELem = self.numel() * self.dim();
+    std::vector<int64_t> nShape{self.numel(), self.dim()};
+    std::vector<int64_t> nStride(nShape.size(), 1);
+    for (int64_t i = nShape.size() - 2; i >= 0; i--) {
+        nStride[i] = nStride[i + 1] * nShape[i + 1];
+    }
+
+    diopiTensorHandle_t nzBuff = nullptr;
+    diopiSize_t nzBuffSize = vectorToDiopiSize(nShape);
+    diopiRequireTensor(ctx, &nzBuff, &nzBuffSize, nullptr, diopi_dtype_int64, diopi_device);
+    AscendTensor nzTensor(nzBuff);
+
+    auto aclNZTensor = ::aclCreateTensor(
+        nShape.data(), nShape.size(), aclDataType::ACL_INT64, nStride.data(), 0, aclFormat::ACL_FORMAT_ND, &numELem, 1, const_cast<void*>(nzTensor.data()));
+    DIOPI_ASCEND_CALL_ACLNN(aclnnNonzero, ctx, self, aclNZTensor);
+
+    int64_t* vDims = nullptr;
+    uint64_t vDimsNum = 0;
+    auto ret = aclGetViewShape(aclNZTensor, &vDims, &vDimsNum);
+    ASCEND_CHECK_ABORT(ret == 0, "NonZero aclGetViewShape failed.");
+
+    std::vector<int64_t> nzShape(vDims, vDims + vDimsNum);
+    nzTensor = nzTensor.resize(nzShape);
+
+    delete vDims;
+    vDims = nullptr;
+
+    diopiTensorHandle_t nzTrans = nullptr;
+    std::vector<int64_t> nzTransShape{nzShape[1], nzShape[0]};
+    diopiSize_t nzTransSize = vectorToDiopiSize(nzTransShape);
+    diopiRequireTensor(ctx, &nzTrans, &nzTransSize, nullptr, diopi_dtype_int64, diopi_device);
+    std::vector<int64_t> transDims{1, 0};
+    diopiSize_t permuteDims = vectorToDiopiSize(transDims);
+    DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, nzTensor, permuteDims, nzTrans);
+
+    return AscendTensor(nzTrans);
+}
+
+static std::vector<AscendTensor> expandIndicesTensors(diopiContextHandle_t ctx, const AscendTensor& self, const std::vector<AscendTensor>& indices) {
+    std::vector<AscendTensor> result;
+    for (auto& t : indices) {
+        if (!t.defined()) {
+            result.push_back(t);
+        } else {
+            if (t.dtype() == diopi_dtype_uint8 || t.dtype() == diopi_dtype_bool) {
+                ASCEND_CHECK(t.dtype() != diopi_dtype_uint8,
+                             "indexing with dtype torch.uint8 is now deprecated,"
+                             " please use a dtype torch.bool instead.");
+                for (uint64_t j = 0; j < static_cast<uint64_t>(t.dim()); j++) {
+                    uint64_t srcIdx = result.size() + j;
+                    ASCEND_CHECK_ABORT(t.shape(j) == self.shape(srcIdx),
+                                       "The shape of the mask %ld at index %ld does not match the shape of the indexed tensor %ld at index %ld",
+                                       t.shape(j),
+                                       j,
+                                       self.shape(srcIdx),
+                                       srcIdx);
+                }
+                AscendTensor non = nonZeroTensor(ctx, t);
+                for (int64_t j = 0; j < t.dim(); j++) {
+                    // select row j on a copy, so the in-place view change does not leak across iterations
+                    result.push_back(AscendTensor(non).select(0, j));
+                }
+            } else {
+                result.push_back(t);
+            }
+        }
+    }
+    return result;
+}
+
+static aclTensor* createEmptyAclTensor() {
+    std::vector<int64_t> nShape{0};
+    std::vector<int64_t> nStride{1};
+    int64_t storageSize = 0;
+    void* storage = nullptr;
+
+    return ::aclCreateTensor(nShape.data(), nShape.size(), aclDataType::ACL_FLOAT16, nStride.data(), 0, aclFormat::ACL_FORMAT_ND, &storageSize, 0, storage);
+}
+
+static std::vector<AscendTensor> indicesExpandedOutplace(std::vector<AscendTensor> indices) {
+    bool first = true;
+    std::vector<int64_t> sizes;
+
+    for (auto& idx : indices) {
+        if (!idx.defined()) {
+            continue;
+        } else if (first) {
+            sizes = idx.shape();
+            first = false;
+        } else {
+            sizes = inferSize(sizes, idx.shape());
+        }
+    }
+
+    std::vector<AscendTensor> result;
+    for (auto& idx : indices) {
+        if (!idx.defined() || (idx.shape() == sizes)) {
+            result.push_back(idx);
+        } else {
+            result.push_back(idx.expand(sizes));
+        }
+    }
+    return result;
+}
+
+static bool hasContiguousSubspace(std::vector<AscendTensor> indices) {  // true if all the non-null tensors are adjacent
+    auto isDefined = [](const AscendTensor& tensor) { return tensor.defined(); };
+    auto isNull = [](const AscendTensor& tensor) { return !tensor.defined(); };
+    auto start = std::find_if(indices.begin(), indices.end(), isDefined);
+    auto stop = std::find_if(indices.rbegin(), indices.rend(), isDefined);
+    auto it = std::find_if(start, stop.base(), isNull);
+    return it == stop.base();
+}
+
+static std::tuple<AscendTensor, std::vector<AscendTensor>> transposeToFront(AscendTensor self, std::vector<AscendTensor> indices) {
+    std::vector<int64_t> dims;
+    std::vector<AscendTensor> transposedIndices;
+
+    dims.reserve(self.dim());
+    for (int64_t i = 0; i < self.dim(); i++) {
+        if (indices[i].defined()) {
+            dims.push_back(i);
+            transposedIndices.push_back(indices[i]);
+        }
+    }
+
+    for (int64_t i = 0; i < self.dim(); i++) {
+        if (!indices[i].defined()) {
+            dims.push_back(i);
+            transposedIndices.push_back(indices[i]);
+        }
+    }
+
+    return std::make_tuple(self.permute(dims), transposedIndices);
+}
+
+static std::vector<int64_t> indexReshape(std::vector<AscendTensor> endIndices, int64_t dimsBefore, int64_t dimsAfter) {
+    std::vector<int64_t> indexShape;
+    for (auto& idx : endIndices) {
+        if (idx.defined()) {
+            std::vector<int64_t> shape;
+            shape.insert(shape.end(), dimsBefore, 1);
+            shape.insert(shape.end(), idx.shape().begin(), idx.shape().end());
+            shape.insert(shape.end(), dimsAfter, 1);
+            if (indexShape.empty()) {
+                indexShape = shape;
+            } else {
+                indexShape = inferSize(indexShape, shape);
+            }
+        }
+    }
+    return indexShape;
+}
+
+static std::vector<int64_t> indexOutputSize(const AscendTensor& self, std::vector<AscendTensor>& indices) {
+    std::vector<AscendTensor> midIndices = indicesExpandedOutplace(indices);
+    while (midIndices.size() < (size_t)self.dim()) {
+        midIndices.emplace_back(nullptr);
+    }
+
+    AscendTensor src = self;
+    std::vector<AscendTensor> endIndices = midIndices;
+    if (!hasContiguousSubspace(midIndices)) {
+        endIndices.clear();
+        std::tie(src, endIndices) = transposeToFront(self, midIndices);
+    }
+
+    int64_t dimsBefore = 0;
+    int64_t dimsAfter = 0;
+    int64_t dimsIndexed = 0;
+
+    std::vector<int64_t> replaceShape;
+    std::vector<int64_t> indexedSizes;
+
+    for (size_t dim = 0; dim < endIndices.size(); dim++) {
+        if (!endIndices[dim].defined()) {
+            if (dimsIndexed == 0) {
+                dimsBefore++;
+            } else {
+                dimsAfter++;
+            }
+        } else {
+            dimsIndexed++;
+            replaceShape = endIndices[dim].shape();
+            indexedSizes.push_back(src.shape(dim));
+        }
+    }
+
+    if (std::find(indexedSizes.begin(), indexedSizes.end(), 0) != indexedSizes.end() &&
+        std::find(replaceShape.begin(), replaceShape.end(), 0) == replaceShape.end()) {
+        ASCEND_CHECK_ABORT(false, "index is out of bounds for dimension with size 0");
+    }
+
+    auto selfShape = src.shape();
+    int64_t end = dimsBefore + dimsIndexed;
+    selfShape.erase(selfShape.begin() + dimsBefore, selfShape.begin() + end);
+    selfShape.insert(selfShape.begin() + dimsBefore, replaceShape.begin(), replaceShape.end());
+
+    std::vector<int64_t> indexShape = indexReshape(endIndices, dimsBefore, dimsAfter);
+    std::vector<int64_t> outputSize = indexShape;
+    if (indexShape != selfShape) {
+        outputSize = inferSize(indexShape, selfShape);
+    }
+
+    return outputSize;
+}
+
+diopiError_t diopiIndex(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t* indices, int64_t nums) {
+    AscendTensor inputAt(input);
+    std::vector<AscendTensor> indicesOrigin(nums);
+    for (int64_t i = 0; i < nums; i++) {
+        if (indices[i] != nullptr) {
+            indicesOrigin[i] = AscendTensor(indices[i]);
+        }
+    }
+
+    std::vector<AscendTensor> indicesList = castIntIndicesToLongIndices(ctx, indicesOrigin);
+    checkIndexTensorTypes(indicesList);
+
+    auto indicesExpanded = expandIndicesTensors(ctx, inputAt, indicesList);
+
+    std::vector<aclTensor*> allDefinedIndices;
+    auto emptyTensor = createEmptyAclTensor();
+    for (const auto& idx : indicesExpanded) {
+        if (idx.defined()) {
+            allDefinedIndices.push_back(aclnn_adaptor::createAclTensorFromAscendTensor(idx));
+        } else {
+            allDefinedIndices.push_back(emptyTensor);
+        }
+    }
+
+    std::vector<int64_t> outShape = indexOutputSize(inputAt, indicesExpanded);
+
+    diopiSize_t outSize = vectorToDiopiSize(outShape);
+    diopiRequireTensor(ctx, out, &outSize, nullptr, inputAt.dtype(), diopi_device);
+
+    DIOPI_ASCEND_CALL_ACLNN(aclnnIndex, ctx, inputAt, allDefinedIndices, *out);
+    return diopiSuccess;
+}
+
+diopiError_t diopiIndexBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t zerosLikeInput, diopiConstTensorHandle_t* indices,
+                                int64_t nums, diopiConstTensorHandle_t gradOutput) {
+    AscendTensor gradInputTensor(gradInput);
+    AscendTensor gradOutputTensor(gradOutput);
+    if (gradInputTensor.numel() == 0 || gradOutputTensor.numel() == 0) {
+        return diopiSuccess;
+    }
+
+    std::vector<diopiConstTensorHandle_t> indicesVec;
+    indicesVec.reserve(nums);
+
+    for (int i = 0; i < nums; i++) {
+        if (indices[i] != nullptr) {
+            indicesVec.emplace_back(indices[i]);
+        } else {
+            int64_t array[1] = {0};
+            diopiSize_t size = {array, 1};
+            diopiTensorHandle_t emptyTensor = nullptr;
+            diopiRequireTensor(ctx, &emptyTensor, &size, nullptr, gradOutputTensor.dtype(), diopi_device);
+            indicesVec.emplace_back(emptyTensor);
+        }
+    }
+
+    DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, gradInput, zerosLikeInput);
+    DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, gradInput, indicesVec, gradOutput, true, false);
+
+    return diopiSuccess;
+}
+
+}  // namespace ascend
+}  // namespace impl
diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt
index ba7701105..84f285f0b 100755
--- a/impl/ascend_npu/CMakeLists.txt
+++ b/impl/ascend_npu/CMakeLists.txt
@@ -167,6 +167,7 @@ set(OLD_IMPL_SRC
     ${OLD_IMPL_DIR}/functions/arange.cpp
     ${OLD_IMPL_DIR}/functions/gather.cpp
     ${OLD_IMPL_DIR}/functions/layer_norm.cpp
+    ${OLD_IMPL_DIR}/functions/index.cpp
     ${OLD_IMPL_DIR}/functions/index_put.cpp
     ${OLD_IMPL_DIR}/functions/index_select.cpp
     ${OLD_IMPL_DIR}/functions/repeat.cpp
diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml
index 9dbdec336..7def339c0 100755
--- a/impl/ascend_npu/ascend_config.yaml
+++ b/impl/ascend_npu/ascend_config.yaml
@@ -112,6 +112,8 @@ ascend:
 - diopiHardtanh
 - diopiHardtanhBackward
 - diopiHardtanhInp
+- diopiIndex
+- diopiIndexBackward
 - diopiIndexPut
 - diopiIndexPutInp
 - diopiIndexSelect
@@ -265,8 +267,6 @@ ascend_npu:
 - diopiApplyPenalty
 - diopiContextAttentionInference
 - diopiGetNativeMemoryFormat
-- diopiIndex
-- diopiIndexBackward
 - diopiNLLLoss
 - diopiNLLLossBackward
 - diopiNLLLossV2
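
For reference, a minimal usage sketch of the new entry point (illustrative only, not part of the patch). It assumes the standard DIOPI proto header layout and that `ctx`, `input` and `idx` were created elsewhere, e.g. via diopiRequireTensor; with `input` of shape [2, 3, 4] and `idx` an int64 tensor of shape [5] placed at dim 1 (dim 0 left as nullptr), diopiIndex allocates an output of shape [2, 5, 4].

#include <diopi/functions.h>

// Sketch: index dim 1 of `input` with `idx`; a nullptr entry means "keep that
// dimension whole". diopiIndex allocates *out internally via diopiRequireTensor,
// so `out` starts as nullptr.
diopiError_t indexAlongDim1(diopiContextHandle_t ctx, diopiConstTensorHandle_t input, diopiConstTensorHandle_t idx, diopiTensorHandle_t* out) {
    diopiConstTensorHandle_t indices[2] = {nullptr, idx};
    *out = nullptr;
    return diopiIndex(ctx, out, input, indices, 2);
}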