From 30fadeeae0ba18cceb467c08abcd22c798e8c5e2 Mon Sep 17 00:00:00 2001 From: DoorKickers <1105976166@qq.com> Date: Fri, 23 Aug 2024 12:40:05 +0800 Subject: [PATCH] fix setCurStream & add grid_sample & prepare for diopiPool2d --- diopi_test/python/configs/diopi_configs.py | 25 +++++++++++ .../python/conformance/diopi_functions.py | 10 +++++ impl/torch/functions/functions.cpp | 43 +++++++++++++++++++ proto/include/diopi/functions.h | 17 ++++++++ 4 files changed, 95 insertions(+) diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 73525e1cc..b8c66bd58 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -8202,6 +8202,31 @@ ), ), + 'grid_sample': dict( + name=["grid_sample"], + interface=['torch.nn.functional'], + para=dict( + mode=["bilinear", "nearest", "bilinear", "nearest"], + ), + tensor_para=dict( + args=[ + { + "ins": ['input'], + "shape": ((2, 3, 15, 15), (3, 3, 20, 20, 20), (2, 3, 25, 25), (3, 3, 30, 30, 30)), + "dtype": [np.float16, np.float32, np.float64], + "gen_fn": 'Genfunc.randn', + }, + { + "ins": ['grid'], + "shape": ((2, 5, 5, 2), (3, 10, 10, 10, 3), (2, 20, 20, 2), (3, 60, 60, 60, 3)), + "dtype": [np.float16, np.float32, np.float64], + "gen_fn": 'Genfunc.randn', + "gen_num_range": [1, 19], + }, + ], + ), + ), + 'multinomial': dict( name=["multinomial"], interface=['torch'], diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index ae7283d90..9249b0fc4 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -5097,6 +5097,16 @@ def meshgrid(tensors, shape=None): check_returncode(ret) return out +def grid_sample(input, grid, mode="bilinear"): + if len(input.size().data) == 4: + out = Tensor(size=(input.size().data[0], input.size().data[1], grid.size().data[1], grid.size().data[2],), dtype=input.dtype()) + else: + out = Tensor(size=(input.size().data[0], input.size().data[1], grid.size().data[1], grid.size().data[2], grid.size().data[3],), dtype=input.dtype()) + func = check_function("diopiGridSample") + ret = func(input.context(), out, input, grid, mode) + check_returncode(ret) + return out + def cast_dtype(input, out) -> Tensor: call = "diopiCastDtype" diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 4e25c362f..59eb11a61 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -133,6 +133,29 @@ diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHand return diopiSuccess; } +// TODO +diopiError_t diopiPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const char* mode, diopiSize_t kernel_size, + diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode, bool exclusive, bool adaptive) { + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + at::IntArrayRef atKernelSize = impl::aten::buildAtIntArray(kernel_size); + at::IntArrayRef atStride = impl::aten::buildAtIntArray(stride); + at::IntArrayRef atPadding = impl::aten::buildAtIntArray(padding); + at::IntArrayRef atDilation = impl::aten::buildAtIntArray(dilation); + bool atCeilMode = ceil_mode; + at::Tensor atOut = {}; + if (strcmp(mode, "max") == 0 && adaptive) { + } + if (strcmp(mode, "max") == 0 && !adaptive) { + } + if (strcmp(mode, "avg") == 0 && adaptive) { + } + if (strcmp(mode, "avg") == 0 && !adaptive) { + } + + return diopiSuccess; +} + /** * @brief * @param rounding_mode supported in pytorch>=1.8 @@ -769,6 +792,7 @@ diopiError_t diopiSortBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gra } diopiError_t diopiComplex(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t real, diopiConstTensorHandle_t imag) { + impl::aten::setCurStream(ctx); auto atReal = impl::aten::buildATen(real); auto atImag = impl::aten::buildATen(imag); auto atOut = impl::aten::buildATen(out); @@ -779,6 +803,7 @@ diopiError_t diopiComplex(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio } diopiError_t diopiConj(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { + impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); auto atOut = impl::aten::buildATen(out); atOut = torch::conj(atInput); @@ -788,6 +813,7 @@ diopiError_t diopiConj(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC } diopiError_t diopiImag(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { + impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); auto atOut = impl::aten::buildATen(out); atOut = torch::imag(atInput); @@ -797,6 +823,7 @@ diopiError_t diopiImag(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC } diopiError_t diopiReal(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) { + impl::aten::setCurStream(ctx); auto atInput = impl::aten::buildATen(input); auto atOut = impl::aten::buildATen(out); atOut = torch::real(atInput); @@ -3018,6 +3045,22 @@ diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, return diopiSuccess; } +diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t grid, + const char* mode) { + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + auto atGrid = impl::aten::buildATen(grid); + auto atOut = impl::aten::buildATen(out); + int interpolation_mode = 0; + if (strcmp(mode, "bilinear") != 0) { + interpolation_mode = 1; + } + atOut = CALL_ATEN_FUNC(grid_sampler, atInput, atGrid, interpolation_mode, 0, 0); + impl::aten::updateATen2Tensor(ctx, atOut, out); + + return diopiSuccess; +} + diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t param, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg, diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int64_t step, bool amsgrad) { diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index 2e2a43d42..309f9f51f 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -388,6 +388,12 @@ DIOPI_API diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTen diopiConstTensorHandle_t input, diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode, diopiConstTensorHandle_t indices); +/** +TODO + */ +DIOPI_API diopiError_t diopiPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const char* mode, diopiSize_t kernel_size, + diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode, bool exclusive, bool adaptive); + /** * @brief Applies a 2D adaptive average pooling over an input signal composed of several input planes. * @param[in] ctx Context environment. @@ -3503,6 +3509,17 @@ DIOPI_API diopiError_t diopiNormalInp(diopiContextHandle_t ctx, diopiTensorHandl */ DIOPI_API diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputsNum); +/** + * @brief Compute grid sample. + * @param[in] ctx Context environment. + * @param[in] input the original tensor to be sampled. + * @param[in] grid the pixel locations of sampling. + * @param[in] mode the sampling mode. [bilinear, nearest]. + * @param[out] out the result sampling tensor. + */ +DIOPI_API diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t grid, + const char* mode); + /** * @brief Returns a tensor where each row contains num_samples indices sampled from the * multinomial probability distribution located in the corresponding row of tensor input.