Skip to content

Commit

Permalink
Run ONNX import in parallel.
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottTodd committed Mar 4, 2024
1 parent 507d5a6 commit 716faa0
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions iree_tests/onnx/import_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import onnx
from multiprocessing import Pool
from pathlib import Path
from onnx import numpy_helper
from onnx import numpy_helper, version_converter
import shutil
import subprocess
import numpy as np
Expand Down Expand Up @@ -51,6 +53,16 @@ def convert_io_proto(proto_filename, type_proto):
return None


def import_onnx_files_with_cleanup(test_dir_path):
test_name = test_dir_path.name
imported_dir_path = Path(GENERATED_FILES_OUTPUT_ROOT) / test_name
result = import_onnx_files(test_dir_path, imported_dir_path)
if not result:
# Note: could comment this out to keep partially imported directories.
shutil.rmtree(imported_dir_path)
return (test_name, result)


def import_onnx_files(test_dir_path, imported_dir_path):
# This imports one 'test_[name]' subfolder from this:
#
Expand Down Expand Up @@ -146,6 +158,16 @@ def import_onnx_files(test_dir_path, imported_dir_path):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ONNX test case importer.")
parser.add_argument(
"-j",
"--jobs",
type=int,
default=8,
help="Number of parallel processes to use when importing test cases",
)
args = parser.parse_args()

test_dir_paths = find_onnx_tests(NODE_TESTS_ROOT)

# TODO(scotttodd): add flag to not clear output dir?
Expand All @@ -154,27 +176,16 @@ def import_onnx_files(test_dir_path, imported_dir_path):
GENERATED_FILES_OUTPUT_ROOT.mkdir(parents=True)

print(f"Importing tests in '{NODE_TESTS_ROOT}'")

print("******************************************************************")
passed_imports = []
failed_imports = []
# TODO(scotttodd): parallelize this (or move into a test runner like pytest)
for i in range(len(test_dir_paths)):
test_dir_path = test_dir_paths[i]
test_name = test_dir_path.name

current_number = str(i).rjust(4, "0")
progress_str = f"[{current_number}/{len(test_dir_paths)}]"
print(f"{progress_str}: Importing {test_name}")

imported_dir_path = Path(GENERATED_FILES_OUTPUT_ROOT) / test_name
result = import_onnx_files(test_dir_path, imported_dir_path)
if result:
passed_imports.append(test_name)
else:
failed_imports.append(test_name)
# Note: could comment this out to keep partially imported directories.
shutil.rmtree(imported_dir_path)
with Pool(args.jobs) as pool:
results = pool.imap_unordered(import_onnx_files_with_cleanup, test_dir_paths)
for result in results:
if result[1]:
passed_imports.append(result[0])
else:
failed_imports.append(result[0])
print("******************************************************************")

passed_imports.sort()
Expand Down

0 comments on commit 716faa0

Please sign in to comment.