def argsort(input, dim=-1, descending=False, stable=False):
    """Return the indices that would sort ``input`` along ``dim``.

    Mirrors ``torch.argsort``: the result is an index tensor with the same
    shape as ``input``. ``descending`` flips the sort direction and
    ``stable`` requests that equal elements keep their relative order.
    """
    # The output holds indices, so its dtype is the globally configured
    # integer index type rather than the input's dtype.
    indices = Tensor(input.size().data, from_numpy_dtype(glob_vars.int_type))
    diopi_argsort = check_function("diopiArgsort")
    return_code = diopi_argsort(input.context(), indices, input, stable, dim, descending)
    check_returncode(return_code)
    return indices
// Computes the indices that sort `input` along `dim` by delegating to
// ATen's argsort on the current CUDA stream, then writes the result into `out`.
diopiError_t diopiArgsort(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, bool stable, const int64_t* dim, bool descending) {
    // Bind ATen work to the stream associated with this DIOPI context.
    impl::aten::setCurStream(ctx);
    auto atOut = impl::aten::buildATen(out);
    auto atInput = impl::aten::buildATen(input);
    // A null `dim` selects the default: sort along the last dimension (-1),
    // matching torch.argsort's default.
    // NOTE(review): `atOut` is immediately overwritten by the tensor argsort
    // returns, so the buildATen(out) result above looks unused — confirm the
    // CALL_ATEN_CUDA_FUNC macro does not reference `atOut` before removing it.
    atOut = CALL_ATEN_CUDA_FUNC(argsort, atInput, stable, (dim ? *dim : -1), descending);
    // Copy (casting if the index dtypes differ) the computed indices into `out`.
    impl::aten::updateATen2Tensor(ctx, atOut, out);

    return diopiSuccess;
}
 */
DIOPI_API diopiError_t diopiArgsort(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, bool stable, const int64_t* dim,
                                    bool descending);  // dim == NULL selects the default (last) dimension, i.e. -1