From 6f4e9d657d3031b8ea5f65f4b1134f9c30b4a795 Mon Sep 17 00:00:00 2001 From: DoorKickers <1105976166@qq.com> Date: Wed, 14 Aug 2024 14:45:35 +0800 Subject: [PATCH] add acosh on diopi torch impl --- diopi_test/python/configs/diopi_configs.py | 38 +++++++++++++++++++ .../python/conformance/diopi_functions.py | 4 ++ impl/torch/functions/functions.cpp | 17 +++++++++ proto/include/diopi/functions.h | 15 ++++++++ 4 files changed, 74 insertions(+) diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index ee01db6d7..9a50cd4b6 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -1345,6 +1345,44 @@ ), ), + 'acosh': dict( + name=['acosh'], + interface=['torch'], + is_inplace=True, + saved_args=dict(output=0), + dtype=[np.float16, np.float32, np.float64], + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ + { + "ins": ['input'], + "shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072), + (256, 128, 3, 3), + (2, 31, 512, 6, 40), + (0,), (16, 0), (1, 0, 6)), + }, + ], + ), + ), + + 'acosh_not_float': dict( + name=['acosh'], + interface=['torch'], + dtype=[np.int16, np.int32, np.int64, np.uint8, np.int8, np.bool_], + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ + { + "ins": ['input'], + "shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072), + (256, 128, 3, 3), + (2, 31, 512, 6, 40), + (0,), (16, 0), (1, 0, 6)), + }, + ], + ), + ), + 'sign': dict( name=['sign'], interface=['torch'], diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index cca369c8d..77a52de00 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -433,6 +433,10 @@ def asinh(input, inplace=False) -> Tensor: return unary_op(input, inplace, "diopiAsinh", promote_type(input, Dtype.float32)) +def acosh(input, inplace=False) -> Tensor: + return unary_op(input, inplace, "diopiAcosh", promote_type(input, Dtype.float32)) + + def exp(input, inplace=False) -> Tensor: return unary_op(input, inplace, "diopiExp", promote_type(input, Dtype.float32)) diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index d7d199108..dabdf6ed3 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -1043,6 +1043,23 @@ diopiError_t diopiAsinhInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) return diopiSuccess; } +diopiError_t diopiAcosh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + auto atOut = impl::aten::buildATen(out); + CALL_ATEN_CUDA_FUNC(acosh_out, atOut, atInput); + + return diopiSuccess; +} + +diopiError_t diopiAcoshInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) { + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + CALL_ATEN_CUDA_FUNC(acosh_, atInput); + + return diopiSuccess; +} + diopiError_t diopiSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index d5d01029c..b3c7d9fe7 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -992,6 +992,21 @@ DIOPI_API diopiError_t diopiAsinhInp(diopiContextHandle_t ctx, diopiTensorHandle */ DIOPI_API diopiError_t diopiAsinh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input); +/** + * @brief The in-place version of diopiAcosh(). + * @param[in] ctx Context environment. + * @param[inout] input the input tensor and will be stroed reuslt tensor. type = [float16, float32, float64]. + */ +DIOPI_API diopiError_t diopiAcoshInp(diopiContextHandle_t ctx, diopiTensorHandle_t input); + +/** + * @brief Returns a new tensor with the arc hyperbolic cosine of the elements of input. + * @param[in] ctx Context environment. + * @param[in] input the input tensor. type = [float16, float32, float64]. + * @param[out] out the output tensor. type = [float16, float32, float64]. + */ +DIOPI_API diopiError_t diopiAcosh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input); + /** * @brief The in-place version of diopiSigmoid(). * @param[in] ctx Context environment.