diff --git a/poetry.lock b/poetry.lock index d8581e5..b495225 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -3721,4 +3721,4 @@ docs = [] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "07e41c2f66c37b647a6d4c45158259f24b249c60c4d489dc28f98408800707e0" +content-hash = "4cb94599570ba60ecaa704cb7044bb0663d86c9ca0111ac57454203f88a98028" diff --git a/pyproject.toml b/pyproject.toml index 3043008..5203864 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ poetry = "^1.5.1" pytest = "^7.3.1" certifi = "^2023.7.22" coverage = {extras = ["toml"], version = "^7.2.7"} -pytest-dotenv = "^0.5.2" +pytest-dotenv = "0.5.2" black = "^23.3.0" mypy = "^1.4.1" pytest-asyncio = "^0.21.1" @@ -67,10 +67,10 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.pytest] -env_files = [".env_test"] -testpaths = ["tests"] [tool.pytest.ini_options] +env_files = [".env"] +testpaths = ["tests"] asyncio_mode = "strict" log_cli = true log_level = "DEBUG" diff --git a/src/nuropb/nuropb_api.py b/src/nuropb/nuropb_api.py new file mode 100644 index 0000000..9c9f501 --- /dev/null +++ b/src/nuropb/nuropb_api.py @@ -0,0 +1,248 @@ +""" +Factory functions for instantiating nuropb api's. +""" +import logging +from typing import Optional, Dict, Any, Callable +from uuid import uuid4 + +from nuropb.rmq_api import RMQAPI +from nuropb.rmq_lib import configure_nuropb_rmq, create_virtual_host, build_amqp_url, build_rmq_api_url, \ + rmq_api_url_from_amqp_url + +logger = logging.getLogger(__name__) + + +def default_connection_properties(connection_properties: Dict[str, Any]) -> Dict[str, Any]: + if "host" not in connection_properties: + connection_properties["host"] = "localhost" + if "username" not in connection_properties: + connection_properties["username"] = "guest" + if "password" not in connection_properties: + connection_properties["password"] = "guest" + if "vhost" not in connection_properties: + connection_properties["vhost"] = "nuropb" + if "verify" not in connection_properties: + connection_properties["verify"] = False + if "ssl" not in connection_properties: + connection_properties["ssl"] = False + if "port" not in connection_properties and connection_properties["ssl"]: + connection_properties["port"] = 5671 + elif "port" not in connection_properties: + connection_properties["port"] = 5672 + + return connection_properties + + +def create_client( + name: Optional[str] = None, + instance_id: Optional[str] = None, + connection_properties: Optional[Dict[str, Any]] = None, + transport_settings: Optional[str | Dict[str, Any]] = None, + transport: Optional[RMQAPI] = RMQAPI, +) -> RMQAPI: + """ Create a client api instance for the nuropb service mesh. This caller of this function + will have to implement the asyncio call to connect to the service mesh: + await client_api.connect() + + :param name: used to identify the api connection to the service mesh + :param instance_id: used to create the service mesh response queue for this api connection + :param connection_properties: str or dict with values as required for the chosen transport api client + :param transport_settings: dict with values as required for the underlying transport api + :param transport: the class of the transport api client to use + :return: + """ + + if connection_properties is None: + connection_properties = default_connection_properties({ + "vhost": "nuropb", + "ssl": False, + "verify": False, + }) + elif isinstance(connection_properties, dict): + connection_properties = default_connection_properties(connection_properties) + + if transport is None: + transport = RMQAPI + if transport_settings is None: + transport_settings = {} + + client_api: RMQAPI = transport( + amqp_url=connection_properties, + service_name=name, + instance_id=instance_id, + transport_settings=transport_settings, + ) + return client_api + + +async def connect(instance_id: Optional[str] = None): + client_api = create_client( + instance_id=instance_id, + ) + await client_api.connect() + return client_api + + +def configure_mesh( + mesh_name: Optional[str] = None, + connection_properties: Optional[Dict[str, Any]] = None, + transport_settings: Optional[str | Dict[str, Any]] = None, +): + if mesh_name is None: + mesh_name = "nuropb" + + if connection_properties is None: + connection_properties = default_connection_properties({ + "vhost": mesh_name, + "ssl": False, + "verify": False, + }) + + if isinstance(connection_properties, str): + amqp_url = connection_properties + + elif isinstance(connection_properties, dict): + connection_properties = default_connection_properties(connection_properties) + + if connection_properties["ssl"]: + rmq_scheme = "amqps" + else: + rmq_scheme = "amqp" + + host = connection_properties["host"] + port = connection_properties["port"] + username = connection_properties["username"] + password = connection_properties["password"] + vhost = connection_properties["vhost"] + + amqp_url = build_amqp_url( + host, port, username, password, vhost, rmq_scheme + ) + else: + raise ValueError("connection_properties must be a str or dict") + + rmq_api_url = rmq_api_url_from_amqp_url(amqp_url) + create_virtual_host( + api_url=rmq_api_url, + vhost_url=amqp_url, + ) + + if transport_settings is None: + transport_settings = {} + if "rpc_exchange" not in transport_settings: + transport_settings["rpc_exchange"] = "nuropb-rpc-exchange" + if "events_exchange" not in transport_settings: + transport_settings["events_exchange"] = "nuropb-events-exchange" + if "dl_exchange" not in transport_settings: + transport_settings["dl_exchange"] = "nuropb-dl-exchange" + if "dl_queue" not in transport_settings: + transport_settings["dl_queue"] = "nuropb-dl-queue" + + configure_nuropb_rmq( + rmq_url=connection_properties, + events_exchange=transport_settings["events_exchange"], + rpc_exchange=transport_settings["rpc_exchange"], + dl_exchange=transport_settings["dl_exchange"], + dl_queue=transport_settings["dl_queue"], + ) + + +class MeshService: + """ A generic service class that can be used to create a connection only service instance for the + nuropb service mesh. This class could also be used as a template or to define a subclass for + creating a service instance. + """ + _service_name: str + _instance_id: str + _event_bindings: list[str] + _event_callback: Optional[Callable] + + def __init__( + self, + service_name: str, + instance_id: Optional[str] = None, + event_bindings: Optional[list[str]] = None, + event_callback: Optional[Callable] = None, + ): + self._service_name = service_name + self._instance_id = instance_id or uuid4().hex + self._event_bindings = event_bindings or [] + self._event_callback = event_callback + + async def _handle_event_( + self, + topic: str, + event: dict, + target: list[str] | None = None, + context: dict | None = None, + trace_id: str | None = None, + ): + _ = self + if self._event_callback is not None: + await self._event_callback(topic, event, target, context, trace_id) + + +def create_service( + name: str, + instance_id: Optional[str] = None, + service_instance: Optional[object] = None, + connection_properties: Optional[Dict[str, Any]] = None, + transport_settings: Optional[str | Dict[str, Any]] = None, + transport: Optional[RMQAPI] = RMQAPI, + event_bindings: Optional[list[str]] = None, + event_callback: Optional[Callable] = None, +) -> RMQAPI: + """ Create a client api instance for the nuropb service mesh. This caller of this function + will have to implement the asyncio call to connect to the service mesh: + await client_api.connect() + + :param name: used to identify this service to the service mesh + :param instance_id: used to create the service mesh response queue for this individual api + connection + :param service_instance: the instance of the service class that is intended to be exposed + to the service mesh + :param connection_properties: str or dict with values as required for the chosen transport + api client + :param transport_settings: dict with values as required for the underlying transport api + :param transport: the class of the transport api client to use + :param event_bindings: when service_instance is None, a list of event topics that this + service will subscribe to. + when service_instance is not None, the list will override the event_bindings of the + transport_settings if any are defined. + :param event_callback: when service_instance is None, a callback function that will be + called when an event is received + :return: + """ + + if connection_properties is None: + connection_properties = default_connection_properties({ + "vhost": "nuropb", + "ssl": False, + "verify": False, + }) + elif isinstance(connection_properties, dict): + connection_properties = default_connection_properties(connection_properties) + + if transport is None: + transport = RMQAPI + if transport_settings is None: + transport_settings = {} + + if service_instance is None: + service_instance = MeshService( + service_name=name, + instance_id=instance_id, + event_bindings=event_bindings, + event_callback=event_callback, + ) + elif event_bindings is not None: + transport_settings["event_bindings"] = event_bindings + + service_api: RMQAPI = transport( + amqp_url=connection_properties, + service_name=name, + service_instance=service_instance, + instance_id=instance_id, + transport_settings=transport_settings, + ) + return service_api diff --git a/src/nuropb/rmq_lib.py b/src/nuropb/rmq_lib.py index 8e56b07..4cfd42b 100644 --- a/src/nuropb/rmq_lib.py +++ b/src/nuropb/rmq_lib.py @@ -18,19 +18,28 @@ def build_amqp_url( - host: str, port: str | int, username: str, password: str, vhost: str + host: str, port: str | int, username: str, password: str, vhost: str, scheme: str = "amqp" ) -> str: """Creates an AMQP URL for connecting to RabbitMQ""" - return f"amqp://{username}:{password}@{host}:{port}/{vhost}" + if username: + password = f":{password}" if password.strip() else "" + return f"{scheme}://{username}{password}@{host}:{port}/{vhost}" + else: + return f"{scheme}://{host}:{port}/{vhost}" def build_rmq_api_url( scheme: str, host: str, port: str | int, username: str | None, password: str | None ) -> str: """Creates an HTTP URL for connecting to RabbitMQ management API""" - if username is None or password is None: + if username: + if password: + password = f":{password}" + else: + password = "" + return f"{scheme}://{username}{password}@{host}:{port}/api" + else: return f"{scheme}://{host}:{port}/api" - return f"{scheme}://{username}:{password}@{host}:{port}/api" def rmq_api_url_from_amqp_url( @@ -43,11 +52,19 @@ def rmq_api_url_from_amqp_url( :return: the RabbitMQ management API URL """ url_parts = urlparse(amqp_url) + scheme = scheme or url_parts.scheme + scheme = "https" if scheme == "amqps" else "http" username = url_parts.username password = url_parts.password + port = port or url_parts.port host = url_parts.hostname if url_parts.hostname else "localhost" - port = 15672 if port is None else port - scheme = "http" if scheme is None else scheme + if port: + port = int(port) + 10000 + elif not port and scheme == "https": + port = 15671 + elif not port: + port = 15672 + return build_rmq_api_url(scheme, host, port, username, password) @@ -96,7 +113,7 @@ def get_connection_parameters( """ if isinstance(amqp_url, dict): # create TLS connection parameters - + use_ssl = amqp_url.get("ssl", False) host = amqp_url.get("host", None) port = amqp_url.get("port", None) pika_parameters = { @@ -113,7 +130,10 @@ def get_connection_parameters( if vhost: pika_parameters["virtual_host"] = vhost - if amqp_url.get("cafile", None) or amqp_url.get("certfile"): + """ By specifying cafile, it is assumed that the connection will be over SSL/TLS + """ + if use_ssl or amqp_url.get("cafile", None): + use_ssl = True cafile = amqp_url.get("cafile", None) if cafile: # pragma: no cover context = ssl.create_default_context( @@ -145,6 +165,12 @@ def get_connection_parameters( ) pika_parameters["ssl_options"] = ssl_options + if pika_parameters["port"] is None and use_ssl: + pika_parameters["port"] = 5671 + elif pika_parameters["port"] is None: + pika_parameters["port"] = 5672 + + if amqp_url.get("username", None): credentials = PlainCredentials(amqp_url["username"], amqp_url["password"]) pika_parameters["credentials"] = credentials diff --git a/src/nuropb/rmq_transport.py b/src/nuropb/rmq_transport.py index f8f1bb7..9733908 100644 --- a/src/nuropb/rmq_transport.py +++ b/src/nuropb/rmq_transport.py @@ -200,7 +200,7 @@ def __init__( - int prefetch_count: The number of messages to prefetch defaults to 1, unlimited is 0. Experiment with larger values for higher throughput in your user case. - When an existing transport initialised and connected, and a subsequent transport + When an existing transport is initialised and connected, and a subsequent transport instance is connected with the same service_name and instance_id as the first, the broker will shut down the channel of subsequent instances when they attempt to configure their response queue. This is because the response queue is opened in exclusive mode. The @@ -456,7 +456,7 @@ def connect(self) -> asyncio.Future[bool]: self._connected_future = asyncio.Future() connection_parameters = get_connection_parameters( - amqp_url = self._amqp_url, + amqp_url=self._amqp_url, name=self._service_name, instance_id=self._instance_id, client_only=self._client_only, diff --git a/src/nuropb/utils.py b/src/nuropb/utils.py index bcdf5a3..fb8e27e 100644 --- a/src/nuropb/utils.py +++ b/src/nuropb/utils.py @@ -8,7 +8,20 @@ def obfuscate_credentials(url_with_credentials: str | Dict[str, Any]) -> str: :return: str """ if isinstance(url_with_credentials, dict): - return "tls-amqp://{username}:@{host}:{port}/{vhost}".format(**url_with_credentials) + port = url_with_credentials.get("port", "") + if port: + port = f":{port}" + else: + port = "" + + if url_with_credentials.get("use_ssl", False) or url_with_credentials.get( + "cafile", None + ): + scheme = "amqps" + else: + scheme = "amqp" + + return "{scheme}://{username}:@{host}{port}/{vhost}".format(scheme=scheme, **url_with_credentials) pattern = r"(:.*?@)" result = re.sub( diff --git a/tests/conftest.py b/tests/conftest.py index 8728439..f9a2130 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ from nuropb.testing.stubs import IN_GITHUB_ACTIONS, ServiceExample logging.getLogger("pika").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) @pytest.fixture(scope="session") @@ -31,6 +32,7 @@ def etcd_config(): port=2379, ) + @pytest.fixture(scope="session") def test_settings(): start_time = datetime.datetime.utcnow() @@ -40,12 +42,14 @@ def test_settings(): RMQ_AMQP_PORT: ${{ job.services.rabbitmq.ports['5672'] }} RMQ_API_PORT: ${{ job.services.rabbitmq.ports['15672'] }} """ - api_port = os.environ.get("RMQ_API_PORT", 15672) - amqp_port = os.environ.get("RMQ_AMQP_PORT", 5672) + logger.info(os.environ) + api_port = os.environ.get("RMQ_API_PORT", "15672") + amqp_port = os.environ.get("RMQ_AMQP_PORT", "5672") yield { "api_scheme": "http", "api_port": api_port, + "scheme": "amqp", "port": amqp_port, "host": "127.0.0.1", "username": "guest", @@ -58,6 +62,8 @@ def test_settings(): "event_bindings": [], "prefetch_count": 1, "default_ttl": 60 * 30 * 1000, # 30 minutes + "verify": False, + "ssl": False, } end_time = datetime.datetime.utcnow() logging.info( @@ -73,23 +79,15 @@ def rmq_settings(test_settings): logging.debug("Setting up RabbitMQ test instance") vhost = f"pytest-{secrets.token_hex(8)}" - if IN_GITHUB_ACTIONS: - settings = dict( - host=test_settings["host"], - port=test_settings["port"], - username=test_settings["username"], - password=test_settings["password"], - vhost=vhost, - ) - else: - settings = dict( - username="guest", - password="guest", - host="localhost", - port=5672, - vhost=vhost, - verify=False, - ) + settings = dict( + host=test_settings["host"], + port=test_settings["port"], + username=test_settings["username"], + password=test_settings["password"], + vhost=vhost, + verify=test_settings["verify"], + ssl=test_settings["ssl"], + ) api_url = build_rmq_api_url( scheme=test_settings["api_scheme"], @@ -141,6 +139,7 @@ def test_rmq_url_static(test_settings): username=test_settings["username"], password=test_settings["password"], vhost=vhost, + scheme=test_settings["scheme"], ) api_url = build_rmq_api_url( scheme=test_settings["api_scheme"], diff --git a/tests/test_nuropb_api.py b/tests/test_nuropb_api.py new file mode 100644 index 0000000..b9af0e5 --- /dev/null +++ b/tests/test_nuropb_api.py @@ -0,0 +1,76 @@ +import os + +from nuropb.nuropb_api import create_service, create_client, configure_mesh + +import pytest + + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +if IN_GITHUB_ACTIONS: + pytest.skip("Skipping model tests when run in Github Actions", allow_module_level=True) + + +@pytest.mark.asyncio +async def test_client_and_service_api_quick_setup(test_settings, rmq_settings): + + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + ) + connection_properties = rmq_settings + + configure_mesh( + mesh_name=connection_properties["vhost"], + connection_properties=connection_properties, + transport_settings=transport_settings, + ) + + service_api = create_service( + name="test_service", + connection_properties=connection_properties, + transport_settings=transport_settings, + ) + await service_api.connect() + client_api = create_client( + connection_properties={ + "vhost": connection_properties["vhost"], + "port": rmq_settings["port"], + "host": rmq_settings["host"], + } + ) + await client_api.connect() + + await client_api.disconnect() + assert client_api.connected is False + await service_api.disconnect() + assert service_api.connected is False + + +@pytest.mark.asyncio +async def test_client_and_service_api_quick_setup_raw_defaults(rmq_settings): + + configure_mesh( + connection_properties={ + "port": rmq_settings["port"], + "host": rmq_settings["host"], + } + ) + service_api = create_service( + name="test_service", + connection_properties={ + "port": rmq_settings["port"], + "host": rmq_settings["host"], + } + ) + await service_api.connect() + client_api = create_client(connection_properties={ + "port": rmq_settings["port"], + "host": rmq_settings["host"], + }) + await client_api.connect() + + await client_api.disconnect() + assert client_api.connected is False + await service_api.disconnect() + assert service_api.connected is False diff --git a/tests/transport_connection/test_channel_state.py b/tests/transport_connection/test_channel_state.py index dd82860..278ba4e 100644 --- a/tests/transport_connection/test_channel_state.py +++ b/tests/transport_connection/test_channel_state.py @@ -37,7 +37,7 @@ async def close_channel(): await asyncio.sleep(0.001) asyncio.create_task(close_channel()) - asyncio.sleep(0.001) + await asyncio.sleep(0.001) result = await mesh_client.request( mesh_service.service_name, diff --git a/tests/transport_connection/test_connection_properties.py b/tests/transport_connection/test_connection_properties.py index 113d8c0..e163b7c 100644 --- a/tests/transport_connection/test_connection_properties.py +++ b/tests/transport_connection/test_connection_properties.py @@ -5,12 +5,30 @@ import pytest from nuropb.rmq_api import RMQAPI +from nuropb.rmq_lib import rmq_api_url_from_amqp_url from nuropb.rmq_transport import RMQTransport from nuropb.testing.stubs import ServiceStub logger = logging.getLogger(__name__) +def test_ampq_url_to_api_url(): + api_url = rmq_api_url_from_amqp_url("amqp://guest:guest@localhost:5672/nuropb-example") + assert api_url == "http://guest:guest@localhost:15672/api" + + api_url = rmq_api_url_from_amqp_url("amqp://guest@localhost:5672/nuropb-example") + assert api_url == "http://guest@localhost:15672/api" + + api_url = rmq_api_url_from_amqp_url("amqp:///nuropb-example") + assert api_url == "http://localhost:15672/api" + + api_url = rmq_api_url_from_amqp_url("amqps://guest:guest@localhost:5672/nuropb-example") + assert api_url == "https://guest:guest@localhost:15672/api" + + api_url = rmq_api_url_from_amqp_url("amqps://guest:guest@localhost/nuropb-example") + assert api_url == "https://guest:guest@localhost:15671/api" + + @pytest.mark.asyncio async def test_setting_connection_properties(rmq_settings, test_settings): """The client connection properties can be set by the user. The user can set the connection @@ -18,14 +36,7 @@ async def test_setting_connection_properties(rmq_settings, test_settings): constructor. The connection properties are used to set the properties of the AMQP connection that is established by the client. """ - amqp_url = { - "host": "localhost", - "username": "guest", - "password": "guest", - "port": rmq_settings["port"], - "vhost": rmq_settings["vhost"], - "verify": False, - } + amqp_url = rmq_settings.copy() transport_settings = dict( dl_exchange=test_settings["dl_exchange"], rpc_bindings=test_settings["rpc_bindings"], @@ -74,13 +85,7 @@ def message_callback(*args, **kwargs): async def test_single_instance_connection(rmq_settings, test_settings): """Test Single instance connections """ - amqp_url = { - "host": "localhost", - "username": "guest", - "password": "guest", - "port": rmq_settings["port"], - "vhost": rmq_settings["vhost"], - } + amqp_url = rmq_settings.copy() transport_settings = dict( dl_exchange=test_settings["dl_exchange"], rpc_bindings=test_settings["rpc_bindings"], @@ -129,13 +134,8 @@ def message_callback(*args, **kwargs): @pytest.mark.asyncio async def test_bad_credentials(rmq_settings, test_settings): - amqp_url = { - "host": rmq_settings["host"], - "username": rmq_settings["username"], - "password": "bad_guest", - "port": rmq_settings["port"], - "vhost": rmq_settings["vhost"], - } + amqp_url = rmq_settings.copy() + amqp_url["username"] = "bad-username" transport_settings = dict( dl_exchange=test_settings["dl_exchange"], rpc_bindings=test_settings["rpc_bindings"], @@ -165,13 +165,8 @@ async def test_bad_credentials(rmq_settings, test_settings): @pytest.mark.asyncio async def test_bad_vhost(rmq_settings, test_settings): - amqp_url = { - "host": rmq_settings["host"], - "username": rmq_settings["username"], - "password": rmq_settings["password"], - "port": rmq_settings["port"], - "vhost": "bad_vhost", - } + amqp_url = rmq_settings.copy() + amqp_url["vhost"] = "bad-vhost" transport_settings = dict( dl_exchange=test_settings["dl_exchange"], rpc_bindings=test_settings["rpc_bindings"], diff --git a/tests/test_rqm_api.py b/tests/transport_connection/test_rqm_api.py similarity index 99% rename from tests/test_rqm_api.py rename to tests/transport_connection/test_rqm_api.py index ba16073..90f05b2 100644 --- a/tests/test_rqm_api.py +++ b/tests/transport_connection/test_rqm_api.py @@ -20,6 +20,7 @@ def test_rmq_preparation(test_settings, rmq_settings, test_api_url): else: tmp_url = rmq_settings.copy() tmp_url["vhost"] = f"{rmq_settings['vhost']}-{secrets.token_hex(8)}" + create_virtual_host(test_api_url, tmp_url) create_virtual_host(test_api_url, tmp_url) delete_virtual_host(test_api_url, tmp_url) diff --git a/tests/transport_connection/test_tls_connection.py b/tests/transport_connection/test_tls_connection.py index 3bc06b8..934d838 100644 --- a/tests/transport_connection/test_tls_connection.py +++ b/tests/transport_connection/test_tls_connection.py @@ -15,6 +15,58 @@ @pytest.mark.asyncio async def test_tls_connect(rmq_settings, test_settings): + def message_callback(message): + print(message) + amqp_url = rmq_settings.copy() + amqp_url.update({ + "ssl": True, + "port": 5671, + "verify": False, + }) + transport_settings = dict( + dl_exchange=test_settings["dl_exchange"], + rpc_bindings=test_settings["rpc_bindings"], + event_bindings=test_settings["event_bindings"], + prefetch_count=test_settings["prefetch_count"], + default_ttl=test_settings["default_ttl"], + rpc_exchange=test_settings["rpc_exchange"], + events_exchange=test_settings["events_exchange"], + ) + transport1 = RMQTransport( + service_name="test-service", + instance_id=uuid4().hex, + amqp_url=amqp_url, + message_callback=message_callback, + **transport_settings, + ) + await transport1.start() + from pika.adapters.utils.io_services_utils import _AsyncSSLTransport + assert isinstance(transport1._connection._transport, _AsyncSSLTransport) + assert transport1.connected is True + + service = ServiceStub( + service_name="test-service", + instance_id=uuid4().hex, + ) + api = RMQAPI( + service_name=service.service_name, + instance_id=service.instance_id, + service_instance=service, + amqp_url=amqp_url, + transport_settings=transport_settings + ) + await api.connect() + assert api.connected is True + + await transport1.stop() + assert transport1.connected is False + await api.disconnect() + assert api.connected is False + + +@pytest.mark.asyncio +async def test_tls_connect_with_cafile(rmq_settings, test_settings): + def message_callback(message): print(message) @@ -22,17 +74,14 @@ def message_callback(message): certfile = os.path.join(os.path.dirname(__file__), "cert-2.pem") keyfile = os.path.join(os.path.dirname(__file__), "key-2.pem") - amqp_url = { + amqp_url = rmq_settings.copy() + amqp_url.update({ "cafile": cacertfile, - "host": "localhost", - "username": "guest", - "password": "guest", "port": 5671, - "vhost": rmq_settings["vhost"], "verify": False, "certfile": certfile, "keyfile": keyfile, - } + }) transport_settings = dict( dl_exchange=test_settings["dl_exchange"], rpc_bindings=test_settings["rpc_bindings"], @@ -50,6 +99,8 @@ def message_callback(message): **transport_settings, ) await transport1.start() + from pika.adapters.utils.io_services_utils import _AsyncSSLTransport + assert isinstance(transport1._connection._transport, _AsyncSSLTransport) assert transport1.connected is True service = ServiceStub(