Skip to content

Commit

Permalink
add python test
Browse files Browse the repository at this point in the history
  • Loading branch information
a162837 committed Dec 30, 2024
1 parent 6a3cefd commit 164be97
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
auto ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
DenseTensor ex_min;
MetaTensor meta_min(&ex_min);
CastInferMeta(min, x.dtype(), &meta_min);
DenseTensor ex_max;
MetaTensor meta_max(&ex_max);
CastInferMeta(max, x.dtype(), &meta_max);
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);

const T* x_data = x.data<T>();
const T* min_data = ex_min.data<T>();
Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/clip_tensor_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
auto ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
auto ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
DenseTensor ex_min;
MetaTensor meta_min(&ex_min);
CastInferMeta(min, x.dtype(), &meta_min);
DenseTensor ex_max;
MetaTensor meta_max(&ex_max);
CastInferMeta(max, x.dtype(), &meta_max);
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);

const T* x_data = x.data<T>();
const T* min_data = ex_min.data<T>();
Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,14 @@ void ClipTensorGradKernel(const Context& dev_ctx,
const DenseTensor& max,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
auto ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
DenseTensor ex_min;
MetaTensor meta_min(&ex_min);
CastInferMeta(min, x.dtype(), &meta_min);
DenseTensor ex_max;
MetaTensor meta_max(&ex_max);
CastInferMeta(max, x.dtype(), &meta_max);
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);

const T* x_data = x.data<T>();
auto numel = x.numel();
Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/clip_tensor_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ void ClipTensorKernel(const Context& dev_ctx,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
auto ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
auto ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
DenseTensor ex_min;
MetaTensor meta_min(&ex_min);
CastInferMeta(min, x.dtype(), &meta_min);
DenseTensor ex_max;
MetaTensor meta_max(&ex_max);
CastInferMeta(max, x.dtype(), &meta_max);
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);

std::vector<const DenseTensor*> ins = {&x, &ex_min, &ex_max};
std::vector<DenseTensor*> outs = {out};
Expand Down

0 comments on commit 164be97

Please sign in to comment.