diff --git a/optimas/generators/ax/service/base.py b/optimas/generators/ax/service/base.py index 3d2bf42b..13fb47bd 100644 --- a/optimas/generators/ax/service/base.py +++ b/optimas/generators/ax/service/base.py @@ -23,7 +23,7 @@ Trial, VaryingParameter, Parameter, - TrialStatus, + TrialParameter, ) from optimas.generators.ax.base import AxGenerator from optimas.utils.ax import AxModelManager @@ -109,6 +109,9 @@ def __init__( model_save_period: Optional[int] = 5, model_history_dir: Optional[str] = "model_history", ) -> None: + custom_trial_parameters = [ + TrialParameter("ax_trial_id", dtype=int), + ] super().__init__( varying_parameters=varying_parameters, objectives=objectives, @@ -119,6 +122,7 @@ def __init__( save_model=save_model, model_save_period=model_save_period, model_history_dir=model_history_dir, + custom_trial_parameters=custom_trial_parameters, allow_fixed_parameters=True, allow_updating_parameters=True, )