Skip to content

Commit

Permalink
Convert ONNX models to version 17+ as part of importing them.
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottTodd committed Mar 4, 2024
1 parent 6b8f1c0 commit 507d5a6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
6 changes: 5 additions & 1 deletion iree_tests/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# TODO(scotttodd): compile into a build/temp dir instead of the source dir
*.vmfb
# TODO(scotttodd): compile into a build dir instead of the source dir
# TODO(scotttodd): convert into a build/temp dir instead of the source dir
*.onnx

# Leftovers from an older iteration of compile/run test scripts.
config_*.txt
34 changes: 27 additions & 7 deletions iree_tests/onnx/import_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
ONNX_REPO_GENERATED_TESTS_ROOT = ONNX_REPO_ROOT / "onnx/backend/test/data"
NODE_TESTS_ROOT = ONNX_REPO_GENERATED_TESTS_ROOT / "node"

# Convert test cases to at least this version using The ONNX Version Converter.
ONNX_CONVERTER_OUTPUT_MIN_VERSION = 17

# Write imported files to our own 'generated' folder.
GENERATED_FILES_OUTPUT_ROOT = REPO_ROOT / "iree_tests/onnx/node/generated"

Expand Down Expand Up @@ -71,15 +74,32 @@ def import_onnx_files(test_dir_path, imported_dir_path):
test_data_flagfile_path = imported_dir_path / "test_data_flags.txt"
test_data_flagfile_lines = []

# Import model.onnx to model.mlir.
# TODO(scotttodd): copy the .onnx file into the generated folder? Useful for reproducing
# could also add a symlink or other files with info
# e.g. importer tool / version / flags used, reproducer command
onnx_model_path = test_dir_path / "model.onnx"
# Convert model.onnx up to ONNX_CONVERTER_OUTPUT_MIN_VERSION if needed.
# TODO(scotttodd): stamp some info e.g. importer tool / version / flags used
original_model_path = test_dir_path / "model.onnx"
converted_model_path = imported_dir_path / "model.onnx"

original_model = onnx.load_model(original_model_path)
original_version = original_model.opset_import[0].version
if original_version < ONNX_CONVERTER_OUTPUT_MIN_VERSION:
try:
converted_model = version_converter.convert_version(
original_model, ONNX_CONVERTER_OUTPUT_MIN_VERSION
)
onnx.save(converted_model, converted_model_path)
except:
# Conversion failed. Do our best with the original file.
# print(f"WARNING: ONNX conversion failed for {test_dir_path.name}")
shutil.copy(original_model_path, converted_model_path)
else:
# No conversion needed.
shutil.copy(original_model_path, converted_model_path)

# Import converted model.onnx to model.mlir.
imported_model_path = imported_dir_path / "model.mlir"
exec_args = [
"iree-import-onnx",
str(onnx_model_path),
str(converted_model_path),
"-o",
str(imported_model_path),
]
Expand All @@ -101,7 +121,7 @@ def import_onnx_files(test_dir_path, imported_dir_path):
test_data_dir = test_data_dirs[0]
test_inputs = list(test_data_dir.glob("input_*.pb"))
test_outputs = list(test_data_dir.glob("output_*.pb"))
model = onnx.load(onnx_model_path)
model = onnx.load(converted_model_path)
for i in range(len(test_inputs)):
test_input = test_inputs[i]
t = convert_io_proto(test_input, model.graph.input[i].type)
Expand Down

0 comments on commit 507d5a6

Please sign in to comment.