Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix database connection initialization using parameters #529

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 6 additions & 40 deletions ct-app/database/database_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@ class DatabaseConnection:
Database connection class.
"""

def __init__(self):
def __init__(self, params: Parameters):
"""
Create a new DatabaseConnection based on environment variables setting user, password, host, port, database, sslmode, sslrootcert, sslcert and sslkey.
"""
self.params = Parameters()("PG")
self._assert_parameters()

url = URL(
drivername="postgresql+psycopg2",
username=self.params.pg.user,
password=self.params.pg.password,
host=self.params.pg.host,
port=self.params.pg.port,
database=self.params.pg.database,
username=params.user,
password=params.password,
host=params.host,
port=params.port,
database=params.database,
query={}
)

Expand All @@ -38,38 +36,6 @@ def __init__(self):

log.info("Database connection established.")

def _assert_parameters(self):
"""
Asserts that all required parameters are set.
"""
for group, values in self.required_parameters().items():
assert len(getattr(self.params, group).__dict__), (
f"Missing all '{group.upper()}' environment variables. "
+ "The following ones are required: "
+ f"{', '.join([(group+'(_)'+v).upper() for v in values])}"
)

for value in values:
assert hasattr(self.params.pg, value), (
"Environment variable "
+ f"'{group.upper()}(_){value.upper()}' missing"
)

@classmethod
def required_parameters(cls):
"""
Returns the required parameters for the DatabaseConnection.
"""
return {
"pg": [
"user",
"password",
"host",
"port",
"database"
]
}

def __enter__(self):
"""
Return the session (used by context manager)
Expand Down
8 changes: 3 additions & 5 deletions ct-app/postman/postman_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def send_1_hop_message(

attempts += 1 # send_status in [TaskStatus.SPLITTED, TaskStatus.SUCCESS]

if attempts >= params.maxIterations:
if attempts >= params.distribution.maxIterations:
send_status = TaskStatus.TIMEOUT

if send_status in [TaskStatus.RETRIED, TaskStatus.SPLIT]:
Expand All @@ -72,7 +72,7 @@ def send_1_hop_message(

# store results in database
if send_status != TaskStatus.RETRIED:
with DatabaseConnection() as session:
with DatabaseConnection(params.pg) as session:
entry = Reward(
peer_id=peer,
node_address=node_peer_id,
Expand Down Expand Up @@ -138,9 +138,7 @@ async def async_send_1_hop_message(
max_possible,
node_peer_id,
timestamp,
params.distribution.batchSize,
params.distribution.delayBetweenTwoMessages,
params.distribution.messageDeliveryDelay,
params,
)

status = TaskStatus.SPLIT if relayed < expected_count else TaskStatus.SUCCESS
Expand Down
16 changes: 10 additions & 6 deletions ct-app/postman/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

from core.components.hoprd_api import MESSAGE_TAG, HoprdAPI
from core.components.parameters import Parameters
from database import DatabaseConnection, Peer


Expand All @@ -16,8 +17,8 @@ def createBatches(cls, total_count: int, batch_size: int) -> list[int]:
return [batch_size] * full_batches + [remainder] * bool(remainder)

@classmethod
def peerIDToInt(cls, peer_id: str) -> int:
with DatabaseConnection() as session:
def peerIDToInt(cls, peer_id: str, parameters: Parameters) -> int:
with DatabaseConnection(parameters) as session:
existing_peer = session.query(Peer).filter_by(peer_id=peer_id).first()

if existing_peer:
Expand Down Expand Up @@ -50,14 +51,17 @@ async def send_messages_in_batches(
expected_count: int,
recipient: str,
timestamp: float,
batch_size: int,
delay_between_two_messages: float,
message_delivery_timeout: float,
params: Parameters,
):

batch_size = params.distribution.batchSize,
delay_between_two_messages = params.distribution.delayBetweenTwoMessages,
message_delivery_timeout = params.distribution.messageDeliveryDelay,

relayed_count = 0
issued_count = 0

tag = MESSAGE_TAG + cls.peerIDToInt(relayer)
tag = MESSAGE_TAG + cls.peerIDToInt(relayer, params.pg)

batches = cls.createBatches(expected_count, batch_size)

Expand Down
Loading