ci: add codspeed to benchmark ANTA
mtache committed Sep 12, 2024
1 parent abdaea1 commit 2d894fc
Showing 7 changed files with 340 additions and 0 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/code-testing.yml
@@ -133,3 +133,20 @@ jobs:
run: pip install .[doc]
- name: "Build mkdocs documentation offline"
run: mkdocs build
benchmarks:
name: Benchmark ANTA for Python 3.12
runs-on: ubuntu-latest
needs: [test-python]
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install dependencies
run: pip install .[dev]
- name: Run benchmarks
uses: CodSpeedHQ/action@v3
with:
token: ${{ secrets.CODSPEED_TOKEN }}
run: pytest --codspeed --no-cov --log-cli-level INFO tests/benchmark
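For reference, `pytest-codspeed` measures whatever callable is passed to its `benchmark` fixture; the new tests below use the same pattern. A minimal standalone sketch (hypothetical test, not part of this commit):

from pytest_codspeed import BenchmarkFixture

def test_sum(benchmark: BenchmarkFixture) -> None:
    # The callable passed to `benchmark` is what CodSpeed measures.
    result = benchmark(lambda: sum(range(1_000)))
    assert result == 499500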
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -69,6 +69,8 @@ repos:
- types-pyOpenSSL
- pylint_pydantic
- pytest
- pytest-codspeed
- respx

- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -68,6 +68,8 @@ dev = [
"pytest-asyncio>=0.21.1",
"pytest-cov>=4.1.0",
"pytest-dependency",
"pytest-codspeed>=2.2.0",
"respx",
"pytest-html>=3.2.0",
"pytest-httpx>=0.30.0",
"pytest-metadata>=3.0.0",
@@ -167,6 +169,7 @@ addopts = "-ra -q -vv --cov --cov-report term:skip-covered --color yes"
log_level = "WARNING"
render_collapsed = true
testpaths = ["tests"]
norecursedirs = ["tests/benchmark"] # Do not run performance testing outside of CodSpeed
filterwarnings = [
"error",
# cvprac is raising the next warning
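Because `tests/benchmark` is listed in `norecursedirs`, a plain `pytest` run skips the benchmark tests; they are only collected when the directory is passed explicitly, which is how the CI job above invokes them:

pytest --codspeed --no-cov --log-cli-level INFO tests/benchmark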
4 changes: 4 additions & 0 deletions tests/benchmark/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Benchmark tests for ANTA."""
53 changes: 53 additions & 0 deletions tests/benchmark/conftest.py
@@ -0,0 +1,53 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Fixtures for benchmarking ANTA."""

import logging

import pytest
import respx
from _pytest.terminal import TerminalReporter

from anta.catalog import AntaCatalog
from anta.device import AsyncEOSDevice
from anta.inventory import AntaInventory

from .utils import AntaMockEnvironment

logger = logging.getLogger(__name__)

TEST_CASE_COUNT = None


@pytest.fixture(scope="session")
def catalog() -> AntaCatalog:
"""Fixture that generate an ANTA catalog from unit test data. Also configure respx to mock eAPI responses."""
global TEST_CASE_COUNT # noqa: PLW0603 pylint: disable=global-statement
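# The route is registered on respx's default router; it takes effect in tests decorated with @respx.mock (see test_anta.py).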
eapi_route = respx.post(path="/command-api", headers={"Content-Type": "application/json-rpc"})
env = AntaMockEnvironment()
TEST_CASE_COUNT = len(env.catalog.tests)
eapi_route.side_effect = env.eapi_response
return env.catalog


@pytest.fixture
def inventory(request: pytest.FixtureRequest) -> AntaInventory:
"""Generate an ANTA inventory."""
inv = AntaInventory()
for i in range(request.param["count"]):
inv.add_device(
AsyncEOSDevice(
host=f"device-{i}.avd.arista.com",
username="admin",
password="admin", # noqa: S106
name=f"device-{i}",
disable_cache=(not request.param["cache"]),
)
)
return inv


def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
"""Display the total number of ANTA unit test cases used to benchmark."""
terminalreporter.write_sep("=", f"{TEST_CASE_COUNT} ANTA test cases")
101 changes: 101 additions & 0 deletions tests/benchmark/test_anta.py
@@ -0,0 +1,101 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Benchmark tests for ANTA."""

import asyncio
import logging
from unittest.mock import AsyncMock, Mock, patch

import pytest
import respx
from pytest_codspeed import BenchmarkFixture

from anta.catalog import AntaCatalog
from anta.inventory import AntaInventory
from anta.result_manager import ResultManager
from anta.result_manager.models import AntaTestStatus
from anta.runner import main

from .utils import collect, collect_commands

logger = logging.getLogger(__name__)


@pytest.mark.parametrize(
"inventory",
[
pytest.param({"count": 1, "cache": False}, id="1 device"),
pytest.param({"count": 2, "cache": False}, id="2 devices"),
],
indirect=True,
)
def test_anta_dry_run(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None:
"""Test and benchmark ANTA in Dry-Run Mode."""
# Disable logging during ANTA execution so that logging calls do not add time to the benchmark
logging.disable()

def bench() -> ResultManager:
"""Need to wrap the ANTA Runner to instantiate a new ResultManger for each benchmark run."""
manager = ResultManager()
asyncio.run(main(manager, inventory, catalog, dry_run=True))
return manager

manager = benchmark(bench)
logging.disable(logging.NOTSET)
if len(manager.results) != 0:
pytest.fail("ANTA Dry-Run mode should not return any result", pytrace=False)


@pytest.mark.parametrize(
"inventory",
[
pytest.param({"count": 1, "cache": False}, id="1 device"),
pytest.param({"count": 2, "cache": False}, id="2 devices"),
],
indirect=True,
)
@patch("asyncio.open_connection", AsyncMock(spec=asyncio.open_connection, return_value=(Mock(), Mock()))) # We just want all devices to be reachable
@patch("anta.models.AntaTest.collect", collect)
@patch("anta.device.AntaDevice.collect_commands", collect_commands)
@respx.mock # Mock eAPI responses
def test_anta(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None:
"""Test and benchmark ANTA. Mock eAPI responses."""
# Disable logging during ANTA execution so that logging calls do not add time to the benchmark
logging.disable()

def bench() -> ResultManager:
"""Need to wrap the ANTA Runner to instantiate a new ResultManger for each benchmark run."""
manager = ResultManager()
asyncio.run(main(manager, inventory, catalog))
return manager

manager = benchmark(bench)
logging.disable(logging.NOTSET)

if len(catalog.tests) * len(inventory) != len(manager.results):
# This could mean duplicates exist.
# TODO: consider removing this code and refactor unit test data as a dictionary with tuple keys instead of a list
seen = set()
dupes = []
for test in catalog.tests:
if test in seen:
dupes.append(test)
else:
seen.add(test)
if dupes:
for test in dupes:
msg = f"Found duplicate in test catalog: {test}"
logger.error(msg)
pytest.fail(f"Expected {len(catalog.tests) * len(inventory)} test results but got {len(manager.results)}", pytrace=False)
bench_info = (
"\n--- ANTA NRFU Benchmark Information ---\n"
f"Test results: {len(manager.results)}\n"
f"Success: {manager.get_total_results({AntaTestStatus.SUCCESS})}\n"
f"Failure: {manager.get_total_results({AntaTestStatus.FAILURE})}\n"
f"Skipped: {manager.get_total_results({AntaTestStatus.SKIPPED})}\n"
f"Error: {manager.get_total_results({AntaTestStatus.ERROR})}\n"
f"Unset: {manager.get_total_results({AntaTestStatus.UNSET})}\n"
"---------------------------------------"
)
logger.info(bench_info)
160 changes: 160 additions & 0 deletions tests/benchmark/utils.py
@@ -0,0 +1,160 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Utils for the ANTA benchmark tests."""

from __future__ import annotations

import asyncio
import importlib
import json
import pkgutil
from typing import TYPE_CHECKING, Any

import httpx
from pydantic import ValidationError

from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.models import AntaCommand, AntaTest

if TYPE_CHECKING:
from collections.abc import Generator
from types import ModuleType

from anta.device import AntaDevice


async def collect(self: AntaTest) -> None:
"""Patched anta.models.AntaTest.collect() method.
When generating the catalog, we inject a unit test case name into the custom_field input so we can retrieve the eos_data for that specific test case.
We use this unit test case name in the eAPI request ID.
"""
if self.inputs.result_overwrite is None or self.inputs.result_overwrite.custom_field is None:
msg = f"The custom_field input is not present for test {self.name}"
raise RuntimeError(msg)
await self.device.collect_commands(self.instance_commands, collection_id=f"{self.name}:{self.inputs.result_overwrite.custom_field}")


async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None:
"""Patched anta.device.AntaDevice.collect_commands() method.
For the same reason as above, we append the command index to the eAPI request ID.
"""
await asyncio.gather(*(self.collect(command=command, collection_id=f"{collection_id}:{idx}") for idx, command in enumerate(commands)))
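
# Example of the resulting identifiers for a hypothetical unit test case:
#   collect() sets collection_id to "VerifyUptime:success"
#   collect_commands() appends each command index: "VerifyUptime:success:0"
# The eAPI layer wraps this into the request ID "ANTA-VerifyUptime:success:0-{command ID}",
# which AntaMockEnvironment.eapi_response() parses below.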


class AntaMockEnvironment: # pylint: disable=too-few-public-methods
"""Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance.
Also provide the attribute 'eos_data_catalog` with the output of all the commands used in the test catalog.
Each module in `tests.units.anta_tests` has a `DATA` constant.
The `DATA` structure is a list of dictionaries used to parametrize the test. The list elements have the following keys:
- `name` (str): Test name as displayed by Pytest.
- `test` (AntaTest): An AntaTest subclass imported in the test module - e.g. VerifyUptime.
- `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test.
- `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`.
The keys of `eos_data_catalog` are tuples of the form (DATA['test'].__name__, DATA['name']); the values are the corresponding `eos_data`.
"""

def __init__(self) -> None:
self.catalog, self.eos_data_catalog = self._generate_catalog()

def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]:
"""Generate the `catalog` and `eos_data_catalog` attributes."""

def import_test_modules() -> Generator[ModuleType, None, None]:
"""Yield all test modules from the given package."""
package = importlib.import_module("tests.units.anta_tests")
prefix = package.__name__ + "."
for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix):
if not is_pkg and module_name.split(".")[-1].startswith("test_"):
module = importlib.import_module(module_name)
if hasattr(module, "DATA"):
yield module

test_definitions = []
eos_data_catalog = {}
for module in import_test_modules():
for test_data in module.DATA:
test = test_data["test"]
result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=test_data["name"])
# Some unit tests purposely have invalid inputs; we just skip them
try:
if test_data["inputs"] is None:
inputs = test.Input(result_overwrite=result_overwrite)
else:
inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite)
except ValidationError:
continue

test_definition = AntaTestDefinition(
test=test,
inputs=inputs,
)
eos_data_catalog[(test.__name__, test_data["name"])] = test_data["eos_data"]
test_definitions.append(test_definition)

return (AntaCatalog(tests=test_definitions), eos_data_catalog)

def eapi_response(self, request: httpx.Request) -> httpx.Response:
"""Mock eAPI response.
If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`,
the function will return the eos_data from the unit test case.
Otherwise, it will mock the 'show version' command or raise an exception.
"""
words_count = 3

def parse_req_id(req_id: str) -> tuple[str, str, int] | None:
"""Parse the patched request ID from the eAPI request."""
req_id = req_id.removeprefix("ANTA-").rpartition("-")[0]
words = req_id.split(":", words_count)
if len(words) == words_count:
test_name, unit_test_name, command_index = words
return test_name, unit_test_name, int(command_index)
return None

jsonrpc = json.loads(request.content)
assert jsonrpc["method"] == "runCmds"
commands = jsonrpc["params"]["cmds"]
ofmt = jsonrpc["params"]["format"]
req_id: str = jsonrpc["id"]
result = None

# Extract the test name, unit test name, and command index from the request ID
if (words := parse_req_id(req_id)) is not None:
test_name, unit_test_name, idx = words

# This should never happen, but better safe than sorry
if (test_name, unit_test_name) not in self.eos_data_catalog:
msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found"
raise RuntimeError(msg)

eos_data = self.eos_data_catalog[(test_name, unit_test_name)]

# This could happen if the unit test data is not correctly defined
if idx >= len(eos_data):
msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data"
raise RuntimeError(msg)
result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx]
elif {"cmd": "show version"} in commands and ofmt == "json":
# Mock 'show version' request performed during inventory refresh.
result = {
"modelName": "pytest",
}

if result is not None:
return httpx.Response(
status_code=200,
json={
"jsonrpc": "2.0",
"id": req_id,
"result": [result],
},
)
msg = f"The following eAPI Request has not been mocked: {jsonrpc}"
raise NotImplementedError(msg)
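
# For orientation, a request carrying a patched ID looks roughly like this
# (hypothetical values; the shape is inferred from the parsing above):
# {
#     "jsonrpc": "2.0",
#     "method": "runCmds",
#     "params": {"cmds": [{"cmd": "show uptime"}], "format": "json"},
#     "id": "ANTA-VerifyUptime:success:0-140245632485568",
# }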
