Skip to content

Commit

Permalink
Merge pull request #21 from robertbetts/develop
Browse files Browse the repository at this point in the history
Improved connection error handling
  • Loading branch information
robertbetts authored Sep 28, 2023
2 parents fffeb9a + 2db5d92 commit 5855cab
Show file tree
Hide file tree
Showing 33 changed files with 1,000 additions and 382 deletions.
2 changes: 1 addition & 1 deletion examples/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def main():
total_sample_count = 0
total_request_time = 0

batch_size = 500
batch_size = 5000
number_of_batches = 5
ioloop = asyncio.get_event_loop()

Expand Down
14 changes: 10 additions & 4 deletions examples/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ async def main():
try:
logging.info("Service %s ready", service_example._service_name)
await asyncio.Event().wait()
except KeyboardInterrupt:
logging.info("Shutting down signal received")

logging.info("Service %s done", service_example._service_name)
await container.stop()
except BaseException as err:
logging.info("Shutting down. %s: %s", type(err).__name__, err)
await container.stop()
finally:
logging.info("Service %s done", service_example._service_name)


if __name__ == "__main__":
Expand All @@ -99,4 +102,7 @@ async def main():
logging.getLogger("pika").setLevel(logging.CRITICAL)
logging.getLogger("etcd3").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt:
pass
13 changes: 10 additions & 3 deletions examples/server_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ async def main():
transport_settings=transport_settings,
)
await mesh_api.connect()

try:
logging.info("Service %s ready", service_instance._service_name)
await asyncio.Event().wait()
except KeyboardInterrupt:
logging.info("Shutting down signal received")
await mesh_api.disconnect()
except BaseException as err:
logging.info("Shutting down. %s: %s", type(err).__name__, err)
await mesh_api.disconnect()
finally:
logging.info("Service %s done", service_instance._service_name)

logging.info("Service %s done", service_instance._service_name)

Expand All @@ -61,4 +65,7 @@ async def main():
logging.getLogger("pika").setLevel(logging.WARNING)
logging.getLogger("etcd3").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt:
pass
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "nuropb"
version = "0.1.6"
version = "0.1.7"
description = "NuroPb - A Distributed Event Driven Service Mesh"
authors = ["Robert Betts <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -47,6 +47,7 @@ env_files = [".env_test"]
testpaths = ["tests"]

[tool.pytest.ini_options]
asyncio_mode = "strict"
log_cli = true
log_level = "DEBUG"
log_cli_format = " %(levelname).1s %(asctime)s,%(msecs)d %(module)s %(lineno)s %(message)s"
Expand Down
24 changes: 18 additions & 6 deletions src/nuropb/contexts/describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

logger = logging.getLogger(__name__)

AuthoriseFunc = Callable[[str], Dict[str, Any]]
AuthoriseFunc = Callable[[str], Optional[Dict[str, Any]]]


def method_visible_on_mesh(method: Callable[..., Any]) -> bool:
Expand Down Expand Up @@ -87,10 +87,21 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:


def describe_service(class_instance: object) -> Dict[str, Any] | None:
"""Returns a description of the class methods that will be exposed to the service mesh"""
"""Returns a description of the class methods that will be exposed to the service mesh
"""
service_info = {
"service_name": "",
"service_version": "",
"description": "",
"encrypted_methods": [],
"methods": {},
"warnings": [],
}

if class_instance is None:
logger.warning("No service class base has been input")
return None
service_info["warnings"].append("No service class base")

else:
service_name = getattr(class_instance, "_service_name", None)
service_description = getattr(class_instance, "__doc__", None)
Expand Down Expand Up @@ -174,18 +185,19 @@ def map_argument(arg_props: Any) -> Tuple[str, Dict[str, Any]]:
}
methods.append((name, method_spec))

service_info = {
service_info.update({
"service_name": service_name,
"service_version": service_version,
"description": service_description,
"encrypted_methods": service_has_encrypted_methods,
"methods": dict(methods),
}
})

if service_has_encrypted_methods:
private_key = service_name = getattr(class_instance, "_private_key", None)
if private_key is None:
raise ValueError(
service_info["warnings"].append("Service has encrypted methods but no private key has been set.")
logger.debug(
f"Service {service_name} has encrypted methods but no private key has been set"
)

Expand Down
15 changes: 12 additions & 3 deletions src/nuropb/encodings/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def add_public_key(
"""
self._service_public_keys[service_name] = public_key

def get_public_key(self, service_name: str) -> rsa.RSAPublicKey:
def get_public_key(self, service_name: str) -> rsa.RSAPublicKey | None:
"""Get a public key for a service
:param service_name: str
:return: rsa.RSAPublicKey
"""
return self._service_public_keys.get[service_name]
return self._service_public_keys.get(service_name)

def has_public_key(self, service_name: str) -> bool:
"""Check if a service has a public key
Expand Down Expand Up @@ -142,7 +142,12 @@ def encrypt_payload(

if service_name is None:
# Mode 4, get public key from the private key
public_key = self._private_key.public_key()
if self._private_key is None:
raise ValueError(
f"Service public key not found for service: {self._service_name}"
) # pragma: no cover
else:
public_key = self._private_key.public_key()
else:
# Mode 1, get public key from the destination service's public key
public_key = self._service_public_keys[service_name]
Expand Down Expand Up @@ -175,6 +180,10 @@ def decrypt_payload(self, payload: bytes, correlation_id: str) -> bytes:
encrypted_key, encrypted_payload = payload.split(b".", 1)
if correlation_id not in self._correlation_id_symmetric_keys:
"""Mode 3, use public key from the private key to decrypt key"""
if self._private_key is None:
raise ValueError(
f"Service public key not found for service: {self._service_name}"
) # pragma: no cover
key = decrypt_key(encrypted_key, self._private_key)
# remember the key for this correlation_id to encrypt the response
self._correlation_id_symmetric_keys[correlation_id] = key
Expand Down
22 changes: 21 additions & 1 deletion src/nuropb/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

logger = logging.getLogger(__name__)

NUROPB_VERSION = "0.1.6"
NUROPB_VERSION = "0.1.7"
NUROPB_PROTOCOL_VERSION = "0.1.1"
NUROPB_PROTOCOL_VERSIONS_SUPPORTED = ("0.1.1",)
NUROPB_MESSAGE_TYPES = (
Expand Down Expand Up @@ -260,6 +260,26 @@ class NuropbTimeoutError(NuropbException):

class NuropbTransportError(NuropbException):
"""NuropbTransportError: represents an error that inside the plumbing."""
_close_connection: bool

def __init__(
self,
description: Optional[str] = None,
payload: Optional[PayloadDict] = None,
exception: Optional[BaseException] = None,
close_connection: bool = False,
):
super().__init__(
description=description,
payload=payload,
exception=exception,
)
self._close_connection = close_connection

@property
def close_connection(self) -> bool:
"""close_connection: returns True if the connection should be closed"""
return self._close_connection


class NuropbMessageError(NuropbException):
Expand Down
136 changes: 84 additions & 52 deletions src/nuropb/rmq_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,19 @@


class RMQAPI(NuropbInterface):
"""RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport."""
"""RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport.
When an existing transport initialised and connected, and a subsequent transport
instance is connected with the same service_name and instance_id as the first, the broker
will shut down the channel of subsequent instances when they attempt to configure their
response queue. This is because the response queue is opened in exclusive mode. The
exclusive mode is used to ensure that only one consumer (nuropb api connection) is
consuming from the response queue.
Deliberately specifying a fixed instance_id, is a valid mechanism to ensure that a service
can only run in single instance mode. This is useful for services that are not designed to
be run in a distributed manner or where there is specific service configuration required.
"""

_mesh_name: str
_connection_name: str
Expand All @@ -42,6 +54,17 @@ class RMQAPI(NuropbInterface):
_service_discovery: Dict[str, Any]
_service_public_keys: Dict[str, Any]

@classmethod
def _get_vhost(cls, amqp_url: str | Dict[str, Any]) -> str:
if isinstance(amqp_url, str):
parts = amqp_url.split("/")
vhost = amqp_url.split("/")[-1]
if len(parts) < 4:
raise ValueError("Invalid amqp_url, missing vhost")
else:
vhost = amqp_url["vhost"]
return vhost

def __init__(
self,
amqp_url: str | Dict[str, Any],
Expand All @@ -52,14 +75,12 @@ def __init__(
events_exchange: Optional[str] = None,
transport_settings: Optional[Dict[str, Any]] = None,
):
"""RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport."""
if isinstance(amqp_url, str):
parts = amqp_url.split("/")
vhost = amqp_url.split("/")[-1]
if len(parts) < 4:
raise ValueError("Invalid amqp_url, missing vhost")
else:
vhost = amqp_url["vhost"]
"""RMQAPI: A NuropbInterface implementation that uses RabbitMQ as the underlying transport.
Where exchange inputs are none, but they user present in transport_settings, then use the
values from transport_settings
"""
vhost = self._get_vhost(amqp_url)

self._mesh_name = vhost

Expand Down Expand Up @@ -120,6 +141,14 @@ def __init__(
"No service instance provided, service will not be able to handle requests"
) # pragma: no cover

""" where exchange inputs are none, but they user present in transport_settings,
then use the values from transport settings
"""
if rpc_exchange is None and transport_settings.get("rpc_exchange", None):
rpc_exchange = transport_settings["rpc_exchange"]
if events_exchange is None and transport_settings.get("events_exchange", None):
events_exchange = transport_settings["events_exchange"]

transport_settings.update(
{
"service_name": self._service_name,
Expand Down Expand Up @@ -246,6 +275,41 @@ def receive_transport_message(
service_message["nuropb_type"],
)

@classmethod
def _handle_immediate_request_error(
cls,
rpc_response: bool,
payload: RequestPayloadDict | ResponsePayloadDict,
error: Dict[str, Any] | BaseException
) -> ResponsePayloadDict:

if rpc_response and isinstance(error, BaseException):
raise error
elif rpc_response:
raise NuropbMessageError(
description=error["description"],
payload=payload,
)

if isinstance(error, NuropbException):
error = error.to_dict()
elif isinstance(error, BaseException):
error = {
"error": f"{type(error).__name__}",
"description": f"{type(error).__name__}: {error}",
}

return {
"tag": "response",
"context": payload["context"],
"correlation_id": payload["correlation_id"],
"trace_id": payload["trace_id"],
"result": None,
"error": error,
"warning": None,
"reply_to": "",
}

async def request(
self,
service: str,
Expand Down Expand Up @@ -317,51 +381,19 @@ async def request(
encrypted=encrypted,
)
except Exception as err:
if rpc_response is False:
return {
"tag": "response",
"context": context,
"correlation_id": correlation_id,
"trace_id": trace_id,
"result": None,
"error": {
"error": f"{type(err).__name__}",
"description": f"Error sending request message: {err}",
},
}
else:
raise err
return self._handle_immediate_request_error(rpc_response, message, err)

try:
response = await response_future
if rpc_response is True and response["error"] is not None:
raise NuropbMessageError(
description=response["error"]["description"],
payload=response,
)
elif rpc_response is True:
return response["result"]
else:
return response
except BaseException as err:
if rpc_response is True:
raise err
else:
if not isinstance(err, NuropbException):
error = {
"error": f"{type(err).__name__}",
"description": f"Error waiting for response: {err}",
}
else:
error = err.to_dict()
return {
"tag": "response",
"context": context,
"correlation_id": correlation_id,
"trace_id": trace_id,
"result": None,
"error": error,
}
except Exception as err:
return self._handle_immediate_request_error(rpc_response, message, err)

if response["error"] is not None:
return self._handle_immediate_request_error(rpc_response, response, response["error"])
if rpc_response is True:
return response["result"]
else:
return response

def command(
self,
Expand Down Expand Up @@ -501,7 +533,7 @@ async def describe_service(
)
return service_info
except Exception as err:
logger.error(f"error loading the public key for {service_name}: {err}")
raise ValueError(f"error loading the public key for {service_name}: {err}")

async def requires_encryption(self, service_name: str, method_name: str) -> bool:
"""requires_encryption: Queries the service discovery information for the service_name
Expand Down
Loading

0 comments on commit 5855cab

Please sign in to comment.