-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
340 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) 2023-2024 Arista Networks, Inc. | ||
# Use of this source code is governed by the Apache License 2.0 | ||
# that can be found in the LICENSE file. | ||
"""Benchmark tests for ANTA.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) 2023-2024 Arista Networks, Inc. | ||
# Use of this source code is governed by the Apache License 2.0 | ||
# that can be found in the LICENSE file. | ||
"""Fixtures for benchmarking ANTA.""" | ||
|
||
import logging | ||
|
||
import pytest | ||
import respx | ||
from _pytest.terminal import TerminalReporter | ||
|
||
from anta.catalog import AntaCatalog | ||
from anta.device import AsyncEOSDevice | ||
from anta.inventory import AntaInventory | ||
|
||
from .utils import AntaMockEnvironment | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
TEST_CASE_COUNT = None | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def catalog() -> AntaCatalog: | ||
"""Fixture that generate an ANTA catalog from unit test data. Also configure respx to mock eAPI responses.""" | ||
global TEST_CASE_COUNT # noqa: PLW0603 pylint: disable=global-statement | ||
eapi_route = respx.post(path="/command-api", headers={"Content-Type": "application/json-rpc"}) | ||
env = AntaMockEnvironment() | ||
TEST_CASE_COUNT = len(env.catalog.tests) | ||
eapi_route.side_effect = env.eapi_response | ||
return env.catalog | ||
|
||
|
||
@pytest.fixture | ||
def inventory(request: pytest.FixtureRequest) -> AntaInventory: | ||
"""Generate an ANTA inventory.""" | ||
inv = AntaInventory() | ||
for i in range(request.param["count"]): | ||
inv.add_device( | ||
AsyncEOSDevice( | ||
host=f"device-{i}.avd.arista.com", | ||
username="admin", | ||
password="admin", # noqa: S106 | ||
name=f"device-{i}", | ||
disable_cache=(not request.param["cache"]), | ||
) | ||
) | ||
return inv | ||
|
||
|
||
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None: | ||
"""Display the total number of ANTA unit test cases used to benchmark.""" | ||
terminalreporter.write_sep("=", f"{TEST_CASE_COUNT} ANTA test cases") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) 2023-2024 Arista Networks, Inc. | ||
# Use of this source code is governed by the Apache License 2.0 | ||
# that can be found in the LICENSE file. | ||
"""Benchmark tests for ANTA.""" | ||
|
||
import asyncio | ||
import logging | ||
from unittest.mock import AsyncMock, Mock, patch | ||
|
||
import pytest | ||
import respx | ||
from pytest_codspeed import BenchmarkFixture | ||
|
||
from anta.catalog import AntaCatalog | ||
from anta.inventory import AntaInventory | ||
from anta.result_manager import ResultManager | ||
from anta.result_manager.models import AntaTestStatus | ||
from anta.runner import main | ||
|
||
from .utils import collect, collect_commands | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"inventory", | ||
[ | ||
pytest.param({"count": 1, "cache": False}, id="1 device"), | ||
pytest.param({"count": 2, "cache": False}, id="2 devices"), | ||
], | ||
indirect=True, | ||
) | ||
def test_anta_dry_run(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None: | ||
"""Test and benchmark ANTA in Dry-Run Mode.""" | ||
# Disable logging during ANTA execution to avoid having these function time in benchmarks | ||
logging.disable() | ||
|
||
def bench() -> ResultManager: | ||
"""Need to wrap the ANTA Runner to instantiate a new ResultManger for each benchmark run.""" | ||
manager = ResultManager() | ||
asyncio.run(main(manager, inventory, catalog, dry_run=True)) | ||
return manager | ||
|
||
manager = benchmark(lambda: bench()) # pylint: disable=unnecessary-lambda | ||
logging.disable(logging.NOTSET) | ||
if len(manager.results) != 0: | ||
pytest.fail("ANTA Dry-Run mode should not return any result", pytrace=False) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"inventory", | ||
[ | ||
pytest.param({"count": 1, "cache": False}, id="1 device"), | ||
pytest.param({"count": 2, "cache": False}, id="2 devices"), | ||
], | ||
indirect=True, | ||
) | ||
@patch("asyncio.open_connection", AsyncMock(spec=asyncio.open_connection, return_value=(Mock(), Mock()))) # We just want all devices to be reachable | ||
@patch("anta.models.AntaTest.collect", collect) | ||
@patch("anta.device.AntaDevice.collect_commands", collect_commands) | ||
@respx.mock # Mock eAPI responses | ||
def test_anta(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None: | ||
"""Test and benchmark ANTA. Mock eAPI responses.""" | ||
# Disable logging during ANTA execution to avoid having these function time in benchmarks | ||
logging.disable() | ||
|
||
def bench() -> ResultManager: | ||
"""Need to wrap the ANTA Runner to instantiate a new ResultManger for each benchmark run.""" | ||
manager = ResultManager() | ||
asyncio.run(main(manager, inventory, catalog)) | ||
return manager | ||
|
||
manager = benchmark(lambda: bench()) # pylint: disable=unnecessary-lambda | ||
logging.disable(logging.NOTSET) | ||
|
||
if len(catalog.tests) * len(inventory) != len(manager.results): | ||
# This could mean duplicates exist. | ||
# TODO: consider removing this code and refactor unit test data as a dictionary with tuple keys instead of a list | ||
seen = set() | ||
dupes = [] | ||
for test in catalog.tests: | ||
if test in seen: | ||
dupes.append(test) | ||
else: | ||
seen.add(test) | ||
if dupes: | ||
for test in dupes: | ||
msg = f"Found duplicate in test catalog: {test}" | ||
logger.error(msg) | ||
pytest.fail(f"Expected {len(catalog.tests) * len(inventory)} test results but got {len(manager.results)}", pytrace=False) | ||
bench_info = ( | ||
"\n--- ANTA NRFU Benchmark Information ---\n" | ||
f"Test results: {len(manager.results)}\n" | ||
f"Success: {manager.get_total_results({AntaTestStatus.SUCCESS})}\n" | ||
f"Failure: {manager.get_total_results({AntaTestStatus.FAILURE})}\n" | ||
f"Skipped: {manager.get_total_results({AntaTestStatus.SKIPPED})}\n" | ||
f"Error: {manager.get_total_results({AntaTestStatus.ERROR})}\n" | ||
f"Unset: {manager.get_total_results({AntaTestStatus.UNSET})}\n" | ||
"---------------------------------------" | ||
) | ||
logger.info(bench_info) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) 2023-2024 Arista Networks, Inc. | ||
# Use of this source code is governed by the Apache License 2.0 | ||
# that can be found in the LICENSE file. | ||
"""Utils for the ANTA benchmark tests.""" | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
import importlib | ||
import json | ||
import pkgutil | ||
from typing import TYPE_CHECKING, Any | ||
|
||
import httpx | ||
from pydantic import ValidationError | ||
|
||
from anta.catalog import AntaCatalog, AntaTestDefinition | ||
from anta.models import AntaCommand, AntaTest | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Generator | ||
from types import ModuleType | ||
|
||
from anta.device import AntaDevice | ||
|
||
|
||
async def collect(self: AntaTest) -> None: | ||
"""Patched anta.models.AntaTest.collect() method. | ||
When generating the catalog, we inject a unit test case name in the custom_field input to be able to retrieve the eos_data for this specific test. | ||
We use this unit test case name in the eAPI request ID. | ||
""" | ||
if self.inputs.result_overwrite is None or self.inputs.result_overwrite.custom_field is None: | ||
msg = f"The custom_field input is not present for test {self.name}" | ||
raise RuntimeError(msg) | ||
await self.device.collect_commands(self.instance_commands, collection_id=f"{self.name}:{self.inputs.result_overwrite.custom_field}") | ||
|
||
|
||
async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None: | ||
"""Patched anta.device.AntaDevice.collect_commands() method. | ||
For the same reason as above, we inject the command index of the test to the eAPI request ID. | ||
""" | ||
await asyncio.gather(*(self.collect(command=command, collection_id=f"{collection_id}:{idx}") for idx, command in enumerate(commands))) | ||
|
||
|
||
class AntaMockEnvironment: # pylint: disable=too-few-public-methods | ||
"""Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance. | ||
Also provide the attribute 'eos_data_catalog` with the output of all the commands used in the test catalog. | ||
Each module in `tests.units.anta_tests` has a `DATA` constant. | ||
The `DATA` structure is a list of dictionaries used to parametrize the test. The list elements have the following keys: | ||
- `name` (str): Test name as displayed by Pytest. | ||
- `test` (AntaTest): An AntaTest subclass imported in the test module - e.g. VerifyUptime. | ||
- `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test. | ||
- `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`. | ||
The keys of `eos_data_catalog` is the tuple (DATA['test'], DATA['name']). The values are `eos_data`. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self.catalog, self.eos_data_catalog = self._generate_catalog() | ||
|
||
def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]: | ||
"""Generate the `catalog` and `eos_data_catalog` attributes.""" | ||
|
||
def import_test_modules() -> Generator[ModuleType, None, None]: | ||
"""Yield all test modules from the given package.""" | ||
package = importlib.import_module("tests.units.anta_tests") | ||
prefix = package.__name__ + "." | ||
for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix): | ||
if not is_pkg and module_name.split(".")[-1].startswith("test_"): | ||
module = importlib.import_module(module_name) | ||
if hasattr(module, "DATA"): | ||
yield module | ||
|
||
test_definitions = [] | ||
eos_data_catalog = {} | ||
for module in import_test_modules(): | ||
for test_data in module.DATA: | ||
test = test_data["test"] | ||
result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=test_data["name"]) | ||
# Some unit tests purposely have invalid inputs, we just skip them | ||
try: | ||
if test_data["inputs"] is None: | ||
inputs = test.Input(result_overwrite=result_overwrite) | ||
else: | ||
inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite) | ||
except ValidationError: | ||
continue | ||
|
||
test_definition = AntaTestDefinition( | ||
test=test, | ||
inputs=inputs, | ||
) | ||
eos_data_catalog[(test.__name__, test_data["name"])] = test_data["eos_data"] | ||
test_definitions.append(test_definition) | ||
|
||
return (AntaCatalog(tests=test_definitions), eos_data_catalog) | ||
|
||
def eapi_response(self, request: httpx.Request) -> httpx.Response: | ||
"""Mock eAPI response. | ||
If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`, | ||
the function will return the eos_data from the unit test case. | ||
Otherwise, it will mock 'show version' command or raise an Exception. | ||
""" | ||
words_count = 3 | ||
|
||
def parse_req_id(req_id: str) -> tuple[str, str, int] | None: | ||
"""Parse the patched request ID from the eAPI request.""" | ||
req_id = req_id.removeprefix("ANTA-").rpartition("-")[0] | ||
words = req_id.split(":", words_count) | ||
if len(words) == words_count: | ||
test_name, unit_test_name, command_index = words | ||
return test_name, unit_test_name, int(command_index) | ||
return None | ||
|
||
jsonrpc = json.loads(request.content) | ||
assert jsonrpc["method"] == "runCmds" | ||
commands = jsonrpc["params"]["cmds"] | ||
ofmt = jsonrpc["params"]["format"] | ||
req_id: str = jsonrpc["id"] | ||
result = None | ||
|
||
# Extract the test name, unit test name, and command index from the request ID | ||
if (words := parse_req_id(req_id)) is not None: | ||
test_name, unit_test_name, idx = words | ||
|
||
# This should never happen, but better be safe than sorry | ||
if (test_name, unit_test_name) not in self.eos_data_catalog: | ||
msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found" | ||
raise RuntimeError(msg) | ||
|
||
eos_data = self.eos_data_catalog[(test_name, unit_test_name)] | ||
|
||
# This could happen if the unit test data is not correctly defined | ||
if idx >= len(eos_data): | ||
msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data" | ||
raise RuntimeError(msg) | ||
result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx] | ||
elif {"cmd": "show version"} in commands and ofmt == "json": | ||
# Mock 'show version' request performed during inventory refresh. | ||
result = { | ||
"modelName": "pytest", | ||
} | ||
|
||
if result is not None: | ||
return httpx.Response( | ||
status_code=200, | ||
json={ | ||
"jsonrpc": "2.0", | ||
"id": req_id, | ||
"result": [result], | ||
}, | ||
) | ||
msg = f"The following eAPI Request has not been mocked: {jsonrpc}" | ||
raise NotImplementedError(msg) |