Skip to content

Commit

Permalink
add atanh on diopi torch impl
Browse files Browse the repository at this point in the history
  • Loading branch information
DoorKickers committed Aug 14, 2024
1 parent 6f4e9d6 commit 89168cd
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 0 deletions.
38 changes: 38 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,44 @@
),
),

'atanh': dict(
name=['atanh'],
interface=['torch'],
is_inplace=True,
saved_args=dict(output=0),
dtype=[np.float16, np.float32, np.float64],
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input'],
"shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072),
(256, 128, 3, 3),
(2, 31, 512, 6, 40),
(0,), (16, 0), (1, 0, 6)),
},
],
),
),

'atanh_not_float': dict(
name=['atanh'],
interface=['torch'],
dtype=[np.int16, np.int32, np.int64, np.uint8, np.int8, np.bool_],
tensor_para=dict(
gen_fn='Genfunc.randn',
args=[
{
"ins": ['input'],
"shape": ((), (1, ), (1024,), (364800, 4), (2, 128, 3072),
(256, 128, 3, 3),
(2, 31, 512, 6, 40),
(0,), (16, 0), (1, 0, 6)),
},
],
),
),

'sign': dict(
name=['sign'],
interface=['torch'],
Expand Down
3 changes: 3 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ def asinh(input, inplace=False) -> Tensor:
def acosh(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiAcosh", promote_type(input, Dtype.float32))

def atanh(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiAtanh", promote_type(input, Dtype.float32))


def exp(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiExp", promote_type(input, Dtype.float32))
Expand Down
17 changes: 17 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,23 @@ diopiError_t diopiAcoshInp(diopiContextHandle_t ctx, diopiTensorHandle_t input)
return diopiSuccess;
}

diopiError_t diopiAtanh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(atanh_out, atOut, atInput);

return diopiSuccess;
}

diopiError_t diopiAtanhInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
CALL_ATEN_CUDA_FUNC(atanh_, atInput);

return diopiSuccess;
}

diopiError_t diopiSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
Expand Down
15 changes: 15 additions & 0 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,21 @@ DIOPI_API diopiError_t diopiAcoshInp(diopiContextHandle_t ctx, diopiTensorHandle
*/
DIOPI_API diopiError_t diopiAcosh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
* @brief The in-place version of diopiAtanh().
* @param[in] ctx Context environment.
* @param[inout] input the input tensor and will be stroed reuslt tensor. type = [float16, float32, float64].
*/
DIOPI_API diopiError_t diopiAtanhInp(diopiContextHandle_t ctx, diopiTensorHandle_t input);

/**
* @brief Returns a new tensor with the arc hyperbolic tangent of the elements of input.
* @param[in] ctx Context environment.
* @param[in] input the input tensor. type = [float16, float32, float64].
* @param[out] out the output tensor. type = [float16, float32, float64].
*/
DIOPI_API diopiError_t diopiAtanh(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
* @brief The in-place version of diopiSigmoid().
* @param[in] ctx Context environment.
Expand Down

0 comments on commit 89168cd

Please sign in to comment.