From f57378e33aad8b0ab0e064b5d9ec417f84bf1ff0 Mon Sep 17 00:00:00 2001 From: wiryls <7984500+wiryls@users.noreply.github.com> Date: Wed, 7 Aug 2024 10:19:07 +0800 Subject: [PATCH 1/6] feat: add a new diopiError_t enum to fallback to cpu (#1335) --- proto/include/diopi/diopirt.h | 1 + 1 file changed, 1 insertion(+) diff --git a/proto/include/diopi/diopirt.h b/proto/include/diopi/diopirt.h index a451b0cd2..2cc9a556f 100644 --- a/proto/include/diopi/diopirt.h +++ b/proto/include/diopi/diopirt.h @@ -44,6 +44,7 @@ typedef enum { diopiNoRegisteredGetLastErrorFunction = 11, diopi5DNotSupported = 12, diopiNoImplement = 13, + diopiForceFallbackToCPU = 14, diopiDtypeNotSupported = 1000, } diopiError_t; From fdf95278f3fd7607b8e1c6118c730670f7042044 Mon Sep 17 00:00:00 2001 From: Lantian Zhang <50076473+DoorKickers@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:52:17 +0800 Subject: [PATCH 2/6] align diopi test to torch2.1 (#1338) align diopi test for torch2.1 --- diopi_test/python/configs/diopi_configs.py | 7 ++++--- impl/torch/functions/functions.cpp | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index b852b81be..e57b393da 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -959,6 +959,7 @@ "shape": ((), (16,), (72,), (2, 11856), (2, 741, 80), (4, 4, 16, 20), (0,), (4, 0), (9, 0, 16)), + "gen_fn": dict(fn='Genfunc.uniform', low=0, high=1), }, { "ins": ['weight'], @@ -5236,7 +5237,7 @@ }, { "ins": ['max_exp_avg_sq'], - "shape": [None, None, (4, 8), (12, 4, 8)], + "shape": [(), (16,), (4, 8), (12, 4, 8)], "gen_fn": 'Genfunc.rand', }, ] @@ -6020,13 +6021,13 @@ { "ins": ['input'], "shape": ((8, 0), (0, 128), (256, 8)), - "dtype": [np.float32, np.float16, np.float64], + "dtype": [np.float16, np.float32, np.float64], "gen_fn": 'Genfunc.randn', }, { "ins": ['mat2'], "shape": ((0, 128), (128, 128), (8, 0)), - "dtype": [np.float16, np.float64, np.float32], + "dtype": [np.float16, np.float32, np.float64], "gen_fn": 'Genfunc.randn', }, ], diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 94b6705bd..0313ebe6a 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -2915,7 +2915,11 @@ diopiError_t diopiRmsprop(diopiContextHandle_t ctx, diopiTensorHandle_t param, d at::Tensor atAvg; if (centered) { +#if TORCH_MM_VERSION >= 2010 + atGradAvg.lerp_(atGrad, 1 - alpha); +#else atGradAvg.mul_(alpha).add_(atGrad, 1 - alpha); +#endif atAvg = atSquareAvg.addcmul(atGradAvg, atGradAvg, -1).sqrt_().add_(eps); } else { atAvg = atSquareAvg.sqrt().add_(eps); From e13aea4b3ec4faa906243421186d5adbd7726ef0 Mon Sep 17 00:00:00 2001 From: Fu Jingguo Date: Fri, 9 Aug 2024 17:01:12 +0800 Subject: [PATCH 3/6] [Ascend] fuj/acl-index (#1332) * impl index for aclnn p1 * impl index for ascend p2 * fix bug for index impl ascend p3 * fix a warning bug for index --- impl/ascend/aclnn/adaptor.hpp | 10 +- impl/ascend/ascend_tensor.cpp | 102 +++++++++ impl/ascend/ascend_tensor.hpp | 4 + impl/ascend/functions/index.cpp | 320 +++++++++++++++++++++++++++++ impl/ascend_npu/CMakeLists.txt | 1 + impl/ascend_npu/ascend_config.yaml | 4 +- 6 files changed, 438 insertions(+), 3 deletions(-) create mode 100644 impl/ascend/functions/index.cpp diff --git a/impl/ascend/aclnn/adaptor.hpp b/impl/ascend/aclnn/adaptor.hpp index 117423c78..f0c4ff953 100644 --- a/impl/ascend/aclnn/adaptor.hpp +++ b/impl/ascend/aclnn/adaptor.hpp @@ -149,6 +149,10 @@ struct IsBoolStdArray> : std::true_type {}; inline aclIntArray* createAclIntArrayFromIntVector(const std::vector& vec) { return ::aclCreateIntArray(vec.data(), vec.size()); } +inline aclTensorList* createAclTensorListFromAclTensorVector(const std::vector& tensorsVec) { + return ::aclCreateTensorList(tensorsVec.data(), tensorsVec.size()); +} + inline aclTensorList* createAclTensorListFromAscendTensorVector(const std::vector& tensorsVec) { std::vector tList(tensorsVec.size()); for (size_t i = 0; i < tensorsVec.size(); i++) { @@ -175,7 +179,11 @@ inline aclTensorList* createAclTensorListFromConstDiopiTensorVector(const std::v template >> decltype(auto) convertType(T&& param) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return std::forward(param); + } else if constexpr (std::is_same_v>) { + return createAclTensorListFromAclTensorVector(std::forward(param)); + } else if constexpr (std::is_same_v) { return createAclTensorFromAscendTensor(std::forward(param)); } else if constexpr (std::is_same_v || std::is_same_v) { return createAclTensorFromDiopiTensor(std::forward(param)); diff --git a/impl/ascend/ascend_tensor.cpp b/impl/ascend/ascend_tensor.cpp index e966bc5f4..f39f87902 100644 --- a/impl/ascend/ascend_tensor.cpp +++ b/impl/ascend/ascend_tensor.cpp @@ -6,9 +6,11 @@ #include "ascend_tensor.hpp" +// #include #include #include #include +#include #include #include "common/debug.hpp" @@ -82,6 +84,106 @@ AscendTensor& AscendTensor::asStrided(const std::vector& shape, const s return *this; } +AscendTensor& AscendTensor::permute(std::vector dims) { + ASCEND_CHECK_ABORT(this->dim() == dims.size(), "permute dims does not match the tensor dims."); + + std::vector newShape(dims.size(), 0); + std::vector newStride(dims.size(), 0); + + for (size_t i = 0; i < dims.size(); i++) { + newShape[i] = this->shape(dims[i]); + newStride[i] = this->stride(dims[i]); + } + + this->shape_ = newShape; + this->stride_ = newStride; + + return *this; +} + +AscendTensor& AscendTensor::expand(std::vector shape) { + ASCEND_CHECK_ABORT(shape.size() >= this->dim(), + "the number of sizes provided[% ld] must be greater or eaqual to the number of dimensions of the tensor[% ld].", + shape.size(), + this->dim()); + + // todo: dim() == 0 + int64_t expandDims = shape.size() - this->shape().size(); + std::vector tShapeExp(expandDims, 0); + auto tShape = this->shape(); + tShapeExp.insert(tShapeExp.end(), tShape.begin(), tShape.end()); + std::vector newShape = shape; + + for (int64_t i = 0; i < newShape.size(); i++) { + if (newShape[i] < 0 && i < expandDims) { + ASCEND_CHECK_ABORT(false, "The expanded size of the tensor (%ld) isn't allowed in a leading, non-existing dimension %ld", newShape[i], i); + } + + if (i >= expandDims) { + if (newShape[i] == -1) { + newShape[i] = tShapeExp[i]; + } else { + ASCEND_CHECK_ABORT(tShapeExp[i] == 1 || newShape[i] == tShapeExp[i], + "The expanded size of the tensor (%ld) must match the existing size (%ld) at non-singleton dimension %ld.", + newShape[i], + tShapeExp[i], + i); + } + } + } + + int64_t numElem = std::accumulate(newShape.begin(), newShape.end(), 1, std::multiplies<>()); + std::vector newStride(expandDims, 0); + auto tStride = this->stride(); + newStride.insert(newStride.end(), tStride.begin(), tStride.end()); + for (int64_t i = expandDims; i < shape.size(); i++) { + if (shape[i] == -1 || shape[i] == tShapeExp[i]) { + continue; + } else { + newStride[i] = 0; + } + } + + this->numel_ = numElem; + this->shape_ = newShape; + this->stride_ = newStride; + + return *this; +} + +AscendTensor& AscendTensor::resize(const std::vector& shape) { + int64_t numElem = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); + std::vector stride(shape.size(), 1); + for (int64_t j = shape.size() - 2; j >= 0; j--) { + stride[j] = stride[j + 1] * shape[j + 1]; + } + + this->numel_ = numElem; + this->shape_ = shape; + this->stride_ = stride; + + return *this; +} +AscendTensor& AscendTensor::select(int64_t dim, int64_t index) { + auto shape = this->shape(); + auto stride = this->stride(); + + ASCEND_CHECK_ABORT(dim >= 0 && dim < shape.size(), "selected dim [%ld] execeed the tensor dims [%ld].", dim, shape.size()); + + if (dim < shape.size() - 1) { + int64_t offset = dim * shape[dim] * stride[dim]; + this->storageOffset_ = offset; + } + this->numel_ /= shape[dim]; + + shape.erase(shape.begin() + dim); + stride.erase(stride.begin() + dim); + this->shape_ = shape; + this->stride_ = stride; + + return *this; +} + AscendTensor& AscendTensor::unsqueeze(int dim) { // Note: `channels_last` tensor uses this will become uncontiguous // which is same with pytorch diff --git a/impl/ascend/ascend_tensor.hpp b/impl/ascend/ascend_tensor.hpp index 5c20faab4..cf295e87b 100644 --- a/impl/ascend/ascend_tensor.hpp +++ b/impl/ascend/ascend_tensor.hpp @@ -245,6 +245,10 @@ class AscendTensor final { AscendTensor& asStrided(const std::vector& shape, const std::vector& stride); AscendTensor& unsqueeze(int dim); AscendTensor& view(const std::vector& shape); + AscendTensor& resize(const std::vector& shape); + AscendTensor& select(int64_t dim, int64_t index); + AscendTensor& permute(std::vector dims); + AscendTensor& expand(std::vector shape); private: // diopi origin tensor diff --git a/impl/ascend/functions/index.cpp b/impl/ascend/functions/index.cpp new file mode 100644 index 000000000..43d9b9d7e --- /dev/null +++ b/impl/ascend/functions/index.cpp @@ -0,0 +1,320 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include + +#include "../aclnn/acl_scalar.hpp" +#include "../aclnn/adaptor.hpp" + +namespace impl { +namespace ascend { + +static std::vector castIntIndicesToLongIndices(diopiContextHandle_t ctx, std::vector& indices) { + std::vector result; + for (auto& t : indices) { + if (!t.defined()) { + result.emplace_back(nullptr); + continue; + } + if (t.dtype() == diopi_dtype_int32) { + diopiTensorHandle_t indexHandle = nullptr; + auto shape = t.shape(); + diopiSize_t size = vectorToDiopiSize(shape); + diopiRequireTensor(ctx, &indexHandle, &size, nullptr, diopi_dtype_int64, diopi_device); + DIOPI_ASCEND_CALL_ACLNN(aclnnCast, ctx, t, diopi_dtype_int64, indexHandle); + result.emplace_back(indexHandle); + } else { + if (t.device() == diopi_host) { + result.emplace_back(hostToDevice(ctx, t.tensorHandle())); + } else { + result.emplace_back(t); + } + } + } + return result; +} + +static void checkIndexTensorTypes(const std::vector& indices) { + for (const auto& t : indices) { + if (t.defined()) { + diopiDtype_t type = t.dtype(); + ASCEND_CHECK_ABORT(type == diopi_dtype_int64 || type == diopi_dtype_bool || type == diopi_dtype_uint8, + "tensors used as indices must be long, byte or bool tensors"); + } + } +} + +static AscendTensor nonZeroTensor(diopiContextHandle_t ctx, const AscendTensor& self) { + int64_t numELem = self.numel() * self.dim(); + std::vector nShape{self.numel(), self.dim()}; + std::vector nStride(nShape.size(), 1); + for (int64_t i = nShape.size() - 2; i >= 0; i--) { + nStride[i] = nStride[i + 1] * nShape[i + 1]; + } + + diopiTensorHandle_t nzBuff = nullptr; + diopiSize_t nzBuffSize = vectorToDiopiSize(nShape); + diopiRequireTensor(ctx, &nzBuff, &nzBuffSize, nullptr, diopi_dtype_int64, diopi_device); + AscendTensor nzTensor(nzBuff); + + auto aclNZTensor = ::aclCreateTensor( + nShape.data(), nShape.size(), aclDataType::ACL_INT64, nStride.data(), 0, aclFormat::ACL_FORMAT_ND, &numELem, 1, const_cast(nzTensor.data())); + DIOPI_ASCEND_CALL_ACLNN(aclnnNonzero, ctx, self, aclNZTensor); + + int64_t* vDims = nullptr; + uint64_t vDimsNum = 0; + auto ret = aclGetViewShape(aclNZTensor, &vDims, &vDimsNum); + ASCEND_CHECK_ABORT(ret == 0, "NonZero aclGetViewShape failed."); + + std::vector nzShape(vDims, vDims + vDimsNum); + nzTensor = nzTensor.resize(nzShape); + + delete vDims; + vDims = nullptr; + + diopiTensorHandle_t nzTrans = nullptr; + std::vector nzTransShape{nzShape[1], nzShape[0]}; + diopiSize_t nzTransSize = vectorToDiopiSize(nzTransShape); + diopiRequireTensor(ctx, &nzTrans, &nzTransSize, nullptr, diopi_dtype_int64, diopi_device); + std::vector transDims{1, 0}; + diopiSize_t permuteDims = vectorToDiopiSize(transDims); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, nzTensor, permuteDims, nzTrans); + + return AscendTensor(nzTrans); +} + +static std::vector expandIndicesTensors(diopiContextHandle_t ctx, const AscendTensor& self, const std::vector& indices) { + std::vector result; + for (auto& t : indices) { + if (!t.defined()) { + result.push_back(t); + } else { + if (t.dtype() == diopi_dtype_uint8 || t.dtype() == diopi_dtype_bool) { + ASCEND_CHECK(t.dtype() != diopi_dtype_uint8, + "indexing with dtype torch.uint8 is now deprecated," + " please use a dtype torch.bool instead."); + for (uint64_t j = 0; j < static_cast(t.dim()); j++) { + uint64_t srcIdx = result.size() + j; + ASCEND_CHECK_ABORT(t.shape(j) == self.shape(srcIdx), + "The shape of the mask %ld at index %ld does not match the shape of the indexed tensor %ld at index %ld", + t.dim(), + j, + self.dim(), + srcIdx); + } + AscendTensor non = nonZeroTensor(ctx, t); + for (int64_t j = 0; j < t.dim(); j++) { + result.push_back(non.select(0, j)); + } + } else { + result.push_back(t); + } + } + } + return result; +} + +static aclTensor* createEmptyAclTensor() { + std::vector nShape{0}; + std::vector nStride{1}; + int64_t storageSize = 0; + void* storage = nullptr; + + return ::aclCreateTensor(nShape.data(), nShape.size(), aclDataType::ACL_FLOAT16, nStride.data(), 0, aclFormat::ACL_FORMAT_ND, &storageSize, 0, storage); +} + +static std::vector indicesExpandedOutplace(std::vector indices) { + bool first = true; + std::vector sizes; + + for (auto& idx : indices) { + if (!idx.defined()) { + continue; + } else if (first) { + sizes = idx.shape(); + first = false; + } else { + sizes = inferSize(sizes, idx.shape()); + } + } + + std::vector result; + for (auto& idx : indices) { + if (!idx.defined() || (idx.shape() == sizes)) { + result.push_back(idx); + } else { + result.push_back(idx.expand(sizes)); + } + } + return result; +} + +static bool hasContiguousSubspace(std::vector indices) { // true if all the non-null tensors are adjacent + auto isDefined = [](const AscendTensor& tensor) { return tensor.defined(); }; + auto isNull = [](const AscendTensor& tensor) { return !tensor.defined(); }; + auto start = std::find_if(indices.begin(), indices.end(), isDefined); + auto stop = std::find_if(indices.rbegin(), indices.rend(), isDefined); + auto it = std::find_if(start, stop.base(), isNull); + return it == stop.base(); +} + +static std::tuple> transposeToFront(AscendTensor self, std::vector indices) { + std::vector dims; + std::vector transposedIndices; + + dims.reserve(self.dim()); + for (int64_t i = 0; i < self.dim(); i++) { + if (indices[i].defined()) { + dims.push_back(i); + transposedIndices.push_back(indices[i]); + } + } + + for (int64_t i = 0; i < self.dim(); i++) { + if (!indices[i].defined()) { + dims.push_back(i); + transposedIndices.push_back(indices[i]); + } + } + + return std::make_tuple(self.permute(dims), transposedIndices); +} + +static std::vector indexReshape(std::vector endIndices, int64_t dimsBefore, int64_t dimsAfter) { + std::vector indexShape; + for (auto& idx : endIndices) { + if (idx.defined()) { + std::vector shape; + shape.insert(shape.end(), dimsBefore, 1); + shape.insert(shape.end(), idx.shape().begin(), idx.shape().end()); + shape.insert(shape.end(), dimsAfter, 1); + if (indexShape.empty()) { + indexShape = shape; + } else { + indexShape = inferSize(indexShape, shape); + } + } + } + return indexShape; +} + +static std::vector indexOutputSize(const AscendTensor& self, std::vector& indices) { + std::vector midIndices = indicesExpandedOutplace(indices); + while (midIndices.size() < (size_t)self.dim()) { + midIndices.emplace_back(nullptr); + } + + AscendTensor src = self; + std::vector endIndices = midIndices; + if (!hasContiguousSubspace(midIndices)) { + endIndices.clear(); + std::tie(src, endIndices) = transposeToFront(self, midIndices); + } + + int64_t dimsBefore = 0; + int64_t dimsAfter = 0; + int64_t dimsIndexed = 0; + + std::vector replaceShape; + std::vector indexedSizes; + + for (size_t dim = 0; dim < endIndices.size(); dim++) { + if (!endIndices[dim].defined()) { + if (dimsIndexed == 0) { + dimsBefore++; + } else { + dimsAfter++; + } + } else { + dimsIndexed++; + replaceShape = endIndices[dim].shape(); + indexedSizes.push_back(src.shape(dim)); + } + } + + if (std::find(indexedSizes.begin(), indexedSizes.end(), 0) != indexedSizes.end() && + std::find(replaceShape.begin(), replaceShape.end(), 0) == replaceShape.end()) { + ASCEND_CHECK_ABORT(false, "index is out of bounds for dimension with size 0"); + } + + auto selfShape = src.shape(); + int64_t end = dimsBefore + dimsIndexed; + selfShape.erase(selfShape.begin() + dimsBefore, selfShape.begin() + end); + selfShape.insert(selfShape.begin() + dimsBefore, replaceShape.begin(), replaceShape.end()); + + std::vector indexShape = indexReshape(endIndices, dimsBefore, dimsAfter); + std::vector outputSize = indexShape; + if (indexShape != selfShape) { + outputSize = inferSize(indexShape, selfShape); + } + + return outputSize; +} + +diopiError_t diopiIndex(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t* indices, int64_t nums) { + AscendTensor inputAt(input); + std::vector indicesOrigin(nums); + for (int64_t i = 0; i < nums; i++) { + if (indices[i] != nullptr) { + indicesOrigin[i] = AscendTensor(indices[i]); + } + } + + std::vector indicesList = castIntIndicesToLongIndices(ctx, indicesOrigin); + checkIndexTensorTypes(indicesList); + + auto indicesExpanded = expandIndicesTensors(ctx, inputAt, indicesList); + + std::vector allDefinedIndices; + auto emptyTensor = createEmptyAclTensor(); + for (const auto& idx : indicesExpanded) { + if (idx.defined()) { + allDefinedIndices.push_back(aclnn_adaptor::createAclTensorFromAscendTensor(idx)); + } else { + allDefinedIndices.push_back(emptyTensor); + } + } + + std::vector outShape = indexOutputSize(inputAt, indicesExpanded); + + diopiSize_t outSize = vectorToDiopiSize(outShape); + diopiRequireTensor(ctx, out, &outSize, nullptr, inputAt.dtype(), diopi_device); + + DIOPI_ASCEND_CALL_ACLNN(aclnnIndex, ctx, inputAt, allDefinedIndices, *out); + return diopiSuccess; +} + +diopiError_t diopiIndexBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t zerosLikeInput, diopiConstTensorHandle_t* indices, + int64_t nums, diopiConstTensorHandle_t gradOutput) { + AscendTensor gradInputTensor(gradInput); + AscendTensor gradOutputTensor(gradOutput); + if (gradInputTensor.numel() == 0 || gradOutputTensor.numel() == 0) { + return diopiSuccess; + } + + std::vector indicesVec; + indicesVec.reserve(nums); + + for (int i = 0; i < nums; i++) { + if (indices[i] != nullptr) { + indicesVec.emplace_back(indices[i]); + } else { + int64_t array[1] = {0}; + diopiSize_t size = {array, 1}; + diopiTensorHandle_t emptyTensor = nullptr; + diopiRequireTensor(ctx, &emptyTensor, &size, nullptr, gradOutputTensor.dtype(), diopi_device); + indicesVec.emplace_back(emptyTensor); + } + } + + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceCopy, ctx, gradInput, zerosLikeInput); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, gradInput, indicesVec, gradOutput, true, false); + + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt index ba7701105..84f285f0b 100755 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -167,6 +167,7 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions/arange.cpp ${OLD_IMPL_DIR}/functions/gather.cpp ${OLD_IMPL_DIR}/functions/layer_norm.cpp + ${OLD_IMPL_DIR}/functions/index.cpp ${OLD_IMPL_DIR}/functions/index_put.cpp ${OLD_IMPL_DIR}/functions/index_select.cpp ${OLD_IMPL_DIR}/functions/repeat.cpp diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 9dbdec336..7def339c0 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -112,6 +112,8 @@ ascend: - diopiHardtanh - diopiHardtanhBackward - diopiHardtanhInp +- diopiIndex +- diopiIndexBackward - diopiIndexPut - diopiIndexPutInp - diopiIndexSelect @@ -265,8 +267,6 @@ ascend_npu: - diopiApplyPenalty - diopiContextAttentionInference - diopiGetNativeMemoryFormat -- diopiIndex -- diopiIndexBackward - diopiNLLLoss - diopiNLLLossBackward - diopiNLLLossV2 From f8d12275b51bb4564366ab301dd2e954686b9a70 Mon Sep 17 00:00:00 2001 From: Lantian Zhang <50076473+DoorKickers@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:56:04 +0800 Subject: [PATCH 4/6] [torch] Fix some cuda impl related && Add log1p (#1336) * fix some original cuda impl code * fix clang format * add log1p for diopi torch & add log1pInp & add cuda test device_config & update diopi test * add cuda test device_config for last commit's forgetfulness & refactor format * fix cuda impl's test cmakelists: remove {CMAKE_SOURCE_DIR} --------- Co-authored-by: caikun-pjlab --- diopi_test/python/configs/diopi_configs.py | 8 +- .../python/conformance/diopi_functions.py | 4 + impl/cuda/device_configs.py | 78 +++++++++++++++++++ impl/cuda/error.cpp | 1 + impl/cuda/functions.cu | 50 +++++++++++- impl/cuda/test/CMakeLists.txt | 30 ++++--- impl/cuda/test/conform_test.cpp | 12 +++ impl/torch/functions/functions.cpp | 17 ++++ proto/include/diopi/functions.h | 16 ++++ 9 files changed, 194 insertions(+), 22 deletions(-) create mode 100644 impl/cuda/device_configs.py diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index e57b393da..530d5d995 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -1120,7 +1120,7 @@ ), 'pointwise_op_abs_input': dict( - name=['log', 'log2', 'log10', 'sqrt', 'rsqrt'], + name=['log', 'log2', 'log10', 'log1p', 'sqrt', 'rsqrt'], interface=['torch'], is_inplace=True, dtype=[np.float16, np.float32, np.float64], @@ -1138,7 +1138,7 @@ ), 'log_integer_input': dict( - name=['log', 'log2', 'log10'], + name=['log', 'log2', 'log10', 'log1p'], interface=['torch'], dtype=[np.int16, np.int32, np.int64, np.uint8, np.int8], tensor_para=dict( @@ -1155,7 +1155,7 @@ ), 'log_zero_input': dict( - name=['log', 'log2', 'log10'], + name=['log', 'log2', 'log10', 'log1p'], interface=['torch'], dtype=[np.float16, np.float32, np.float64, np.int16, np.int32, np.int64, @@ -1174,7 +1174,7 @@ ), 'log_neg_input': dict( - name=['log', 'log2', 'log10'], + name=['log', 'log2', 'log10', 'log1p'], interface=['torch'], dtype=[np.float16, np.float32, np.float64, np.int16, np.int32, np.int64, diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index fe69f17e8..66bb2ac98 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -429,6 +429,10 @@ def log10(input, inplace=False) -> Tensor: return unary_op(input, inplace, "diopiLog10", promote_type(input, Dtype.float32)) +def log1p(input, inplace=False) -> Tensor: + return unary_op(input, inplace, "diopiLog1p", promote_type(input, Dtype.float32)) + + def erf(input, inplace=False) -> Tensor: return unary_op(input, inplace, "diopiErf", promote_type(input, Dtype.float32)) diff --git a/impl/cuda/device_configs.py b/impl/cuda/device_configs.py new file mode 100644 index 000000000..67b849252 --- /dev/null +++ b/impl/cuda/device_configs.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024, DeepLink. +import numpy as np +from skip import Skip + +device_configs = { + "log_integer_input": dict( + name=["log1p"], + tensor_para=dict( + args=[ + { + "ins": ["input"], + "dtype": [ + Skip(np.int16), + Skip(np.int32), + Skip(np.int64), + Skip(np.int8), + Skip(np.uint8), + Skip(np.float16), + ], + }, + ] + ), + ), + "pointwise_op_abs_input": dict( + name=["log1p"], + tensor_para=dict( + args=[ + { + "ins": ["input"], + "dtype": [ + Skip(np.int16), + Skip(np.int32), + Skip(np.int64), + Skip(np.int8), + Skip(np.uint8), + Skip(np.float16), + ], + }, + ] + ), + ), + "log_zero_input": dict( + name=["log1p"], + tensor_para=dict( + args=[ + { + "ins": ["input"], + "dtype": [ + Skip(np.int16), + Skip(np.int32), + Skip(np.int64), + Skip(np.int8), + Skip(np.uint8), + Skip(np.float16), + ], + }, + ] + ), + ), + "log_neg_input": dict( + name=["log1p"], + tensor_para=dict( + args=[ + { + "ins": ["input"], + "dtype": [ + Skip(np.int16), + Skip(np.int32), + Skip(np.int64), + Skip(np.int8), + Skip(np.uint8), + Skip(np.float16), + ], + }, + ] + ), + ), +} diff --git a/impl/cuda/error.cpp b/impl/cuda/error.cpp index 679211ab3..a856db0a1 100644 --- a/impl/cuda/error.cpp +++ b/impl/cuda/error.cpp @@ -28,4 +28,5 @@ void _set_last_error_string(const char* err) { sprintf(strLastErrorOther, "%s", err); } +const char* diopiGetLastErrorString() { return cuda_get_last_error_string(); } } // extern "C" diff --git a/impl/cuda/functions.cu b/impl/cuda/functions.cu index 69defe670..2e2478f84 100644 --- a/impl/cuda/functions.cu +++ b/impl/cuda/functions.cu @@ -12,6 +12,8 @@ #include "helper.hpp" #include "cuda_helper.hpp" +#include + template __global__ void vecAdd(const void* a, const void* b, void* c, const int numel, const T alpha) { int id = blockIdx.x * blockDim.x + threadIdx.x; @@ -100,7 +102,7 @@ extern "C" diopiError_t diopiAdd(diopiContextHandle_t ctx, diopiTensorHandle_t o int blockSize = 256; double coff = 0.0; - if (trInput.dtype() <= 7) { + if (alpha->stype <= 7) { coff = alpha->ival; } else { coff = alpha->fval; @@ -188,3 +190,49 @@ extern "C" diopiError_t diopiFill(diopiContextHandle_t ctx, diopiTensorHandle_t return diopiSuccess; } + +template __global__ +void vecLog1p(const void* a, void* b, const int numel) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + const T* A = static_cast(a); + T* B = static_cast(b); + if (id < numel) { + B[id] = logf(1 + A[id]); + } +} + +extern "C" diopiError_t diopiLog1p(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { + auto stream = impl::cuda::getStream(ctx); + auto trInput = impl::cuda::makeTensor(input); + auto trOut = impl::cuda::makeTensor(out); + + int blockSize = 256; + int gridSize = (trOut.numel() + blockSize - 1) / blockSize; + + DISPATCH_DTYPE(vecLog1p, trInput.dtype(), gridSize, blockSize, stream, + trInput.data(), trOut.data(), trInput.numel()); + + return diopiSuccess; +} + +template __global__ +void vecLog1pInp(void* a, const int numel) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + T* A = static_cast(a); + if (id < numel) { + A[id] = logf(1 + A[id]); + } +} + +extern "C" diopiError_t diopiLog1pInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) { + auto stream = impl::cuda::getStream(ctx); + auto trInput = impl::cuda::makeTensor(input); + + int blockSize = 256; + int gridSize = (trInput.numel() + blockSize - 1) / blockSize; + + DISPATCH_DTYPE(vecLog1pInp, trInput.dtype(), gridSize, blockSize, stream, + trInput.data(), trInput.numel()); + + return diopiSuccess; +} \ No newline at end of file diff --git a/impl/cuda/test/CMakeLists.txt b/impl/cuda/test/CMakeLists.txt index 3cb7ae181..2b28d5af0 100644 --- a/impl/cuda/test/CMakeLists.txt +++ b/impl/cuda/test/CMakeLists.txt @@ -11,6 +11,8 @@ add_subdirectory(${DIOPI_IMPL_DIR}/third_party/pybind11 build) include_directories(SYSTEM "${PROJECT_SOURCE_DIR}/test/include") include_directories(SYSTEM "${PROJECT_SOURCE_DIR}/../third_party/pybind11/include") +set(FUNCTION_SAVE_PATH "${DIOPI_TEST_DIR}/diopi_stub/csrc") +set(TEST_GEN_PATH "${DIOPI_TEST_DIR}/diopi_stub/codegen") set(RUNTIME_SRC litert.cpp conform_test.cpp @@ -28,30 +30,24 @@ cuda_add_library(diopirt SHARED ${RUNTIME_SRC}) target_link_libraries(${DIOPIRT} PRIVATE diopirt) target_link_libraries(diopirt ${DEVICEIMPL}) -set(FUNCTION_SAVE_PATH "${DIOPI_TEST_DIR}/csrc") +file(GLOB TEST_TEMPLATE_CODE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${DIOPI_TEST_DIR}/diopi_stub/codegen/*.py) +add_custom_target(test_gen_dependency DEPENDS ${TEST_TEMPLATE_CODE}) -set(TEST_GEN_PATH "${DIOPI_TEST_DIR}/codegen") +set(GEN_FILES ${FUNCTION_SAVE_PATH}/export_functions.cpp) add_custom_target(test_code_gen COMMAND python3 ${TEST_GEN_PATH}/gen.py - --device=torch --use_adaptor=${USE_ADAPTOR}) -add_custom_target(functions_copy ALL - COMMAND ln -f ${FUNCTION_SAVE_PATH}/export_functions.cpp ${PROJECT_SOURCE_DIR}/test - DEPENDS test_code_gen) - -file(TOUCH export_functions.cpp) -set(FUNCTIONS_SRC - export_functions.cpp -) + --device=torch + BYPRODUCTS ${GEN_FILES} + DEPENDS test_gen_dependency) + +set(FUNCTIONS_SRC ${GEN_FILES}) pybind11_add_module(${DIOPIFUNCTIONS} SHARED ${FUNCTIONS_SRC}) target_link_libraries(${DIOPIFUNCTIONS} PRIVATE diopirt) -add_dependencies(${DIOPIFUNCTIONS} functions_copy) -if(${USE_ADAPTOR} STREQUAL "true") - add_dependencies(${DIOPIFUNCTIONS} adaptor_code_gen) -endif() +add_dependencies(${DIOPIFUNCTIONS} test_code_gen) file(MAKE_DIRECTORY ${DIOPI_TEST_DIR}/python) add_custom_target(python_copy ALL - COMMAND ln -f ${LIBRARY_OUTPUT_PATH}/$ ${DIOPI_TEST_DIR}/python - COMMAND ln -f ${LIBRARY_OUTPUT_PATH}/$ ${DIOPI_TEST_DIR}/python + COMMAND ln -f ${LIBRARY_OUTPUT_PATH}/$ ${DIOPI_TEST_DIR}/python/diopilib + COMMAND ln -f ${LIBRARY_OUTPUT_PATH}/$ ${DIOPI_TEST_DIR}/python/diopilib DEPENDS ${DIOPIFUNCTIONS} ${DIOPIRT}) diff --git a/impl/cuda/test/conform_test.cpp b/impl/cuda/test/conform_test.cpp index 2e48956e8..6ad9882c8 100644 --- a/impl/cuda/test/conform_test.cpp +++ b/impl/cuda/test/conform_test.cpp @@ -9,6 +9,9 @@ #include #include +#include + +#include "litert.hpp" extern "C" { @@ -80,6 +83,15 @@ diopiError_t initLibrary() { return diopiSuccess; } diopiError_t finalizeLibrary() { return diopiSuccess; } +diopiError_t buildGeneratorState(diopiContextHandle_t ctx, diopiTensorHandle_t out) { + std::vector vec{808}; + diopiSize_t size{vec.data(), static_cast(vec.size())}; + diopiTensorHandle_t tensor = nullptr; + diopiRequireTensor(ctx, &tensor, &size, nullptr, diopi_dtype_uint8, diopi_host); + *out = *tensor; + return diopiSuccess; +} + } // extern "C" namespace impl { diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 0313ebe6a..0312be32b 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -1069,6 +1069,23 @@ diopiError_t diopiLog10Inp(diopiContextHandle_t ctx, diopiTensorHandle_t input) return diopiSuccess; } +diopiError_t diopiLog1p(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + auto atOut = impl::aten::buildATen(out); + CALL_ATEN_CUDA_FUNC(log1p_out, atOut, atInput); + + return diopiSuccess; +} + +diopiError_t diopiLog1pInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) { + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + CALL_ATEN_CUDA_FUNC(log1p_, atInput); + + return diopiSuccess; +} + diopiError_t diopiErf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index 578286245..a12bdbaca 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -1023,6 +1023,22 @@ DIOPI_API diopiError_t diopiLog10Inp(diopiContextHandle_t ctx, diopiTensorHandle */ DIOPI_API diopiError_t diopiLog10(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input); +/** + * @brief The in-place version of diopiLog1p. + * @param[in] ctx Context environment. + * @param[in] input the input tensor. + * @param[out] out the output tensor. + */ +DIOPI_API diopiError_t diopiLog1pInp(diopiContextHandle_t ctx, diopiTensorHandle_t input); + +/** + * @brief Compute the element-wise natural logarithm of 1 plus the input tensor. + * @param[in] ctx Context environment. + * @param[in] input the input tensor. + * @param[out] out the output tensor. + */ +DIOPI_API diopiError_t diopiLog1p(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input); + DIOPI_API diopiError_t diopiErfInp(diopiContextHandle_t ctx, diopiTensorHandle_t input); DIOPI_API diopiError_t diopiErf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input); From 1149415df56d78ac46de6b7ab3782749a13cb5b3 Mon Sep 17 00:00:00 2001 From: Lingjie Date: Tue, 13 Aug 2024 20:56:01 +0800 Subject: [PATCH 5/6] build(droplet): add DEVICEIMPL as imported library (#1341) build: add DEVICEIMPL as imported library in droplet --- impl/droplet/CMakeLists.txt | 5 ++++- impl/droplet/test/CMakeLists.txt | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/impl/droplet/CMakeLists.txt b/impl/droplet/CMakeLists.txt index f6ec94ff6..57f3ca5d7 100644 --- a/impl/droplet/CMakeLists.txt +++ b/impl/droplet/CMakeLists.txt @@ -14,9 +14,12 @@ endif() find_library(DIOPI_DROPLET_LIB NAMES ${DEVICEIMPL} HINTS ${DIOPI_DROPLET_DIR}/lib) message(STATUS "DIOPI-IMPL lib: ${DIOPI_DROPLET_LIB}") if(NOT DIOPI_DROPLET_LIB) - message(FATAL_ERROR "${DEVICEIMPL} library not found !") + message(FATAL_ERROR "${DEVICEIMPL} library not found !") endif() +add_library(${DEVICEIMPL} SHARED IMPORTED GLOBAL) +set_target_properties(${DEVICEIMPL} PROPERTIES IMPORTED_LOCATION ${DIOPI_DROPLET_LIB}) + if (TEST) add_subdirectory(test) endif() diff --git a/impl/droplet/test/CMakeLists.txt b/impl/droplet/test/CMakeLists.txt index ce5748e61..65c53e3e3 100644 --- a/impl/droplet/test/CMakeLists.txt +++ b/impl/droplet/test/CMakeLists.txt @@ -52,7 +52,7 @@ endif() target_link_libraries(${DIOPI_EXPORT_RT} PRIVATE -Wl,--no-as-needed diopiruntime -Wl,--as-needed) target_link_libraries(diopiruntime tangrt_shared) -target_link_libraries(diopiruntime ${DIOPI_DROPLET_LIB}) +target_link_libraries(diopiruntime ${DEVICEIMPL}) set(FUNCTION_SAVE_PATH "${DIOPI_TEST_DIR}/diopi_stub/csrc") From b86d7efe2c71e4dec48ec80ff2326ad42e4bdce8 Mon Sep 17 00:00:00 2001 From: liujingfeng4A069 Date: Wed, 14 Aug 2024 11:14:52 +0800 Subject: [PATCH 6/6] [ascend] replace token_softmax_reducev_inference and token_attention_inference (#1321) --- impl/ascend/common/acloprunner.hpp | 4 + impl/ascend/common/utils.cpp | 43 +++++++ impl/ascend/convert_config.yaml | 9 ++ impl/ascend/functions/syn_batch_norm.cpp | 34 ++++++ .../token_attention_inference.cpp | 107 ++++++++++++++++++ .../token_softmax_reducev_inference.cpp | 103 +++++++++++++++++ impl/ascend_npu/CMakeLists.txt | 3 + impl/ascend_npu/ascend_config.yaml | 7 +- 8 files changed, 308 insertions(+), 2 deletions(-) create mode 100644 impl/ascend/functions/syn_batch_norm.cpp create mode 100644 impl/ascend/functions_ext/token_attention_inference.cpp create mode 100644 impl/ascend/functions_ext/token_softmax_reducev_inference.cpp mode change 100755 => 100644 impl/ascend_npu/CMakeLists.txt mode change 100755 => 100644 impl/ascend_npu/ascend_config.yaml diff --git a/impl/ascend/common/acloprunner.hpp b/impl/ascend/common/acloprunner.hpp index c6a845f02..3fdae6df8 100644 --- a/impl/ascend/common/acloprunner.hpp +++ b/impl/ascend/common/acloprunner.hpp @@ -64,6 +64,10 @@ diopiError_t makeOnesLike(diopiContextHandle_t ctx, diopiTensorHandle_t* out, di diopiTensorHandle_t hostToDevice(diopiContextHandle_t ctx, diopiConstTensorHandle_t src); +AscendTensor hostToDeviceAsync(diopiContextHandle_t ctx, const AscendTensor& hostTensor); + +AscendTensor deviceToHostSync(diopiContextHandle_t ctx, const AscendTensor& deviceTensor); + inline std::vector calcStrides(int ndims, diopiSize_t size, diopiMemoryFormat_t format = diopiMemoryFormat_t::Contiguous) { std::vector strides; strides.resize(ndims); diff --git a/impl/ascend/common/utils.cpp b/impl/ascend/common/utils.cpp index fe9da2c0e..b29465270 100644 --- a/impl/ascend/common/utils.cpp +++ b/impl/ascend/common/utils.cpp @@ -692,6 +692,49 @@ diopiTensorHandle_t hostToDevice(diopiContextHandle_t ctx, diopiConstTensorHandl } } +AscendTensor hostToDeviceAsync(diopiContextHandle_t ctx, const AscendTensor& hostTensor) { + diopiDevice_t device = hostTensor.device(); + + if (device == diopi_host) { + diopiTensorHandle_t dst; + diopiSize_t size{hostTensor.shape().data(), hostTensor.dim()}; + diopiSize_t stride{hostTensor.stride().data(), (int64_t)hostTensor.stride().size()}; + diopiDtype_t dtype = hostTensor.dtype(); + diopiRequireTensor(ctx, &dst, &size, &stride, dtype, diopi_device); + const void* srcPtr = hostTensor.data(); + void* dstPtr; + diopiGetTensorData(dst, &dstPtr); + diopiStreamHandle_t stream; + diopiGetStream(ctx, &stream); + int64_t elemsize = hostTensor.numel() * hostTensor.elemsize(); + CALL_ACLRT(aclrtMemcpyAsync(dstPtr, elemsize, const_cast(srcPtr), elemsize, ACL_MEMCPY_HOST_TO_DEVICE, stream)); + return AscendTensor(dst); + } else { + return hostTensor; + } +} + +AscendTensor deviceToHostSync(diopiContextHandle_t ctx, const AscendTensor& deviceTensor) { + if (deviceTensor.device() == diopi_device) { + diopiTensorHandle_t dst; + diopiSize_t size{deviceTensor.shape().data(), deviceTensor.dim()}; + diopiSize_t stride{deviceTensor.stride().data(), (int64_t)deviceTensor.stride().size()}; + diopiDtype_t dtype = deviceTensor.dtype(); + diopiRequireTensor(ctx, &dst, &size, &stride, dtype, diopi_host); + const void* srcPtr = deviceTensor.data(); + void* dstPtr; + diopiGetTensorData(dst, &dstPtr); + diopiStreamHandle_t stream; + diopiGetStream(ctx, &stream); + int64_t elemsize = deviceTensor.numel() * deviceTensor.elemsize(); + CALL_ACLRT(aclrtMemcpyAsync(dstPtr, elemsize, const_cast(srcPtr), elemsize, ACL_MEMCPY_DEVICE_TO_HOST, stream)); + CALL_ACLRT(aclrtSynchronizeStream(stream)); + return AscendTensor(dst); + } else { + return deviceTensor; + } +} + static diopiError_t choiceDtype(const std::set& opSupportedDtypes, diopiDtype_t* dtype) { if (opSupportedDtypes.find(diopi_dtype_float32) != opSupportedDtypes.end()) { *dtype = diopi_dtype_float32; diff --git a/impl/ascend/convert_config.yaml b/impl/ascend/convert_config.yaml index ac320648a..50b78be98 100755 --- a/impl/ascend/convert_config.yaml +++ b/impl/ascend/convert_config.yaml @@ -479,3 +479,12 @@ - diopiMaxPool2dBackward: tensor_dtype: indices: (int64)->int32 + +- diopiBatchNormStats: + dtype: (float64)->float32 + +- diopiBatchNormGatherStatsWithCounts: + dtype: (float64)->float32 + +- diopiBatchNormBackwardReduce: + dtype: (float64)->float32 diff --git a/impl/ascend/functions/syn_batch_norm.cpp b/impl/ascend/functions/syn_batch_norm.cpp new file mode 100644 index 000000000..60b183fca --- /dev/null +++ b/impl/ascend/functions/syn_batch_norm.cpp @@ -0,0 +1,34 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include "../aclnn/adaptor.hpp" + +namespace impl { +namespace ascend { + +diopiError_t diopiBatchNormStats(diopiContextHandle_t ctx, diopiTensorHandle_t mean, diopiTensorHandle_t invstd, diopiConstTensorHandle_t input, double eps) { + DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNormStats, ctx, input, eps, mean, invstd); + return diopiSuccess; +} + +diopiError_t diopiBatchNormBackwardReduce(diopiContextHandle_t ctx, diopiTensorHandle_t sumDy, diopiTensorHandle_t sumDyXmu, diopiTensorHandle_t gradWeight, + diopiTensorHandle_t gradBias, diopiConstTensorHandle_t gradOut, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t mean, diopiConstTensorHandle_t invstd, diopiConstTensorHandle_t weight, bool inputG, + bool weightG, bool biasG) { + DIOPI_ASCEND_CALL_ACLNN( + aclnnBatchNormReduceBackward, ctx, gradOut, input, mean, invstd, weight, inputG, weightG, biasG, sumDy, sumDyXmu, gradWeight, gradBias); + return diopiSuccess; +} + +diopiError_t diopiBatchNormGatherStatsWithCounts(diopiContextHandle_t ctx, diopiTensorHandle_t mean, diopiTensorHandle_t invstd, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t meanAll, diopiConstTensorHandle_t invstdAll, diopiTensorHandle_t runningMean, + diopiTensorHandle_t runningVar, float momentum, float eps, diopiConstTensorHandle_t counts) { + DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNormGatherStatsWithCounts, ctx, input, meanAll, invstdAll, runningMean, runningVar, momentum, eps, counts, mean, invstd); + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend/functions_ext/token_attention_inference.cpp b/impl/ascend/functions_ext/token_attention_inference.cpp new file mode 100644 index 000000000..677e0dfe0 --- /dev/null +++ b/impl/ascend/functions_ext/token_attention_inference.cpp @@ -0,0 +1,107 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include "../aclnn/adaptor.hpp" +#include "../common/acloprunner.hpp" +#include "impl_functions.hpp" + +namespace impl { +namespace ascend { + +diopiError_t diopiTokenAttentionInference(diopiContextHandle_t ctx, diopiTensorHandle_t attentionOut, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, + diopiConstTensorHandle_t bLoc, diopiConstTensorHandle_t bStartLoc, diopiConstTensorHandle_t bSeqLen, + int maxInputLen) { + AscendTensor attentionOutAt(attentionOut), qAt(q), kAt(k), bLocAt(bLoc), bStartLocAt(bStartLoc), bSeqLenAt(bSeqLen); + int batch = bLocAt.shape(0); + int head = qAt.shape(1); + int dim = qAt.shape(2); + qAt = qAt.view({batch, head, 1, dim}); + diopiDtype_t dtype = qAt.dtype(); + diopiDevice_t device = qAt.device(); + + AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); + AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); + + const int* bSeqLenAtData = reinterpret_cast(bSeqLenHostAt.data()); + const int* bStartLocAtData = reinterpret_cast(bStartLocHostAt.data()); + + for (int i = 0; i < batch; i++) { + int curSeqLen = *(bSeqLenAtData + i); + int curSeqStartLoc = *(bStartLocAtData + i); + AscendTensor kLocAt, indexAt; + makeTensor(ctx, indexAt, {curSeqLen}, diopi_dtype_int32); + diopiScalar_t start = constructDiopiScalarT(diopi_dtype_int32, maxInputLen - curSeqLen); + diopiScalar_t end = constructDiopiScalarT(diopi_dtype_int32, maxInputLen); + diopiScalar_t step = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &start, &end, &step, indexAt); + + AscendTensor bLocAtSlice; + makeTensor(ctx, bLocAtSlice, {1, bLocAt.shape(1)}, bLocAt.dtype()); + + diopiScalar_t sliceIndexScalar = constructDiopiScalarT(diopi_dtype_int32, i); + AscendTensor sliceIndexAt; + makeTensorFromScalar(ctx, sliceIndexAt, &sliceIndexScalar, bLocAt.device()); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAt, 0, sliceIndexAt, bLocAtSlice); + bLocAtSlice.view({bLocAt.shape(1)}); + makeTensor(ctx, kLocAt, {curSeqLen}, bLocAt.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAtSlice, 0, indexAt, kLocAt); + + diopiTensorHandle_t keyTmp; + diopiConstTensorHandle_t indexAtHandle = kLocAt.tensorHandle(); + ascend_npu::diopiIndex(ctx, &keyTmp, k, &indexAtHandle, 1); + + AscendTensor keyTmpAt(keyTmp); + + keyTmpAt = keyTmpAt.unsqueeze(0); + AscendTensor keyAt; + makeTensor(ctx, keyAt, {1, head, curSeqLen, dim}, keyTmpAt.dtype()); + std::vector dims{0, 2, 1, 3}; + diopiSize_t permuteDims = vectorToDiopiSize(dims); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, keyTmpAt, permuteDims, keyAt); + + AscendTensor outLocAt; + makeTensor(ctx, outLocAt, {curSeqLen}, diopi_dtype_int32); + diopiScalar_t startScalar = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc); + diopiScalar_t endScalar = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc + curSeqLen); + diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &startScalar, &endScalar, &stepScalar, outLocAt); + + AscendTensor scalarTensor; + diopiScalar_t scalarI = constructDiopiScalarT(diopi_dtype_int64, i); + makeTensorFromScalar(ctx, scalarTensor, &scalarI, qAt.device()); + + diopiTensorHandle_t qIndex; + diopiConstTensorHandle_t scalarTensorHandle = scalarTensor.tensorHandle(); + ascend_npu::diopiIndex(ctx, &qIndex, qAt.tensorHandle(), &scalarTensorHandle, 1); + + AscendTensor qIndexAt(qIndex); + + AscendTensor matmulOutAt; + makeTensor(ctx, matmulOutAt, {keyAt.shape(0), keyAt.shape(1), qIndexAt.shape(0), keyAt.shape(2)}, keyAt.dtype()); + qIndexAt.unsqueeze(0); + + AscendTensor keyTmp2At; + makeTensor(ctx, keyTmp2At, {keyAt.shape(0), keyAt.shape(1), keyAt.shape(3), keyAt.shape(2)}, keyAt.dtype()); + dims = {0, 1, 3, 2}; + permuteDims = vectorToDiopiSize(dims); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, keyAt, permuteDims, keyTmp2At); + + DIOPI_ASCEND_CALL_ACLNN( + aclnnMatmul, ctx, qIndexAt.view({qIndexAt.shape(0), qIndexAt.shape(2), qIndexAt.shape(1), qIndexAt.shape(3)}), keyTmp2At, matmulOutAt, 0); + + AscendTensor sqrtDimAt; + diopiScalar_t sqrtDim = constructDiopiScalarT(qAt.dtype(), sqrt(dim)); + makeTensorFromScalar(ctx, sqrtDimAt, &sqrtDim, matmulOutAt.device()); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceDiv, ctx, matmulOutAt, sqrtDimAt); + + std::vector indices{AscendTensor(), outLocAt}; + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, attentionOutAt, indices, matmulOutAt.view({head, curSeqLen}), false, true); + } + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend/functions_ext/token_softmax_reducev_inference.cpp b/impl/ascend/functions_ext/token_softmax_reducev_inference.cpp new file mode 100644 index 000000000..4991716f5 --- /dev/null +++ b/impl/ascend/functions_ext/token_softmax_reducev_inference.cpp @@ -0,0 +1,103 @@ +/** + * @file + * @author DeepLink + * @copyright (c) 2024, DeepLink. + */ + +#include + +#include "../aclnn/adaptor.hpp" +#include "../common/acloprunner.hpp" +#include "impl_functions.hpp" + +namespace impl { +namespace ascend { + +diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t logics, diopiConstTensorHandle_t v, + diopiConstTensorHandle_t bLoc, diopiConstTensorHandle_t bStartLoc, diopiConstTensorHandle_t bSeqLen, + int maxInputLen, int otherKVIndex) { + AscendTensor outAt(out), logicsAt(logics), vAt(v), bLocAt(bLoc), bStartLocAt(bStartLoc), bSeqLenAt(bSeqLen); + int batch = bLocAt.shape(0); + int head = vAt.shape(1); + int dim = vAt.shape(2); + diopiDtype_t dtype = logicsAt.dtype(); + diopiDevice_t device = logicsAt.device(); + + AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); + AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); + + const int* bSeqLenAtData = reinterpret_cast(bSeqLenHostAt.data()); + const int* bStartLocAtData = reinterpret_cast(bStartLocHostAt.data()); + + for (int i = 0; i < batch; i++) { + int curSeqLen = *(bSeqLenAtData + i); + int curSeqStartLoc = *(bStartLocAtData + i); + AscendTensor indexAt; + makeTensor(ctx, indexAt, {curSeqLen}, diopi_dtype_int32); + diopiScalar_t start = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc); + diopiScalar_t end = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc + curSeqLen); + diopiScalar_t step = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &start, &end, &step, indexAt); + + diopiTensorHandle_t indexOut; + diopiConstTensorHandle_t indices[2] = {diopiConstTensorHandle_t(), indexAt.tensorHandle()}; + ascend_npu::diopiIndex(ctx, &indexOut, logicsAt.tensorHandle(), indices, 2); + AscendTensor indexOutAt(indexOut); + + AscendTensor softmaxOutAt; + makeTensor(ctx, softmaxOutAt, indexOutAt.shape(), indexOutAt.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnSoftmax, ctx, indexOutAt, indexOutAt.dim() - 1, softmaxOutAt); + + softmaxOutAt = softmaxOutAt.view({head, 1, 1, curSeqLen}); + AscendTensor pAt; + makeTensor(ctx, pAt, {softmaxOutAt.shape(1), softmaxOutAt.shape(0), softmaxOutAt.shape(2), softmaxOutAt.shape(3)}, logicsAt.dtype()); + std::vector dims{1, 0, 2, 3}; + diopiSize_t permuteDims = vectorToDiopiSize(dims); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, softmaxOutAt, permuteDims, pAt); + + makeTensor(ctx, indexAt, {curSeqLen}, diopi_dtype_int32); + diopiScalar_t startVLoc = constructDiopiScalarT(diopi_dtype_int32, maxInputLen - curSeqLen); + diopiScalar_t endVLoc = constructDiopiScalarT(diopi_dtype_int32, maxInputLen); + diopiScalar_t stepvLoc = constructDiopiScalarT(diopi_dtype_int32, 1); + DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &startVLoc, &endVLoc, &stepvLoc, indexAt); + + AscendTensor bLocAtSlice; + makeTensor(ctx, bLocAtSlice, {1, bLocAt.shape(1)}, bLocAt.dtype()); + diopiScalar_t sliceIndexScalar = constructDiopiScalarT(diopi_dtype_int32, i); + AscendTensor sliceIndexAt; + makeTensorFromScalar(ctx, sliceIndexAt, &sliceIndexScalar, bLocAt.device()); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAt, 0, sliceIndexAt, bLocAtSlice); + bLocAtSlice.view({bLocAt.shape(1)}); + + AscendTensor vLocAt; + makeTensor(ctx, vLocAt, {curSeqLen}, bLocAt.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAtSlice, 0, indexAt, vLocAt); + + diopiTensorHandle_t vIndexOut; + diopiConstTensorHandle_t indexAtHandle = vLocAt.tensorHandle(); + ascend_npu::diopiIndex(ctx, &vIndexOut, vAt.tensorHandle(), &indexAtHandle, 1); + + AscendTensor vIndexOutAt(vIndexOut); + vIndexOutAt = vIndexOutAt.view({1, curSeqLen, head, dim}); + + AscendTensor vAt; + makeTensor(ctx, vAt, {1, head, curSeqLen, dim}, vIndexOutAt.dtype()); + dims = {0, 2, 1, 3}; + permuteDims = vectorToDiopiSize(dims); + DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, vIndexOutAt, permuteDims, vAt); + + AscendTensor matmulOutAt; + makeTensor(ctx, matmulOutAt, {pAt.shape(0), pAt.shape(1), pAt.shape(2), vAt.shape(3)}, pAt.dtype()); + DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, pAt, vAt, matmulOutAt, 0); + + diopiScalar_t scalarI = constructDiopiScalarT(diopi_dtype_int32, i); + AscendTensor tensorI; + makeTensorFromScalar(ctx, tensorI, &scalarI, matmulOutAt.device()); + std::vector indexPutIndices{tensorI}; + DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, outAt, indexPutIndices, matmulOutAt.view({head, dim}), false, true); + } + return diopiSuccess; +} + +} // namespace ascend +} // namespace impl diff --git a/impl/ascend_npu/CMakeLists.txt b/impl/ascend_npu/CMakeLists.txt old mode 100755 new mode 100644 index 84f285f0b..79ce646e2 --- a/impl/ascend_npu/CMakeLists.txt +++ b/impl/ascend_npu/CMakeLists.txt @@ -194,6 +194,7 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions/equal.cpp ${OLD_IMPL_DIR}/functions/masked_select.cpp ${OLD_IMPL_DIR}/functions/unique.cpp + ${OLD_IMPL_DIR}/functions/syn_batch_norm.cpp ${OLD_IMPL_DIR}/functions_mmcv/roi_align_npu.cpp ${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp ${OLD_IMPL_DIR}/functions_ext/rotary_embedding.cpp @@ -205,6 +206,8 @@ set(OLD_IMPL_SRC ${OLD_IMPL_DIR}/functions_ext/prompt_flash_attention.cpp ${OLD_IMPL_DIR}/functions_ext/paged_attention.cpp ${OLD_IMPL_DIR}/functions_ext/matmul_all_reduce.cpp + ${OLD_IMPL_DIR}/functions_ext/token_attention_inference.cpp + ${OLD_IMPL_DIR}/functions_ext/token_softmax_reducev_inference.cpp #${OLD_IMPL_DIR}/test/export_functions.cpp #${OLD_IMPL_DIR}/test/conform_test.cpp ${OLD_IMPL_DIR}/common/utils.cpp diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml old mode 100755 new mode 100644 index 7def339c0..74b1c8733 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -24,6 +24,9 @@ ascend: - diopiBaddbmmInp - diopiBatchNorm - diopiBatchNormBackward +- diopiBatchNormStats +- diopiBatchNormBackwardReduce +- diopiBatchNormGatherStatsWithCounts - diopiBitwiseAnd - diopiBitwiseAndInp - diopiBitwiseAndInpScalar @@ -249,6 +252,8 @@ ascend: - diopiThreshold - diopiThresholdBackward - diopiThresholdInp +- diopiTokenAttentionInference +- diopiTokenSoftmaxReduceVInference - diopiTopk - diopiTranspose - diopiTril @@ -277,5 +282,3 @@ ascend_npu: - diopiScaledMaskedSoftmax - diopiScaledMaskedSoftmaxBackward - diopiTensorDestructionHook -- diopiTokenAttentionInference -- diopiTokenSoftmaxReduceVInference