Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 8th No.2】为 Paddle 新增 baddbmm API #70757

Open
wants to merge 38 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e2299f9
added reset peak value initialization
Qin-sx Dec 5, 2024
dc30348
added comments
Qin-sx Dec 7, 2024
d081dde
added cpp tests
Qin-sx Dec 7, 2024
35a12c8
added python tests
Qin-sx Dec 7, 2024
cb5036c
added a python test for reset_max_memory_allocated
Qin-sx Dec 8, 2024
5d28856
formatted by pre-commit
Qin-sx Dec 9, 2024
f6b3d84
formatted by pre-commit (clang-format)
Qin-sx Dec 9, 2024
b6ea9be
added reset max memory reserved function
Qin-sx Dec 13, 2024
caad368
deleted memory stats and reset peak memory stats
Qin-sx Dec 15, 2024
299a7c0
Merge branch 'develop' of https://github.com/Qin-sx/Paddle into develop
Qin-sx Dec 24, 2024
323e7b2
formatted baddbmm version
Qin-sx Dec 29, 2024
a0ac3ee
added file for test
Qin-sx Jan 9, 2025
fdf6322
added bfloat16 case
Qin-sx Jan 10, 2025
ed02d8d
added InferSymbolicShapeInterface
Qin-sx Jan 11, 2025
2b3ed56
added more tests
Qin-sx Jan 11, 2025
47de1e6
tests overtime (20s), reduced size
Qin-sx Jan 11, 2025
da9467b
test overtime again, reduced size again
Qin-sx Jan 11, 2025
85bb2f5
reduce size for overtime again
Qin-sx Jan 11, 2025
39ae437
just one test for overtime
Qin-sx Jan 11, 2025
c812aa0
divided tests in some files
Qin-sx Jan 12, 2025
284e601
enable static
Qin-sx Jan 12, 2025
c18014a
deleted baddbmm in tensor_compat.h
Qin-sx Jan 12, 2025
ebc6f56
added float in hip
Qin-sx Jan 12, 2025
cd54e10
pre-commit
Qin-sx Jan 12, 2025
8afaf6a
added more tests
Qin-sx Jan 13, 2025
306df76
Merge branch 'PaddlePaddle:develop' into develop
Qin-sx Jan 13, 2025
968dc8b
Merge branch 'PaddlePaddle:develop' into develop
Qin-sx Jan 14, 2025
8aa753c
Merge branch 'PaddlePaddle:develop' into develop
Qin-sx Jan 14, 2025
06ba3bc
modified func order and copyright year
Qin-sx Jan 14, 2025
a287784
updated tests
Qin-sx Jan 15, 2025
b68104b
updated blas hip
Qin-sx Jan 16, 2025
5151e36
added broadcast
Qin-sx Jan 18, 2025
d06509b
added more tests
Qin-sx Jan 19, 2025
5c966a1
added more tests for baddbmm_
Qin-sx Jan 19, 2025
c5598a5
added more tests
Qin-sx Jan 20, 2025
f5f1e38
typo
Qin-sx Jan 22, 2025
d572d4d
Merge branch 'develop' into develop_baddbmm_format
Qin-sx Jan 22, 2025
9c6ac6a
deleted comments and 'print'
Qin-sx Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"addmm",
"allclose",
"any",
"baddbmm",
"bce_loss",
"bmm",
"diag",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,74 @@ bool AucOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool BaddbmmOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  // Symbolic shape inference for out = beta * input + alpha * bmm(x, y):
  //   x: [B, M, K], y: [B, K, N], out: [B, M, N],
  //   input: [B, M, N] (3-D) or [M, N] (2-D), broadcastable to out.
  const auto &input_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(0));
  const auto &x_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(1));
  const auto &y_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(2));

  auto ndim_input = input_shape.shape().size();
  auto ndim_x = x_shape.shape().size();
  auto ndim_y = y_shape.shape().size();

  PADDLE_ENFORCE_EQ(ndim_input == 3 || ndim_input == 2,
                    true,
                    common::errors::InvalidArgument(
                        "The input tensor input's dimension must be 3 or 2. "
                        "But received input's dimension = [%d].",
                        ndim_input));
  PADDLE_ENFORCE_EQ(ndim_x,
                    3,
                    common::errors::InvalidArgument(
                        "The input tensor x's dimension must be 3. "
                        "But received x's dimension = [%d].",
                        ndim_x));
  PADDLE_ENFORCE_EQ(ndim_y,
                    3,
                    common::errors::InvalidArgument(
                        "The input tensor y's dimension must be 3. "
                        "But received y's dimension = [%d].",
                        ndim_y));

  // Result of bmm(x, y) is [B, M, N].
  std::vector<symbol::DimExpr> output_shape;
  output_shape.push_back(x_shape.shape()[0]);  // batch size B
  output_shape.push_back(x_shape.shape()[1]);  // M
  output_shape.push_back(y_shape.shape()[2]);  // N

  infer_context->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(output_shape)});

  infer_context->AddEqualCstr(x_shape.shape()[0],
                              y_shape.shape()[0]);  // batch size
  infer_context->AddEqualCstr(x_shape.shape()[2], y_shape.shape()[1]);

  if (ndim_input == 3) {
    // 3-D input broadcasts dim-by-dim against the [B, M, N] result.
    infer_context->AddBroadcastableCstr(input_shape.shape()[0],
                                        x_shape.shape()[0]);  // batch size
    infer_context->AddBroadcastableCstr(input_shape.shape()[1],
                                        x_shape.shape()[1]);
    infer_context->AddBroadcastableCstr(input_shape.shape()[2],
                                        y_shape.shape()[2]);
  } else if (ndim_input == 2) {
    // BUGFIX: a 2-D input has shape [M, N] and broadcasts against the last
    // two dims of the result. Its first dim must therefore pair with x's row
    // dim (x[1] == M) — previously it was incorrectly paired with x's batch
    // dim (x[0]), which rejected valid [M, N] inputs whenever M != B.
    infer_context->AddBroadcastableCstr(input_shape.shape()[0],
                                        x_shape.shape()[1]);
    infer_context->AddBroadcastableCstr(input_shape.shape()[1],
                                        y_shape.shape()[2]);
  }

  return true;
}

// In-place variant baddbmm_: shape semantics are identical to baddbmm, so
// inference is delegated to the out-of-place implementation.
bool Baddbmm_OpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  return BaddbmmOpInferSymbolicShape(op, infer_context);
}

bool BatchFcOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos)
// baddbmm and its in-place variant share one InferSymbolicShape function.
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Baddbmm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Baddbmm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm)
Expand Down
22 changes: 22 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,28 @@ Tensor addmm_decomp(const Tensor& input,
full_scalar<T>(beta, input.dtype()) * input;
}

// Decomposition of baddbmm: out = alpha * bmm(x, y) + beta * input.
// The batched matmul is expanded into per-batch 2-D matmuls whose results
// are concatenated back together.
// NOTE(review): correctness relies on get_slice keeping the leading batch
// dim (so concat's default axis rebuilds [B, M, N]) — confirm get_slice's
// contract.
template <typename T>
Tensor baddbmm_decomp(const Tensor& input,
                      const Tensor& x,
                      const Tensor& y,
                      const float beta,
                      const float alpha) {
  const int num_batches = x.shape()[0];

  std::vector<Tensor> products;
  products.reserve(num_batches);
  for (int b = 0; b < num_batches; ++b) {
    products.push_back(matmul<T>(get_slice<T>(x, b), get_slice<T>(y, b)));
  }

  const Tensor bmm_out = concat<T>(products);

  // Scale factors are materialized in each operand's own dtype.
  return full_scalar<T>(alpha, bmm_out.dtype()) * bmm_out +
         full_scalar<T>(beta, input.dtype()) * input;
}

template <typename T>
Tensor eye_decomp(const paddle::Scalar& num_rows,
const paddle::Scalar& num_columns,
Expand Down
82 changes: 82 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,88 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

// Infers the output meta (dims, lod, dtype) for baddbmm:
//   out = beta * input + alpha * bmm(x, y)
// with x: [B, M, K], y: [B, K, N], input: 3-D or 2-D (broadcastable),
// out: [B, M, N]. Output dtype follows `input`.
void BaddbmmInferMeta(const MetaTensor& input,
                      const MetaTensor& x,
                      const MetaTensor& y,
                      float beta,
                      float alpha,
                      MetaTensor* out) {
  auto input_dims = input.dims();
  auto x_dims = x.dims();
  auto y_dims = y.dims();

  auto ndim_input = input_dims.size();
  auto ndim_x = x_dims.size();
  auto ndim_y = y_dims.size();

  VLOG(3) << "baddbmm operator input.shape=" << input_dims
          << " x.shape=" << x_dims << " y.shape=" << y_dims << " beta=" << beta
          << " alpha=" << alpha << " ndim_input=" << ndim_input
          << " ndim_x=" << ndim_x << " ndim_y=" << ndim_y;

  // A zero-element product means the variable was never initialized.
  PADDLE_ENFORCE_NE(
      product(input_dims),
      0,
      errors::PreconditionNotMet("The Input variable 'input' has not "
                                 "been initialized. You may need to confirm "
                                 "if you put exe.run(startup_program) "
                                 "after optimizer.minimize function."));

  PADDLE_ENFORCE_NE(
      product(x_dims),
      0,
      errors::PreconditionNotMet("The Input variable 'x' has not "
                                 "been initialized. You may need to confirm "
                                 "if you put exe.run(startup_program) "
                                 "after optimizer.minimize function."));

  PADDLE_ENFORCE_NE(
      product(y_dims),
      0,
      errors::PreconditionNotMet("The Input variable 'y' has not "
                                 "been initialized. You may need to confirm "
                                 "if you put exe.run(startup_program) "
                                 "after optimizer.minimize function."));
  // dim check
  PADDLE_ENFORCE_EQ(ndim_input == 3 || ndim_input == 2,
                    true,
                    errors::InvalidArgument(
                        "The input tensor input's dimension must be 3 or 2. "
                        "But received input's dimension = [%d].",
                        ndim_input));
  PADDLE_ENFORCE_EQ(
      ndim_x,
      3,
      errors::InvalidArgument("The input tensor x's dimension must be 3. "
                              "But received x's dimension = [%d].",
                              ndim_x));
  PADDLE_ENFORCE_EQ(
      ndim_y,
      3,
      errors::InvalidArgument("The input tensor y's dimension must be 3. "
                              "But received y's dimension = [%d].",
                              ndim_y));

  // FIX: also validate that the batch dims of x and y agree, mirroring the
  // AddEqualCstr(x[0], y[0]) constraint in the symbolic-shape inference.
  // Skip when either dim is dynamic (-1, unknown at compile time).
  if (x_dims[0] >= 0 && y_dims[0] >= 0) {
    PADDLE_ENFORCE_EQ(
        x_dims[0],
        y_dims[0],
        errors::InvalidArgument("The batch dimension of x must be equal to "
                                "the batch dimension of y. "
                                "But received x's batch dimension = [%d], "
                                "y's batch dimension = [%d].",
                                x_dims[0],
                                y_dims[0]));
  }

  PADDLE_ENFORCE_EQ(
      x_dims[2],
      y_dims[1],
      errors::InvalidArgument("The second dimension of x must be equal to the "
                              "first dimension of y. "
                              "But received x's second dimension = [%d], y's "
                              "first dimension = [%d].",
                              x_dims[2],
                              y_dims[1]));

  // out: [B, M, N] = [x[0], x[1], y[2]].
  std::vector<int64_t> output_dims;
  output_dims.push_back(x_dims[0]);
  output_dims.push_back(x_dims[1]);
  output_dims.push_back(y_dims[2]);

  out->set_dims(common::make_ddim(output_dims));
  out->share_lod(input);
  out->set_dtype(input.dtype());
}

void AffineChannelInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ void AddmmInferMeta(const MetaTensor& input,
float alpha,
MetaTensor* out);

// Infers output meta for baddbmm: out = beta * input + alpha * bmm(x, y),
// where x is [B, M, K], y is [B, K, N], and input is 3-D or 2-D.
void BaddbmmInferMeta(const MetaTensor& input,
                      const MetaTensor& x,
                      const MetaTensor& y,
                      float beta,
                      float alpha,
                      MetaTensor* out);

void AffineChannelInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
33 changes: 33 additions & 0 deletions paddle/phi/kernels/baddbmm_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Backward kernel for baddbmm (out = beta * input + alpha * bmm(x, y)).
// Computes the gradients of input, x and y from out_grad; any of the
// output pointers may be null when that gradient is not required.
// NOTE(review): parameter order here is (alpha, beta), while the forward
// kernel takes (beta, alpha) — confirm this matches the op's yaml signature.
template <typename T, typename Context>
void BaddbmmGradKernel(const Context& dev_ctx,
                       const DenseTensor& input,
                       const DenseTensor& x,
                       const DenseTensor& y,
                       const DenseTensor& out_grad,
                       float alpha,
                       float beta,
                       DenseTensor* input_grad,
                       DenseTensor* x_grad,
                       DenseTensor* y_grad);

}  // namespace phi
30 changes: 30 additions & 0 deletions paddle/phi/kernels/baddbmm_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Forward kernel for baddbmm: out = beta * input + alpha * bmm(x, y),
// where x is [B, M, K], y is [B, K, N], and input broadcasts to [B, M, N].
template <typename T, typename Context>
void BaddbmmKernel(const Context& dev_ctx,
                   const DenseTensor& input,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   float beta,
                   float alpha,
                   DenseTensor* out);

}  // namespace phi
22 changes: 22 additions & 0 deletions paddle/phi/kernels/cpu/baddbmm_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/baddbmm_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h"

// Register the CPU baddbmm backward kernel for float and double.
// NOTE(review): fp16/bf16 appear to be left to the GPU registration — verify.
PD_REGISTER_KERNEL(
    baddbmm_grad, CPU, ALL_LAYOUT, phi::BaddbmmGradKernel, float, double) {}
22 changes: 22 additions & 0 deletions paddle/phi/kernels/cpu/baddbmm_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/baddbmm_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/baddbmm_kernel_impl.h"

// Register the CPU baddbmm forward kernel for float and double.
// NOTE(review): fp16/bf16 appear to be left to the GPU registration — verify.
PD_REGISTER_KERNEL(
    baddbmm, CPU, ALL_LAYOUT, phi::BaddbmmKernel, float, double) {}
27 changes: 27 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ class Blas {
T beta,
T* C) const;

// GEMM overload whose alpha/beta scaling type U may differ from the matrix
// element type T (U defaults to T). Added so baddbmm can scale low-precision
// matrices (e.g. fp16/bf16 T) with float coefficients.
// Computes C = alpha * op(A) * op(B) + beta * C, with op chosen by
// transA/transB and M/N/K the usual GEMM dimensions.
template <typename T, typename U = T>
void GEMM(CBLAS_TRANSPOSE transA,
          CBLAS_TRANSPOSE transB,
          int M,
          int N,
          int K,
          U alpha,
          const T* A,
          const T* B,
          U beta,
          T* C) const;

template <typename T>
void GEMM(bool transA,
bool transB,
Expand Down Expand Up @@ -292,6 +304,21 @@ class Blas {
int64_t strideA,
int64_t strideB) const;

// Strided-batched GEMM overload with a separate scaling type U (defaults to
// T), mirroring the two-type GEMM above. For each batch b in
// [0, batchCount): C_b = alpha * op(A_b) * op(B_b) + beta * C_b, where
// consecutive A/B matrices are strideA/strideB elements apart.
// NOTE(review): no stride parameter for C — presumably C batches are packed
// contiguously (stride M*N); confirm against the implementation.
template <typename T, typename U = T>
void BatchedGEMM(CBLAS_TRANSPOSE transA,
                 CBLAS_TRANSPOSE transB,
                 int M,
                 int N,
                 int K,
                 U alpha,
                 const T* A,
                 const T* B,
                 U beta,
                 T* C,
                 int batchCount,
                 int64_t strideA,
                 int64_t strideB) const;

template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
Expand Down
Loading
Loading