Commit

added component tests
jeandemeusy committed Jan 23, 2024
1 parent 2deec55 commit e3939f2
Showing 12 changed files with 308 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
@@ -41,7 +41,7 @@ jobs:
- name: Test with pytest
working-directory: ct-app
run: |
pytest test
pytest test -W ignore::DeprecationWarning
image:
name: Build and push container image
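For context, the same warning suppression could also be expressed in pytest configuration rather than on the CI command line. This is only an illustrative sketch; the file ct-app/pytest.ini is hypothetical and not part of this commit:

# ct-app/pytest.ini (hypothetical) — equivalent to passing -W ignore::DeprecationWarning
[pytest]
filterwarnings =
    ignore::DeprecationWarning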
4 changes: 2 additions & 2 deletions ct-app/core/__main__.py
@@ -4,7 +4,7 @@
from prometheus_client import start_http_server

from .components.parameters import Parameters
from .components.utils import EnvUtils, Utils
from .components.utils import EnvironmentUtils, Utils
from .core import Core
from .node import Node

@@ -51,5 +51,5 @@ def main():


if __name__ == "__main__":
if EnvUtils.checkRequiredEnvVar("core"):
if EnvironmentUtils.checkRequiredEnvVar("core"):
main()
45 changes: 45 additions & 0 deletions ct-app/core/components/environment_utils.py
@@ -0,0 +1,45 @@
import subprocess
from os import environ
from typing import Any

from .baseclass import Base


class EnvironmentUtils(Base):
    def print_prefix(self) -> str:
        return "EnvUtils"

    @classmethod
    def envvar(cls, var_name: str, default: Any = None, type: type = str):
        if var_name in environ:
            return type(environ[var_name])
        else:
            return default

    @classmethod
    def envvarWithPrefix(cls, prefix: str, type=str) -> dict[str, Any]:
        var_dict = {
            key: type(v) for key, v in environ.items() if key.startswith(prefix)
        }

        return dict(sorted(var_dict.items()))

    @classmethod
    def checkRequiredEnvVar(cls, folder: str):
        result = subprocess.run(
            f"sh ./scripts/list_required_parameters.sh {folder}".split(),
            capture_output=True,
            text=True,
        ).stdout

        all_set_flag = True
        for var in result.splitlines():
            exists = var in environ
            all_set_flag *= exists

            # print each variable with a check mark if it is set, or a red cross if it is not
            cls().info(f"{'✅' if exists else '❌'} {var}")

        if not all_set_flag:
            cls().error("Some required environment variables are not set.")
        return all_set_flag
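A minimal usage sketch of the extracted helper, assuming it is imported from core.components.environment_utils; the PARAM_* variable names below are invented for illustration:

import os

from core.components.environment_utils import EnvironmentUtils

os.environ["PARAM_BATCH_SIZE"] = "25"
os.environ["PARAM_TIMEOUT"] = "3.5"

# single variable, converted to the requested type, with a default when unset
batch_size = EnvironmentUtils.envvar("PARAM_BATCH_SIZE", default=10, type=int)  # -> 25
retries = EnvironmentUtils.envvar("PARAM_RETRIES", default=3, type=int)         # -> 3

# every variable sharing a prefix, returned as a dict sorted by key
params = EnvironmentUtils.envvarWithPrefix("PARAM_")
# -> {"PARAM_BATCH_SIZE": "25", "PARAM_TIMEOUT": "3.5"}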
4 changes: 2 additions & 2 deletions ct-app/core/components/parameters.py
@@ -1,5 +1,5 @@
from .baseclass import Base
from .utils import EnvUtils
from .environment_utils import EnvironmentUtils


class Parameters(Base):
@@ -16,7 +16,7 @@ def __call__(self, *prefixes: str or list[str]):
if subparams_name[-1] == "_":
subparams_name = subparams_name[:-1]

for key, value in EnvUtils.envvarWithPrefix(prefix).items():
for key, value in EnvironmentUtils.envvarWithPrefix(prefix).items():
k = key.replace(prefix, "").lower()

try:
123 changes: 55 additions & 68 deletions ct-app/core/components/utils.py
@@ -2,10 +2,8 @@
import json
import os
import random
import subprocess
import time
from datetime import datetime, timedelta
from os import environ
from typing import Any

import aiohttp
@@ -18,69 +16,44 @@
from core.model.topology_entry import TopologyEntry

from .baseclass import Base


class EnvUtils(Base):
def print_prefix(self) -> str:
return "EnvUtils"

@classmethod
def envvar(cls, var_name: str, default: Any = None, type: type = str):
if var_name in environ:
return type(environ[var_name])
else:
return default

@classmethod
def envvarWithPrefix(cls, prefix: str, type=str) -> dict[str, Any]:
var_dict = {
key: type(v) for key, v in environ.items() if key.startswith(prefix)
}

return dict(sorted(var_dict.items()))

@classmethod
def checkRequiredEnvVar(cls, folder: str):
result = subprocess.run(
f"sh ./scripts/list_required_parameters.sh {folder}".split(),
capture_output=True,
text=True,
).stdout

all_set_flag = True
for var in result.splitlines():
exists = var in environ
all_set_flag *= exists

# print var with a leading check mark if it exists or red X (emoji) if it doesn't
cls().info(f"{'✅' if exists else '❌'} {var}")

if not all_set_flag:
cls().error("Some required environment variables are not set.")
return all_set_flag
from .channelstatus import ChannelStatus
from .environment_utils import EnvironmentUtils


class Utils(Base):
@classmethod
def nodesAddresses(
cls, address_prefix: str, keyenv: str
) -> tuple[list[str], list[str]]:
addresses = EnvUtils.envvarWithPrefix(address_prefix).values()
keys = EnvUtils.envvarWithPrefix(keyenv).values()
"""
Returns a tuple containing the addresses and keys of the nodes.
:param address_prefix: The prefix of the environment variables containing addresses.
:param keyenv: The prefix of the environment variables containing keys.
:returns: A tuple containing the addresses and keys.
"""
addresses = EnvironmentUtils.envvarWithPrefix(address_prefix).values()
keys = EnvironmentUtils.envvarWithPrefix(keyenv).values()

return list(addresses), list(keys)

@classmethod
async def httpPOST(cls, url, data) -> tuple[int, dict]:
async def post(session: ClientSession, url: str, data: dict):
async def httpPOST(cls, url: str, data: dict) -> tuple[int, dict]:
"""
Performs an HTTP POST request.
:param url: The URL to send the request to.
:param data: The data to be sent.
:returns: A tuple containing the status code and the response.
"""

async def _post(session: ClientSession, url: str, data: dict):
async with session.post(url, json=data) as response:
status = response.status
response = await response.json()
return status, response

async with aiohttp.ClientSession() as session:
try:
status, response = await post(session, url, data)
status, response = await _post(session, url, data)
except Exception:
return None, None
else:
Expand All @@ -96,13 +69,12 @@ def mergeTopologyPeersSubgraph(
"""
Merge metrics and subgraph data with the unique peer IDs, addresses,
balance links.
:param: topology_dict: A dict mapping peer IDs to node addresses.
:param: peers_list: A dict containing metrics with peer ID as the key.
:param: subgraph_dict: A dict containing subgraph data with safe address as key.
:param topology_dict: A dict mapping peer IDs to node addresses.
:param peers_list: A dict containing metrics with peer ID as the key.
:param subgraph_dict: A dict containing subgraph data with safe address as key.
:returns: A dict with peer ID as the key and the merged information.
"""
merged_result: list[Peer] = []

network_addresses = [p.address for p in peers_list]

# Merge based on peer ID with the channel topology as the baseline
@@ -133,7 +105,7 @@ def allowManyNodePerSafe(cls, peers: list[Peer]):
"""
Split the stake managed by a safe address equally between the nodes
that the safe manages.
:param: peer: list of peers
:param peers: list of peers
:returns: nothing.
"""
safe_counts = {peer.safe_address: 0 for peer in peers}
@@ -152,9 +124,9 @@ def excludeElements(
) -> list[Peer]:
"""
Removes elements from a dictionary based on a blacklist.
:param: source_data (dict): The dictionary to be updated.
:param: blacklist (list): A list containing the keys to be removed.
:returns: nothing.
:param source_data (dict): The dictionary to be updated.
:param blacklist (list): A list containing the keys to be removed.
:returns: A list containing the removed elements.
"""

peer_addresses = [peer.address for peer in source_data]
@@ -176,8 +148,8 @@ def rewardProbability(cls, peers: list[Peer]) -> list[int]:
def rewardProbability(cls, peers: list[Peer]) -> list[int]:
"""
Evaluate the function for each stake value in the eligible_peers dictionary.
:param eligible_peers: A dict containing the data.
:returns: nothing.
:param peers: The list of peers to evaluate.
:returns: A list containing the excluded elements due to low stake.
"""

indexes_to_remove = [
@@ -198,13 +170,12 @@ def rewardProbability(cls, peers: list[Peer]) -> list[int]:
return excluded

@classmethod
def jsonFromGCP(cls, bucket_name, blob_name, schema=None):
def jsonFromGCP(cls, bucket_name: str, blob_name: str):
"""
Reads a JSON file and validates its contents using a schema.
:param: bucket_name: The name of the bucket
:param: blob_name: The name of the blob
;param: schema (opt): The validation schema
:returns: (dict): The contents of the JSON file.
:param bucket_name: The name of the bucket
:param blob_name: The name of the blob
:returns: The contents of the JSON file.
"""

storage_client = storage.Client()
@@ -234,6 +205,14 @@ def stringArrayToGCP(cls, bucket_name: str, blob_name: str, data: list[str]):

@classmethod
def generateFilename(cls, prefix: str, foldername: str, extension: str = "csv"):
"""
Generates a filename with the following format:
<prefix>_<timestamp>.<extension>
:param prefix: The prefix of the filename
:param foldername: The folder where the file will be stored
:param extension: The extension of the file
:returns: The filename
"""
timestamp = time.strftime("%Y%m%d%H%M%S")

if extension.startswith("."):
@@ -247,6 +226,7 @@ def nextEpoch(cls, seconds: int) -> datetime:
"""
Calculates the delay until the next whole `minutes`min and `seconds`sec.
:param seconds: next whole second to trigger the function
:returns: The next epoch
"""
if seconds == 0:
raise ValueError("'seconds' must be greater than 0")
@@ -261,6 +241,7 @@ def nextDelayInSeconds(cls, seconds: int) -> int:
"""
Calculates the delay until the next whole `minutes`min and `seconds`sec.
:param seconds: next whole second to trigger the function
:returns: The delay in seconds.
"""
if seconds == 0:
return 1
@@ -276,6 +257,8 @@ async def aggregatePeerBalanceInChannels(cls, channels: list) -> dict[str, dict]:
async def aggregatePeerBalanceInChannels(cls, channels: list) -> dict[str, dict]:
"""
Returns a dict containing all unique source_peerId-source_address links.
:param channels: The list of channels.
:returns: A dict containing all peerIds-balanceInChannels links.
"""

results: dict[str, dict] = {}
@@ -284,10 +267,11 @@ async def aggregatePeerBalanceInChannels(cls, channels: list) -> dict[str, dict]
hasattr(c, "source_peer_id")
and hasattr(c, "source_address")
and hasattr(c, "status")
and hasattr(c, "balance")
):
continue

if c.status != "Open":
if ChannelStatus(c.status) != ChannelStatus.Open:
continue

if c.source_peer_id not in results:
@@ -301,15 +285,18 @@ async def aggregatePeerBalanceInChannels(cls, channels: list) -> dict[str, dict]
return results

@classmethod
def splitDict(cls, peers: dict[str, int], bins: int) -> list[dict]:
def splitDict(cls, src: dict[str, Any], bins: int) -> list[dict[str, Any]]:
"""
Splits randomly a dict into multiple sub-dictionary.
Randomly splits a dict into multiple sub-dictionaries of almost equal sizes.
:param src: The dict to be split.
:param bins: The number of sub-dictionaries.
:returns: A list containing the sub-dictionaries.
"""
# Split the dictionary into multiple sub-dictionaries
split = [{} for i in range(bins)]
split = [{} for _ in range(bins)]

# Assign a random number to each element in the dictionary
for peer_id, data in peers.items():
split[random.randint(0, bins - 1)][peer_id] = data
for idx, (key, value) in enumerate(random.sample(src.items(), len(src))):
split[idx % bins][key] = value

return split
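A small usage sketch of the reworked splitDict, assuming Utils is imported from core.components.utils: the shuffled round-robin assignment keeps the sub-dictionary sizes within one element of each other. (Sampling a dict_items view has been deprecated since Python 3.9, which is plausibly the DeprecationWarning silenced in the CI step above.)

from core.components.utils import Utils

peers = {f"peer_{i}": i * 10 for i in range(10)}

bins = Utils.splitDict(peers, 3)

# bin sizes are deterministic (4, 3, 3 here); only the membership is random
assert sorted(len(b) for b in bins) == [3, 3, 4]

# nothing is lost or duplicated across the bins
merged = {key: value for b in bins for key, value in b.items()}
assert merged == peers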
4 changes: 2 additions & 2 deletions ct-app/postman/postman_tasks.py
@@ -7,7 +7,7 @@
from celery import Celery
from core.components.hoprd_api import HoprdAPI
from core.components.parameters import Parameters
from core.components.utils import EnvUtils, Utils
from core.components.utils import EnvironmentUtils, Utils
from database import DatabaseConnection, Reward

from .task_status import TaskStatus
@@ -18,7 +18,7 @@

params = Parameters()("PARAM_", "RABBITMQ_")

if not EnvUtils.checkRequiredEnvVar("postman"):
if not EnvironmentUtils.checkRequiredEnvVar("postman"):
exit(1)

app = Celery(
13 changes: 10 additions & 3 deletions ct-app/test/components/test_channelstatus.py
@@ -2,7 +2,14 @@


def test_channelstatus():
assert ChannelStatus.isPending("PendingToClose")
assert not ChannelStatus.isPending("Open")
assert not ChannelStatus.isOpen("PendingToClose")
assert ChannelStatus.isOpen("Open")
assert not ChannelStatus.isOpen("PendingToClose")
assert not ChannelStatus.isOpen("Closed")

assert not ChannelStatus.isPending("Open")
assert ChannelStatus.isPending("PendingToClose")
assert not ChannelStatus.isPending("Closed")

assert not ChannelStatus.isClosed("Open")
assert not ChannelStatus.isClosed("PendingToClose")
assert ChannelStatus.isClosed("Closed")
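The ChannelStatus enum itself is not part of this diff. A shape consistent with these tests and with the ChannelStatus(c.status) comparison in utils.py could look like the sketch below (an assumption, not the repository's actual implementation):

from enum import Enum


class ChannelStatus(Enum):
    Open = "Open"
    PendingToClose = "PendingToClose"
    Closed = "Closed"

    @classmethod
    def isOpen(cls, status: str) -> bool:
        # "Open" -> True, any other known status -> False
        return cls(status) == cls.Open

    @classmethod
    def isPending(cls, status: str) -> bool:
        return cls(status) == cls.PendingToClose

    @classmethod
    def isClosed(cls, status: str) -> bool:
        return cls(status) == cls.Closed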