diff --git a/extensions/eda/plugins/event_source/subscription.py b/extensions/eda/plugins/event_source/subscription.py new file mode 100644 index 000000000..f8a8f176d --- /dev/null +++ b/extensions/eda/plugins/event_source/subscription.py @@ -0,0 +1,148 @@ +""" +subscription.py + +An ansible-rulebook event source plugin template. + +Arguments: + - apic1: APIC IP or hostname + - username: APIC username + - password: APIC password + - subscriptions: query api endpoints for subscriptions + - refresh_timeout: (optional) timeout of subscription + +Examples: + cisco.aci_eda.websocket: + hostname: apic1 + username: ansible + password: ansible + subscriptions: + - /api/mo/uni/tn-demo/ap-demo/epg-demo.json?query-target=children&target-subtree-class=fvCEp&query-target=subtree + refresh_timeout: 60 + +""" + +import asyncio +import json +import os +import sys +import signal +import ssl +from typing import Any, Dict, NoReturn +import requests +import websockets +import urllib3 + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +def login(hostname: str, username: str, password: str) -> str: + """ + login to apic and get session token + + :param hostname: apic hostname or ip + :param username: apic username + :param password: apic password + :return: session token + """ + login_url = f"https://{hostname}/api/aaaLogin.json" + payload = {"aaaUser": {"attributes": {"name": username, "pwd": password}}} + token = "" + + login_response = requests.post(login_url, json=payload, verify=False) + if login_response.ok: + response_json = login_response.json() + token = response_json["imdata"][0]["aaaLogin"]["attributes"]["token"] + return token + + +def subscribe(hostname: str, token: str, rf_timeout: int, sub_urls: list[str]) -> list[str]: + """ + subscribe to a websocket + + :param hostname: apic hostname or ip + :param token: apic session token + :param rf_timeout: refresh timeout of subscription + :param sub_urls: subscriptions url + :return: list of subscription ids + """ + sub_ids = [] + + for sub in sub_urls: + sub_url = f"https://{hostname}{sub}&subscription=yes&refresh-timeout={rf_timeout}" + cookie = {"APIC-cookie": token} + sub_response = requests.get(sub_url, verify=False, cookies=cookie, timeout=60) + if sub_response.ok: + sub_id = sub_response.json()["subscriptionId"] + sub_ids.append(sub_id) + return sub_ids + + +async def refresh(hostname: str, token: str, refresh_timeout: int, sub_ids: list[str]) -> NoReturn: + """ + refresh subscriptions + + :param hostname: apic hostname or ip + :param token: session token + :param refresh_timeout: subscription refresh timeout + :param sub_ids: subscription ids + :return: NoReturn + """ + cookie = {"APIC-cookie": token} + while True: + await asyncio.sleep(refresh_timeout / 2) + for sub_id in sub_ids: + refresh_url = f"https://{hostname}/api/subscriptionRefresh.json?id={sub_id}" + requests.get(refresh_url, verify=False, cookies=cookie, timeout=60) + + +async def main(queue: asyncio.Queue, args: Dict[str, Any]): + hostname = args.get("hostname", "") + username = args.get("username", "") + password = args.get("password", "") + refresh_timeout = int(args.get("refresh_timeout", 300)) + subscriptions = args.get("subscriptions") + + if "" in [hostname, username, password]: + print(f"hostname, username and password can't be empty:{hostname}, {username}, *****") + sys.exit(1) + + if not isinstance(subscriptions, list) or subscriptions == [] or subscriptions is None: + print(f"subscriptions is empty or not a list: {subscriptions}") + sys.exit(1) + + token = login(hostname=hostname, username=username, password=password) + websocket_uri = f"wss://{hostname}/socket{token}" + ctx = ssl.SSLContext() + + async with websockets.connect(websocket_uri, ssl=ctx) as ws: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(ws.close())) + sub_ids = subscribe(hostname, token, refresh_timeout, subscriptions) + + # task to refresh subscription token + asyncio.create_task(refresh(hostname, token, refresh_timeout, sub_ids)) + + async for message in ws: + await queue.put(json.loads(message)) + + +if __name__ == "__main__": + # this function is conly called when executed directly + apic_username = os.environ["apic_username"] + apic_password = os.environ["apic_password"] + apic_url = os.environ["apic_url"] + + class MockQueue(asyncio.Queue): + async def put(self, item): + print(item) + + mock_arguments = { + "hostname": apic_url, + "username": apic_username, + "password": apic_password, + "subscriptions": [ + '/api/node/class/faultInst.json?query-target-filter=and(eq(faultInst.code,"F1386"))', + ], + "refresh_timeout": 30, + } + asyncio.run(main(MockQueue(), mock_arguments)) diff --git a/requirements.txt b/requirements.txt index 073a0a8b6..a706c79d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,6 @@ cryptography pyOpenSSL python_dateutil xmljson +requests +websockets +asyncmock \ No newline at end of file diff --git a/tests/unit/event_source/__init__.py b/tests/unit/event_source/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/event_source/test_subscription.py b/tests/unit/event_source/test_subscription.py new file mode 100644 index 000000000..d83f61c20 --- /dev/null +++ b/tests/unit/event_source/test_subscription.py @@ -0,0 +1,79 @@ +from unittest.mock import patch +from typing import Any, List +from asyncmock import AsyncMock +from extensions.eda.plugins.event_source.subscription import main as subscription_main +import pytest +import json +import asyncio + + +# Refresh mock method +def refresh_patch(hostname: str, token: str, rf_timeout: int, sub_urls: List[str]) -> None: + pass + + +# Login mock method +def login_patch(hostname: str, username: str, password: str) -> str: + return f"{hostname}{username}{password}" + + +# Subscribe mock method +def subscribe_patch(hostname, token, rf_timeout, sub_urls) -> List[str]: + return [f"{hostname}{token}{rf_timeout}{url}" for url in sub_urls] + + +# Mock iterator +class AsyncIterator: + def __init__(self) -> None: + self.count = 0 + + async def __anext__(self) -> str: + if self.count < 2: + self.count += 1 + return json.dumps({"eventid": f"00{self.count}"}) + else: + raise StopAsyncIteration + + +# Mock Async Websocket +class MockWebSocket(AsyncMock): # type: ignore[misc] + def __aiter__(self) -> AsyncIterator: + return AsyncIterator() + + async def close(self) -> None: + pass + + +# Mock AsyncQueue +class MockQueue(asyncio.Queue[Any]): + def __init__(self) -> None: + self.queue: list[Any] = [] + + async def put(self, item: Any) -> None: + self.queue.append(item) + + +def test_websocket_subscription() -> None: + + with patch( + "websockets.connect", + return_value=MockWebSocket(), + ), patch("unit.event_source.tmp_subscription.login", return_value=login_patch), patch( + "unit.event_source.tmp_subscription.subscribe", return_value=subscribe_patch + ), patch("unit.event_source.tmp_subscription.refresh", return_value=refresh_patch): + + my_queue = MockQueue() + asyncio.run( + subscription_main( + my_queue, + { + "hostname": "my-apic.com", + "username": "admin", + "password": "admin", + "subscriptions": ['/api/node/class/faultInst.json?query-target-filter=and(eq(faultInst.code,"F1386"))'], + }, + ) + ) + + assert my_queue.queue[0] == {"eventid": "001"} + assert len(my_queue.queue) == 2