Skip to content

Commit

Permalink
fix setCurStream & add grid_sample & prepare for diopiPool2d
Browse files Browse the repository at this point in the history
  • Loading branch information
DoorKickers committed Aug 23, 2024
1 parent 5eb369e commit 30fadee
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 0 deletions.
25 changes: 25 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8202,6 +8202,31 @@
),
),

'grid_sample': dict(
name=["grid_sample"],
interface=['torch.nn.functional'],
para=dict(
mode=["bilinear", "nearest", "bilinear", "nearest"],
),
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((2, 3, 15, 15), (3, 3, 20, 20, 20), (2, 3, 25, 25), (3, 3, 30, 30, 30)),
"dtype": [np.float16, np.float32, np.float64],
"gen_fn": 'Genfunc.randn',
},
{
"ins": ['grid'],
"shape": ((2, 5, 5, 2), (3, 10, 10, 10, 3), (2, 20, 20, 2), (3, 60, 60, 60, 3)),
"dtype": [np.float16, np.float32, np.float64],
"gen_fn": 'Genfunc.randn',
"gen_num_range": [1, 19],
},
],
),
),

'multinomial': dict(
name=["multinomial"],
interface=['torch'],
Expand Down
10 changes: 10 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5097,6 +5097,16 @@ def meshgrid(tensors, shape=None):
check_returncode(ret)
return out

def grid_sample(input, grid, mode="bilinear"):
if len(input.size().data) == 4:
out = Tensor(size=(input.size().data[0], input.size().data[1], grid.size().data[1], grid.size().data[2],), dtype=input.dtype())
else:
out = Tensor(size=(input.size().data[0], input.size().data[1], grid.size().data[1], grid.size().data[2], grid.size().data[3],), dtype=input.dtype())
func = check_function("diopiGridSample")
ret = func(input.context(), out, input, grid, mode)
check_returncode(ret)
return out


def cast_dtype(input, out) -> Tensor:
call = "diopiCastDtype"
Expand Down
43 changes: 43 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,29 @@ diopiError_t diopiMaxPool2dWithIndices(diopiContextHandle_t ctx, diopiTensorHand
return diopiSuccess;
}

// TODO
diopiError_t diopiPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const char* mode, diopiSize_t kernel_size,
diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode, bool exclusive, bool adaptive) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
at::IntArrayRef atKernelSize = impl::aten::buildAtIntArray(kernel_size);
at::IntArrayRef atStride = impl::aten::buildAtIntArray(stride);
at::IntArrayRef atPadding = impl::aten::buildAtIntArray(padding);
at::IntArrayRef atDilation = impl::aten::buildAtIntArray(dilation);
bool atCeilMode = ceil_mode;
at::Tensor atOut = {};
if (strcmp(mode, "max") == 0 && adaptive) {
}
if (strcmp(mode, "max") == 0 && !adaptive) {
}
if (strcmp(mode, "avg") == 0 && adaptive) {
}
if (strcmp(mode, "avg") == 0 && !adaptive) {
}

return diopiSuccess;
}

/**
* @brief
* @param rounding_mode supported in pytorch>=1.8
Expand Down Expand Up @@ -769,6 +792,7 @@ diopiError_t diopiSortBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gra
}

diopiError_t diopiComplex(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t real, diopiConstTensorHandle_t imag) {
impl::aten::setCurStream(ctx);
auto atReal = impl::aten::buildATen(real);
auto atImag = impl::aten::buildATen(imag);
auto atOut = impl::aten::buildATen(out);
Expand All @@ -779,6 +803,7 @@ diopiError_t diopiComplex(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
}

diopiError_t diopiConj(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOut = impl::aten::buildATen(out);
atOut = torch::conj(atInput);
Expand All @@ -788,6 +813,7 @@ diopiError_t diopiConj(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC
}

diopiError_t diopiImag(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOut = impl::aten::buildATen(out);
atOut = torch::imag(atInput);
Expand All @@ -797,6 +823,7 @@ diopiError_t diopiImag(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC
}

diopiError_t diopiReal(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atOut = impl::aten::buildATen(out);
atOut = torch::real(atInput);
Expand Down Expand Up @@ -3018,6 +3045,22 @@ diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs,
return diopiSuccess;
}

diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t grid,
const char* mode) {
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atGrid = impl::aten::buildATen(grid);
auto atOut = impl::aten::buildATen(out);
int interpolation_mode = 0;
if (strcmp(mode, "bilinear") != 0) {
interpolation_mode = 1;
}
atOut = CALL_ATEN_FUNC(grid_sampler, atInput, atGrid, interpolation_mode, 0, 0);
impl::aten::updateATen2Tensor(ctx, atOut, out);

return diopiSuccess;
}

diopiError_t diopiAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t param, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay,
int64_t step, bool amsgrad) {
Expand Down
17 changes: 17 additions & 0 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ DIOPI_API diopiError_t diopiMaxPool2dBackward(diopiContextHandle_t ctx, diopiTen
diopiConstTensorHandle_t input, diopiSize_t kernel_size, diopiSize_t stride, diopiSize_t padding,
diopiSize_t dilation, bool ceil_mode, diopiConstTensorHandle_t indices);

/**
TODO
*/
DIOPI_API diopiError_t diopiPool2d(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const char* mode, diopiSize_t kernel_size,
diopiSize_t stride, diopiSize_t padding, diopiSize_t dilation, bool ceil_mode, bool exclusive, bool adaptive);

/**
* @brief Applies a 2D adaptive average pooling over an input signal composed of several input planes.
* @param[in] ctx Context environment.
Expand Down Expand Up @@ -3503,6 +3509,17 @@ DIOPI_API diopiError_t diopiNormalInp(diopiContextHandle_t ctx, diopiTensorHandl
*/
DIOPI_API diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputsNum);

/**
* @brief Compute grid sample.
* @param[in] ctx Context environment.
* @param[in] input the original tensor to be sampled.
* @param[in] grid the pixel locations of sampling.
* @param[in] mode the sampling mode. [bilinear, nearest].
* @param[out] out the result sampling tensor.
*/
DIOPI_API diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t grid,
const char* mode);

/**
* @brief Returns a tensor where each row contains num_samples indices sampled from the
* multinomial probability distribution located in the corresponding row of tensor input.
Expand Down

0 comments on commit 30fadee

Please sign in to comment.