Skip to content

Commit

Permalink
add log1p for diopi torch & add log1pInp & add cuda test device_confi…
Browse files Browse the repository at this point in the history
…g & update diopi test
  • Loading branch information
DoorKickers committed Aug 12, 2024
1 parent e42a15e commit 183eff2
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 24 deletions.
25 changes: 4 additions & 21 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@
),

'pointwise_op_abs_input': dict(
name=['log', 'log2', 'log10', 'sqrt', 'rsqrt'],
name=['log', 'log2', 'log10', 'log1p', 'sqrt', 'rsqrt'],
interface=['torch'],
is_inplace=True,
dtype=[np.float16, np.float32, np.float64],
Expand All @@ -1137,7 +1137,7 @@
),

'log_integer_input': dict(
name=['log', 'log2', 'log10'],
name=['log', 'log2', 'log10', 'log1p'],
interface=['torch'],
dtype=[np.int16, np.int32, np.int64, np.uint8, np.int8],
tensor_para=dict(
Expand All @@ -1153,25 +1153,8 @@
),
),

'log1p': dict(
name=['log1p'],
interface=['torch'],
dtype=[np.float32, np.float64],
tensor_para=dict(
gen_fn='Genfunc.positive',
args=[
{
"ins": ['input'],
"shape": ((1, ), (1024,), (364800, 4), (2, 128, 3072),
(256, 128, 3, 3),
(2, 31, 512, 6, 40)),
},
],
),
),

'log_zero_input': dict(
name=['log', 'log2', 'log10'],
name=['log', 'log2', 'log10', 'log1p'],
interface=['torch'],
dtype=[np.float16, np.float32, np.float64,
np.int16, np.int32, np.int64,
Expand All @@ -1190,7 +1173,7 @@
),

'log_neg_input': dict(
name=['log', 'log2', 'log10'],
name=['log', 'log2', 'log10', 'log1p'],
interface=['torch'],
dtype=[np.float16, np.float32, np.float64,
np.int16, np.int32, np.int64,
Expand Down
4 changes: 3 additions & 1 deletion diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,10 @@ def log2(input, inplace=False) -> Tensor:
def log10(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiLog10", promote_type(input, Dtype.float32))


def log1p(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiLog1p")
return unary_op(input, inplace, "diopiLog1p", promote_type(input, Dtype.float32))


def erf(input, inplace=False) -> Tensor:
return unary_op(input, inplace, "diopiErf", promote_type(input, Dtype.float32))
Expand Down
23 changes: 23 additions & 0 deletions impl/cuda/functions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,28 @@ extern "C" diopiError_t diopiLog1p(diopiContextHandle_t ctx, diopiTensorHandle_t
DISPATCH_DTYPE(vecLog1p, trInput.dtype(), gridSize, blockSize, stream,
trInput.data(), trOut.data(), trInput.numel());

return diopiSuccess;
}

template<typename T> __global__
void vecLog1pInp(void* a, const int numel) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
T* A = static_cast<T*>(a);
if (id < numel) {
A[id] = logf(1 + A[id]);
}
}

extern "C" diopiError_t diopiLog1pInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {

auto stream = impl::cuda::getStream(ctx);
auto trInput = impl::cuda::makeTensor(input);

int blockSize = 256;
int gridSize = (trInput.numel() + blockSize - 1) / blockSize;

DISPATCH_DTYPE(vecLog1pInp, trInput.dtype(), gridSize, blockSize, stream,
trInput.data(), trInput.numel());

return diopiSuccess;
}
17 changes: 17 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,23 @@ diopiError_t diopiLog10Inp(diopiContextHandle_t ctx, diopiTensorHandle_t input)
return diopiSuccess;
}

diopiError_t diopiLog1p(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOut = impl::aten::buildATen(out);
CALL_ATEN_CUDA_FUNC(log1p_out, atOut, atInput);

return diopiSuccess;
}

diopiError_t diopiLog1pInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
at::log1p_(atInput);

return diopiSuccess;
}

diopiError_t diopiErf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
Expand Down
12 changes: 10 additions & 2 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -1026,11 +1026,19 @@ DIOPI_API diopiError_t diopiLog10(diopiContextHandle_t ctx, diopiTensorHandle_t
/**
* @brief Compute the element-wise natural logarithm of 1 plus the input tensor.
* @param[in] ctx Context environment.
* @param[in] input the input tensor. type = [float32, float64, int16, int32].
* @param[out] out the output tensor. type = [float32, float64].
* @param[in] input the input tensor.
* @param[out] out the output tensor.
*/
DIOPI_API diopiError_t diopiLog1p(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
* @brief The in-place version of diopiLog1p.
* @param[in] ctx Context environment.
* @param[in] input the input tensor.
* @param[out] out the output tensor.
*/
DIOPI_API diopiError_t diopiLog1pInp(diopiContextHandle_t ctx, diopiTensorHandle_t input);

DIOPI_API diopiError_t diopiErfInp(diopiContextHandle_t ctx, diopiTensorHandle_t input);
DIOPI_API diopiError_t diopiErf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

Expand Down

0 comments on commit 183eff2

Please sign in to comment.