Skip to content

Commit

Permalink
more doc
Browse files Browse the repository at this point in the history
Signed-off-by: Melvin Strobl <[email protected]>
  • Loading branch information
Melvin Strobl committed Dec 2, 2024
1 parent 0a780f4 commit 9a250e3
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 34 deletions.
43 changes: 24 additions & 19 deletions src/fromhopetoheuristics/pipelines/hyperparameter_study/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional, Callable
import subprocess
import pandas as pd
import os
Expand All @@ -10,11 +10,14 @@
log = logging.getLogger(__name__)


from typing import Dict, List


def create_hyperparam_optimizer(
n_trials: str,
n_trials: int,
timeout: int,
enabled_hyperparameters: List,
optimization_metric: List,
enabled_hyperparameters: List[str],
optimization_metric: List[str],
path: str,
sampler: str,
sampler_seed: int,
Expand All @@ -26,8 +29,8 @@ def create_hyperparam_optimizer(
resume_study: bool,
n_jobs: int,
run_id: str,
hyperparameters: Dict,
) -> Hyperparam_Optimizer:
hyperparameters: Dict[str, List[float]],
) -> Dict[str, Hyperparam_Optimizer]:

hyperparam_optimizer = Hyperparam_Optimizer(
name=run_id,
Expand All @@ -51,7 +54,10 @@ def create_hyperparam_optimizer(
hyperparam_optimizer.set_fixed_parameters({})

def objective(
trial, parameters, report_callback=None, early_stop_callback=None
trial: o.trial.Trial,
parameters: Dict[str, float],
report_callback: Optional[Callable[[Dict[str, float], int], None]] = None,
early_stop_callback: Optional[Callable[[], bool]] = None,
) -> float:
"""This function is the optimization target that is called by Optuna
for each trial. It runs the experiment with the given parameters
Expand Down Expand Up @@ -102,7 +108,7 @@ def objective(
]
)

def get_objective_for_trial(trial_id) -> float:
def get_objective_for_trial(trial_id: int) -> float:
tmp_file_name = f".hyperhyper{trial_id}.json"
results = pd.read_json(tmp_file_name)
os.remove(tmp_file_name)
Expand All @@ -125,17 +131,16 @@ def get_objective_for_trial(trial_id) -> float:
return {"hyperparam_optimizer": hyperparam_optimizer}


def run_optuna(
hyperparam_optimizer: Hyperparam_Optimizer,
):
hyperparam_optimizer.minimize(idx=0)
def run_optuna(hyperparam_optimizer: Hyperparam_Optimizer) -> None:
"""
Run the hyperparameter optimization study.
Args:
hyperparam_optimizer: The hyperparameter optimizer to run.
# try:
# hyperparam_optimizer.log_study(
# selected_parallel_params=optuna_selected_parallel_params,
# selected_slice_params=optuna_selected_slice_params,
# )
# except Exception as e:
# log.exception("Error while logging study")
Returns:
None
"""
hyperparam_optimizer.minimize(idx=0)

return {}
103 changes: 88 additions & 15 deletions src/fromhopetoheuristics/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
ExtendedTriplet,
)

from typing import Tuple
import json
from typing import Tuple, List
import logging

log = logging.getLogger(__name__)
Expand All @@ -20,12 +19,52 @@ class SplitConfig(D0Config):


class QallseSplit(QallseD0):
"""
A QallseSplit is a QallseD0 model that is split into multiple parts based on
the angle in the XY plane of the first hit of each doublet. This is done by
overriding the `_create_doublets` and `_create_triplets` methods to apply
early cuts to the doublets and triplets based on the angle part of the first
hit. The angle part is determined by the `geometric_index` parameter of
the `config` attribute.
Attributes
----------
config : SplitConfig
The configuration for the QallseSplit model.
Methods
-------
_create_doublets(initial_doublets)
Generate Doublet structures from the initial doublets, calling
`_is_invalid_doublet` to apply early cuts.
_create_triplets()
Generate Triplet structures from Doublets, calling
`_is_invalid_triplet` to apply early cuts.
_is_invalid_triplet(triplet)
Check if a Triplet is invalid based on the angle part of the first hit.
_get_base_config()
Return the base configuration for the QallseSplit model.
serialize()
Serialize the model and their associated xplets.
"""

config = SplitConfig()

def _create_doublets(self, initial_doublets):
# Generate Doublet structures from the initial doublets,
# calling _is_invalid_doublet to apply early cuts
doublets = []
def _create_doublets(self, initial_doublets: List[Tuple[int, int]]) -> None:
"""
Generate Doublet structures from the initial doublets, calling
_is_invalid_doublet to apply early cuts.
Parameters
----------
initial_doublets : List[Tuple[int, int]]
A list of tuples containing start and end hit IDs.
Returns
-------
None
"""
doublets: List[ExtendedDoublet] = []
for start_id, end_id in initial_doublets:
start, end = self.hits[start_id], self.hits[end_id]
d = ExtendedDoublet(start, end)
Expand All @@ -37,10 +76,16 @@ def _create_doublets(self, initial_doublets):
self.logger.info(f"created {len(doublets)} doublets.")
self.doublets = doublets

def _create_triplets(self):
# Generate Triplet structures from Doublets,
# calling _is_invalid_triplet to apply early cuts
triplets = []
def _create_triplets(self) -> None:
"""
Generate Triplet structures from Doublets, calling
_is_invalid_triplet to apply early cuts.
Returns
-------
None
"""
triplets: List[ExtendedTriplet] = []
for d1 in self.doublets:
for d2 in d1.h2.outer:
t = ExtendedTriplet(d1, d2)
Expand All @@ -51,7 +96,20 @@ def _create_triplets(self):
self.logger.info(f"created {len(triplets)} triplets.")
self.triplets = triplets

def _is_invalid_triplet(self, triplet: ExtendedTriplet):
def _is_invalid_triplet(self, triplet: ExtendedTriplet) -> bool:
"""
Check if a triplet is invalid with respect to the geometric index.
Parameters
----------
triplet : ExtendedTriplet
The triplet to check
Returns
-------
bool
True if invalid, False otherwise.
"""
if super()._is_invalid_triplet(triplet):
return True

Expand Down Expand Up @@ -95,9 +153,25 @@ def serialize(self) -> Tuple:
return qubo, xplet


def build_model(doublets, model, add_missing):

# prepare doublets
def build_model(
doublets: List[ExtendedDoublet], model: QallseSplit, add_missing: bool
) -> None:
"""
Prepares and builds the QUBO model from the given doublets.
Parameters
----------
doublets : List[ExtendedDoublet]
A list of extended doublets to be used in the model.
model : QallseSplit
The QallseSplit model to build the QUBO.
add_missing : bool
Flag indicating whether to add missing doublets.
Returns
-------
None
"""
if add_missing:
log.info("Cheat on, adding missing doublets.")
doublets = model.dataw.add_missing_doublets(doublets)
Expand All @@ -107,5 +181,4 @@ def build_model(doublets, model, add_missing):
f"Precision: {p * 100:.4f}%, Recall:{r * 100:.4f}%, Missing: {len(ms)}"
)

# build the qubo
model.build_model(doublets=doublets)

0 comments on commit 9a250e3

Please sign in to comment.