Skip to content

Commit

Permalink
add asinh on diopi torch impl
Browse files Browse the repository at this point in the history
  • Loading branch information
DoorKickers committed Aug 14, 2024
1 parent dca4fcc commit 766a976
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 0 deletions.
38 changes: 38 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,44 @@
),
),

'asinh': dict(
name=['asinh'],
interface=['torch'],
is_inplace=True,
saved_args=dict(output=0),
dtype=[np.float16, np.float32, np.float64],
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input'],
"shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072),
(256, 128, 3, 3),
(2, 31, 512, 6, 40),
(0,), (16, 0), (1, 0, 6)),
},
],
),
),

'asinh_not_float': dict(
name=['asinh'],
interface=['torch'],
dtype=[np.int16, np.int32, np.int64, np.uint8, np.int8, np.bool_],
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input'],
"shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072),
(256, 128, 3, 3),
(2, 31, 512, 6, 40),
(0,), (16, 0), (1, 0, 6)),
},
],
),
),

'sign': dict(
name=['sign'],
interface=['torch'],
Expand Down
4 changes: 4 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ def atan(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiAtan", promote_type(input, Dtype.float32))


def asinh(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiAsinh", promote_type(input, Dtype.float32))


def exp(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiExp", promote_type(input, Dtype.float32))

Expand Down
17 changes: 17 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,23 @@ diopiError_t diopiAtanInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
return diopiSuccess;
}

diopiError_t diopiAsinh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(asinh_out, atOut, atInput);

return diopiSuccess;
}

diopiError_t diopiAsinhInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
CALL_ATEN_CUDA_FUNC(asinh_, atInput);

return diopiSuccess;
}

diopiError_t diopiSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
Expand Down
15 changes: 15 additions & 0 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,21 @@ DIOPI_API diopiError_t diopiAtan(diopiContextHandle_t ctx, diopiTensorHandle_t o
*/
DIOPI_API diopiError_t diopiAtanInp(diopiContextHandle_t ctx, diopiTensorHandle_t input);

/**
* @brief The in-place version of diopiAsinh().
* @param[in] ctx Context environment.
* @param[inout] input the input tensor and will be stroed reuslt tensor. type = [float16, float32, float64].
*/
DIOPI_API diopiError_t diopiAsinhInp(diopiContextHandle_t ctx, diopiTensorHandle_t input);

/**
* @brief Returns a new tensor with the arc hyperbolic sine of the elements of input.
* @param[in] ctx Context environment.
* @param[in] input the input tensor. type = [float16, float32, float64].
* @param[out] out the output tensor. type = [float16, float32, float64].
*/
DIOPI_API diopiError_t diopiAsinh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
* @brief The in-place version of diopiSigmoid().
* @param[in] ctx Context environment.
Expand Down

0 comments on commit 766a976

Please sign in to comment.