Skip to content

Commit

Permalink
fix op pd_op.mean
Browse files Browse the repository at this point in the history
  • Loading branch information
0x3878f committed Feb 7, 2025
1 parent 1c8c3fb commit d111ea4
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions paddle2onnx/mapper/tensor/reduce_mean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ void ReduceMeanMapper::Opset18() {
GetAttr(axis_name_, &dim_);
}
} else {
GetAttr("axis", &dim_);
if (HasAttr(axis_name_)) {
GetAttr(axis_name_, &dim_);
} else {
TryGetInputValue(axis_name_, &dim_);
}
if (dim_.size() == 0) {
reduce_all_ = true;
} else {
Expand All @@ -51,6 +55,9 @@ void ReduceMeanMapper::Opset18() {
if (IsAttrVar(axis_name_)) {
auto info = GetAttrVar(axis_name_);
dims = helper_->AutoCast(info[0].name, info[0].dtype, P2ODataType::INT64);
} else if(HasInput(axis_name_)) {
auto info = GetInput(axis_name_);
dims = helper_->AutoCast(info[0].name, info[0].dtype, P2ODataType::INT64);
} else {
if (!reduce_all_) {
dims = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, dim_);
Expand Down Expand Up @@ -92,7 +99,12 @@ void ReduceMeanMapper::Opset11() {
GetAttr(axis_name_, &dim_);
}
} else {
GetAttr("axis", &dim_);
if (HasAttr(axis_name_)) {
GetAttr(axis_name_, &dim_);
} else {
Assert(TryGetInputValue(axis_name_, &dim_),
"Can not get input 'axis(dim)' value.");
}
if (dim_.size() == 0) {
reduce_all_ = true;
} else {
Expand Down

0 comments on commit d111ea4

Please sign in to comment.