From a94c42b6b5d3d5293f128d0bad0637cdcc2784f5 Mon Sep 17 00:00:00 2001 From: Jean Demeusy Date: Thu, 6 Jun 2024 13:40:06 +0200 Subject: [PATCH 1/2] Fix database connection initialization using parameters --- ct-app/database/database_connection.py | 46 ++++---------------------- ct-app/postman/postman_tasks.py | 8 ++--- ct-app/postman/utils.py | 16 +++++---- 3 files changed, 19 insertions(+), 51 deletions(-) diff --git a/ct-app/database/database_connection.py b/ct-app/database/database_connection.py index 4047c2dc..4bb27295 100644 --- a/ct-app/database/database_connection.py +++ b/ct-app/database/database_connection.py @@ -15,20 +15,18 @@ class DatabaseConnection: Database connection class. """ - def __init__(self): + def __init__(self, params): """ Create a new DatabaseConnection based on environment variables setting user, password, host, port, database, sslmode, sslrootcert, sslcert and sslkey. """ - self.params = Parameters()("PG") - self._assert_parameters() url = URL( drivername="postgresql+psycopg2", - username=self.params.pg.user, - password=self.params.pg.password, - host=self.params.pg.host, - port=self.params.pg.port, - database=self.params.pg.database, + username=params.user, + password=params.password, + host=params.host, + port=params.port, + database=params.database, query={} ) @@ -38,38 +36,6 @@ def __init__(self): log.info("Database connection established.") - def _assert_parameters(self): - """ - Asserts that all required parameters are set. - """ - for group, values in self.required_parameters().items(): - assert len(getattr(self.params, group).__dict__), ( - f"Missing all '{group.upper()}' environment variables. " - + "The following ones are required: " - + f"{', '.join([(group+'(_)'+v).upper() for v in values])}" - ) - - for value in values: - assert hasattr(self.params.pg, value), ( - "Environment variable " - + f"'{group.upper()}(_){value.upper()}' missing" - ) - - @classmethod - def required_parameters(cls): - """ - Returns the required parameters for the DatabaseConnection. - """ - return { - "pg": [ - "user", - "password", - "host", - "port", - "database" - ] - } - def __enter__(self): """ Return the session (used by context manager) diff --git a/ct-app/postman/postman_tasks.py b/ct-app/postman/postman_tasks.py index fd1b801a..cd15f3bb 100644 --- a/ct-app/postman/postman_tasks.py +++ b/ct-app/postman/postman_tasks.py @@ -62,7 +62,7 @@ def send_1_hop_message( attempts += 1 # send_status in [TaskStatus.SPLITTED, TaskStatus.SUCCESS] - if attempts >= params.maxIterations: + if attempts >= params.distribution.maxIterations: send_status = TaskStatus.TIMEOUT if send_status in [TaskStatus.RETRIED, TaskStatus.SPLIT]: @@ -72,7 +72,7 @@ def send_1_hop_message( # store results in database if send_status != TaskStatus.RETRIED: - with DatabaseConnection() as session: + with DatabaseConnection(params.pg) as session: entry = Reward( peer_id=peer, node_address=node_peer_id, @@ -138,9 +138,7 @@ async def async_send_1_hop_message( max_possible, node_peer_id, timestamp, - params.distribution.batchSize, - params.distribution.delayBetweenTwoMessages, - params.distribution.messageDeliveryDelay, + params, ) status = TaskStatus.SPLIT if relayed < expected_count else TaskStatus.SUCCESS diff --git a/ct-app/postman/utils.py b/ct-app/postman/utils.py index 053d9e73..0e3937c2 100644 --- a/ct-app/postman/utils.py +++ b/ct-app/postman/utils.py @@ -1,6 +1,7 @@ import asyncio from core.components.hoprd_api import MESSAGE_TAG, HoprdAPI +from core.components.parameters import Parameters from database import DatabaseConnection, Peer @@ -16,8 +17,8 @@ def createBatches(cls, total_count: int, batch_size: int) -> list[int]: return [batch_size] * full_batches + [remainder] * bool(remainder) @classmethod - def peerIDToInt(cls, peer_id: str) -> int: - with DatabaseConnection() as session: + def peerIDToInt(cls, peer_id: str, parameters: Parameters) -> int: + with DatabaseConnection(parameters) as session: existing_peer = session.query(Peer).filter_by(peer_id=peer_id).first() if existing_peer: @@ -50,14 +51,17 @@ async def send_messages_in_batches( expected_count: int, recipient: str, timestamp: float, - batch_size: int, - delay_between_two_messages: float, - message_delivery_timeout: float, + params: Parameters, ): + + batch_size = params.distribution.batchSize, + delay_between_two_messages = params.distribution.delayBetweenTwoMessages, + message_delivery_timeout = params.distribution.messageDeliveryDelay, + relayed_count = 0 issued_count = 0 - tag = MESSAGE_TAG + cls.peerIDToInt(relayer) + tag = MESSAGE_TAG + cls.peerIDToInt(relayer, params.pg) batches = cls.createBatches(expected_count, batch_size) From c3c6af5c1049fb7c728b6620e41d464270a843ec Mon Sep 17 00:00:00 2001 From: Jean Demeusy Date: Thu, 6 Jun 2024 13:44:04 +0200 Subject: [PATCH 2/2] typehint --- ct-app/database/database_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ct-app/database/database_connection.py b/ct-app/database/database_connection.py index 4bb27295..3f981108 100644 --- a/ct-app/database/database_connection.py +++ b/ct-app/database/database_connection.py @@ -15,7 +15,7 @@ class DatabaseConnection: Database connection class. """ - def __init__(self, params): + def __init__(self, params: Parameters): """ Create a new DatabaseConnection based on environment variables setting user, password, host, port, database, sslmode, sslrootcert, sslcert and sslkey. """