
Commit

Merge remote-tracking branch 'upstream/main' into git-lfs-files
ScottTodd committed Apr 2, 2024
2 parents c5979d1 + 30a605d commit 3d73c05
Showing 7 changed files with 131 additions and 885 deletions.
Binary file added e2eshark/onnx/models/QuantizedMLP/model.onnx
Binary file not shown.
40 changes: 40 additions & 0 deletions e2eshark/onnx/models/QuantizedMLP/model.py
@@ -0,0 +1,40 @@
import numpy, torch, sys
import onnxruntime

# import from e2eshark/tools to allow running in the current dir; for runs through
# run.py, commonutils is symbolically linked so that any run directory works
sys.path.insert(0, "../../../tools/stubs")
from commonutils import E2ESHARK_CHECK_DEF

# Create an instance of it for this test
E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF)


# The generated or checked-in onnx file must always be called model.onnx.
# tools/stubs/onnxmodel.py is appended to this model.py to form runmodel.py
# in the run directory, which is then taken through the flow.
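# (Illustrative sketch only, not part of this test: conceptually the harness
# builds the run script with something along the lines of
#     cat model.py ../../../tools/stubs/onnxmodel.py > <rundir>/runmodel.py
# and then executes runmodel.py from inside the run directory; the actual
# mechanism lives in run.py.)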


# start an onnxrt session
session = onnxruntime.InferenceSession("model.onnx", None)

# Even though the model is quantized, its inputs and outputs are not,
# so use float32
model_input_X = numpy.random.rand(1, 16).astype(numpy.float32)

# gets the model's single input, X, in inputs[0]
inputs = session.get_inputs()
# gets Z in outputs[0]
outputs = session.get_outputs()


model_output = session.run(
    [outputs[0].name],
    {inputs[0].name: model_input_X},
)[0]
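# Note (not in the original comments): session.run returns a list with one
# array per requested output, so the trailing [0] above extracts the single
# output array for this model.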
E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

print("Input:", E2ESHARK_CHECK["input"])
print("Output:", E2ESHARK_CHECK["output"])
77 changes: 77 additions & 0 deletions e2eshark/onnx/operators/GatherND/model.py
@@ -0,0 +1,77 @@
# Copyright 2024 Advanced Micro Devices
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# run.py creates runmodel.py by concatenating this file model.py
# and tools/stubs/onnxmodel.py
# Description: testing GatherND
# See https://onnx.ai/onnx/intro/python.html for intro on creating
# onnx model using python onnx API
import numpy, torch, sys
import onnxruntime
from onnx import numpy_helper, TensorProto, save_model
from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info
from onnx.checker import check_model

# import from e2eshark/tools to allow running in the current dir; for runs through
# run.py, commonutils is symbolically linked so that any run directory works
sys.path.insert(0, "../../../tools/stubs")
from commonutils import E2ESHARK_CHECK_DEF

# Create an instance of it for this test
E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF)

# Create the inputs (ValueInfoProto)
D = make_tensor_value_info("D", TensorProto.FLOAT, [2, 2, 3])
I = make_tensor_value_info("I", TensorProto.INT64, [2, 3, 2])

# Create an output
Z = make_tensor_value_info("Z", TensorProto.FLOAT, [2, 3, 3])
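# Shape check for GatherND (batch_dims defaults to 0): each innermost index
# pair [i, j] in I selects the slice D[i, j, :] of length 3, so the output
# shape is I.shape[:-1] + D.shape[2:] = (2, 3) + (3,) = (2, 3, 3), matching Z.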

# Create a node (NodeProto)
gather_nd_node = make_node(
    "GatherND",  # op type
    ["D", "I"],  # inputs
    ["Z"],  # outputs
    "gather_nd_node",  # node name
)

# Create the graph (GraphProto)
graph = make_graph(
    [gather_nd_node],
    "gather_nd_graph",
    [D, I],
    [Z],
)

# Create the model (ModelProto)
onnx_model = make_model(graph)
onnx_model.opset_import[0].version = 13
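# check_model is imported above; an optional validation step (not part of the
# original test) could be added here, e.g. check_model(onnx_model), to verify
# the graph before it is saved.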

# Save the model
# save_model(onnx_model, "model.onnx")
with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())


session = onnxruntime.InferenceSession("model.onnx", None)
model_input_D = numpy.random.randn(2, 2, 3).astype(numpy.float32)
model_input_I = numpy.random.randint(2, size=(2, 3, 2)).astype(numpy.int64)
# gets D in inputs[0] and I in inputs[1]
inputs = session.get_inputs()
# gets Z in outputs[0]
outputs = session.get_outputs()

model_output = session.run(
    [outputs[0].name],
    {inputs[0].name: model_input_D, inputs[1].name: model_input_I},
)

# Move to torch tensors to handle bfloat16, since numpy does not support bfloat16
E2ESHARK_CHECK["input"] = [
    torch.from_numpy(model_input_D),
    torch.from_numpy(model_input_I),
]
E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]

print("Input:", E2ESHARK_CHECK["input"])
print("Output:", E2ESHARK_CHECK["output"])
6 changes: 3 additions & 3 deletions iree_tests/README.md
@@ -55,7 +55,7 @@
Tests are run using the [pytest](https://docs.pytest.org/en/stable/) framework.
A [`conftest.py`](conftest.py) file collects test cases from subdirectories,
wrapping each directory matching the format described above to one test case
per test configuration. Test configurations are defined in JSON config files
-like [`configs/config_cpu_llvm_sync.json`](./configs/config_cpu_llvm_sync.json).
+like [`configs/config_onnx_cpu_llvm_sync.json`](./configs/config_onnx_cpu_llvm_sync.json).

### Common venv setup with deps

@@ -110,15 +110,15 @@
Run ONNX tests on CPU and print all errors:
```bash
$ pytest iree_tests/onnx -n auto \
  --ignore-xfails \
-  --config-files ./iree_tests/configs/config_cpu_llvm_sync.json
+  --config-files ./iree_tests/configs/config_onnx_cpu_llvm_sync.json
```

Run ONNX compilation tests only and print all errors:

```bash
$ pytest iree_tests/onnx -n auto \
  --ignore-xfails --skip-all-runs \
-  --config-files ./iree_tests/configs/config_cpu_llvm_sync.json
+  --config-files ./iree_tests/configs/config_onnx_cpu_llvm_sync.json
```

### Advanced pytest usage tips
