Update Output Comparison for pytest #146
@@ -7,6 +7,8 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List
+from numpy import load
+from numpy.testing import assert_allclose
 import pyjson5
 import os
 import pytest
@@ -293,6 +295,9 @@ def __init__(self, spec, **kwargs):
         self.run_args.extend(self.spec.iree_run_module_flags)
         self.run_args.append(f"--flagfile={self.spec.data_flagfile_name}")

+        self.atol=1e-05
+        self.rtol=1e-06
+
     def runtest(self):
         if self.spec.skip_test:
             pytest.skip()
@@ -313,7 +318,7 @@ def runtest(self):
         if not self.spec.expect_run_success:
             self.add_marker(
                 pytest.mark.xfail(
-                    raises=IreeRunException,
+                    raises=(IreeRunException, AssertionError),
                     strict=True,
                     reason="Expected run to fail",
                 )
@@ -327,8 +332,23 @@ def test_compile(self):

     def test_run(self):
         proc = subprocess.run(self.run_args, capture_output=True, cwd=self.test_cwd)
+        # iree-run-module execution failure
         if proc.returncode != 0:
             raise IreeRunException(proc, self.test_cwd, self.compile_args)
+        # TODO: add support for comparison of non-numpy-supported dtypes using iree-run-module
+        # numerical error
+        self.test_numerical_accuracy()
+
+    def test_numerical_accuracy(self):
+        num_iree_output_files = len(list(self.test_cwd.glob("iree_output_*.npy")))
+        num_output_files = len(list(self.test_cwd.glob("output_*.npy")))
+        if num_iree_output_files != num_output_files:
+            raise AssertionError(f"Number of golden outputs ({num_output_files}) and iree outputs ({num_iree_output_files}) don't match")
+
+        for i in range(num_output_files):
+            iree_output = load(self.test_cwd / f"iree_output_{i}.npy")
+            golden_output = load(self.test_cwd / f"output_{i}.npy")
+            assert_allclose(iree_output, golden_output, atol=self.atol, rtol=self.rtol, equal_nan=False)
Comment on lines +338 to +351:
I'd like to continue using
This style of testing aims to have as thin a test runner as possible, leaning mostly on the native tools themselves. Right now all the test runner does is
By having a thin test runner with a narrow set of responsibilities, other test runner implementations are possible and results are easier to reproduce outside of the test environment.
How about we first see if we can modify https://github.com/openxla/iree/blob/main/runtime/src/iree/tooling/comparison.cc to be more permissive with numpy data type mismatches, or switch the expected outputs from numpy to binary files?

     def repr_failure(self, excinfo):
         """Called when self.runtest() raises an exception."""
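As a rough sketch of the direction the review comment above points in (an editor's illustration, not code from this PR), the runner could keep delegating output comparison to iree-run-module itself instead of loading .npy files in Python. The helper name below is hypothetical, and the --expected_output=@<file> syntax and nonzero-exit-on-mismatch behavior are assumptions about iree-run-module that should be checked against the IREE version in use:

from pathlib import Path
from typing import List

def append_expected_output_flags(run_args: List[str], test_cwd: Path) -> List[str]:
    # Hypothetical helper (not part of this PR): hand the golden outputs to
    # iree-run-module so its comparison logic (tooling/comparison.cc) decides
    # pass/fail and the pytest runner stays thin.
    args = list(run_args)
    num_outputs = len(list(test_cwd.glob("output_*.npy")))
    for i in range(num_outputs):
        # Assumes expected outputs can be read from .npy files via
        # --expected_output=@<file> and that a mismatch yields a nonzero
        # return code, which the existing IreeRunException path would surface.
        args.append(f"--expected_output=@output_{i}.npy")
    return args

If those assumptions hold, test_run would not need a separate test_numerical_accuracy step; comparison failures would show up through the same return-code check that already raises IreeRunException.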
Comment:
We can put comparison thresholds in the flagfiles themselves, rather than make that a property of the test runner.
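A minimal sketch of that idea (again an editor's illustration, not part of this PR): the per-test flagfile that __init__ already passes via --flagfile= could carry the comparison settings alongside the inputs, assuming iree-run-module exposes expected-output and per-dtype threshold flags (names such as --expected_f32_threshold below are unverified assumptions):

--input=@input_0.npy
--expected_output=@output_0.npy
--expected_f32_threshold=1e-05

That would keep tolerances with the rest of each test case's data, so neither the runner class nor the spec needs hardcoded atol/rtol defaults.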