Skip to content

Commit

Permalink
use antares-study-version package to handle versions (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle authored Oct 10, 2024
1 parent 892436f commit d613377
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[test]
python -m pip install -r requirements-test.txt
- name: Test with pytest
run: |
pytest
5 changes: 4 additions & 1 deletion antareslauncher/data_repo/data_repo_tinydb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import logging
import typing as t

Expand Down Expand Up @@ -83,7 +84,9 @@ def save_study(self, study: StudyDTO):
pk_name = self.db_primary_key
pk_value = getattr(study, pk_name)
old = self.db.get(tinydb.where(pk_name) == pk_value)
new = vars(study)
study_dict = vars(study)
new = copy.deepcopy(study_dict) # to avoid modifying the study object
new["antares_version"] = f"{new['antares_version']:2d}"
if old:
diff = _calc_diff(old, new)
logger.info(f"Updating study '{pk_value}' in database: {diff!r}")
Expand Down
5 changes: 3 additions & 2 deletions antareslauncher/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController
from antareslauncher.use_cases.retrieve.state_updater import StateUpdater
from antareslauncher.use_cases.wait_loop_controller.wait_controller import WaitController
from antares.study.version import SolverMinorVersion


class NoJsonConfigFileError(Exception):
Expand Down Expand Up @@ -67,7 +68,7 @@ class MainParameters:
json_dir: Path
default_json_db_name: str
slurm_script_path: str
antares_versions_on_remote_server: t.Sequence[str]
antares_versions_on_remote_server: t.Sequence[SolverMinorVersion]
default_ssh_dict: t.Mapping[str, t.Any]
db_primary_key: str
partition: str = ""
Expand Down Expand Up @@ -120,7 +121,7 @@ def run_with(arguments: argparse.Namespace, parameters: MainParameters, show_ban
post_processing=arguments.post_processing,
antares_versions_on_remote_server=parameters.antares_versions_on_remote_server,
other_options=arguments.other_options or "",
antares_version=arguments.antares_version,
antares_version=SolverMinorVersion.parse(arguments.antares_version),
),
)
launch_controller = LaunchController(repo=data_repo, env=environment, display=display)
Expand Down
3 changes: 2 additions & 1 deletion antareslauncher/parameters_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from antareslauncher.main import MainParameters
from antareslauncher.main_option_parser import ParserParameters
from antares.study.version import SolverMinorVersion

ALT2_PARENT = Path.home() / "antares_launcher_settings"
ALT1_PARENT = Path.cwd()
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(self, json_ssh_conf: Path, yaml_filepath: Path):
self.remote_slurm_script_path = obj["SLURM_SCRIPT_PATH"]
self.partition = obj.get("PARTITION", "")
self.quality_of_service = obj.get("QUALITY_OF_SERVICE", "")
self.antares_versions = obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"]
self.antares_versions = [SolverMinorVersion.parse(v) for v in obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"]]
self.db_primary_key = obj["DB_PRIMARY_KEY"]
self.json_dir = Path(obj["JSON_DIR"]).expanduser()
self.json_db_name = obj.get("DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from antareslauncher.remote_environnement.slurm_script_features import ScriptParametersDTO, SlurmScriptFeatures
from antareslauncher.remote_environnement.ssh_connection import SshConnection
from antareslauncher.study_dto import StudyDTO
from antares.study.version import SolverMinorVersion

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -206,7 +207,7 @@ def submit_job(self, my_study: StudyDTO):
input_zipfile_name=Path(my_study.zipfile_path).name,
time_limit=time_limit,
n_cpu=my_study.n_cpu,
antares_version=my_study.antares_version,
antares_version=SolverMinorVersion.parse(my_study.antares_version),
run_mode=my_study.run_mode,
post_processing=my_study.post_processing,
other_options=my_study.other_options or "",
Expand Down
5 changes: 3 additions & 2 deletions antareslauncher/remote_environnement/slurm_script_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shlex

from antareslauncher.study_dto import Modes
from antares.study.version import SolverMinorVersion


@dataclasses.dataclass
Expand All @@ -10,7 +11,7 @@ class ScriptParametersDTO:
input_zipfile_name: str
time_limit: int
n_cpu: int
antares_version: int
antares_version: SolverMinorVersion
run_mode: Modes
post_processing: bool
other_options: str
Expand Down Expand Up @@ -81,7 +82,7 @@ def compose_launch_command(
for arg in [
self.solver_script_path,
script_params.input_zipfile_name,
str(script_params.antares_version),
f"{script_params.antares_version:2d}",
_job_type,
str(script_params.post_processing),
script_params.other_options,
Expand Down
5 changes: 3 additions & 2 deletions antareslauncher/study_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path

from antares.study.version import StudyVersion

class Modes(IntEnum):
antares = 1
Expand Down Expand Up @@ -43,7 +43,7 @@ class StudyDTO:
# Simulation stage data
time_limit: t.Optional[int] = None
n_cpu: int = 1
antares_version: int = 0
antares_version: StudyVersion = StudyVersion.parse(0)
xpansion_mode: str = "" # "", "r", "cpp"
run_mode: Modes = Modes.antares
post_processing: bool = False
Expand All @@ -59,4 +59,5 @@ def from_dict(cls, doc: t.Mapping[str, t.Any]) -> "StudyDTO":
"""
attrs = dict(**doc)
attrs.pop("name", None) # calculated
attrs["antares_version"] = StudyVersion.parse(attrs["antares_version"])
return cls(**attrs)
18 changes: 10 additions & 8 deletions antareslauncher/use_cases/create_list/study_list_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb
from antareslauncher.display.display_terminal import DisplayTerminal
from antareslauncher.study_dto import Modes, StudyDTO
from antares.study.version import SolverMinorVersion, StudyVersion

DEFAULT_VERSION = SolverMinorVersion.parse(0)

def get_solver_version(study_dir: Path, *, default: int = 0) -> int:
def get_solver_version(study_dir: Path, *, default: SolverMinorVersion = DEFAULT_VERSION) -> SolverMinorVersion:
"""
Retrieve the solver version number or else the study version number
from the "study.antares" file.
Expand All @@ -28,7 +30,7 @@ def get_solver_version(study_dir: Path, *, default: int = 0) -> int:
section = config["antares"]
for key in "solver_version", "version":
if key in section:
return int(section[key])
return SolverMinorVersion.parse(section[key])
return default


Expand All @@ -41,9 +43,9 @@ class StudyListComposerParameters:
xpansion_mode: str # "", "r", "cpp"
output_dir: str
post_processing: bool
antares_versions_on_remote_server: t.Sequence[str]
antares_versions_on_remote_server: t.Sequence[SolverMinorVersion]
other_options: str
antares_version: int = 0
antares_version: SolverMinorVersion = DEFAULT_VERSION


class StudyListComposer:
Expand All @@ -66,7 +68,7 @@ def __init__(
self.antares_version = parameters.antares_version
self._new_study_added = False
self.DEFAULT_JOB_LOG_DIR_PATH = str(Path(self.log_dir) / "JOB_LOGS")
self.ANTARES_VERSIONS_ON_REMOTE_SERVER = [int(v) for v in parameters.antares_versions_on_remote_server]
self.ANTARES_VERSIONS_ON_REMOTE_SERVER = parameters.antares_versions_on_remote_server

def get_list_of_studies(self):
"""Retrieve the list of studies from the repo
Expand All @@ -76,7 +78,7 @@ def get_list_of_studies(self):
"""
return self._repo.get_list_of_studies()

def _create_study(self, path: Path, antares_version: int, xpansion_mode: str) -> StudyDTO:
def _create_study(self, path: Path, antares_version: SolverMinorVersion, xpansion_mode: str) -> StudyDTO:
run_mode = {
"": Modes.antares,
"r": Modes.xpansion_r,
Expand All @@ -86,7 +88,7 @@ def _create_study(self, path: Path, antares_version: int, xpansion_mode: str) ->
path=str(path),
n_cpu=self.n_cpu,
time_limit=self.time_limit,
antares_version=antares_version,
antares_version=StudyVersion.parse(antares_version),
job_log_dir=self.DEFAULT_JOB_LOG_DIR_PATH,
output_dir=str(self.output_dir),
xpansion_mode=xpansion_mode,
Expand Down Expand Up @@ -120,7 +122,7 @@ def update_study_database(self):

def _update_database_with_directory(self, directory_path: Path):
solver_version = get_solver_version(directory_path)
antares_version = self.antares_version or solver_version
antares_version = self.antares_version if self.antares_version != DEFAULT_VERSION else solver_version
if not antares_version:
self._display.show_message(
"... not a valid Antares study",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
antares-study-version~=1.0.7
bcrypt~=3.2.2
cffi~=1.15.1
cryptography~=39.0.1
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from antareslauncher.display.display_terminal import DisplayTerminal
from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer, StudyListComposerParameters
from tests.unit.assets import ASSETS_DIR
from antares.study.version import SolverMinorVersion


@pytest.fixture(name="studies_in_dir")
Expand Down Expand Up @@ -44,15 +45,16 @@ def study_list_composer_fixture(
xpansion_mode="",
output_dir=str(tmp_path.joinpath("FINISHED")),
post_processing=False,
antares_versions_on_remote_server=[
antares_versions_on_remote_server=[SolverMinorVersion.parse(v) for v in [
"800",
"810",
"820",
"830",
"840",
"850",
],
]],
other_options="",

),
)
return composer
3 changes: 0 additions & 3 deletions tests/unit/test_data_repo_tinydb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import random
from pathlib import Path
from unittest import mock
from uuid import uuid4

import pytest
import tinydb

from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb
from antareslauncher.study_dto import StudyDTO
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_remote_environment_with_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from antareslauncher.remote_environnement.slurm_script_features import ScriptParametersDTO, SlurmScriptFeatures
from antareslauncher.study_dto import Modes, StudyDTO
from antares.study.version import StudyVersion


class TestRemoteEnvironmentWithSlurm:
Expand Down Expand Up @@ -50,7 +51,7 @@ def study(self) -> StudyDTO:
path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f",
n_cpu=42,
zipfile_path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f-foo.zip",
antares_version=700,
antares_version=StudyVersion.parse(700),
local_final_zipfile_path="local_final_zipfile_path",
run_mode=Modes.antares,
)
Expand Down Expand Up @@ -689,7 +690,7 @@ def test_compose_launch_command(
f" --cpus-per-task={study.n_cpu}"
f" {filename_launch_script}"
f" {Path(study.zipfile_path).name}"
f" {study.antares_version}"
f" {study.antares_version:2d}"
f" {job_type}"
f" {post_processing}"
f" ''"
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_study_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from antares.study.version import StudyVersion

from antareslauncher.study_dto import StudyDTO


def test_study_dto_from_dict_old_version_syntax():

study_dict = {
"path": "/path/to/study",
"antares_version": 880
}

study_dto = StudyDTO.from_dict(study_dict)
assert study_dto.antares_version == StudyVersion.parse("8.8")


def test_study_dto_from_dict():
study_dict = {
"path": "/path/to/study",
"antares_version": "9.0"
}
study_dto = StudyDTO.from_dict(study_dict)
assert study_dto.antares_version == StudyVersion.parse("9.0")
12 changes: 7 additions & 5 deletions tests/unit/test_study_list_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer, get_solver_version
from antares.study.version import SolverMinorVersion

CONFIG_NOMINAL_VERSION = """\
[antares]
Expand Down Expand Up @@ -93,17 +94,18 @@ def test_update_study_database__antares_version(
study_list_composer: StudyListComposer,
antares_version: int,
):
study_list_composer.antares_version = antares_version
parsed_version = SolverMinorVersion.parse(antares_version)
study_list_composer.antares_version = parsed_version
study_list_composer.update_study_database()
studies = study_list_composer.get_list_of_studies()

# check the versions
actual_versions = {s.name: s.antares_version for s in studies}
if antares_version == 0:
expected_versions = {
"013 TS Generation - Solar power": 850, # solver_version
"024 Hurdle costs - 1": 840, # versions
"SMTA-case": 810, # version
"013 TS Generation - Solar power": "8.5", # solver_version
"024 Hurdle costs - 1": "8.4", # versions
"SMTA-case": "8.1", # version
}
elif antares_version in study_list_composer.ANTARES_VERSIONS_ON_REMOTE_SERVER:
study_names = {
Expand All @@ -114,7 +116,7 @@ def test_update_study_database__antares_version(
"MISSING Study version",
"SMTA-case",
}
expected_versions = dict.fromkeys(study_names, antares_version)
expected_versions = dict.fromkeys(study_names, parsed_version)
else:
expected_versions = {}
assert actual_versions == {n: expected_versions[n] for n in actual_versions}

0 comments on commit d613377

Please sign in to comment.