add argsort on diopi torch impl
DoorKickers committed Aug 14, 2024
1 parent 74b8872 commit a8c9223
Showing 4 changed files with 72 additions and 0 deletions.
42 changes: 42 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
@@ -5595,6 +5595,48 @@
        ),
    ),

    'argsort': dict(
        name=['argsort'],
        interface=["torch"],
        para=dict(
            dim=[0, -1, 0, 1, -1, 0, 2, 1],
            stable=[True, False, True, False, False, True, True, False],
            descending=[True, False, True, False, False, True, True, False],
        ),
        tensor_para=dict(
            args=[
                {
                    "ins": ['input'],
                    "shape": ((), (1,), (1024, 80), (2, 256, 256), (2, 1, 64, 64),
                              (12, 0), (2, 0, 9), (0, 9, 8, 7)),
                    "dtype": [np.float64, np.float16, np.float32, np.int32, np.int16,
                              np.int64, np.uint8, np.int8],
                    "gen_fn": 'Genfunc.randn',
                },
            ],
        ),
    ),

    'argsort_same_value': dict(
        name=['argsort'],
        interface=["torch"],
        para=dict(
            dim=[-1, 0, -1, 1],
            stable=[True, False, True, False],
            descending=[True, False, True, False],
        ),
        tensor_para=dict(
            args=[
                {
                    "ins": ['input'],
                    "shape": ((1,), (1024, 80), (2, 256, 256), (2, 1, 64, 64)),
                    "dtype": [np.float32],
                    "gen_fn": 'Genfunc.zeros',
                },
            ],
        ),
    ),
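
These configs drive the conformance suite against torch.argsort as the reference (interface=["torch"]); the 'argsort_same_value' group feeds all-zero inputs so that the stable flag is genuinely exercised. A minimal sketch of that behaviour in plain PyTorch, assuming a torch version whose argsort accepts stable:

    import torch

    x = torch.zeros(1024, 80)
    # A stable sort keeps equal elements in their original relative order,
    # so argsort over an all-zero tensor must return 0, 1, 2, ... along dim.
    idx = torch.argsort(x, dim=-1, descending=True, stable=True)
    assert torch.equal(idx[0], torch.arange(80))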

    'adadelta': dict(
        name=["adadelta"],
        interface=["CustomizedTest"],
8 changes: 8 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
@@ -3586,6 +3586,14 @@ def argmin(input, dim=None, keepdim=False):

    return out

def argsort(input, dim=-1, descending=False, stable=False):
    out = Tensor(input.size().data, from_numpy_dtype(glob_vars.int_type))
    func = check_function("diopiArgsort")
    ret = func(input.context(), out, input, stable, dim, descending)
    check_returncode(ret)

    return out
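
The wrapper allocates out with the input's shape and the global integer index type, then lets the backend fill it. The shape/dtype contract it relies on is the usual torch one; a quick sketch:

    import torch

    x = torch.randn(2, 256, 256)
    idx = torch.argsort(x, dim=1)
    assert idx.shape == x.shape        # indices have the same shape as the input
    assert idx.dtype == torch.int64    # torch returns int64 indices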


def smooth_l1_loss(input, target, reduction="mean", beta=1.0):
    assert (
10 changes: 10 additions & 0 deletions impl/torch/functions/functions.cpp
@@ -3210,6 +3210,16 @@ diopiError_t diopiArgmin(diopiContextHandle_t ctx, diopiTensorHandle_t out, diop
    return diopiSuccess;
}

diopiError_t diopiArgsort(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, bool stable, const int64_t* dim, bool descending) {
    impl::aten::setCurStream(ctx);
    auto atOut = impl::aten::buildATen(out);
    auto atInput = impl::aten::buildATen(input);
    atOut = CALL_ATEN_CUDA_FUNC(argsort, atInput, stable, (dim ? *dim : -1), descending);
    impl::aten::updateATen2Tensor(ctx, atOut, out);

    return diopiSuccess;
}
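
The implementation forwards stable, dim (falling back to -1 when the pointer is null, i.e. torch's default), and descending straight to ATen's argsort, then copies the result back into out. A sketch of the invariant that delegation preserves, written in Python for brevity:

    import torch

    x = torch.randn(1024, 80)
    idx = torch.argsort(x, dim=-1, descending=True)
    # Gathering with the returned indices reproduces the sorted values.
    assert torch.equal(torch.gather(x, -1, idx),
                       torch.sort(x, dim=-1, descending=True).values)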

diopiError_t diopiSmoothL1Loss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target,
                               diopiReduction_t reduction, double beta) {
    impl::aten::setCurStream(ctx);
12 changes: 12 additions & 0 deletions proto/include/diopi/functions.h
@@ -2723,6 +2723,18 @@ DIOPI_API diopiError_t diopiArgmax(diopiContextHandle_t ctx, diopiTensorHandle_t
 */
DIOPI_API diopiError_t diopiArgmin(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const int64_t* dim, bool keepdim);

/**
 * @brief Returns the indices that sort a tensor along a given dimension in ascending order by value.
 * @param[in] ctx Context environment.
 * @param[in] input the input tensor. type=[float32, float64, float16, int16, int32, int64, uint8, int8, bool].
 * @param[in] dim the dimension to do the operation over. type=[int32, int64].
 * @param[in] descending controls the sorting order (ascending or descending).
 * @param[in] stable controls the relative order of equivalent elements.
 * @param[out] out the output tensor. type=[int32, int64].
 */
DIOPI_API diopiError_t diopiArgsort(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, bool stable, const int64_t* dim,
                                    bool descending);
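
For orientation, the argument order here (stable before dim) differs from torch.argsort's keyword order; the declaration maps onto the reference roughly as diopiArgsort(ctx, out, input, stable, dim, descending) <-> torch.argsort(input, dim, descending, stable). A small sketch of the two orderings the descending flag selects:

    import torch

    x = torch.tensor([3.0, 1.0, 2.0])
    print(torch.argsort(x))                    # tensor([1, 2, 0]): ascending (default)
    print(torch.argsort(x, descending=True))   # tensor([0, 2, 1]): descending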

/**
 * @brief The function is used to implement the Adadelta optimizer. Its functionality is to perform a single parameter update.
 * @param[in] ctx Context environment.
