Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/v2 patch 2 #409

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions tests/apps/test_hpo.py

This file was deleted.

110 changes: 110 additions & 0 deletions tests/hpo/apps/main/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Literal, TypeAlias

from collections.abc import Callable, Generator
import os
from pathlib import Path
import shutil
import sqlite3
import tempfile

import pytest

# Define type aliases for the database utility functions
TrialCountFunc: TypeAlias = Callable[[Path, str], int]
TrialValuesFunc: TypeAlias = Callable[[Path, str], list[float]]
DBUtils: TypeAlias = dict[Literal["get_trial_count", "get_trial_values"], TrialCountFunc | TrialValuesFunc]
ConfigModFunc: TypeAlias = Callable[[Path, str, int, str], Path]


@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
"""Create a temporary directory with test files and clean up afterwards"""
with tempfile.TemporaryDirectory() as tmp_dir:
temp_path = Path(tmp_dir)
original_dir = os.getcwd()
os.chdir(tmp_dir)

source_dir = Path(__file__).parent
test_files = ["config.yaml", "objective.sh", "objective_for_test.py"]

for file_name in test_files:
source_file = source_dir / file_name
target_file = temp_path / file_name
if source_file.exists():
shutil.copy2(source_file, target_file)
os.chmod(target_file, 0o755)
print(f"\n=== Content of {file_name} ===")
print((target_file).read_text())
print("=" * 40)
else:
pytest.skip(f"Required test file {file_name} not found in {source_dir}")

os.chmod(temp_path / "objective.sh", 0o755)
yield temp_path
os.chdir(original_dir)


@pytest.fixture
def db_utils() -> DBUtils:
"""Fixture providing database utility functions"""

def get_trial_count(db_path: Path, study_name: str) -> int:
"""Get the number of trials from the SQLite database for a specific study"""
if not db_path.exists():
raise AssertionError("Database file does not exist")
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM trials
WHERE study_id = (SELECT study_id FROM studies WHERE study_name = ?)
""",
(study_name,),
)
count = cursor.fetchone()[0]
conn.close()
return int(count) if count is not None else 0

def get_trial_values(db_path: Path, study_name: str) -> list[float]:
"""Get the values of all completed trials for a specific study"""
if not db_path.exists():
raise AssertionError("Database file does not exist")
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(
"""
SELECT trial_values.value
FROM trials
JOIN studies ON trials.study_id = studies.study_id
JOIN trial_values ON trials.trial_id = trial_values.trial_id
WHERE studies.study_name = ?
AND trials.state = 'COMPLETE'
ORDER BY trials.number
""",
(study_name,),
)
values = [row[0] for row in cursor.fetchall()]
conn.close()
return values

return {"get_trial_count": get_trial_count, "get_trial_values": get_trial_values}


@pytest.fixture
def config_modifier() -> ConfigModFunc:
"""Fixture providing configuration file modification functionality"""

def modify_config(config_path: Path, study_name: str, n_trials: int, db_name: str) -> Path:
"""Modify config file with new study name and number of trials"""
with open(config_path) as f:
content = f.read()
content = content.replace("study_name: my_study", f"study_name: {study_name}")
content = content.replace("n_trials: 30", f"n_trials: {n_trials}")
if db_name:
content = content.replace("url: sqlite:///aiaccel_storage.db", f"url: sqlite:///{db_name}")
new_config_path = config_path.parent / f"config_{study_name}.yaml"
with open(new_config_path, "w") as f:
f.write(content)
return new_config_path

return modify_config
20 changes: 20 additions & 0 deletions tests/hpo/apps/main/test_optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pathlib import Path
from unittest.mock import patch
import uuid

from conftest import ConfigModFunc, DBUtils


def test_normal_execution(temp_dir: Path, db_utils: DBUtils, config_modifier: ConfigModFunc) -> None:
"""Test normal execution without resume functionality"""
from aiaccel.hpo.apps.optimize import main

db_name = "normal_test.db"
study_name = f"test_study_{uuid.uuid4().hex[:8]}"
config_path = config_modifier(temp_dir / "config.yaml", study_name, 30, db_name)

with patch("sys.argv", ["optimize.py", "objective.sh", "--config", str(config_path)]):
main()

trial_count = db_utils["get_trial_count"](temp_dir / db_name, study_name)
assert trial_count == 30
Loading
Loading