Skip to content

Commit

Permalink
Refactor runner to add limit
Browse files Browse the repository at this point in the history
  • Loading branch information
carl-baillargeon committed Jun 12, 2024
1 parent 6df9006 commit 13f7195
Showing 1 changed file with 70 additions and 17 deletions.
87 changes: 70 additions & 17 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from anta.tools import Catchtime, cprofile

if TYPE_CHECKING:
from collections.abc import Coroutine
from asyncio import Future
from collections.abc import AsyncGenerator, Coroutine

from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.device import AntaDevice
Expand All @@ -29,6 +30,7 @@
logger = logging.getLogger(__name__)

DEFAULT_NOFILE = 16384
MAXIMUM_TEST_CONCURRENCY = 20000


def adjust_rlimit_nofile() -> tuple[int, int]:
Expand Down Expand Up @@ -75,6 +77,56 @@ def log_cache_statistics(devices: list[AntaDevice]) -> None:
logger.info("Caching is not enabled on %s", device.name)


async def run_tests(tests_generator: AsyncGenerator[Coroutine[Any, Any, TestResult], None], limit: int) -> AsyncGenerator[TestResult, None]:
"""Run tests with a concurrency limit.
This function takes an asynchronous generator of test coroutines and runs them
with a limit on the number of concurrent tests. It yields test results as each
test completes.
Args:
----
tests_generator: An asynchronous generator that yields test coroutines.
limit: The maximum number of concurrent tests to run.
Yields
------
TestResult: The result of each completed test.
"""
aws = aiter(tests_generator)
aws_ended = False
pending: set[Future[TestResult]] = set()

while pending or not aws_ended:
# Add tests to the pending set until the limit is reached or no more tests are available
while len(pending) < limit and not aws_ended:
try:
# Get the next test coroutine from the generator
aw = await anext(aws)
except StopAsyncIteration: # noqa: PERF203
aws_ended = True
logger.debug("All tests have been added to the pending set.")
else:
# Ensure the coroutine is scheduled to run and add it to the pending set
pending.add(asyncio.ensure_future(aw))
logger.debug("Added a test to the pending set: %s", aw)

if len(pending) >= limit:
logger.debug("Concurrency limit reached: %s tests running. Waiting for tests to complete.", limit)

if not pending:
logger.debug("No pending tests and all tests have been processed. Exiting.")
return

# Wait for at least one of the pending tests to complete
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
logger.debug("Completed %s test(s). Pending count: %s", len(done), len(pending))

# Yield results of completed tests
while done:
yield await done.pop()


async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devices: set[str] | None, *, established_only: bool) -> AntaInventory | None:
"""Set up the inventory for the ANTA run.
Expand Down Expand Up @@ -157,23 +209,24 @@ def prepare_tests(
return device_to_tests


def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> list[Coroutine[Any, Any, TestResult]]:
"""Get the coroutines for the ANTA run.
async def generate_tests(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> AsyncGenerator[Coroutine[Any, Any, TestResult], None]:
"""Generate the coroutines for the ANTA run.
It creates an async generator of coroutines which are created by the `test` method of the AntaTest instances. Each coroutine is a test to run.
Args:
----
selected_tests: A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function.
selected_tests: A mapping of devices to the tests to run. The selected tests are created by the `prepare_tests` function.
Returns
-------
The list of coroutines to run.
Yields
------
Coroutine[Any, Any, TestResult]: The coroutine (test) to run.
"""
coros = []
for device, test_definitions in selected_tests.items():
for test in test_definitions:
try:
test_instance = test.test(device=device, inputs=test.inputs)
coros.append(test_instance.test())
coroutine = test_instance.test()
except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught
# An AntaTest instance is potentially user-defined code.
# We need to catch everything and exit gracefully with an error message.
Expand All @@ -184,7 +237,8 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio
],
)
anta_log_exception(e, message, logger)
return coros
else:
yield coroutine


@cprofile()
Expand Down Expand Up @@ -251,20 +305,19 @@ async def main( # noqa: PLR0913
"Please consult the ANTA FAQ."
)

coroutines = get_coroutines(selected_tests)
tests_generator = generate_tests(selected_tests)

if dry_run:
logger.info("Dry-run mode, exiting before running the tests.")
for coro in coroutines:
coro.close()
async for test in tests_generator:
test.close()
return

if AntaTest.progress is not None:
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines))
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=catalog.final_tests_count)

with Catchtime(logger=logger, message="Running ANTA tests"):
test_results = await asyncio.gather(*coroutines)
for r in test_results:
manager.add(r)
async for result in run_tests(tests_generator, limit=MAXIMUM_TEST_CONCURRENCY):
manager.add(result)

log_cache_statistics(selected_inventory.devices)

0 comments on commit 13f7195

Please sign in to comment.