diff --git a/optimas/diagnostics/ax_model_manager.py b/optimas/diagnostics/ax_model_manager.py index a0608216..32a77fec 100644 --- a/optimas/diagnostics/ax_model_manager.py +++ b/optimas/diagnostics/ax_model_manager.py @@ -127,7 +127,14 @@ def _build_ax_client_from_dataframe( # should work. ax_client = AxClient( generation_strategy=GenerationStrategy( - [GenerationStep(model=Models.GPEI if len(objectives) == 1 else Models.MOO, num_trials=-1)] + [ + GenerationStep( + model=( + Models.GPEI if len(objectives) == 1 else Models.MOO + ), + num_trials=-1, + ) + ] ), verbose_logging=False, ) @@ -248,8 +255,8 @@ def get_best_point( best_point = best_arm.parameters index = self.get_arm_index(best_arm.name) else: - # AxClient.get_best_parameters seems to always return the best point - # from the observed values, independently of the value of `use_model_predictions`. + # AxClient.get_best_parameters seems to always return the best point + # from the observed values, independently of the value of `use_model_predictions`. index, best_point, _ = self.ax_client.get_best_trial( use_model_predictions=use_model_predictions )