[ExecuTorch] Arm Ethos: Add pass tests #8561

Merged 7 commits on Feb 25, 2025
8 changes: 8 additions & 0 deletions backends/arm/test/TARGETS
@@ -1,4 +1,8 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_arm_tests")


oncall("executorch")

python_library(
name = "common",
@@ -8,6 +12,7 @@ python_library(
"//executorch/backends/arm:arm_backend",
"//executorch/exir:lib",
"//executorch/exir/backend:compile_spec_schema",
"fbsource//third-party/pypi/pytest:pytest",
]
)

@@ -40,7 +45,10 @@ python_library(
"//executorch/backends/arm:tosa_mapping",
"//executorch/backends/arm:tosa_specification",
"//executorch/backends/arm/quantizer:arm_quantizer",
"//executorch/backends/arm:arm_partitioner",
"//executorch/devtools/backend_debug:delegation_info",
"fbsource//third-party/pypi/tabulate:tabulate",
]
)

define_arm_tests()
42 changes: 34 additions & 8 deletions backends/arm/test/conftest.py
@@ -13,7 +13,12 @@
from typing import Any

import pytest
import torch

try:
import tosa_reference_model
except ImportError:
logging.warning("tosa_reference_model not found, can't run reference model tests")
tosa_reference_model = None

"""
This file contains the pytest hooks, fixtures etc. for the Arm test suite.
@@ -24,18 +29,29 @@


def pytest_configure(config):

pytest._test_options = {} # type: ignore[attr-defined]

if config.option.arm_run_corstoneFVP:
pytest._test_options["corstone_fvp"] = False # type: ignore[attr-defined]
if (
getattr(config.option, "arm_run_corstoneFVP", False)
and config.option.arm_run_corstoneFVP
):
corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
corstone320_exists = shutil.which("FVP_Corstone_SSE-320")
if not (corstone300_exists and corstone320_exists):
raise RuntimeError(
"Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed."
)
# Only enable if we also have the TOSA reference model available.
pytest._test_options["corstone_fvp"] = True # type: ignore[attr-defined]
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]

pytest._test_options["fast_fvp"] = False # type: ignore[attr-defined]
if getattr(config.option, "fast_fvp", False):
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]

# TODO: remove this flag once we have a way to run the reference model tests with Buck
pytest._test_options["tosa_ref_model"] = False # type: ignore[attr-defined]
if tosa_reference_model is not None:
pytest._test_options["tosa_ref_model"] = True # type: ignore[attr-defined]
logging.basicConfig(level=logging.INFO, stream=sys.stdout)


@@ -44,9 +60,15 @@ def pytest_collection_modifyitems(config, items):


def pytest_addoption(parser):
parser.addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
parser.addoption("--arm_run_corstoneFVP", action="store_true")
parser.addoption("--fast_fvp", action="store_true")
def try_addoption(*args, **kwargs):
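# parser.addoption raises if the option has already been registered (e.g. when this conftest is picked up more than once); ignore that and carry on.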
try:
parser.addoption(*args, **kwargs)
except Exception:
pass

try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.")
try_addoption("--fast_fvp", action="store_true")


def pytest_sessionstart(session):
@@ -78,6 +100,8 @@ def set_random_seed():
Rerun with a specific seed found under a random seed test
ARM_TEST_SEED=3478246 pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k <TESTCASE>
"""
import torch

if os.environ.get("ARM_TEST_SEED", "RANDOM") == "RANDOM":
random.seed() # reset seed, in case any other test has fiddled with it
seed = random.randint(0, 2**32 - 1)
@@ -161,6 +185,8 @@ def _load_libquantized_ops_aot_lib():
res = subprocess.run(find_lib_cmd, capture_output=True)
if res.returncode == 0:
library_path = res.stdout.decode().strip()
import torch

torch.ops.load_library(library_path)
else:
raise RuntimeError(
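For reference, the pass tests below read these options back through a conftest helper (test_rescale_pass.py calls conftest.is_option_enabled("tosa_ref_model")). A minimal sketch of such a helper, assuming the pytest._test_options dict populated in pytest_configure above; the actual implementation in conftest.py may differ:

import pytest

def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
    """Return True if the given option was enabled in pytest_configure."""
    if pytest._test_options.get(option, False):  # type: ignore[attr-defined]
        return True
    if fail_if_not_enabled:
        pytest.fail(f"Required option '{option}' is not enabled for this test")
    return False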
6 changes: 4 additions & 2 deletions backends/arm/test/passes/test_rescale_pass.py
@@ -116,7 +116,7 @@ def _test_rescale_pipeline(
):
"""Tests a model with many ops that requires rescales. As more ops are quantized to int32 and
need the InsertRescalesPass, make sure that they play nicely together."""
(
tester = (
ArmTester(
module,
example_inputs=test_data,
@@ -126,8 +126,9 @@ def _test_rescale_pipeline(
.export()
.to_edge_transform_and_lower()
.to_executorch()
.run_method_and_compare_outputs(test_data)
)
if conftest.is_option_enabled("tosa_ref_model"):
tester.run_method_and_compare_outputs(test_data)


def _test_rescale_pipeline_ethosu(
@@ -152,6 +153,7 @@ def _test_rescale_pipeline_ethosu(
class TestRescales(unittest.TestCase):

@parameterized.expand(RescaleNetwork.test_parameters)
@pytest.mark.tosa_ref_model
def test_quantized_rescale(self, x, y):
_test_rescale_pipeline(RescaleNetwork(), (x, y))

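The tosa_ref_model marker used above (and registered in pytest.ini below) would also allow this gating to happen at collection time. A sketch of how a pytest_collection_modifyitems hook could skip marked tests when the reference model is missing; this is only an illustration, not necessarily what the hook in conftest.py does:

import pytest

def pytest_collection_modifyitems(config, items):
    """Skip tosa_ref_model-marked tests when tosa_reference_model is unavailable."""
    if pytest._test_options.get("tosa_ref_model", False):  # type: ignore[attr-defined]
        return
    skip_marker = pytest.mark.skip(reason="tosa_reference_model is not installed")
    for item in items:
        if "tosa_ref_model" in item.keywords:
            item.add_marker(skip_marker)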
3 changes: 2 additions & 1 deletion backends/arm/test/pytest.ini
@@ -2,4 +2,5 @@
addopts = --strict-markers
markers =
slow: Tests that take long time
corstone_fvp: Tests that use Corstone300 or Corstone320 FVP
corstone_fvp: Tests that use Corstone300 or Corstone320 FVP # These also use the TOSA reference model
tosa_ref_model: Tests that use the TOSA reference model # Temporary!
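Since --strict-markers is in effect, the new tosa_ref_model marker must be declared here before tests such as test_quantized_rescale above can carry @pytest.mark.tosa_ref_model. Once registered, those tests can also be deselected from the command line, e.g. pytest -m "not tosa_ref_model" backends/arm/test/passes.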
1 change: 0 additions & 1 deletion backends/arm/test/runner_utils.py
@@ -22,7 +22,6 @@
try:
import tosa_reference_model
except ImportError:
logger.warning("tosa_reference_model not found, can't run reference model tests")
tosa_reference_model = None
from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa

33 changes: 33 additions & 0 deletions backends/arm/test/targets.bzl
@@ -0,0 +1,33 @@
load("//caffe2/test/fb:defs.bzl", "define_tests")
load("@bazel_skylib//lib:paths.bzl", "paths")

def define_arm_tests():
# TODO Add more tests
test_files = native.glob(["passes/test_*.py"])

# https://github.com/pytorch/executorch/issues/8606
test_files.remove("passes/test_ioquantization_pass.py")

TESTS = {}

for test_file in test_files:
test_file_name = paths.basename(test_file)
test_name = test_file_name.replace("test_", "").replace(".py", "")
TESTS[test_name] = [test_file]

define_tests(
pytest = True,
tests = TESTS,
pytest_config = "pytest.ini",
resources = ["conftest.py"],
preload_deps = [
"//executorch/kernels/quantized:custom_ops_generated_lib",
],
deps = [
":arm_tester",
":conftest",
"//executorch/exir:lib",
"fbsource//third-party/pypi/pytest:pytest",
"fbsource//third-party/pypi/parameterized:parameterized",
],
)
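For illustration, with the test files that currently match the glob (test_rescale_pass.py from this PR, plus the excluded test_ioquantization_pass.py), the loop would hand define_tests a mapping along these lines; actual contents depend on which passes/test_*.py files exist in the tree:

TESTS = {
    "rescale_pass": ["passes/test_rescale_pass.py"],
    # "ioquantization_pass" is absent: it is removed from test_files above
    # (https://github.com/pytorch/executorch/issues/8606).
}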