Skip to content

Commit

Permalink
update infer/utility.py to support json format model (#14233)
Browse files Browse the repository at this point in the history
* update infer/utility.py to support json format model

* merge from #13524

* fix bug

* fix bug

* Update tools/infer/utility.py

Co-authored-by: jzhang533 <[email protected]>

* fix codestyle

---------

Co-authored-by: jzhang533 <[email protected]>
  • Loading branch information
GreatV and jzhang533 authored Nov 18, 2024
1 parent 500381c commit fbba217
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,21 +232,26 @@ def create_predictor(args, mode, logger):
else:
file_names = ["model", "inference"]
for file_name in file_names:
model_file_path = "{}/{}.pdmodel".format(model_dir, file_name)
params_file_path = "{}/{}.pdiparams".format(model_dir, file_name)
if os.path.exists(model_file_path) and os.path.exists(params_file_path):
params_file_path = f"{model_dir}/{file_name}.pdiparams"
if os.path.exists(params_file_path):
break
if not os.path.exists(model_file_path):
raise ValueError(
"not find model.pdmodel or inference.pdmodel in {}".format(model_dir)
)

if not os.path.exists(params_file_path):
raise ValueError(f"not find {file_name}.pdiparams in {model_dir}")

if not (
os.path.exists(f"{model_dir}/{file_name}.pdmodel")
or os.path.exists(f"{model_dir}/{file_name}.json")
):
raise ValueError(
"not find model.pdiparams or inference.pdiparams in {}".format(
model_dir
)
f"neither {file_name}.json nor {file_name}.pdmodel was found in {model_dir}."
)

if os.path.exists(f"{model_dir}/{file_name}.json"):
model_file_path = f"{model_dir}/{file_name}.json"
else:
model_file_path = f"{model_dir}/{file_name}.pdmodel"

config = inference.Config(model_file_path, params_file_path)

if hasattr(args, "precision"):
Expand Down

0 comments on commit fbba217

Please sign in to comment.