Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: import aioeapi module #676

Merged
merged 16 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 0 additions & 106 deletions anta/aioeapi.py

This file was deleted.

2 changes: 1 addition & 1 deletion anta/cli/exec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from pathlib import Path
from typing import TYPE_CHECKING, Literal

from aioeapi import EapiCommandError
from click.exceptions import UsageError
from httpx import ConnectError, HTTPError

from anta.device import AntaDevice, AsyncEOSDevice
from anta.models import AntaCommand
from asynceapi import EapiCommandError

if TYPE_CHECKING:
from anta.inventory import AntaInventory
Expand Down
43 changes: 23 additions & 20 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from asyncssh import SSHClientConnection, SSHClientConnectionOptions
from httpx import ConnectError, HTTPError, TimeoutException

from anta import __DEBUG__, aioeapi
import asynceapi
from anta import __DEBUG__
from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaCommand

Expand Down Expand Up @@ -116,7 +117,7 @@ def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
yield "disable_cache", self.cache is None

@abstractmethod
async def _collect(self, command: AntaCommand) -> None:
async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
"""Collect device command output.

This abstract coroutine can be used to implement any command collection method
Expand All @@ -131,11 +132,11 @@ async def _collect(self, command: AntaCommand) -> None:

Args:
----
command: the command to collect

command: The command to collect.
collection_id: An identifier that will used to build the eAPI request ID.
mtache marked this conversation as resolved.
Show resolved Hide resolved
"""

async def collect(self, command: AntaCommand) -> None:
async def collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
"""Collect the output for a specified command.

When caching is activated on both the device and the command,
Expand All @@ -148,8 +149,8 @@ async def collect(self, command: AntaCommand) -> None:

Args:
----
command (AntaCommand): The command to process.

command: The command to collect.
collection_id: An identifier that will used to build the eAPI request ID.
mtache marked this conversation as resolved.
Show resolved Hide resolved
"""
# Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
Expand All @@ -161,20 +162,20 @@ async def collect(self, command: AntaCommand) -> None:
logger.debug("Cache hit for %s on %s", command.command, self.name)
command.output = cached_output
else:
await self._collect(command=command)
await self._collect(command=command, collection_id=collection_id)
await self.cache.set(command.uid, command.output) # pylint: disable=no-member
else:
await self._collect(command=command)
await self._collect(command=command, collection_id=collection_id)

async def collect_commands(self, commands: list[AntaCommand]) -> None:
async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None:
"""Collect multiple commands.

Args:
----
commands: the commands to collect

commands: The commands to collect.
collection_id: An identifier that will used to build the eAPI request ID.
mtache marked this conversation as resolved.
Show resolved Hide resolved
"""
await asyncio.gather(*(self.collect(command=command) for command in commands))
await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands))

@abstractmethod
async def refresh(self) -> None:
Expand Down Expand Up @@ -270,7 +271,7 @@ def __init__(
raise ValueError(message)
self.enable = enable
self._enable_password = enable_password
self._session: aioeapi.Device = aioeapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout)
self._session: asynceapi.Device = asynceapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout)
ssh_params: dict[str, Any] = {}
if insecure:
ssh_params["known_hosts"] = None
Expand Down Expand Up @@ -305,7 +306,7 @@ def _keys(self) -> tuple[Any, ...]:
"""
return (self._session.host, self._session.port)

async def _collect(self, command: AntaCommand) -> None: # noqa: C901 function is too complex - because of many required except blocks
async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: C901 function is too complex - because of many required except blocks #pylint: disable=line-too-long
"""Collect device command output from EOS using aio-eapi.

Supports outformat `json` and `text` as output structure.
Expand All @@ -314,9 +315,10 @@ async def _collect(self, command: AntaCommand) -> None: # noqa: C901 function

Args:
----
command: the AntaCommand to collect.
command: The command to collect.
collection_id: An identifier that will used to build the eAPI request ID.
mtache marked this conversation as resolved.
Show resolved Hide resolved
"""
commands: list[dict[str, Any]] = []
commands: list[dict[str, str | int]] = []
if self.enable and self._enable_password is not None:
commands.append(
{
Expand All @@ -329,14 +331,15 @@ async def _collect(self, command: AntaCommand) -> None: # noqa: C901 function
commands.append({"cmd": "enable"})
commands += [{"cmd": command.command, "revision": command.revision}] if command.revision else [{"cmd": command.command}]
try:
response: list[dict[str, Any]] = await self._session.cli(
response: list[dict[str, Any] | str] = await self._session.cli(
commands=commands,
ofmt=command.ofmt,
version=command.version,
)
req_id=f"ANTA-{collection_id}-{id(command)}" if collection_id else f"ANTA-{id(command)}",
) # type: ignore[assignment] # multiple commands returns a list
# Do not keep response of 'enable' command
command.output = response[-1]
except aioeapi.EapiCommandError as e:
except asynceapi.EapiCommandError as e:
# This block catches exceptions related to EOS issuing an error.
command.errors = e.errors
if command.requires_privileges:
Expand Down
2 changes: 1 addition & 1 deletion anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ async def collect(self) -> None:
"""Collect outputs of all commands of this test class from the device of this test instance."""
try:
if self.blocked is False:
await self.device.collect_commands(self.instance_commands)
await self.device.collect_commands(self.instance_commands, collection_id=self.name)
except Exception as e: # pylint: disable=broad-exception-caught
# device._collect() is user-defined code.
# We need to catch everything if we want the AntaTest object
Expand Down
9 changes: 9 additions & 0 deletions asynceapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Initially written by Jeremy Schulman at https://github.com/jeremyschulman/aio-eapi
gmuloc marked this conversation as resolved.
Show resolved Hide resolved

"""Arista EOS eAPI asyncio client."""

from .config_session import SessionConfig
from .device import Device
from .errors import EapiCommandError

__all__ = ["Device", "SessionConfig", "EapiCommandError"]
55 changes: 55 additions & 0 deletions asynceapi/aio_portcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Utility function to check if a port is open."""
# -----------------------------------------------------------------------------
# System Imports
# -----------------------------------------------------------------------------

from __future__ import annotations

import asyncio
import socket
from typing import TYPE_CHECKING

# -----------------------------------------------------------------------------
# Public Imports
# -----------------------------------------------------------------------------

if TYPE_CHECKING:
from httpx import URL

# -----------------------------------------------------------------------------
# Exports
# -----------------------------------------------------------------------------

__all__ = ["port_check_url"]

# -----------------------------------------------------------------------------
#
# CODE BEGINS
#
# -----------------------------------------------------------------------------


async def port_check_url(url: URL, timeout: int = 5) -> bool:
"""
Open the port designated by the URL given the timeout in seconds.

If the port is avaialble then return True; False otherwise.

Parameters
----------
url: The URL that provides the target system
timeout: Time to await for the port to open in seconds
"""
port = url.port or socket.getservbyname(url.scheme)

try:
wr: asyncio.StreamWriter
_, wr = await asyncio.wait_for(asyncio.open_connection(host=url.host, port=port), timeout=timeout)

# MUST close if opened!
wr.close()

except TimeoutError:
return False
else:
return True
Loading
Loading