-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'DeepLink-org:main' into zq/tiny-fix
- Loading branch information
Showing
18 changed files
with
451 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/** | ||
* @file | ||
* @author DeepLink | ||
* @copyright (c) 2024, DeepLink. | ||
*/ | ||
|
||
#include "../aclnn/adaptor.hpp" | ||
|
||
namespace impl {
namespace ascend {

/**
 * @brief Compute batch-norm statistics (mean and inverse standard deviation)
 *        of `input`, writing results into `mean` and `invstd`.
 *
 * Thin adaptor forwarding to the Ascend aclnnBatchNormStats operator with
 * inputs (input, eps) followed by outputs (mean, invstd).
 * NOTE(review): the exact statistic semantics (e.g. biased vs. unbiased
 * variance, reduction axes) are defined by aclnnBatchNormStats — confirm
 * against the CANN operator documentation.
 */
diopiError_t diopiBatchNormStats(diopiContextHandle_t ctx, diopiTensorHandle_t mean, diopiTensorHandle_t invstd, diopiConstTensorHandle_t input, double eps) {
    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNormStats, ctx, input, eps, mean, invstd);
    return diopiSuccess;
}

/**
 * @brief Reduction step of the batch-norm backward pass.
 *
 * Forwards to aclnnBatchNormReduceBackward. The boolean flags inputG /
 * weightG / biasG are passed straight through to the operator; the reduced
 * sums (sumDy, sumDyXmu) and parameter gradients (gradWeight, gradBias) are
 * the trailing output arguments of the aclnn call.
 */
diopiError_t diopiBatchNormBackwardReduce(diopiContextHandle_t ctx, diopiTensorHandle_t sumDy, diopiTensorHandle_t sumDyXmu, diopiTensorHandle_t gradWeight,
                                          diopiTensorHandle_t gradBias, diopiConstTensorHandle_t gradOut, diopiConstTensorHandle_t input,
                                          diopiConstTensorHandle_t mean, diopiConstTensorHandle_t invstd, diopiConstTensorHandle_t weight, bool inputG,
                                          bool weightG, bool biasG) {
    DIOPI_ASCEND_CALL_ACLNN(
        aclnnBatchNormReduceBackward, ctx, gradOut, input, mean, invstd, weight, inputG, weightG, biasG, sumDy, sumDyXmu, gradWeight, gradBias);
    return diopiSuccess;
}

/**
 * @brief Combine per-device statistics (meanAll, invstdAll) weighted by
 *        `counts` into global mean/invstd, also updating the running stats.
 *
 * Forwards to aclnnBatchNormGatherStatsWithCounts; `runningMean` and
 * `runningVar` are updated in place by the operator using `momentum`
 * (presumed — behavior is defined by the aclnn operator; verify against the
 * CANN docs).
 */
diopiError_t diopiBatchNormGatherStatsWithCounts(diopiContextHandle_t ctx, diopiTensorHandle_t mean, diopiTensorHandle_t invstd, diopiConstTensorHandle_t input,
                                                 diopiConstTensorHandle_t meanAll, diopiConstTensorHandle_t invstdAll, diopiTensorHandle_t runningMean,
                                                 diopiTensorHandle_t runningVar, float momentum, float eps, diopiConstTensorHandle_t counts) {
    DIOPI_ASCEND_CALL_ACLNN(aclnnBatchNormGatherStatsWithCounts, ctx, input, meanAll, invstdAll, runningMean, runningVar, momentum, eps, counts, mean, invstd);
    return diopiSuccess;
}

} // namespace ascend
} // namespace impl
107 changes: 107 additions & 0 deletions
107
impl/ascend/functions_ext/token_attention_inference.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/** | ||
* @file | ||
* @author DeepLink | ||
* @copyright (c) 2024, DeepLink. | ||
*/ | ||
|
||
#include "../aclnn/adaptor.hpp" | ||
#include "../common/acloprunner.hpp" | ||
#include "impl_functions.hpp" | ||
|
||
namespace impl { | ||
namespace ascend { | ||
|
||
diopiError_t diopiTokenAttentionInference(diopiContextHandle_t ctx, diopiTensorHandle_t attentionOut, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k, | ||
diopiConstTensorHandle_t bLoc, diopiConstTensorHandle_t bStartLoc, diopiConstTensorHandle_t bSeqLen, | ||
int maxInputLen) { | ||
AscendTensor attentionOutAt(attentionOut), qAt(q), kAt(k), bLocAt(bLoc), bStartLocAt(bStartLoc), bSeqLenAt(bSeqLen); | ||
int batch = bLocAt.shape(0); | ||
int head = qAt.shape(1); | ||
int dim = qAt.shape(2); | ||
qAt = qAt.view({batch, head, 1, dim}); | ||
diopiDtype_t dtype = qAt.dtype(); | ||
diopiDevice_t device = qAt.device(); | ||
|
||
AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); | ||
AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); | ||
|
||
const int* bSeqLenAtData = reinterpret_cast<const int*>(bSeqLenHostAt.data()); | ||
const int* bStartLocAtData = reinterpret_cast<const int*>(bStartLocHostAt.data()); | ||
|
||
for (int i = 0; i < batch; i++) { | ||
int curSeqLen = *(bSeqLenAtData + i); | ||
int curSeqStartLoc = *(bStartLocAtData + i); | ||
AscendTensor kLocAt, indexAt; | ||
makeTensor(ctx, indexAt, {curSeqLen}, diopi_dtype_int32); | ||
diopiScalar_t start = constructDiopiScalarT(diopi_dtype_int32, maxInputLen - curSeqLen); | ||
diopiScalar_t end = constructDiopiScalarT(diopi_dtype_int32, maxInputLen); | ||
diopiScalar_t step = constructDiopiScalarT(diopi_dtype_int32, 1); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &start, &end, &step, indexAt); | ||
|
||
AscendTensor bLocAtSlice; | ||
makeTensor(ctx, bLocAtSlice, {1, bLocAt.shape(1)}, bLocAt.dtype()); | ||
|
||
diopiScalar_t sliceIndexScalar = constructDiopiScalarT(diopi_dtype_int32, i); | ||
AscendTensor sliceIndexAt; | ||
makeTensorFromScalar(ctx, sliceIndexAt, &sliceIndexScalar, bLocAt.device()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAt, 0, sliceIndexAt, bLocAtSlice); | ||
bLocAtSlice.view({bLocAt.shape(1)}); | ||
makeTensor(ctx, kLocAt, {curSeqLen}, bLocAt.dtype()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAtSlice, 0, indexAt, kLocAt); | ||
|
||
diopiTensorHandle_t keyTmp; | ||
diopiConstTensorHandle_t indexAtHandle = kLocAt.tensorHandle(); | ||
ascend_npu::diopiIndex(ctx, &keyTmp, k, &indexAtHandle, 1); | ||
|
||
AscendTensor keyTmpAt(keyTmp); | ||
|
||
keyTmpAt = keyTmpAt.unsqueeze(0); | ||
AscendTensor keyAt; | ||
makeTensor(ctx, keyAt, {1, head, curSeqLen, dim}, keyTmpAt.dtype()); | ||
std::vector<int64_t> dims{0, 2, 1, 3}; | ||
diopiSize_t permuteDims = vectorToDiopiSize(dims); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, keyTmpAt, permuteDims, keyAt); | ||
|
||
AscendTensor outLocAt; | ||
makeTensor(ctx, outLocAt, {curSeqLen}, diopi_dtype_int32); | ||
diopiScalar_t startScalar = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc); | ||
diopiScalar_t endScalar = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc + curSeqLen); | ||
diopiScalar_t stepScalar = constructDiopiScalarT(diopi_dtype_int32, 1); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &startScalar, &endScalar, &stepScalar, outLocAt); | ||
|
||
AscendTensor scalarTensor; | ||
diopiScalar_t scalarI = constructDiopiScalarT(diopi_dtype_int64, i); | ||
makeTensorFromScalar(ctx, scalarTensor, &scalarI, qAt.device()); | ||
|
||
diopiTensorHandle_t qIndex; | ||
diopiConstTensorHandle_t scalarTensorHandle = scalarTensor.tensorHandle(); | ||
ascend_npu::diopiIndex(ctx, &qIndex, qAt.tensorHandle(), &scalarTensorHandle, 1); | ||
|
||
AscendTensor qIndexAt(qIndex); | ||
|
||
AscendTensor matmulOutAt; | ||
makeTensor(ctx, matmulOutAt, {keyAt.shape(0), keyAt.shape(1), qIndexAt.shape(0), keyAt.shape(2)}, keyAt.dtype()); | ||
qIndexAt.unsqueeze(0); | ||
|
||
AscendTensor keyTmp2At; | ||
makeTensor(ctx, keyTmp2At, {keyAt.shape(0), keyAt.shape(1), keyAt.shape(3), keyAt.shape(2)}, keyAt.dtype()); | ||
dims = {0, 1, 3, 2}; | ||
permuteDims = vectorToDiopiSize(dims); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, keyAt, permuteDims, keyTmp2At); | ||
|
||
DIOPI_ASCEND_CALL_ACLNN( | ||
aclnnMatmul, ctx, qIndexAt.view({qIndexAt.shape(0), qIndexAt.shape(2), qIndexAt.shape(1), qIndexAt.shape(3)}), keyTmp2At, matmulOutAt, 0); | ||
|
||
AscendTensor sqrtDimAt; | ||
diopiScalar_t sqrtDim = constructDiopiScalarT(qAt.dtype(), sqrt(dim)); | ||
makeTensorFromScalar(ctx, sqrtDimAt, &sqrtDim, matmulOutAt.device()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceDiv, ctx, matmulOutAt, sqrtDimAt); | ||
|
||
std::vector<AscendTensor> indices{AscendTensor(), outLocAt}; | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, attentionOutAt, indices, matmulOutAt.view({head, curSeqLen}), false, true); | ||
} | ||
return diopiSuccess; | ||
} | ||
|
||
} // namespace ascend | ||
} // namespace impl |
103 changes: 103 additions & 0 deletions
103
impl/ascend/functions_ext/token_softmax_reducev_inference.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/** | ||
* @file | ||
* @author DeepLink | ||
* @copyright (c) 2024, DeepLink. | ||
*/ | ||
|
||
#include <vector> | ||
|
||
#include "../aclnn/adaptor.hpp" | ||
#include "../common/acloprunner.hpp" | ||
#include "impl_functions.hpp" | ||
|
||
namespace impl { | ||
namespace ascend { | ||
|
||
diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t logics, diopiConstTensorHandle_t v, | ||
diopiConstTensorHandle_t bLoc, diopiConstTensorHandle_t bStartLoc, diopiConstTensorHandle_t bSeqLen, | ||
int maxInputLen, int otherKVIndex) { | ||
AscendTensor outAt(out), logicsAt(logics), vAt(v), bLocAt(bLoc), bStartLocAt(bStartLoc), bSeqLenAt(bSeqLen); | ||
int batch = bLocAt.shape(0); | ||
int head = vAt.shape(1); | ||
int dim = vAt.shape(2); | ||
diopiDtype_t dtype = logicsAt.dtype(); | ||
diopiDevice_t device = logicsAt.device(); | ||
|
||
AscendTensor bSeqLenHostAt = deviceToHostSync(ctx, bSeqLenAt); | ||
AscendTensor bStartLocHostAt = deviceToHostSync(ctx, bStartLocAt); | ||
|
||
const int* bSeqLenAtData = reinterpret_cast<const int*>(bSeqLenHostAt.data()); | ||
const int* bStartLocAtData = reinterpret_cast<const int*>(bStartLocHostAt.data()); | ||
|
||
for (int i = 0; i < batch; i++) { | ||
int curSeqLen = *(bSeqLenAtData + i); | ||
int curSeqStartLoc = *(bStartLocAtData + i); | ||
AscendTensor indexAt; | ||
makeTensor(ctx, indexAt, {curSeqLen}, diopi_dtype_int32); | ||
diopiScalar_t start = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc); | ||
diopiScalar_t end = constructDiopiScalarT(diopi_dtype_int32, curSeqStartLoc + curSeqLen); | ||
diopiScalar_t step = constructDiopiScalarT(diopi_dtype_int32, 1); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &start, &end, &step, indexAt); | ||
|
||
diopiTensorHandle_t indexOut; | ||
diopiConstTensorHandle_t indices[2] = {diopiConstTensorHandle_t(), indexAt.tensorHandle()}; | ||
ascend_npu::diopiIndex(ctx, &indexOut, logicsAt.tensorHandle(), indices, 2); | ||
AscendTensor indexOutAt(indexOut); | ||
|
||
AscendTensor softmaxOutAt; | ||
makeTensor(ctx, softmaxOutAt, indexOutAt.shape(), indexOutAt.dtype()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnSoftmax, ctx, indexOutAt, indexOutAt.dim() - 1, softmaxOutAt); | ||
|
||
softmaxOutAt = softmaxOutAt.view({head, 1, 1, curSeqLen}); | ||
AscendTensor pAt; | ||
makeTensor(ctx, pAt, {softmaxOutAt.shape(1), softmaxOutAt.shape(0), softmaxOutAt.shape(2), softmaxOutAt.shape(3)}, logicsAt.dtype()); | ||
std::vector<int64_t> dims{1, 0, 2, 3}; | ||
diopiSize_t permuteDims = vectorToDiopiSize(dims); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, softmaxOutAt, permuteDims, pAt); | ||
|
||
makeTensor(ctx, indexAt, {curSeqLen}, diopi_dtype_int32); | ||
diopiScalar_t startVLoc = constructDiopiScalarT(diopi_dtype_int32, maxInputLen - curSeqLen); | ||
diopiScalar_t endVLoc = constructDiopiScalarT(diopi_dtype_int32, maxInputLen); | ||
diopiScalar_t stepvLoc = constructDiopiScalarT(diopi_dtype_int32, 1); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnArange, ctx, &startVLoc, &endVLoc, &stepvLoc, indexAt); | ||
|
||
AscendTensor bLocAtSlice; | ||
makeTensor(ctx, bLocAtSlice, {1, bLocAt.shape(1)}, bLocAt.dtype()); | ||
diopiScalar_t sliceIndexScalar = constructDiopiScalarT(diopi_dtype_int32, i); | ||
AscendTensor sliceIndexAt; | ||
makeTensorFromScalar(ctx, sliceIndexAt, &sliceIndexScalar, bLocAt.device()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAt, 0, sliceIndexAt, bLocAtSlice); | ||
bLocAtSlice.view({bLocAt.shape(1)}); | ||
|
||
AscendTensor vLocAt; | ||
makeTensor(ctx, vLocAt, {curSeqLen}, bLocAt.dtype()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexSelect, ctx, bLocAtSlice, 0, indexAt, vLocAt); | ||
|
||
diopiTensorHandle_t vIndexOut; | ||
diopiConstTensorHandle_t indexAtHandle = vLocAt.tensorHandle(); | ||
ascend_npu::diopiIndex(ctx, &vIndexOut, vAt.tensorHandle(), &indexAtHandle, 1); | ||
|
||
AscendTensor vIndexOutAt(vIndexOut); | ||
vIndexOutAt = vIndexOutAt.view({1, curSeqLen, head, dim}); | ||
|
||
AscendTensor vAt; | ||
makeTensor(ctx, vAt, {1, head, curSeqLen, dim}, vIndexOutAt.dtype()); | ||
dims = {0, 2, 1, 3}; | ||
permuteDims = vectorToDiopiSize(dims); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, vIndexOutAt, permuteDims, vAt); | ||
|
||
AscendTensor matmulOutAt; | ||
makeTensor(ctx, matmulOutAt, {pAt.shape(0), pAt.shape(1), pAt.shape(2), vAt.shape(3)}, pAt.dtype()); | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, pAt, vAt, matmulOutAt, 0); | ||
|
||
diopiScalar_t scalarI = constructDiopiScalarT(diopi_dtype_int32, i); | ||
AscendTensor tensorI; | ||
makeTensorFromScalar(ctx, tensorI, &scalarI, matmulOutAt.device()); | ||
std::vector<AscendTensor> indexPutIndices{tensorI}; | ||
DIOPI_ASCEND_CALL_ACLNN(aclnnIndexPutImpl, ctx, outAt, indexPutIndices, matmulOutAt.view({head, dim}), false, true); | ||
} | ||
return diopiSuccess; | ||
} | ||
|
||
} // namespace ascend | ||
} // namespace impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.