Skip to content

Commit

Permalink
Fix database connection initialization using parameters (#529)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandemeusy authored Jun 6, 2024
1 parent 603a81c commit b060e34
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 51 deletions.
46 changes: 6 additions & 40 deletions ct-app/database/database_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@ class DatabaseConnection:
Database connection class.
"""

def __init__(self):
def __init__(self, params: Parameters):
"""
Create a new DatabaseConnection based on environment variables setting user, password, host, port, database, sslmode, sslrootcert, sslcert and sslkey.
"""
self.params = Parameters()("PG")
self._assert_parameters()

url = URL(
drivername="postgresql+psycopg2",
username=self.params.pg.user,
password=self.params.pg.password,
host=self.params.pg.host,
port=self.params.pg.port,
database=self.params.pg.database,
username=params.user,
password=params.password,
host=params.host,
port=params.port,
database=params.database,
query={}
)

Expand All @@ -38,38 +36,6 @@ def __init__(self):

log.info("Database connection established.")

def _assert_parameters(self):
"""
Asserts that all required parameters are set.
"""
for group, values in self.required_parameters().items():
assert len(getattr(self.params, group).__dict__), (
f"Missing all '{group.upper()}' environment variables. "
+ "The following ones are required: "
+ f"{', '.join([(group+'(_)'+v).upper() for v in values])}"
)

for value in values:
assert hasattr(self.params.pg, value), (
"Environment variable "
+ f"'{group.upper()}(_){value.upper()}' missing"
)

@classmethod
def required_parameters(cls):
"""
Returns the required parameters for the DatabaseConnection.
"""
return {
"pg": [
"user",
"password",
"host",
"port",
"database"
]
}

def __enter__(self):
"""
Return the session (used by context manager)
Expand Down
8 changes: 3 additions & 5 deletions ct-app/postman/postman_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def send_1_hop_message(

attempts += 1 # send_status in [TaskStatus.SPLITTED, TaskStatus.SUCCESS]

if attempts >= params.maxIterations:
if attempts >= params.distribution.maxIterations:
send_status = TaskStatus.TIMEOUT

if send_status in [TaskStatus.RETRIED, TaskStatus.SPLIT]:
Expand All @@ -72,7 +72,7 @@ def send_1_hop_message(

# store results in database
if send_status != TaskStatus.RETRIED:
with DatabaseConnection() as session:
with DatabaseConnection(params.pg) as session:
entry = Reward(
peer_id=peer,
node_address=node_peer_id,
Expand Down Expand Up @@ -138,9 +138,7 @@ async def async_send_1_hop_message(
max_possible,
node_peer_id,
timestamp,
params.distribution.batchSize,
params.distribution.delayBetweenTwoMessages,
params.distribution.messageDeliveryDelay,
params,
)

status = TaskStatus.SPLIT if relayed < expected_count else TaskStatus.SUCCESS
Expand Down
16 changes: 10 additions & 6 deletions ct-app/postman/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

from core.components.hoprd_api import MESSAGE_TAG, HoprdAPI
from core.components.parameters import Parameters
from database import DatabaseConnection, Peer


Expand All @@ -16,8 +17,8 @@ def createBatches(cls, total_count: int, batch_size: int) -> list[int]:
return [batch_size] * full_batches + [remainder] * bool(remainder)

@classmethod
def peerIDToInt(cls, peer_id: str) -> int:
with DatabaseConnection() as session:
def peerIDToInt(cls, peer_id: str, parameters: Parameters) -> int:
with DatabaseConnection(parameters) as session:
existing_peer = session.query(Peer).filter_by(peer_id=peer_id).first()

if existing_peer:
Expand Down Expand Up @@ -50,14 +51,17 @@ async def send_messages_in_batches(
expected_count: int,
recipient: str,
timestamp: float,
batch_size: int,
delay_between_two_messages: float,
message_delivery_timeout: float,
params: Parameters,
):

batch_size = params.distribution.batchSize,
delay_between_two_messages = params.distribution.delayBetweenTwoMessages,
message_delivery_timeout = params.distribution.messageDeliveryDelay,

relayed_count = 0
issued_count = 0

tag = MESSAGE_TAG + cls.peerIDToInt(relayer)
tag = MESSAGE_TAG + cls.peerIDToInt(relayer, params.pg)

batches = cls.createBatches(expected_count, batch_size)

Expand Down

0 comments on commit b060e34

Please sign in to comment.