Skip to content

Commit

Permalink
fix: convert legacy .pdmodel programs to PIR before export and handle builtin.slice
Browse files Browse the repository at this point in the history
  • Loading branch information
risemeup1 committed Jan 21, 2025
1 parent 284860a commit 9266cb3
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 34 deletions.
72 changes: 45 additions & 27 deletions paddle2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,36 +56,54 @@ def export(
), f"Model file {model_filename} does not exist."

# translate old ir program to pir
tmp_dir = tempfile.mkdtemp()
dir_and_file, extension = os.path.splitext(model_filename)
filename = os.path.basename(model_filename)
filename_without_extension, _ = os.path.splitext(filename)
save_dir = os.path.join(tmp_dir, filename_without_extension)
if model_filename.endswith(".pdmodel"):
dir_and_file, extension = os.path.splitext(model_filename)
filename = os.path.basename(model_filename)
filename_without_extension, _ = os.path.splitext(filename)
tmp_dir = tempfile.mkdtemp()
paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
with paddle.pir_utils.OldIrGuard():
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(dir_and_file, exe)
)
program = paddle.pir.translate_to_pir(inference_program.desc)
for op in program.global_block().ops:
if op.name() == "pd_op.feed":
feed = op.results()
if op.name() == "pd_op.fetch":
fetch = op.operands_source()
save_dir = os.path.join(tmp_dir, filename_without_extension)
paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)
model_filename = save_dir + ".json"
params_filename = save_dir + ".pdiparams"
assert os.path.exists(
model_filename
), f"Pir Model file {model_filename} does not exist."
assert os.path.exists(
params_filename
), f"Pir Params file {params_filename} does not exist."
if (os.path.exists(model_filename) and os.path.exists(params_filename)):
# dir_and_file, extension = os.path.splitext(model_filename)
# filename = os.path.basename(model_filename)
# filename_without_extension, _ = os.path.splitext(filename)
paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
with paddle.pir_utils.OldIrGuard():
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(dir_and_file, exe)
)
program = paddle.pir.translate_to_pir(inference_program.desc)
for op in program.global_block().ops:
if op.name() == "pd_op.feed":
feed = op.results()
if op.name() == "pd_op.fetch":
fetch = op.operands_source()
# save_dir = os.path.join(tmp_dir, filename_without_extension)
paddle.static.save_inference_model(save_dir, feed, fetch, exe, program=program)
model_filename = save_dir + ".json"
params_filename = save_dir + ".pdiparams"
assert os.path.exists(
model_filename
), f"Pir Model file {model_filename} does not exist."
assert os.path.exists(
params_filename
), f"Pir Params file {params_filename} does not exist."
else:
with paddle.pir_utils.OldIrGuard():
program=paddle.load(model_filename)
pir_program = paddle.pir.translate_to_pir(program.desc)
save_dir = os.path.join(tmp_dir, filename_without_extension)
model_filename=save_dir+ ".json"
with paddle.pir_utils.IrGuard():
paddle.save(pir_program,model_filename)
assert os.path.exists(
model_filename
), f"Pir Model file {model_filename} does not exist."


deploy_backend = deploy_backend.lower()
breakpoint()
if custom_op_info is None:
onnx_model_str = c_p2o.export(
model_filename,
Expand Down
30 changes: 30 additions & 0 deletions paddle2onnx/mapper/tensor/builtin_slice.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle2onnx/mapper/tensor/builtin_slice.h"

namespace paddle2onnx {
REGISTER_PIR_MAPPER(builtin_slice, BuiltinSliceMapper)


// builtin.slice selects one tensor out of a PIR tensor-list value, so it
// lowers to a single ONNX Identity node forwarding the chosen input.
void BuiltinSliceMapper::Opset7() {
  auto input_info = GetInput(0);
  auto output_info = GetOutput(0);
  // Default to the first element; without this, `index` is read
  // uninitialized below whenever the op carries no "index" attribute.
  index = 0;
  if (HasAttr("index")) {
    GetAttr("index", &index);
  }
  helper_->MakeNode("Identity", {input_info[index].name},
                    {output_info[0].name});
}

} // namespace paddle2onnx
38 changes: 38 additions & 0 deletions paddle2onnx/mapper/tensor/builtin_slice.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

// PIR builtin.slice operation
// Mapper for the PIR builtin.slice operation: picks one element of a
// tensor-list input and forwards it (emitted as an ONNX Identity node).
class BuiltinSliceMapper : public Mapper {
 public:
  BuiltinSliceMapper(const PaddlePirParser& p,
                     OnnxHelper* helper,
                     int64_t op_id,
                     bool c)
      : Mapper(p, helper, op_id, c) {}

  void Opset7() override;

 private:
  // Element of the input list to forward; populated from the op's
  // "index" attribute when present. Default-initialized to 0 so the
  // value is never read uninitialized when the attribute is absent.
  int64_t index = 0;
};

} // namespace paddle2onnx
6 changes: 0 additions & 6 deletions paddle2onnx/parser/pir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,6 @@ std::string PaddlePirParser::GetOpArgName(int64_t op_id,
op_name =
op->attributes().at("op_name").dyn_cast<pir::StrAttribute>().AsString();
}
std::string builtin_prefix = "builtin.";
if (op_name.substr(0, builtin_prefix.size()) == builtin_prefix) {
Assert(false,
"builtin op " + op_name +
" is not supported by GetOpInputOutputName2Idx.");
}
if (_op_arg_name_mappings.count(op_name)) {
name = _op_arg_name_mappings.at(op_name).count(name)
? _op_arg_name_mappings.at(op_name).at(name)
Expand Down
2 changes: 2 additions & 0 deletions tests/onnxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
def _test_with_pir(func):
@wraps(func)
def wrapper(*args, **kwargs):
with paddle.pir_utils.DygraphOldIrGuard():
func(*args, **kwargs)
with paddle.pir_utils.DygraphPirGuard():
func(*args, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions tests/run.bat
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ REM %PY_CMD% -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.c
%PY_CMD% -m pip install https://paddle2onnx.bj.bcebos.com/paddle_windows/paddlepaddle_gpu-0.0.0-cp310-cp310-win_amd64.whl

REM Enable development mode and run tests
set FLAGS_enable_pir_api=0
set ENABLE_DEV=ON
echo ============ failed cases ============ >> result.txt

Expand Down
2 changes: 1 addition & 1 deletion tests/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ do
if [[ ${ignore} =~ ${file##*/} ]]; then
echo "跳过"
else
$PY_CMD -m pytest ${file}
FLAGS_enable_pir_api=0 $PY_CMD -m pytest ${file}
if [ $? -ne 0 ]; then
echo ${file} >> result.txt
bug=`expr ${bug} + 1`
Expand Down

0 comments on commit 9266cb3

Please sign in to comment.