diff --git a/examples/client.py b/examples/client.py index 610b7b7..d08ba24 100644 --- a/examples/client.py +++ b/examples/client.py @@ -98,7 +98,7 @@ async def main(): total_sample_count = 0 total_request_time = 0 - batch_size = 500 + batch_size = 5000 number_of_batches = 5 ioloop = asyncio.get_event_loop() diff --git a/examples/server.py b/examples/server.py index 8e63c28..b6c4400 100644 --- a/examples/server.py +++ b/examples/server.py @@ -84,10 +84,13 @@ async def main(): try: logging.info("Service %s ready", service_example._service_name) await asyncio.Event().wait() - except KeyboardInterrupt: logging.info("Shutting down signal received") - - logging.info("Service %s done", service_example._service_name) + await container.stop() + except BaseException as err: + logging.info("Shutting down. %s: %s", type(err).__name__, err) + await container.stop() + finally: + logging.info("Service %s done", service_example._service_name) if __name__ == "__main__": @@ -99,4 +102,7 @@ async def main(): logging.getLogger("pika").setLevel(logging.CRITICAL) logging.getLogger("etcd3").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/examples/server_basic.py b/examples/server_basic.py index f1dc680..ce3a684 100644 --- a/examples/server_basic.py +++ b/examples/server_basic.py @@ -42,12 +42,16 @@ async def main(): transport_settings=transport_settings, ) await mesh_api.connect() - try: logging.info("Service %s ready", service_instance._service_name) await asyncio.Event().wait() - except KeyboardInterrupt: logging.info("Shutting down signal received") + await mesh_api.disconnect() + except BaseException as err: + logging.info("Shutting down. %s: %s", type(err).__name__, err) + await mesh_api.disconnect() + finally: + logging.info("Service %s done", service_instance._service_name) logging.info("Service %s done", service_instance._service_name) @@ -61,4 +65,7 @@ async def main(): logging.getLogger("pika").setLevel(logging.WARNING) logging.getLogger("etcd3").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/pyproject.toml b/pyproject.toml index ac9009b..5ddca48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nuropb" -version = "0.1.6" +version = "0.1.7" description = "NuroPb - A Distributed Event Driven Service Mesh" authors = ["Robert Betts "] readme = "README.md" @@ -47,6 +47,7 @@ env_files = [".env_test"] testpaths = ["tests"] [tool.pytest.ini_options] +asyncio_mode = "strict" log_cli = true log_level = "DEBUG" log_cli_format = " %(levelname).1s %(asctime)s,%(msecs)d %(module)s %(lineno)s %(message)s" diff --git a/src/nuropb/contexts/describe.py b/src/nuropb/contexts/describe.py index cdf97b6..daf7d84 100644 --- a/src/nuropb/contexts/describe.py +++ b/src/nuropb/contexts/describe.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -AuthoriseFunc = Callable[[str], Dict[str, Any]] +AuthoriseFunc = Callable[[str], Optional[Dict[str, Any]]] def method_visible_on_mesh(method: Callable[..., Any]) -> bool: @@ -87,10 +87,21 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: def describe_service(class_instance: object) -> Dict[str, Any] | None: - """Returns a description of the class methods that will be exposed to the service mesh""" + """Returns a description of the class methods that will be exposed to the service mesh + """ + service_info = { + "service_name": "", + "service_version": "", + "description": "", + "encrypted_methods": [], + "methods": {}, + "warnings": [], + } + if class_instance is None: logger.warning("No service class base has been input") - return None + service_info["warnings"].append("No service class base") + else: service_name = getattr(class_instance, "_service_name", None) service_description = getattr(class_instance, "__doc__", None) @@ -174,18 +185,19 @@ def map_argument(arg_props: Any) -> Tuple[str, Dict[str, Any]]: } methods.append((name, method_spec)) - service_info = { + service_info.update({ "service_name": service_name, "service_version": service_version, "description": service_description, "encrypted_methods": service_has_encrypted_methods, "methods": dict(methods), - } + }) if service_has_encrypted_methods: private_key = service_name = getattr(class_instance, "_private_key", None) if private_key is None: - raise ValueError( + service_info["warnings"].append("Service has encrypted methods but no private key has been set.") + logger.debug( f"Service {service_name} has encrypted methods but no private key has been set" ) diff --git a/src/nuropb/encodings/encryption.py b/src/nuropb/encodings/encryption.py index 8d32985..34da85a 100644 --- a/src/nuropb/encodings/encryption.py +++ b/src/nuropb/encodings/encryption.py @@ -101,12 +101,12 @@ def add_public_key( """ self._service_public_keys[service_name] = public_key - def get_public_key(self, service_name: str) -> rsa.RSAPublicKey: + def get_public_key(self, service_name: str) -> rsa.RSAPublicKey | None: """Get a public key for a service :param service_name: str :return: rsa.RSAPublicKey """ - return self._service_public_keys.get[service_name] + return self._service_public_keys.get(service_name) def has_public_key(self, service_name: str) -> bool: """Check if a service has a public key @@ -142,7 +142,12 @@ def encrypt_payload( if service_name is None: # Mode 4, get public key from the private key - public_key = self._private_key.public_key() + if self._private_key is None: + raise ValueError( + f"Service public key not found for service: {self._service_name}" + ) # pragma: no cover + else: + public_key = self._private_key.public_key() else: # Mode 1, get public key from the destination service's public key public_key = self._service_public_keys[service_name] @@ -175,6 +180,10 @@ def decrypt_payload(self, payload: bytes, correlation_id: str) -> bytes: encrypted_key, encrypted_payload = payload.split(b".", 1) if correlation_id not in self._correlation_id_symmetric_keys: """Mode 3, use public key from the private key to decrypt key""" + if self._private_key is None: + raise ValueError( + f"Service public key not found for service: {self._service_name}" + ) # pragma: no cover key = decrypt_key(encrypted_key, self._private_key) # remember the key for this correlation_id to encrypt the response self._correlation_id_symmetric_keys[correlation_id] = key diff --git a/src/nuropb/interface.py b/src/nuropb/interface.py index d8d0a12..ea263f9 100644 --- a/src/nuropb/interface.py +++ b/src/nuropb/interface.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -NUROPB_VERSION = "0.1.6" +NUROPB_VERSION = "0.1.7" NUROPB_PROTOCOL_VERSION = "0.1.1" NUROPB_PROTOCOL_VERSIONS_SUPPORTED = ("0.1.1",) NUROPB_MESSAGE_TYPES = ( @@ -260,6 +260,26 @@ class NuropbTimeoutError(NuropbException): class NuropbTransportError(NuropbException): """NuropbTransportError: represents an error that inside the plumbing.""" + _close_connection: bool + + def __init__( + self, + description: Optional[str] = None, + payload: Optional[PayloadDict] = None, + exception: Optional[BaseException] = None, + close_connection: bool = False, + ): + super().__init__( + description=description, + payload=payload, + exception=exception, + ) + self._close_connection = close_connection + + @property + def close_connection(self) -> bool: + """close_connection: returns True if the connection should be closed""" + return self._close_connection class NuropbMessageError(NuropbException): diff --git a/src/nuropb/rmq_api.py b/src/nuropb/rmq_api.py index 393e302..e7e350d 100644 --- a/src/nuropb/rmq_api.py +++ b/src/nuropb/rmq_api.py @@ -27,7 +27,19 @@ class RMQAPI(NuropbInterface): - """RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport.""" + """RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport. + + When an existing transport initialised and connected, and a subsequent transport + instance is connected with the same service_name and instance_id as the first, the broker + will shut down the channel of subsequent instances when they attempt to configure their + response queue. This is because the response queue is opened in exclusive mode. The + exclusive mode is used to ensure that only one consumer (nuropb api connection) is + consuming from the response queue. + + Deliberately specifying a fixed instance_id, is a valid mechanism to ensure that a service + can only run in single instance mode. This is useful for services that are not designed to + be run in a distributed manner or where there is specific service configuration required. + """ _mesh_name: str _connection_name: str @@ -42,6 +54,17 @@ class RMQAPI(NuropbInterface): _service_discovery: Dict[str, Any] _service_public_keys: Dict[str, Any] + @classmethod + def _get_vhost(cls, amqp_url: str | Dict[str, Any]) -> str: + if isinstance(amqp_url, str): + parts = amqp_url.split("/") + vhost = amqp_url.split("/")[-1] + if len(parts) < 4: + raise ValueError("Invalid amqp_url, missing vhost") + else: + vhost = amqp_url["vhost"] + return vhost + def __init__( self, amqp_url: str | Dict[str, Any], @@ -52,14 +75,12 @@ def __init__( events_exchange: Optional[str] = None, transport_settings: Optional[Dict[str, Any]] = None, ): - """RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport.""" - if isinstance(amqp_url, str): - parts = amqp_url.split("/") - vhost = amqp_url.split("/")[-1] - if len(parts) < 4: - raise ValueError("Invalid amqp_url, missing vhost") - else: - vhost = amqp_url["vhost"] + """RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport. + + Where exchange inputs are none, but they user present in transport_settings, then use the + values from transport_settings + """ + vhost = self._get_vhost(amqp_url) self._mesh_name = vhost @@ -120,6 +141,14 @@ def __init__( "No service instance provided, service will not be able to handle requests" ) # pragma: no cover + """ where exchange inputs are none, but they user present in transport_settings, + then use the values from transport settings + """ + if rpc_exchange is None and transport_settings.get("rpc_exchange", None): + rpc_exchange = transport_settings["rpc_exchange"] + if events_exchange is None and transport_settings.get("events_exchange", None): + events_exchange = transport_settings["events_exchange"] + transport_settings.update( { "service_name": self._service_name, @@ -246,6 +275,41 @@ def receive_transport_message( service_message["nuropb_type"], ) + @classmethod + def _handle_immediate_request_error( + cls, + rpc_response: bool, + payload: RequestPayloadDict | ResponsePayloadDict, + error: Dict[str, Any] | BaseException + ) -> ResponsePayloadDict: + + if rpc_response and isinstance(error, BaseException): + raise error + elif rpc_response: + raise NuropbMessageError( + description=error["description"], + payload=payload, + ) + + if isinstance(error, NuropbException): + error = error.to_dict() + elif isinstance(error, BaseException): + error = { + "error": f"{type(error).__name__}", + "description": f"{type(error).__name__}: {error}", + } + + return { + "tag": "response", + "context": payload["context"], + "correlation_id": payload["correlation_id"], + "trace_id": payload["trace_id"], + "result": None, + "error": error, + "warning": None, + "reply_to": "", + } + async def request( self, service: str, @@ -317,51 +381,19 @@ async def request( encrypted=encrypted, ) except Exception as err: - if rpc_response is False: - return { - "tag": "response", - "context": context, - "correlation_id": correlation_id, - "trace_id": trace_id, - "result": None, - "error": { - "error": f"{type(err).__name__}", - "description": f"Error sending request message: {err}", - }, - } - else: - raise err + return self._handle_immediate_request_error(rpc_response, message, err) try: response = await response_future - if rpc_response is True and response["error"] is not None: - raise NuropbMessageError( - description=response["error"]["description"], - payload=response, - ) - elif rpc_response is True: - return response["result"] - else: - return response - except BaseException as err: - if rpc_response is True: - raise err - else: - if not isinstance(err, NuropbException): - error = { - "error": f"{type(err).__name__}", - "description": f"Error waiting for response: {err}", - } - else: - error = err.to_dict() - return { - "tag": "response", - "context": context, - "correlation_id": correlation_id, - "trace_id": trace_id, - "result": None, - "error": error, - } + except Exception as err: + return self._handle_immediate_request_error(rpc_response, message, err) + + if response["error"] is not None: + return self._handle_immediate_request_error(rpc_response, response, response["error"]) + if rpc_response is True: + return response["result"] + else: + return response def command( self, @@ -501,7 +533,7 @@ async def describe_service( ) return service_info except Exception as err: - logger.error(f"error loading the public key for {service_name}: {err}") + raise ValueError(f"error loading the public key for {service_name}: {err}") async def requires_encryption(self, service_name: str, method_name: str) -> bool: """requires_encryption: Queries the service discovery information for the service_name diff --git a/src/nuropb/rmq_lib.py b/src/nuropb/rmq_lib.py index e76c6bb..f0ce4d9 100644 --- a/src/nuropb/rmq_lib.py +++ b/src/nuropb/rmq_lib.py @@ -11,7 +11,7 @@ from pika.channel import Channel from pika.credentials import PlainCredentials -from nuropb.interface import PayloadDict, NuropbTransportError +from nuropb.interface import PayloadDict, NuropbTransportError, NUROPB_VERSION, NUROPB_PROTOCOL_VERSION logger = logging.getLogger(__name__) @@ -50,50 +50,88 @@ def rmq_api_url_from_amqp_url( return build_rmq_api_url(scheme, host, port, username, password) -def get_connection_parameters(amqp_url: str | Dict[str, Any]) -> pika.ConnectionParameters | pika.URLParameters: - """Return the connection parameters for the transport""" +def get_client_connection_properties(name: Optional[str] = None, instance_id: Optional[str] = None) -> Dict[str, str]: + """Returns the client connection properties for the transport""" + properties = { + "product": "Nuropb", + "version": NUROPB_VERSION, + "protocol": NUROPB_PROTOCOL_VERSION, + "platform": "Python", + } + if name: + properties["name"] = name + if instance_id: + properties["instance_id"] = instance_id + + return properties + + +def get_connection_parameters( + amqp_url: str | Dict[str, Any], + name: Optional[str] = None, + instance_id: Optional[str] = None, + **overrides: Any +) -> pika.ConnectionParameters | pika.URLParameters: + """Return the connection parameters for the transport + :param amqp_url: the AMQP URL or connection parameters to use + :param name: the name of the service or client + :param instance_id: the instance id of the service or client + :param overrides: additional keyword arguments to override the connection parameters + """ if isinstance(amqp_url, dict): # create TLS connection parameters - cafile = amqp_url.get("cafile", None) - if cafile: # pragma: no cover - context = ssl.create_default_context( - cafile=cafile, - ) - else: - context = ssl.create_default_context() - if amqp_url.get("certfile"): - context.load_cert_chain( - certfile=amqp_url.get("certfile"), - keyfile=amqp_url.get("keyfile") + host = amqp_url.get("host", None) + port = amqp_url.get("port", None) + pika_parameters = { + "host": host, + "port": port, + "client_properties": get_client_connection_properties(name=name, instance_id=instance_id), + "heartbeat": 60, + } + vhost = amqp_url.get("vhost", None) + if vhost: + pika_parameters["virtual_host"] = vhost + + if amqp_url.get("cafile", None) or amqp_url.get("certfile"): + cafile = amqp_url.get("cafile", None) + if cafile: # pragma: no cover + context = ssl.create_default_context( + cafile=cafile, + ) + context.verify_mode = ssl.CERT_REQUIRED + else: + context = ssl.create_default_context() + context.verify_mode = ssl.CERT_REQUIRED + + # For client x509 certificate authentication + if amqp_url.get("certfile"): + context.load_cert_chain( + certfile=amqp_url.get("certfile"), + keyfile=amqp_url.get("keyfile") + ) + + # Whether to disable SSL certificate verification + if amqp_url.get("verify", True) is False: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.check_hostname = True + context.verify_mode = ssl.CERT_REQUIRED + + ssl_options = pika.SSLOptions( + context=context, + server_hostname=host ) - - if amqp_url.get("verify", True) is False: - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - else: - context.check_hostname = True - context.verify_mode = ssl.CERT_REQUIRED + pika_parameters["ssl_options"] = ssl_options if amqp_url.get("username", None): credentials = PlainCredentials(amqp_url["username"], amqp_url["password"]) - else: - credentials = None + pika_parameters["credentials"] = credentials - host = amqp_url.get("host", None) - port = amqp_url.get("port", None) - vhost = amqp_url.get("vhost", "/") - ssl_options = pika.SSLOptions( - context=context, - server_hostname=host - ) - conn_params = pika.ConnectionParameters( - host=host, - port=port, - virtual_host=vhost, - credentials=credentials, - ssl_options=ssl_options - ) + pika_parameters.update(overrides) + + conn_params = pika.ConnectionParameters(**pika_parameters) return conn_params else: diff --git a/src/nuropb/rmq_transport.py b/src/nuropb/rmq_transport.py index 7ee1562..ef5d061 100644 --- a/src/nuropb/rmq_transport.py +++ b/src/nuropb/rmq_transport.py @@ -1,6 +1,6 @@ import logging import functools -from typing import List, Set, Optional, Any, Dict, Awaitable, Literal, TypedDict +from typing import List, Set, Optional, Any, Dict, Awaitable, Literal, TypedDict, cast import asyncio import time @@ -8,7 +8,7 @@ from pika import connection from pika.adapters.asyncio_connection import AsyncioConnection from pika.channel import Channel -from pika.exceptions import ChannelClosedByBroker, ProbableAccessDeniedError +from pika.exceptions import ChannelClosedByBroker, ProbableAccessDeniedError, ChannelClosedByClient import pika.spec from pika.frame import Method @@ -25,7 +25,7 @@ NUROPB_PROTOCOL_VERSION, NUROPB_VERSION, NuropbNotDeliveredError, - NuropbCallAgainReject, + NuropbCallAgainReject, RequestPayloadDict, ResponsePayloadDict, ) from nuropb.rmq_lib import ( rmq_api_url_from_amqp_url, @@ -114,21 +114,20 @@ class ServiceNotConfigured(Exception): """Raised when a service is not properly configured on the RabbitMQ broker. the leader will be expected to configure the Exchange and service queues """ - pass class RMQTransport: - """ - - If RabbitMQ closes the connection, this class will stop and indicate - that reconnection is necessary. You should look at the output, as - there are limited reasons why the connection may be closed, which - usually are tied to permission related issues or socket timeouts. + """ RMQTransport is the base class for the RabbitMQ transport. It wraps the + NuroPb service mesh patterns and rules over the AMQP protocol. - If the channel is closed, it will indicate a problem with one of the - commands that were issued and that should surface in the output as well. + When RabbitMQ closes the connection, this class will stop and alert that + reconnection is necessary, in some cases re-connection takes place automatically. + Disconnections should be continuously monitored, there are various reasons why a + connection or channel may be closed after being successfully opened, and usually + related to authentication, permissions, protocol violation or networking. + TODO: Configure the Pika client connection attributes in the pika client properties. """ _service_name: str _instance_id: str @@ -203,6 +202,16 @@ def __init__( - int prefetch_count: The number of messages to prefetch defaults to 1, unlimited is 0. Experiment with larger values for higher throughput in your user case. + When an existing transport initialised and connected, and a subsequent transport + instance is connected with the same service_name and instance_id as the first, the broker + will shut down the channel of subsequent instances when they attempt to configure their + response queue. This is because the response queue is opened in exclusive mode. The + exclusive mode is used to ensure that only one consumer (nuropb api connection) is + consuming from the response queue. + + Deliberately specifying a fixed instance_id, is a valid mechanism to ensure that a service + can only run in single instance mode. This is useful for services that are not designed to + be run in a distributed manner or where there is specific service configuration required. """ self._connected = False self._closing = False @@ -238,7 +247,7 @@ def __init__( self._service_queue = kwargs.get("service_queue", None) or f"nuropb-{self._service_name}-sq" self._response_queue = ( kwargs.get("response_queue", None) - or f"nuropb-{self._service_name}-{self._instance_id}-response" + or f"nuropb-{self._service_name}-{self._instance_id}-rq" ) self._rpc_bindings = rpc_bindings self._event_bindings = event_bindings @@ -343,7 +352,7 @@ def configure_rabbitmq( amqp_url = amqp_url or self._amqp_url if amqp_url is None: raise ValueError("amqp_url is not provided") - rmq_api_url = rmq_api_url or rmq_api_url_from_amqp_url(amqp_url) + # rmq_api_url = rmq_api_url or rmq_api_url_from_amqp_url(amqp_url) if rmq_api_url is None: raise ValueError("rmq_api_url is not provided") try: @@ -372,6 +381,12 @@ async def start(self) -> None: self._connected_future = self.connect() try: await self._connected_future + if self._connected_future.done(): + err = self._connected_future.exception() + if isinstance(err, NuropbTransportError) and err.close_connection: + await self.stop() + elif err is not None: + raise err except ProbableAccessDeniedError as err: if isinstance(self._amqp_url, dict): vhost = self._amqp_url.get("vhost", "") @@ -386,9 +401,18 @@ async def start(self) -> None: f"Access denied to RabbitMQ for the virtual host {vhost}: {err}" ) raise err + except NuropbTransportError as err: + """ Logging already captured, handle the error, likely a channel closed by broker + """ + if not self._connected_future.done(): + self._connected_future.set_exception(err) + if err.close_connection: + await self.stop() + except Exception as err: - logger.error(f"Failed to connect to RabbitMQ: {err}") - raise err + logger.exception("General failure connecting to RabbitMQ. %s: %s", type(err).__name__, err) + if not self._connected_future.done(): + self._connected_future.set_exception(err) async def stop(self) -> None: """Cleanly shutdown the connection to RabbitMQ by stopping the consumer @@ -403,12 +427,15 @@ async def stop(self) -> None: """ if not self._closing: logger.info("Stopping") - if self._consuming: - await self.stop_consuming() + try: + if self._consuming: + await self.stop_consuming() + except Exception as err: + logger.exception(f"Error stopping consuming: {err}") self._disconnected_future = self.disconnect() await self._disconnected_future - def connect(self) -> Awaitable[bool]: + def connect(self) -> asyncio.Future[bool]: """This method initiates a connection to RabbitMQ, returning the connection handle. When the connection is established, the on_connection_open method will be invoked by pika. @@ -447,10 +474,14 @@ def disconnect(self) -> Awaitable[bool]: :return: asyncio.Future """ if self._connection is None: - raise RuntimeError("RMQ transport is not connected") + logger.info("RMQ transport is not connected") + elif self._connection.is_closing or self._connection.is_closed: + logger.info("RMQ transport is already closing or closed") - if self._connection.is_closing or self._connection.is_closed: - raise RuntimeError("RMQ transport is already closing or closed") + if self._connection is None or not self._connected: + disconnected_future: asyncio.Future[bool] = asyncio.Future() + disconnected_future.set_result(True) + return disconnected_future if ( self._disconnected_future is not None @@ -476,23 +507,25 @@ def on_connection_open(self, _connection: AsyncioConnection) -> None: self.open_channel() - def on_connection_open_error( - self, _connection: AsyncioConnection, err: Exception - ) -> None: - """This method is called by pika if the connection to RabbitMQ - can't be established. + def on_connection_open_error(self, conn: AsyncioConnection, reason: Exception) -> None: + """This method is called by pika if the connection to RabbitMQ can't be established. - :param pika.adapters.asyncio_connection.AsyncioConnection _connection: - The connection - :param Exception err: The error + :param pika.adapters.asyncio_connection.AsyncioConnection conn: + :param Exception reason: The error """ - logger.error("Connection open failed: %s", err) + logger.error("Connection open Error. %s: %s", type(reason).__name__, reason) if self._connected_future is not None and not self._connected_future.done(): - self._connected_future.set_exception(err) - + if isinstance(reason, ProbableAccessDeniedError): + close_connection = True + else: + close_connection = False + self._connected_future.set_exception(NuropbTransportError( + description=f"Connection open Error. {type(reason).__name__}: {reason}", + exception=reason, + close_connection=close_connection, + )) if self._connecting: self._connecting = False - # self.connect() def on_connection_closed( self, _connection: AsyncioConnection, reason: Exception @@ -545,24 +578,87 @@ def on_channel_open(self, channel: Channel) -> None: self.declare_response_queue() def on_channel_closed(self, channel: Channel, reason: Exception) -> None: - """Invoked by pika when RabbitMQ unexpectedly closes the channel. Channels are usually closed - if you attempt to do something that violates the protocol, such as re-declare an exchange or - queue with different parameters. In this case, we'll close the connection to shut down the object. + """Invoked by pika when the channel is closed. Channels are at times close by the + broker for various reasons, the most common being protocol violations e.g. acknowledging + messages using an invalid message_tag. In most cases when the channel is closed by + the broker, nuropb will automatically open a new channel and re-declare the service queue. + + In the following cases the channel is not automatically re-opened: + * When the connection is closed by this transport API + * When the connection is close by the broker 403 (ACCESS_REFUSED): + Typically these examples are seen: + - "Provided JWT token has expired at timestamp". In this case the transport will + require a fresh JWT token before attempting to reconnect. + - queue 'XXXXXXX' in vhost 'YYYYY' in exclusive use. In this case there is another + response queue setup with the same name. Having a fixed response queue name is + a valid mechanism to enforce a single instance of a service. + * When the connection is close by the broker 403 (NOT_FOUND): + Typically where there is an exchange configuration issue. + + Always investigate the reasons why a channel is closed and introduce logic to handle + that scenario accordingly. It is important to continuously monitor for this condition. + + TODO: When there is message processing in flight, and particularly with a prefetch + count > 1, then those messages are now not able to be acknowledged. By doing so will + result in a forced channel closure by the broker and potentially a poison pill type + scenario. :param pika.channel.Channel channel: The closed channel :param Exception reason: why the channel was closed """ if isinstance(reason, ChannelClosedByBroker): - logger.critical( - f"RabbitMQ channel {channel} closed by broker with reply_code: {reason.reply_code} " - f"and reply_text: {reason.reply_text}" - ) + if reason.reply_code == 403 and self._response_queue in reason.reply_text: + reason_description = ( + f"RabbitMQ channel closed by the broker ({reason.reply_code})." + " There is already a response queue setup with the same name and instance_id," + " and hence this service is considered single instance only" + ) + elif reason.reply_code == 403 and "Provided JWT token has expired" in reason.reply_text: + reason_description = ( + f"RabbitMQ channel closed by the broker ({reason.reply_code})." + f" AuthenticationExpired: {reason.reply_text}" + ) + else: + reason_description = ( + f"RabbitMQ channel closed by the broker " + f"({reason.reply_code}): {reason.reply_text}" + ) + + logger.critical(reason_description) if self._connected_future and not self._connected_future.done(): - self._connected_future.set_exception(Exception) - self._connecting = False + """ the Connection is still in press and when the channel was closed by the broker + so treat as a serious error and close the connection + """ + self._connected_future.set_exception( + NuropbTransportError( + description=reason_description, + exception=reason, + close_connection=True, + ) + ) - # investigate reasons and methods automatically reopen the channel. - # until a solution is found it will be important to monitor for this condition + elif self._connected_future and self._connected_future.done(): + """ The channel was close after the connection was established + """ + if reason.reply_code in (403, 404): + """ There is no point in auto reconnecting when access is refused, so + shut the connection down. + """ + asyncio.create_task(self.stop()) + else: + """ It's ok to try and re-open the channel + """ + logging.info("Re-opening channel") + self.open_channel() + + elif not isinstance(reason, ChannelClosedByClient): + """ Log the reason for the channel close and allow the re-open process to continue + """ + reason_description = ( + f"RabbitMQ channel closed ({reason.reply_code})." + f"{type(reason).__name__}: {reason}" + ) + logger.warning(reason_description) def declare_service_queue(self, frame: pika.frame.Method) -> None: """Refresh the request queue on RabbitMQ by invoking the Queue.Declare RPC command. When it @@ -671,8 +767,10 @@ def on_bindok(self, _frame: pika.frame.Method, userdata: str) -> None: def on_basic_qos_ok(self, _frame: pika.frame.Method) -> None: """Invoked by pika when the Basic.QoS method has completed. At this - point we will start consuming messages by calling start_consuming - which will invoke the needed RPC commands to start the process. + point we will start consuming messages. + + A callback is added that will be invoked if RabbitMQ cancels the consumer for some reason. + If RabbitMQ does cancel the consumer, on_consumer_cancelled will be invoked by pika. :param pika.frame.Method _frame: The Basic.QosOk response frame @@ -680,15 +778,13 @@ def on_basic_qos_ok(self, _frame: pika.frame.Method) -> None: logger.info("QOS set to: %d", self._prefetch_count) logger.info("Configure message consumption") - """Add a callback that will be invoked if RabbitMQ cancels the consumer for some reason. - If RabbitMQ does cancel the consumer, on_consumer_cancelled will be invoked by pika. - """ if self._channel is None: raise RuntimeError("RMQ transport channel is not open") self._channel.add_on_cancel_callback(self.on_consumer_cancelled) - # Start consuming the response queue these need their own handler as the ack type is automatic + # Start consuming the response queue, message from this queue require their own flow, hence + # a dedicated handler is set up. The ack type is automatic. self._consumer_tags.add( self._channel.basic_consume( on_message_callback=functools.partial( @@ -700,7 +796,7 @@ def on_basic_qos_ok(self, _frame: pika.frame.Method) -> None: ) ) - # Start consuming the requests queue + # Start consuming the requests queue and handle the service message flow. ack type is manual. if not self._client_only: logger.info("Ready to consume requests, events and commands") self._consumer_tags.add( @@ -769,18 +865,21 @@ def on_message_returned( payload=body, correlation_id=correlation_id, ) - nuropb_payload = decode_payload(body, "json") - request_method = nuropb_payload["method"] - nuropb_payload["tag"] = nuropb_type - nuropb_payload["error"] = { - "error": "NuropbNotDeliveredError", - "description": f"Service {method.routing_key} not available. unable to call method {request_method}", - } - nuropb_payload["result"] = None - nuropb_payload.pop("service", None) - nuropb_payload.pop("method", None) - nuropb_payload.pop("params", None) - nuropb_payload.pop("reply_to", None) + request_payload = cast(RequestPayloadDict, decode_payload(body, "json")) + request_method = request_payload["method"] + nuropb_payload = ResponsePayloadDict( + tag="response", + correlation_id=correlation_id, + context=request_payload["context"], + trace_id=trace_id, + result=None, + error={ + "error": "NuropbNotDeliveredError", + "description": f"Service {method.routing_key} not available. unable to call method {request_method}", + }, + warning=None, + reply_to="", + ) message = TransportRespondPayload( nuropb_protocol=NUROPB_PROTOCOL_VERSION, correlation_id=correlation_id, @@ -810,15 +909,6 @@ def on_message_returned( ) self._message_callback(message, message_complete_callback, metadata) - if verbose: - raise NuropbNotDeliveredError( - ( - f"Could not route message {nuropb_type} with " - f"correlation_id: {correlation_id} " - f"trace_id: {trace_id} " - f": {method.reply_code}, {method.reply_text}" - ) - ) def send_message( self, @@ -830,11 +920,12 @@ def send_message( ) -> None: """Send a message to over the RabbitMQ Transport - TODO: Consider the alternative handling if the channel that's closed. - - Wait and retry on a new channel? - - setup a retry queue? - - should there be a high water mark for the number of retries? - - should new messages not be consumed until the channel is re-established and retry queue drained? + TODO: Consider alternative handling when the channel is closed. + also refer to the notes in the on_channel_closed method. + - Wait and retry on a new channel? + - setup a retry queue? + - should there be a high water mark for the number of retries? + - should new messages not be consumed until the channel is re-established and retry queue drained? :param Dict[str, Any] payload: The message contents :param expiry: The message expiry in milliseconds @@ -1401,10 +1492,10 @@ async def stop_consuming(self) -> None: """Tell RabbitMQ that you would like to stop consuming by sending the Basic.Cancel RPC command. """ - if self._channel: - logger.info("Sending a Basic.Cancel RPC command to RabbitMQ") - logger.info("Closing consumers %s", self._consumer_tags) - + if self._channel is None or self._channel.is_closed: + return + else: + logger.info("Stopping consumers and sending a Basic.Cancel command to RabbitMQ") all_consumers_closed: Awaitable[bool] = asyncio.Future() def _on_cancel_ok(frame: pika.frame.Method) -> None: @@ -1414,19 +1505,26 @@ def _on_cancel_ok(frame: pika.frame.Method) -> None: all_consumers_closed.set_result(True) # type: ignore[attr-defined] for consumer_tag in self._consumer_tags: - if self._channel: + if self._channel and self._channel.is_open: self._channel.basic_cancel(consumer_tag, _on_cancel_ok) - logger.info( - "Waiting for %ss for consumers to close", CONSUMER_CLOSED_WAIT_TIMEOUT - ) + try: + logger.info( + "Waiting for %ss for consumers to close", CONSUMER_CLOSED_WAIT_TIMEOUT + ) await asyncio.wait_for( - all_consumers_closed, timeout=CONSUMER_CLOSED_WAIT_TIMEOUT + all_consumers_closed, + timeout=CONSUMER_CLOSED_WAIT_TIMEOUT ) + logger.info("Consumers to gracefully closed") except asyncio.TimeoutError: logger.error( "Timed out while waiting for all consumers to gracefully close" ) + except Exception as err: + logger.exception( + "Error while waiting for all consumers to gracefully close: %s", err + ) if len(self._consumer_tags) != 0: logger.error( @@ -1434,7 +1532,6 @@ def _on_cancel_ok(frame: pika.frame.Method) -> None: ) self._consuming = False - logger.info("RabbitMQ acknowledged the cancellation of the consumer") self.close_channel() def close_channel(self) -> None: diff --git a/src/nuropb/service_handlers.py b/src/nuropb/service_handlers.py index 2f3ec7d..d05eea3 100644 --- a/src/nuropb/service_handlers.py +++ b/src/nuropb/service_handlers.py @@ -59,7 +59,7 @@ def error_dict_from_exception(exception: Exception | BaseException) -> Dict[str, :param exception: :return: """ - if hasattr(exception, "to_dict"): + if isinstance(exception, NuropbException): return exception.to_dict() if hasattr(exception, "description"): @@ -266,7 +266,7 @@ def handle_execution_result( :param message_complete_callback: :return: """ - error = None + error: BaseException | Dict[str, Any] | None = None acknowledgement: AcknowledgeAction = "ack" if asyncio.isfuture(result): error = result.exception() @@ -303,10 +303,15 @@ def handle_execution_result( Do not send a response for commands """ + if isinstance(error, (Exception, BaseException)): + pyload_error = error_dict_from_exception(error) + else: + pyload_error = error + payload = ResponsePayloadDict( tag="response", result=result, - error=error, + error=pyload_error, correlation_id=service_message["correlation_id"], trace_id=service_message["trace_id"], context=service_message["nuropb_payload"]["context"], @@ -386,11 +391,13 @@ def execute_request( or not hasattr(service_instance, method_name) or not callable(getattr(service_instance, method_name)) ): - raise NuropbHandlingError( + exception_result = NuropbHandlingError( description="Unknown method {}".format(method_name), payload=payload, exception=None, ) + handle_execution_result(service_message, exception_result, message_complete_callback) + return try: if method_name == "nuropb_describe": @@ -407,24 +414,30 @@ def execute_request( except NuropbException as err: if verbose: logger.exception(err) - raise + handle_execution_result(service_message, err, message_complete_callback) + return + except Exception as err: if verbose: logger.exception(err) error = f"{type(err).__name__}: {err}" - raise NuropbException( + exception_result = NuropbException( description=f"Runtime exception calling {service_name}.{method_name}: {error}", payload=payload, exception=err, ) + handle_execution_result(service_message, exception_result, message_complete_callback) + return if asyncio.isfuture(result) or asyncio.iscoroutine(result): # Asynchronous responses if is_future(result): - raise ValueError( + exception_result = ValueError( "Tornado Future detected, please use asyncio.Future instead" ) + handle_execution_result(service_message, exception_result, message_complete_callback) + return def future_done_callback(future: Awaitable[Any]) -> None: handle_execution_result( @@ -441,6 +454,7 @@ def future_done_callback(future: Awaitable[Any]) -> None: else: # Synchronous responses handle_execution_result(service_message, result, message_complete_callback) + except Exception as err: if verbose: logger.exception(err) diff --git a/src/nuropb/service_runner.py b/src/nuropb/service_runner.py index 6c55d32..a5c4cfc 100644 --- a/src/nuropb/service_runner.py +++ b/src/nuropb/service_runner.py @@ -82,9 +82,9 @@ class ServiceContainer(ServiceRunner): _is_leader: bool _leader_reference: str _etcd_config: Dict[str, Any] - _etcd_client: EtcdClient - _etcd_lease: Lease - _etcd_watcher: Watcher + _etcd_client: EtcdClient | None + _etcd_lease: Lease | None + _etcd_watcher: Watcher | None _etcd_prefix: str _container_running_future: Awaitable[bool] | None @@ -114,22 +114,29 @@ def __init__( self._leader_reference = "" self._etcd_prefix = f"/nuropb/{self._service_name}" self._etcd_config = etcd_config if etcd_config is not None else {} + self._etcd_client = None + self._etcd_lease = None + self._etcd_watcher = None + + + if not self._etcd_config: + logger.info("etcd features are disabled") + self.running_state = "running-standalone" logger.info( - "Starting the service container for {}:{}".format( + "ServiceRunner is wrapping {}:{}".format( self._service_name, self._instance_id ) ) - if not self._etcd_config: - logger.warning("etcd features are disabled") - self.running_state = "running-standalone" - else: - """asyncio NOTE: the etcd3 client is initialized as an asyncio task and will run - once __init__ has completed and there us a running asyncio event loop. - """ - task = asyncio.create_task(self.init_etcd(on_startup=True)) - task.add_done_callback(lambda _: None) + + # ***NOTE*** MOVED THIS CODE to self.start() + # if self._etcd_config: + # """asyncio NOTE: the etcd3 client is initialized as an asyncio task and will run + # once __init__ has completed and there us a running asyncio event loop. + # """ + # task = asyncio.create_task(self.init_etcd(on_startup=True)) + @property def running_state(self) -> ContainerRunningState: @@ -304,9 +311,6 @@ def nominate_as_leader(self) -> None: """ logger.error(f"Error during leader nomination: {e}") logger.exception(e) - logger.info("called init_etcd() to re-initiate the etcd connection") - task = asyncio.create_task(self.init_etcd(on_startup=False)) - task.add_done_callback(lambda _: None) def update_etcd_service_property(self, key: str, value: Any) -> bool: """update_etcd_service_property: updates the etcd3 service property. @@ -400,15 +404,24 @@ async def startup_steps(self) -> None: await self._instance.connect() async def start(self) -> bool: - """start: starts the container service instance. - - primary entry point to start the service container. + """ + - Starts the etcd client if configured + - Startup Steps: + - Leader election + - Configures the brokers nuropb service mesh configuration if not done + - Starts the container service instance. + :return: None """ started = False try: + if self._etcd_config: + await self.init_etcd(on_startup=True) + await self.startup_steps() logger.info("Container startup complete") started = True + except AMQPConnectionError as err: logger.error(f"Startup error connecting to RabbitMQ: {err}") except HTTPError as err: @@ -428,11 +441,16 @@ async def stop(self) -> None: :return: None """ self.running_state = "stopping" - self._container_shutdown_future = asyncio.Future() self._shutdown = True - self._etcd_watcher.stop() + try: - await self._container_shutdown_future + if self._etcd_watcher: + self._etcd_watcher.stop() + if self._etcd_lease: + self._etcd_lease.revoke() + if self._etcd_client: + self._etcd_client.close() + await self._instance.disconnect() self.running_state = "shutdown" logger.info("Container shutdown complete") except (asyncio.CancelledError, Exception) as err: @@ -440,7 +458,9 @@ async def stop(self) -> None: logger.info(f"container shutdown future cancelled: {err}") else: logger.exception(f"Container shutdown future runtime exception: {err}") - finally: - ioloop: AbstractEventLoop = asyncio.get_running_loop() - if ioloop.is_running(): - ioloop.stop() + + # NOTEs: event loop should be managed outside the scope of this class + # finally: + # ioloop: AbstractEventLoop = asyncio.get_running_loop() + # if ioloop.is_running(): + # ioloop.stop() diff --git a/src/nuropb/testing/stubs.py b/src/nuropb/testing/stubs.py index 03f496c..90b74e6 100644 --- a/src/nuropb/testing/stubs.py +++ b/src/nuropb/testing/stubs.py @@ -1,7 +1,9 @@ import logging -from typing import Any, Dict +from typing import Any, Dict, Optional from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa +from uuid import uuid4 +import os from nuropb.contexts.context_manager import NuropbContextManager from nuropb.contexts.context_manager_decorator import nuropb_context @@ -11,6 +13,9 @@ logger = logging.getLogger(__name__) +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + def get_claims_from_token(bearer_token: str) -> Dict[str, Any] | None: """This is a stub for the authorise_func that is used in the tests""" _ = bearer_token @@ -22,20 +27,43 @@ def get_claims_from_token(bearer_token: str) -> Dict[str, Any] | None: } -class ServiceExample: +class ServiceStub: _service_name: str _instance_id: str _private_key: rsa.RSAPrivateKey - _method_call_count: int - def __init__(self, service_name: str, instance_id: str): + def __init__( + self, + service_name: str, + instance_id: Optional[str] = None, + private_key: Optional[rsa.RSAPrivateKey] = None, + ): self._service_name = service_name - self._instance_id = instance_id - self._private_key = private_key = rsa.generate_private_key( + self._instance_id = instance_id or uuid4().hex + self._private_key = private_key or rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() ) + @property + def service_name(self) -> str: + return self._service_name + + @property + def instance_id(self) -> str: + return self._instance_id + + @property + def private_key(self) -> rsa.RSAPrivateKey: + return self._private_key + + +class ServiceExample(ServiceStub): + _method_call_count: int + _raise_call_again_error: bool + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) self._method_call_count = 0 - self.raise_call_again_error = True + self._raise_call_again_error = True def test_method(self, **kwargs: Any) -> str: _ = kwargs @@ -57,7 +85,7 @@ def test_success_error(self, **kwargs: Any) -> None: @nuropb_context @publish_to_mesh(authorise_func=get_claims_from_token) - def test_requires_user_claims(self, ctx, **kwargs: Any) -> Any: + def test_requires_user_claims(self, ctx: NuropbContextManager, **kwargs: Any) -> Any: assert isinstance(self, ServiceExample) assert isinstance(ctx, NuropbContextManager) self._method_call_count += 1 @@ -66,7 +94,7 @@ def test_requires_user_claims(self, ctx, **kwargs: Any) -> Any: @nuropb_context @publish_to_mesh(authorise_func=get_claims_from_token, requires_encryption=True) - def test_requires_encryption(self, ctx, **kwargs: Any) -> Any: + def test_requires_encryption(self, ctx: NuropbContextManager, **kwargs: Any) -> Any: assert isinstance(self, ServiceExample) assert isinstance(ctx, NuropbContextManager) self._method_call_count += 1 @@ -77,9 +105,9 @@ def test_call_again_error(self, **kwargs: Any) -> Dict[str, Any]: self._method_call_count += 1 logger.debug(f"test_call_again_error: {kwargs}") success_result = f"response from {self._service_name}.test_call_again_error" - if self.raise_call_again_error: + if self._raise_call_again_error: """this is preventing the test from getting into an infinite loop""" - self.raise_call_again_error = False + self._raise_call_again_error = False raise NuropbCallAgain("Test Call Again") result = kwargs.copy() result.update( diff --git a/tests/conftest.py b/tests/conftest.py index c291f48..81bc2e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import os import pytest +import pytest_asyncio from nuropb.rmq_api import RMQAPI from nuropb.rmq_lib import ( @@ -15,12 +16,20 @@ configure_nuropb_rmq, ) from nuropb.rmq_transport import RMQTransport -from nuropb.testing.stubs import ServiceExample +from nuropb.testing.stubs import IN_GITHUB_ACTIONS, ServiceExample logging.getLogger("pika").setLevel(logging.WARNING) -IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +@pytest.fixture(scope="session") +def etcd_config(): + if IN_GITHUB_ACTIONS: + return None + else: + dict( + host="localhost", + port=2379, + ) @pytest.fixture(scope="session") def test_settings(): @@ -52,7 +61,7 @@ def test_settings(): } end_time = datetime.datetime.utcnow() logging.info( - f"Test summary:\n" + f"TEST SESSION SUMMARY:\n" f"start_time: {start_time}\n" f"end_time: {end_time}\n" f"duration: {end_time - start_time}" @@ -60,11 +69,12 @@ def test_settings(): @pytest.fixture(scope="session") -def test_rmq_url(test_settings): +def rmq_settings(test_settings): logging.debug("Setting up RabbitMQ test instance") vhost = f"pytest-{secrets.token_hex(8)}" + if IN_GITHUB_ACTIONS: - rmq_url = build_amqp_url( + settings = dict( host=test_settings["host"], port=test_settings["port"], username=test_settings["username"], @@ -72,15 +82,14 @@ def test_rmq_url(test_settings): vhost=vhost, ) else: - rmq_url = { - "cafile": "tls_connection/ca_cert.pem", - "username": "guest", - "password": "guest", - "host": "localhost", - "port": 5671, - "vhost": vhost, - "verify": False, - } + settings = dict( + username="guest", + password="guest", + host="localhost", + port=5672, + vhost=vhost, + verify=False, + ) api_url = build_rmq_api_url( scheme=test_settings["api_scheme"], @@ -90,7 +99,7 @@ def test_rmq_url(test_settings): password=test_settings["password"], ) - create_virtual_host(api_url, rmq_url) + create_virtual_host(api_url, settings) def message_callback(*args, **kwargs): # pragma: no cover pass @@ -98,7 +107,7 @@ def message_callback(*args, **kwargs): # pragma: no cover transport_settings = dict( service_name=test_settings["service_name"], instance_id=uuid4().hex, - amqp_url=rmq_url, + amqp_url=settings, rpc_exchange=test_settings["rpc_exchange"], events_exchange=test_settings["events_exchange"], dl_exchange=test_settings["dl_exchange"], @@ -111,15 +120,15 @@ def message_callback(*args, **kwargs): # pragma: no cover transport = RMQTransport(**transport_settings) configure_nuropb_rmq( - rmq_url=rmq_url, + rmq_url=settings, events_exchange=transport.events_exchange, rpc_exchange=transport.rpc_exchange, dl_exchange=transport._dl_exchange, dl_queue=transport._dl_queue, ) - yield rmq_url + yield settings logging.debug("Shutting down RabbitMQ test instance") - delete_virtual_host(api_url, rmq_url) + delete_virtual_host(api_url, settings) @pytest.fixture(scope="session") @@ -169,8 +178,8 @@ def message_callback(*args, **kwargs): # pragma: no cover dl_queue=transport._dl_queue, ) yield rmq_url - # logging.debug("Shutting down RabbitMQ test instance") - # delete_virtual_host(api_url, rmq_url) + logging.debug("Shutting down RabbitMQ test instance") + delete_virtual_host(api_url, rmq_url) @pytest.fixture(scope="session") @@ -193,8 +202,8 @@ def service_instance(): ) -@pytest.fixture(scope="function") -def test_mesh_service(test_settings, test_rmq_url, service_instance): +@pytest_asyncio.fixture(scope="function") +async def test_mesh_service(test_settings, rmq_settings, service_instance): service_name = test_settings["service_name"] instance_id = uuid4().hex transport_settings = dict( @@ -208,16 +217,18 @@ def test_mesh_service(test_settings, test_rmq_url, service_instance): service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, ) yield service_api + await service_api.disconnect() -@pytest.fixture(scope="function") -def test_mesh_client(test_rmq_url, test_settings, test_mesh_service): + +@pytest_asyncio.fixture(scope="function") +async def test_mesh_client(rmq_settings, test_settings, test_mesh_service): instance_id = uuid4().hex settings = test_mesh_service.transport.rmq_configuration client_transport_settings = dict( @@ -227,9 +238,11 @@ def test_mesh_client(test_rmq_url, test_settings, test_mesh_service): ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange=settings["rpc_exchange"], events_exchange=settings["events_exchange"], transport_settings=client_transport_settings, ) yield client_api + + await client_api.disconnect() diff --git a/tests/contexts/test_describe.py b/tests/contexts/test_describe.py index 5ef0be2..9416923 100644 --- a/tests/contexts/test_describe.py +++ b/tests/contexts/test_describe.py @@ -103,6 +103,7 @@ class OrderManagementService: Some useful documentation to describe the characteristic of the service and its purpose """ _service_name = "oms_v2" + _instance_id = uuid4().hex _version = "2.0.1" _private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() diff --git a/tests/mesh/test_service_discover.py b/tests/mesh/test_service_discover.py index 2d23a0c..f5acf8e 100644 --- a/tests/mesh/test_service_discover.py +++ b/tests/mesh/test_service_discover.py @@ -113,7 +113,7 @@ async def test_mesh_service_describe(test_mesh_client, test_mesh_service): logger.info(f"response: {pformat(rpc_response)}") -@pytest.mark.asyncio +@pytest.mark.asyncio(async_timeout=10) async def test_mesh_service_encrypt(test_mesh_client, test_mesh_service): """ user the service mesh api helper function to call the describe function for a service on the mesh. Test that service metta information is cached in the mesh client. diff --git a/tests/test_api_async_service_request.py b/tests/test_api_async_service_request.py index 9a7f93e..bc3528b 100644 --- a/tests/test_api_async_service_request.py +++ b/tests/test_api_async_service_request.py @@ -9,7 +9,7 @@ @pytest.mark.asyncio -async def test_async_service_methods(test_settings, test_rmq_url, service_instance): +async def test_async_service_methods(test_settings, rmq_settings, service_instance): service_name = test_settings["service_name"] instance_id = uuid4().hex transport_settings = dict( @@ -23,7 +23,7 @@ async def test_async_service_methods(test_settings, test_rmq_url, service_instan service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, @@ -41,7 +41,7 @@ async def test_async_service_methods(test_settings, test_rmq_url, service_instan ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, diff --git a/tests/test_api_service_request.py b/tests/test_api_service_request.py index b40a132..c80ef60 100644 --- a/tests/test_api_service_request.py +++ b/tests/test_api_service_request.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio -async def test_request_response_fail(test_settings, test_rmq_url, service_instance): +async def test_request_response_fail(test_settings, rmq_settings, service_instance): service_name = test_settings["service_name"] instance_id = uuid4().hex transport_settings = dict( @@ -24,7 +24,7 @@ async def test_request_response_fail(test_settings, test_rmq_url, service_instan service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, @@ -42,7 +42,7 @@ async def test_request_response_fail(test_settings, test_rmq_url, service_instan ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, @@ -56,8 +56,7 @@ async def test_request_response_fail(test_settings, test_rmq_url, service_instan ttl = 60 * 30 * 1000 trace_id = uuid4().hex logger.info(f"Requesting {service}.{method}") - # with pytest.raises(NuropbException) as error: - try: + with pytest.raises(NuropbMessageError) as error: result = await client_api.request( service=service, method=method, @@ -66,13 +65,7 @@ async def test_request_response_fail(test_settings, test_rmq_url, service_instan ttl=ttl, trace_id=trace_id, ) - logger.info(f"result: {result}") - except Exception as error: - logger.info(f"response: {error}") - - # assert ( - # error.value.payload["error"]["description"] == "Unknown method test_method_fail" - # ) + assert error.value.description == "Unknown method test_method_DOES_NOT_EXIST" method = "test_method" rpc_response = await client_api.request( @@ -93,7 +86,7 @@ async def test_request_response_fail(test_settings, test_rmq_url, service_instan @pytest.mark.asyncio -async def test_request_response_pass(test_settings, test_rmq_url, service_instance): +async def test_request_response_pass(test_settings, rmq_settings, service_instance): service_name = test_settings["service_name"] instance_id = uuid4().hex transport_settings = dict( @@ -107,7 +100,7 @@ async def test_request_response_pass(test_settings, test_rmq_url, service_instan service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, @@ -124,7 +117,7 @@ async def test_request_response_pass(test_settings, test_rmq_url, service_instan ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, @@ -157,7 +150,7 @@ async def test_request_response_pass(test_settings, test_rmq_url, service_instan @pytest.mark.asyncio -async def test_request_response_success(test_settings, test_rmq_url, service_instance): +async def test_request_response_success(test_settings, rmq_settings, service_instance): service_name = test_settings["service_name"] instance_id = uuid4().hex transport_settings = dict( @@ -171,7 +164,7 @@ async def test_request_response_success(test_settings, test_rmq_url, service_ins service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, @@ -188,7 +181,7 @@ async def test_request_response_success(test_settings, test_rmq_url, service_ins ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, @@ -222,7 +215,7 @@ async def test_request_response_success(test_settings, test_rmq_url, service_ins @pytest.mark.asyncio async def test_request_response_call_again( - test_settings, test_rmq_url, service_instance + test_settings, rmq_settings, service_instance ): service_name = test_settings["service_name"] instance_id = uuid4().hex @@ -237,7 +230,7 @@ async def test_request_response_call_again( service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, @@ -254,7 +247,7 @@ async def test_request_response_call_again( ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, @@ -288,7 +281,7 @@ async def test_request_response_call_again( @pytest.mark.asyncio async def test_request_response_call_again_loop_fail( - test_settings, test_rmq_url, service_instance + test_settings, rmq_settings, service_instance ): service_name = test_settings["service_name"] instance_id = uuid4().hex @@ -303,7 +296,7 @@ async def test_request_response_call_again_loop_fail( service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=transport_settings, @@ -320,7 +313,7 @@ async def test_request_response_call_again_loop_fail( ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, diff --git a/tests/test_client.py b/tests/test_client.py index 497d620..eeadf8e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio -async def test_request_response_pass(test_settings, test_rmq_url, service_instance): +async def test_request_response_pass(test_settings, rmq_settings, service_instance): instance_id = uuid4().hex client_transport_settings = dict( dl_exchange=test_settings["dl_exchange"], @@ -19,7 +19,7 @@ async def test_request_response_pass(test_settings, test_rmq_url, service_instan ) client_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange="test_rpc_exchange", events_exchange="test_events_exchange", transport_settings=client_transport_settings, diff --git a/tests/test_handlers.py b/tests/test_handlers.py index b522109..fd0d9e2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,7 +14,6 @@ logger = logging.getLogger() -@pytest.mark.asyncio def test_sync_handler_call(service_instance): correlation_id = uuid4().hex trace_id = uuid4().hex diff --git a/tests/test_message_routing.py b/tests/test_message_routing.py new file mode 100644 index 0000000..6fd19ca --- /dev/null +++ b/tests/test_message_routing.py @@ -0,0 +1,102 @@ +import asyncio +import logging +from uuid import uuid4 + +import pytest + +from nuropb.rmq_api import RMQAPI + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_call_self(test_settings, rmq_settings, service_instance): + """ Currently this test passes, as there is no check for the service name in the request method. + Restricting the service name to be different from the service name of the service instance is + under consideration for a future release. + """ + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + ) + service_api = RMQAPI( + service_name=service_instance.service_name, + instance_id=service_instance.instance_id, + service_instance=service_instance, + amqp_url=rmq_settings, + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + transport_settings=transport_settings, + ) + await service_api.connect() + assert service_api.connected is True + logger.info("SERVICE API CONNECTED") + + service = "test_service" + method = "test_method" + params = {"param1": "value1"} + context = {"context1": "value1"} + ttl = 60 * 5 * 1000 + trace_id = uuid4().hex + logger.info(f"Requesting {service}.{method}") + rpc_response = await service_api.request( + service=service, + method=method, + params=params, + context=context, + ttl=ttl, + trace_id=trace_id, + rpc_response=False, + ) + logger.info(f"response: {rpc_response}") + assert rpc_response["result"] == f"response from {service}.{method}" + + await service_api.disconnect() + assert service_api.connected is False + + +@pytest.mark.asyncio +async def test_subscribe_to_events_from_self(test_settings, rmq_settings, service_instance): + """ Currently this test passes, as there is no check to restrict binding to the service queue + for events that originate from the service. + This restriction is under consideration for a future release. + """ + test_topic = f"{service_instance.service_name}.test-event" + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=[test_topic], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + ) + test_future = asyncio.Future() + + def handle_event(topic, event, target, context): + logger.info(f"Received event {topic} {target} {event} {context}") + assert topic == test_topic + test_future.set_result(True) + + setattr(service_instance, "_handle_event_", handle_event) + + service_api = RMQAPI( + service_name=service_instance.service_name, + instance_id=service_instance.instance_id, + service_instance=service_instance, + amqp_url=rmq_settings, + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + transport_settings=transport_settings, + ) + await service_api.connect() + assert service_api.connected is True + logger.info("SERVICE API CONNECTED") + + service_api.publish_event(test_topic, {"test": "event"}, {}) + + await test_future + + await service_api.disconnect() + assert service_api.connected is False diff --git a/tests/test_nuropb_interface.py b/tests/test_nuropb_interface.py index 2e7d3ef..b6a9bd1 100644 --- a/tests/test_nuropb_interface.py +++ b/tests/test_nuropb_interface.py @@ -12,7 +12,7 @@ ) -class TestTransport: +class TstTransport: def __init__(self, **kwargs): self.kwargs = kwargs self.started = False @@ -28,7 +28,7 @@ def connected(self): return self.started -class TestInterface(NuropbInterface): +class TstInterface(NuropbInterface): _leader: bool _transport: object @@ -145,10 +145,10 @@ def acknowledge_function(action: AcknowledgeAction) -> None: async def test_basic_interface(): service_name = "service_name" instance_id = uuid4().hex - interface = TestInterface( + interface = TstInterface( service_name=service_name, instance_id=instance_id, - transport_class=TestTransport, + transport_class=TstTransport, transport_settings={}, ) assert interface.service_name == service_name @@ -164,10 +164,10 @@ async def test_basic_interface(): async def test_interface_send_request(): service_name = "service_name" instance_id = uuid4().hex - interface = TestInterface( + interface = TstInterface( service_name=service_name, instance_id=instance_id, - transport_class=TestTransport, + transport_class=TstTransport, transport_settings={}, ) await interface.connect() diff --git a/tests/test_rqm_api.py b/tests/test_rqm_api.py index 602ec98..ba16073 100644 --- a/tests/test_rqm_api.py +++ b/tests/test_rqm_api.py @@ -10,16 +10,16 @@ logger = logging.getLogger() -def test_rmq_preparation(test_settings, test_rmq_url, test_api_url): +def test_rmq_preparation(test_settings, rmq_settings, test_api_url): """Test that the RMQ instance is and can be correctly configured - create virtual host must be idempotent - delete virtual host must be idempotent """ - if isinstance(test_rmq_url, str): - tmp_url = f"{test_rmq_url}-{secrets.token_hex(8)}" + if isinstance(rmq_settings, str): + tmp_url = f"{rmq_settings}-{secrets.token_hex(8)}" else: - tmp_url = test_rmq_url.copy() - tmp_url["vhost"] = f"{test_rmq_url['vhost']}-{secrets.token_hex(8)}" + tmp_url = rmq_settings.copy() + tmp_url["vhost"] = f"{rmq_settings['vhost']}-{secrets.token_hex(8)}" create_virtual_host(test_api_url, tmp_url) create_virtual_host(test_api_url, tmp_url) delete_virtual_host(test_api_url, tmp_url) @@ -27,24 +27,24 @@ def test_rmq_preparation(test_settings, test_rmq_url, test_api_url): @pytest.mark.asyncio -async def test_instantiate_api(test_settings, test_rmq_url): +async def test_instantiate_api(test_settings, rmq_settings): """Test that the RMQAPI instance can be instantiated""" - if isinstance(test_rmq_url, str): + if isinstance(rmq_settings, str): with pytest.raises(ValueError): - test_url = "/".join(test_rmq_url.split("/")[:-1]) + test_url = "/".join(rmq_settings.split("/")[:-1]) rmq_api = RMQAPI( amqp_url=test_url, ) else: with pytest.raises(AttributeError): - test_url = "/".join(test_rmq_url.split("/")[:-1]) + test_url = "/".join(rmq_settings.split("/")[:-1]) rmq_api = RMQAPI( amqp_url=test_url, ) rmq_api = RMQAPI( - amqp_url=test_rmq_url, + amqp_url=rmq_settings, ) await rmq_api.connect() await rmq_api.connect() @@ -54,7 +54,7 @@ async def test_instantiate_api(test_settings, test_rmq_url): @pytest.mark.asyncio -async def test_rmq_api_client_mode(test_settings, test_rmq_url): +async def test_rmq_api_client_mode(test_settings, rmq_settings): """Test client mode. this is a client only instance of RMQAPI and only established a connection to the RMQ server. It registers a response queue that is automatically associated with the default exchange, requires that RMQ is sufficiently setup. @@ -69,7 +69,7 @@ async def test_rmq_api_client_mode(test_settings, test_rmq_url): ) rmq_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange=test_settings["rpc_exchange"], events_exchange=test_settings["events_exchange"], transport_settings=transport_settings, @@ -84,7 +84,7 @@ async def test_rmq_api_client_mode(test_settings, test_rmq_url): @pytest.mark.asyncio -async def test_rmq_api_service_mode(test_settings, test_rmq_url, service_instance): +async def test_rmq_api_service_mode(test_settings, rmq_settings, service_instance): service_name = test_settings["service_name"] instance_id = uuid4().hex transport_settings = dict( @@ -98,7 +98,7 @@ async def test_rmq_api_service_mode(test_settings, test_rmq_url, service_instanc service_name=service_name, instance_id=instance_id, service_instance=service_instance, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange=test_settings["rpc_exchange"], events_exchange=test_settings["events_exchange"], transport_settings=transport_settings, diff --git a/tests/test_service_container.py b/tests/test_service_container.py index 198219b..eb14d89 100644 --- a/tests/test_service_container.py +++ b/tests/test_service_container.py @@ -1,3 +1,4 @@ +import asyncio import logging from uuid import uuid4 @@ -9,9 +10,8 @@ logger = logging.getLogger() -# @pytest.mark.skip @pytest.mark.asyncio -async def test_rmq_api_client_mode(test_settings, test_rmq_url, test_api_url): +async def test_rmq_api_client_mode(test_settings, rmq_settings, test_api_url, etcd_config): instance_id = uuid4().hex transport_settings = dict( dl_exchange=test_settings["dl_exchange"], @@ -22,26 +22,25 @@ async def test_rmq_api_client_mode(test_settings, test_rmq_url, test_api_url): ) rmq_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange=test_settings["rpc_exchange"], events_exchange=test_settings["events_exchange"], transport_settings=transport_settings, ) + + # FYI, Challenges with etcd features when tests run under Github Actions container = ServiceContainer( rmq_api_url=test_api_url, instance=rmq_api, - etcd_config=dict( - host="localhost", - port=2379, - ), + etcd_config=etcd_config, ) - # must resolved the testing issue on github actions - # await container.start() + await container.start() + await container.stop() @pytest.mark.asyncio async def test_rmq_api_service_mode( - test_settings, test_rmq_url, test_api_url, service_instance + test_settings, rmq_settings, test_api_url, service_instance, etcd_config ): instance_id = uuid4().hex transport_settings = dict( @@ -55,25 +54,24 @@ async def test_rmq_api_service_mode( service_name=test_settings["service_name"], service_instance=service_instance, instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange=test_settings["rpc_exchange"], events_exchange=test_settings["events_exchange"], transport_settings=transport_settings, ) + + # FYI, Challenges with etcd features when tests run under Github Actions container = ServiceContainer( rmq_api_url=test_api_url, instance=rmq_api, - etcd_config=dict( - host="localhost", - port=2379, - ), + etcd_config=etcd_config, ) - # must resolved the testing issue on github actions - # await container.start() + await container.start() + await container.stop() @pytest.mark.asyncio -async def test_rmq_api_service_mode_no_etcd(test_settings, test_rmq_url, test_api_url): +async def test_rmq_api_service_mode_no_etcd(test_settings, rmq_settings, test_api_url): instance_id = uuid4().hex transport_settings = dict( dl_exchange=test_settings["dl_exchange"], @@ -84,7 +82,7 @@ async def test_rmq_api_service_mode_no_etcd(test_settings, test_rmq_url, test_ap ) rmq_api = RMQAPI( instance_id=instance_id, - amqp_url=test_rmq_url, + amqp_url=rmq_settings, rpc_exchange=test_settings["rpc_exchange"], events_exchange=test_settings["events_exchange"], transport_settings=transport_settings, @@ -94,3 +92,4 @@ async def test_rmq_api_service_mode_no_etcd(test_settings, test_rmq_url, test_ap instance=rmq_api, ) await container.start() + await container.stop() diff --git a/tests/tls_connection/test_tls_pika.py b/tests/tls_connection/test_tls_pika.py deleted file mode 100644 index 54f5f12..0000000 --- a/tests/tls_connection/test_tls_pika.py +++ /dev/null @@ -1,46 +0,0 @@ -import os - -import pytest -from nuropb.rmq_transport import RMQTransport - -IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" -if IN_GITHUB_ACTIONS: - pytest.skip("Skipping model tests when run in Github Actions", allow_module_level=True) - - -@pytest.mark.asyncio -async def test_tls_connect(): - - def message_callback(message): - print(message) - - cacertfile = os.path.join(os.path.dirname(__file__), "ca_cert.pem") - certfile = os.path.join(os.path.dirname(__file__), "cert-2.pem") - keyfile = os.path.join(os.path.dirname(__file__), "key-2.pem") - - amqp_url = { - "cafile": cacertfile, - "host": "localhost", - "username": "guest", - "password": "guest", - "port": 5671, - "vhost": "nuropb", - "verify": False, - "certfile": certfile, - "keyfile": keyfile, - } - transport = RMQTransport( - service_name="test-service", - instance_id="test-service-instance", - amqp_url=amqp_url, - message_callback=message_callback, - ) - await transport.start() - assert transport.connected is True - # await asyncio.Event().wait() - await transport.stop() - assert transport.connected is False - - - - diff --git a/tests/tls_connection/ca_cert.pem b/tests/transport_connection/ca_cert.pem similarity index 100% rename from tests/tls_connection/ca_cert.pem rename to tests/transport_connection/ca_cert.pem diff --git a/tests/tls_connection/ca_key.pem b/tests/transport_connection/ca_key.pem similarity index 100% rename from tests/tls_connection/ca_key.pem rename to tests/transport_connection/ca_key.pem diff --git a/tests/tls_connection/cert-1.pem b/tests/transport_connection/cert-1.pem similarity index 100% rename from tests/tls_connection/cert-1.pem rename to tests/transport_connection/cert-1.pem diff --git a/tests/tls_connection/cert-2.pem b/tests/transport_connection/cert-2.pem similarity index 100% rename from tests/tls_connection/cert-2.pem rename to tests/transport_connection/cert-2.pem diff --git a/tests/tls_connection/key-1.pem b/tests/transport_connection/key-1.pem similarity index 100% rename from tests/tls_connection/key-1.pem rename to tests/transport_connection/key-1.pem diff --git a/tests/tls_connection/key-2.pem b/tests/transport_connection/key-2.pem similarity index 100% rename from tests/tls_connection/key-2.pem rename to tests/transport_connection/key-2.pem diff --git a/tests/transport_connection/test_connection_properties.py b/tests/transport_connection/test_connection_properties.py new file mode 100644 index 0000000..113d8c0 --- /dev/null +++ b/tests/transport_connection/test_connection_properties.py @@ -0,0 +1,198 @@ +import logging +from uuid import uuid4 +import asyncio + +import pytest + +from nuropb.rmq_api import RMQAPI +from nuropb.rmq_transport import RMQTransport +from nuropb.testing.stubs import ServiceStub + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_setting_connection_properties(rmq_settings, test_settings): + """The client connection properties can be set by the user. The user can set the connection + properties by passing a dictionary to the connection_properties argument of the RMQAPI + constructor. The connection properties are used to set the properties of the AMQP connection + that is established by the client. + """ + amqp_url = { + "host": "localhost", + "username": "guest", + "password": "guest", + "port": rmq_settings["port"], + "vhost": rmq_settings["vhost"], + "verify": False, + } + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + ) + + def message_callback(*args, **kwargs): + pass + + transport1 = RMQTransport( + service_name="test-service", + instance_id=uuid4().hex, + amqp_url=amqp_url, + message_callback=message_callback, + **transport_settings, + ) + + await transport1.start() + assert transport1.connected is True + + service = ServiceStub( + service_name="test-service", + instance_id=uuid4().hex, + ) + api = RMQAPI( + service_name=service.service_name, + instance_id=service.instance_id, + service_instance=service, + amqp_url=amqp_url, + transport_settings=transport_settings, + ) + await api.connect() + assert api.connected is True + + await transport1.stop() + assert transport1.connected is False + await api.disconnect() + assert api.connected is False + + +@pytest.mark.asyncio +async def test_single_instance_connection(rmq_settings, test_settings): + """Test Single instance connections + """ + amqp_url = { + "host": "localhost", + "username": "guest", + "password": "guest", + "port": rmq_settings["port"], + "vhost": rmq_settings["vhost"], + } + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + ) + + def message_callback(*args, **kwargs): + pass + + transport1 = RMQTransport( + service_name="test-service", + instance_id="Single_instance_id", + amqp_url=amqp_url, + message_callback=message_callback, + **transport_settings, + ) + + await transport1.start() + assert transport1.connected is True + + transport_settings["vhost"] = "bad" + service = ServiceStub( + service_name="test-service", + instance_id="Single_instance_id", + ) + api = RMQAPI( + service_name=service.service_name, + instance_id=service.instance_id, + service_instance=service, + amqp_url=amqp_url, + transport_settings=transport_settings, + ) + await api.connect() + logger.info("Connected : %s", api.connected) + await asyncio.sleep(3) + assert api.connected is False + await transport1.stop() + assert transport1.connected is False + await api.disconnect() + + +@pytest.mark.asyncio +async def test_bad_credentials(rmq_settings, test_settings): + + amqp_url = { + "host": rmq_settings["host"], + "username": rmq_settings["username"], + "password": "bad_guest", + "port": rmq_settings["port"], + "vhost": rmq_settings["vhost"], + } + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + ) + + service = ServiceStub( + service_name="test-service", + instance_id="Single_instance_id", + ) + api = RMQAPI( + service_name=service.service_name, + instance_id=service.instance_id, + service_instance=service, + amqp_url=amqp_url, + transport_settings=transport_settings, + ) + await api.connect() + logger.info("Connected : %s", api.connected) + assert api.connected is False + + +@pytest.mark.asyncio +async def test_bad_vhost(rmq_settings, test_settings): + + amqp_url = { + "host": rmq_settings["host"], + "username": rmq_settings["username"], + "password": rmq_settings["password"], + "port": rmq_settings["port"], + "vhost": "bad_vhost", + } + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + ) + + service = ServiceStub( + service_name="test-service", + instance_id="Single_instance_id", + ) + api = RMQAPI( + service_name=service.service_name, + instance_id=service.instance_id, + service_instance=service, + amqp_url=amqp_url, + transport_settings=transport_settings, + ) + await api.connect() + logger.info("Connected : %s", api.connected) + assert api.connected is False \ No newline at end of file diff --git a/tests/transport_connection/test_tls_connection.py b/tests/transport_connection/test_tls_connection.py new file mode 100644 index 0000000..3bc06b8 --- /dev/null +++ b/tests/transport_connection/test_tls_connection.py @@ -0,0 +1,75 @@ +import os +from uuid import uuid4 + +import pytest + +from nuropb.rmq_api import RMQAPI +from nuropb.rmq_transport import RMQTransport +from nuropb.testing.stubs import ServiceStub + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +if IN_GITHUB_ACTIONS: + pytest.skip("Skipping model tests when run in Github Actions", allow_module_level=True) + + +@pytest.mark.asyncio +async def test_tls_connect(rmq_settings, test_settings): + + def message_callback(message): + print(message) + + cacertfile = os.path.join(os.path.dirname(__file__), "ca_cert.pem") + certfile = os.path.join(os.path.dirname(__file__), "cert-2.pem") + keyfile = os.path.join(os.path.dirname(__file__), "key-2.pem") + + amqp_url = { + "cafile": cacertfile, + "host": "localhost", + "username": "guest", + "password": "guest", + "port": 5671, + "vhost": rmq_settings["vhost"], + "verify": False, + "certfile": certfile, + "keyfile": keyfile, + } + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + ) + transport1 = RMQTransport( + service_name="test-service", + instance_id=uuid4().hex, + amqp_url=amqp_url, + message_callback=message_callback, + **transport_settings, + ) + await transport1.start() + assert transport1.connected is True + + service = ServiceStub( + service_name="test-service", + instance_id=uuid4().hex, + ) + api = RMQAPI( + service_name=service.service_name, + instance_id=service.instance_id, + service_instance=service, + amqp_url=amqp_url, + transport_settings=transport_settings + ) + await api.connect() + assert api.connected is True + + await transport1.stop() + assert transport1.connected is False + await api.disconnect() + assert api.connected is False + + +