From 38bd1003419a8e4bbe38fb2499726301aca1facf Mon Sep 17 00:00:00 2001
From: Lena Kashtelyan
Date: Fri, 28 Feb 2025 09:02:54 -0800
Subject: [PATCH] reap _gen_multiple and replace with gen_for_multiple_trials_with_multiple_models (#3436)

Summary: This diff removes _gen_multiple and replaces it with calls to
gen_for_multiple_trials_with_multiple_models. We plan to replace gen() with
gen_for_multiple_trials_with_multiple_models() in the Ax 1.0 release, so we
will keep both around for now.

Internal - Tldr: {F1973912988}

See diff 1/n in the stack for context

Reviewed By: saitcakmak

Differential Revision: D67319697
---
 ax/generation_strategy/generation_strategy.py | 304 ++++++------------
 .../tests/test_generation_strategy.py         |  68 ----
 ax/service/scheduler.py                       |   6 +-
 ax/service/tests/scheduler_test_utils.py      |  16 +-
 4 files changed, 118 insertions(+), 276 deletions(-)

diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py
index 641f340cdc5..d7a6e5d0d47 100644
--- a/ax/generation_strategy/generation_strategy.py
+++ b/ax/generation_strategy/generation_strategy.py
@@ -356,134 +356,6 @@ def gen(
         )
         return gr[0]
 
-    def _gen_with_multiple_nodes(
-        self,
-        experiment: Experiment,
-        data: Data | None = None,
-        pending_observations: dict[str, list[ObservationFeatures]] | None = None,
-        n: int | None = None,
-        fixed_features: ObservationFeatures | None = None,
-        arms_per_node: dict[str, int] | None = None,
-    ) -> list[GeneratorRun]:
-        """Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
-        ``BatchTrial``, and if producing a ``BatchTrial``, allows for multiple
-        ``GenerationNode``-s (and therefore models) to be used to generate
-        ``GeneratorRun``-s for that trial.
-
-
-        Args:
-            experiment: Experiment, for which the generation strategy is producing
-                a new generator run in the course of `gen`, and to which that
-                generator run will be added as trial(s). Information stored on the
-                experiment (e.g., trial statuses) is used to determine which model
-                will be used to produce the generator run returned from this method.
-            data: Optional data to be passed to the underlying model's `gen`, which
-                is called within this method and actually produces the resulting
-                generator run. By default, data is all data on the `experiment`.
-            pending_observations: A map from metric name to pending
-                observations for that metric, used by some models to avoid
-                resuggesting points that are currently being evaluated.
-            n: Integer representing how many arms should be in the generator run
-                produced by this method. NOTE: Some underlying models may ignore
-                the `n` and produce a model-determined number of arms. In that
-                case this method will also output a generator run with number of
-                arms that can differ from `n`.
-            fixed_features: An optional set of ``ObservationFeatures`` that will be
-                passed down to the underlying models. Note: if provided this will
-                override any algorithmically determined fixed features so it is
-                important to specify all necessary fixed features.
-            arms_per_node: An optional map from node name to the number of arms to
-                generate from that node. If not provided, will default to the number
-                of arms specified in the node's ``InputConstructors`` or n if no
-                ``InputConstructors`` are defined on the node. We expect either n or
-                arms_per_node to be provided, but not both, and this is an advanced
-                argument that should only be used by advanced users.
-
-        Returns:
-            A list of ``GeneratorRuns`` for a single trial.
- """ - grs = [] - continue_gen_for_trial = True - pending_observations = deepcopy(pending_observations) or {} - self.experiment = experiment - self._validate_arms_per_node(arms_per_node=arms_per_node) - pack_gs_gen_kwargs = self._initalize_gen_kwargs( - experiment=experiment, - grs_this_gen=grs, - data=data, - n=n, - fixed_features=fixed_features, - arms_per_node=arms_per_node, - pending_observations=pending_observations, - ) - if self.optimization_complete: - raise GenerationStrategyCompleted( - f"Generation strategy {self} generated all the trials as " - "specified in its nodes." - ) - - while continue_gen_for_trial: - pack_gs_gen_kwargs["grs_this_gen"] = grs - should_transition, node_to_gen_from_name = ( - self._curr.should_transition_to_next_node( - raise_data_required_error=False - ) - ) - node_to_gen_from = self.nodes_dict[node_to_gen_from_name] - if should_transition: - node_to_gen_from._previous_node_name = node_to_gen_from_name - # reset should skip as conditions may have changed, do not reset - # until now so node properites can be as up to date as possible - node_to_gen_from._should_skip = False - arms_from_node = self._determine_arms_from_node( - node_to_gen_from=node_to_gen_from, - n=n, - gen_kwargs=pack_gs_gen_kwargs, - ) - fixed_features_from_node = self._determine_fixed_features_from_node( - node_to_gen_from=node_to_gen_from, - gen_kwargs=pack_gs_gen_kwargs, - ) - sq_ft_from_node = self._determine_sq_features_from_node( - node_to_gen_from=node_to_gen_from, gen_kwargs=pack_gs_gen_kwargs - ) - self._maybe_transition_to_next_node() - if node_to_gen_from._should_skip: - continue - self._fit_current_model(data=data, status_quo_features=sq_ft_from_node) - self._curr.generator_run_limit(raise_generation_errors=True) - if arms_from_node != 0: - try: - curr_node_gr = self._curr.gen( - n=arms_from_node, - pending_observations=pending_observations, - arms_by_signature_for_deduplication=( - experiment.arms_by_signature_for_deduplication - ), - fixed_features=fixed_features_from_node, - ) - except DataRequiredError as err: - # Model needs more data, so we log the error and return - # as many generator runs as we were able to produce, unless - # no trials were produced at all (in which case its safe to raise). - if len(grs) == 0: - raise - logger.debug(f"Model required more data: {err}.") - break - self._generator_runs.append(curr_node_gr) - grs.append(curr_node_gr) - # ensure that the points generated from each node are marked as pending - # points for future calls to gen - pending_observations = extend_pending_observations( - experiment=experiment, - pending_observations=pending_observations, - # only pass in the most recent generator run to avoid unnecessary - # deduplication in extend_pending_observations - generator_runs=[grs[-1]], - ) - continue_gen_for_trial = self._should_continue_gen_for_trial() - return grs - def gen_for_multiple_trials_with_multiple_models( self, experiment: Experiment, @@ -750,7 +622,7 @@ def _step_repr(self, step_str_rep: str) -> str: for step in self._nodes: num_trials = remaining_trials for criterion in step.transition_criteria: - # backwards compatility of num_trials with MinTrials criterion + # backwards compatibility of num_trials with MinTrials criterion if ( criterion.criterion_class == "MinTrials" and isinstance(criterion, TrialBasedCriterion) @@ -819,27 +691,20 @@ def __repr__(self) -> str: return gs_str # ------------------------- Candidate generation helpers. 
------------------------- - - def _gen_multiple( + def _gen_with_multiple_nodes( self, experiment: Experiment, - num_generator_runs: int, data: Data | None = None, - n: int = 1, pending_observations: dict[str, list[ObservationFeatures]] | None = None, - status_quo_features: ObservationFeatures | None = None, - **model_gen_kwargs: Any, + n: int | None = None, + fixed_features: ObservationFeatures | None = None, + arms_per_node: dict[str, int] | None = None, ) -> list[GeneratorRun]: - """Produce multiple generator runs at once, to be made into multiple - trials on the experiment. + """Produces a List of GeneratorRuns for a single trial, either ``Trial`` or + ``BatchTrial``, and if producing a ``BatchTrial``, allows for multiple + ``GenerationNode``-s (and therefore models) to be used to generate + ``GeneratorRun``-s for that trial. - NOTE: This is used to ensure that maximum parallelism and number - of trials per node are not violated when producing many generator - runs from this generation strategy in a row. Without this function, - if one generates multiple generator runs without first making any - of them into running trials, generation strategy cannot enforce that it only - produces as many generator runs as are allowed by the parallelism - limit and the limit on number of trials in current node. Args: experiment: Experiment, for which the generation strategy is producing @@ -850,64 +715,109 @@ def _gen_multiple( data: Optional data to be passed to the underlying model's `gen`, which is called within this method and actually produces the resulting generator run. By default, data is all data on the `experiment`. - n: Integer representing how many arms should be in the generator run - produced by this method. NOTE: Some underlying models may ignore - the ``n`` and produce a model-determined number of arms. In that - case this method will also output a generator run with number of - arms that can differ from ``n``. pending_observations: A map from metric name to pending observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated. - model_gen_kwargs: Keyword arguments that are passed through to - ``GenerationNode.gen``, which will pass them through to - ``GeneratorSpec.gen``, which will pass them to ``Adapter.gen``. - status_quo_features: An ``ObservationFeature`` of the status quo arm, - needed by some models during fit to accomadate relative constraints. - Includes the status quo parameterization and target trial index. + n: Integer representing how many arms should be in the generator run + produced by this method. NOTE: Some underlying models may ignore + the `n` and produce a model-determined number of arms. In that + case this method will also output a generator run with number of + arms that can differ from `n`. + fixed_features: An optional set of ``ObservationFeatures`` that will be + passed down to the underlying models. Note: if provided this will + override any algorithmically determined fixed features so it is + important to specify all necessary fixed features. + arms_per_node: An optional map from node name to the number of arms to + generate from that node. If not provided, will default to the number + of arms specified in the node's ``InputConstructors`` or n if no + ``InputConstructors`` are defined on the node. We expect either n or + arms_per_node to be provided, but not both, and this is an advanced + argument that should only be used by advanced users. + + Returns: + A list of ``GeneratorRuns`` for a single trial. 
""" self.experiment = experiment - self._maybe_transition_to_next_node() - self._fit_current_model(data=data, status_quo_features=status_quo_features) - # Get GeneratorRun limit that respects the node's transition criterion that - # affect the number of generator runs that can be produced. - gr_limit = self._curr.generator_run_limit(raise_generation_errors=True) - if gr_limit == -1: - num_generator_runs = max(num_generator_runs, 1) - else: - num_generator_runs = max(min(num_generator_runs, gr_limit), 1) - generator_runs = [] + if self.optimization_complete: + raise GenerationStrategyCompleted( + f"Generation strategy {self} generated all the trials as " + "specified in its nodes." + ) + grs = [] + continue_gen_for_trial = True pending_observations = deepcopy(pending_observations) or {} - for _ in range(num_generator_runs): - try: - generator_run = self._curr.gen( - n=n, - pending_observations=pending_observations, - arms_by_signature_for_deduplication=( - experiment.arms_by_signature_for_deduplication - ), - **model_gen_kwargs, - ) + self._validate_arms_per_node(arms_per_node=arms_per_node) + pack_gs_gen_kwargs = self._initialize_gen_kwargs( + experiment=experiment, + grs_this_gen=grs, + data=data, + n=n, + fixed_features=fixed_features, + arms_per_node=arms_per_node, + pending_observations=pending_observations, + ) - except DataRequiredError as err: - # Model needs more data, so we log the error and return - # as many generator runs as we were able to produce, unless - # no trials were produced at all (in which case its safe to raise). - if len(generator_runs) == 0: - raise - logger.debug(f"Model required more data: {err}.") - break - - self._generator_runs.append(generator_run) - generator_runs.append(generator_run) - - # Extend the `pending_observation` with newly generated point(s) - # to avoid repeating them. 
- pending_observations = extend_pending_observations( - experiment=experiment, - pending_observations=pending_observations, - generator_runs=[generator_run], + while continue_gen_for_trial: + pack_gs_gen_kwargs["grs_this_gen"] = grs + should_transition, node_to_gen_from_name = ( + self._curr.should_transition_to_next_node( + raise_data_required_error=False + ) + ) + node_to_gen_from = self.nodes_dict[node_to_gen_from_name] + if should_transition: + node_to_gen_from._previous_node_name = node_to_gen_from_name + # reset should skip as conditions may have changed, do not reset + # until now so node properties can be as up to date as possible + node_to_gen_from._should_skip = False + arms_from_node = self._determine_arms_from_node( + node_to_gen_from=node_to_gen_from, + n=n, + gen_kwargs=pack_gs_gen_kwargs, + ) + fixed_features_from_node = self._determine_fixed_features_from_node( + node_to_gen_from=node_to_gen_from, + gen_kwargs=pack_gs_gen_kwargs, ) - return generator_runs + sq_ft_from_node = self._determine_sq_features_from_node( + node_to_gen_from=node_to_gen_from, gen_kwargs=pack_gs_gen_kwargs + ) + self._maybe_transition_to_next_node() + if node_to_gen_from._should_skip: + continue + self._fit_current_model(data=data, status_quo_features=sq_ft_from_node) + self._curr.generator_run_limit(raise_generation_errors=True) + if arms_from_node != 0: + try: + curr_node_gr = self._curr.gen( + n=arms_from_node, + pending_observations=pending_observations, + arms_by_signature_for_deduplication=( + experiment.arms_by_signature_for_deduplication + ), + fixed_features=fixed_features_from_node, + ) + except DataRequiredError as err: + # Model needs more data, so we log the error and return + # as many generator runs as we were able to produce, unless + # no trials were produced at all (in which case its safe to raise). + if len(grs) == 0: + raise + logger.debug(f"Model required more data: {err}.") + break + self._generator_runs.append(curr_node_gr) + grs.append(curr_node_gr) + # ensure that the points generated from each node are marked as pending + # points for future calls to gen + pending_observations = extend_pending_observations( + experiment=experiment, + pending_observations=pending_observations, + # only pass in the most recent generator run to avoid unnecessary + # deduplication in extend_pending_observations + generator_runs=[grs[-1]], + ) + continue_gen_for_trial = self._should_continue_gen_for_trial() + return grs def _should_continue_gen_for_trial(self) -> bool: """Determine if we should continue generating for the current trial, or end @@ -934,7 +844,7 @@ def _should_continue_gen_for_trial(self) -> bool: for tc in self._curr.transition_edges[next_node] ) - def _initalize_gen_kwargs( + def _initialize_gen_kwargs( self, experiment: Experiment, grs_this_gen: list[GeneratorRun], @@ -1059,7 +969,7 @@ def _determine_arms_from_node( gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s gen call, including arms_per_node: an optional map from node name to the number of arms to generate from that node. If not provided, will - default to the numberof arms specified in the node's + default to the number of arms specified in the node's ``InputConstructors`` or n if no``InputConstructors`` are defined on the node. @@ -1103,7 +1013,7 @@ def _fit_current_model( data: Optional ``Data`` to fit or update with; if not specified, generation strategy will obtain the data via ``experiment.lookup_data``. 
status_quo_features: An ``ObservationFeature`` of the status quo arm, - needed by some models during fit to accomadate relative constraints. + needed by some models during fit to accommodate relative constraints. Includes the status quo parameterization and target trial index. """ data = self.experiment.lookup_data() if data is None else data diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index cdacfbd56c8..d6630a1b391 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -959,74 +959,6 @@ def test_hierarchical_search_space(self) -> None: ) ) - def test_gen_multiple(self) -> None: - exp = get_experiment_with_multi_objective() - sobol_MBM_gs = self.sobol_MBM_step_GS - - with mock_patch_method_original( - mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", - original_method=GeneratorSpec.gen, - ) as model_spec_gen_mock, mock_patch_method_original( - mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.fit", - original_method=GeneratorSpec.fit, - ) as model_spec_fit_mock: - # Generate first four Sobol GRs (one more to gen after that if - # first four become trials. - grs = sobol_MBM_gs._gen_multiple(experiment=exp, num_generator_runs=3) - self.assertEqual(len(grs), 3) - # We should only fit once for each model - # refitting for each `gen` would be wasteful as there is no new data. - self.assertEqual(model_spec_fit_mock.call_count, 1) - self.assertEqual(model_spec_gen_mock.call_count, 3) - pending_in_each_gen = enumerate( - args_and_kwargs.kwargs.get("pending_observations") - for args_and_kwargs in model_spec_gen_mock.call_args_list - ) - for gr, (idx, pending) in zip(grs, pending_in_each_gen): - exp.new_trial(generator_run=gr).mark_running(no_runner_required=True) - if idx > 0: - prev_gr = grs[idx - 1] - for arm in prev_gr.arms: - for m in pending: - self.assertIn(ObservationFeatures.from_arm(arm), pending[m]) - model_spec_gen_mock.reset_mock() - - # Check case with pending features initially specified; we should get two - # GRs now (remaining in Sobol step) even though we requested 3. - original_pending = none_throws(get_pending(experiment=exp)) - first_3_trials_obs_feats = [ - ObservationFeatures.from_arm(arm=a, trial_index=idx) - for idx, trial in exp.trials.items() - for a in trial.arms - ] - for m in original_pending: - self.assertTrue( - same_elements(original_pending[m], first_3_trials_obs_feats) - ) - - grs = sobol_MBM_gs._gen_multiple( - experiment=exp, - num_generator_runs=3, - pending_observations=get_pending(experiment=exp), - ) - self.assertEqual(len(grs), 2) - - pending_in_each_gen = enumerate( - args_and_kwargs[1].get("pending_observations") - for args_and_kwargs in model_spec_gen_mock.call_args_list - ) - for gr, (idx, pending) in zip(grs, pending_in_each_gen): - exp.new_trial(generator_run=gr).mark_running(no_runner_required=True) - if idx > 0: - prev_gr = grs[idx - 1] - for arm in prev_gr.arms: - for m in pending: - # In this case, we should see both the originally-pending - # and the new arms as pending observation features. 
- self.assertIn(ObservationFeatures.from_arm(arm), pending[m]) - for p in original_pending[m]: - self.assertIn(p, pending[m]) - def test_gen_for_multiple_uses_total_concurrent_arms_for_a_default( self, ) -> None: diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index c24d83609fd..f74b6dd7094 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -1779,16 +1779,16 @@ def _gen_new_trials_from_generation_strategy( pending = get_pending_observation_features_based_on_trial_status( experiment=self.experiment ) - grs = self.generation_strategy._gen_multiple( + grs = self.generation_strategy.gen_for_multiple_trials_with_multiple_models( experiment=self.experiment, - num_generator_runs=num_trials, + num_trials=num_trials, n=1, pending_observations=pending, fixed_features=get_fixed_features_from_experiment( experiment=self.experiment ), ) - return [[gr] for gr in grs] + return grs # TODO: pass self.trial_type to GS.gen for multi-type experiments def _update_and_save_trials( diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index a51a5a2f5da..6e76820b92e 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -580,9 +580,9 @@ def test_run_multi_arm_generator_run_error(self) -> None: ) with patch.object( type(branin_gs), - "_gen_multiple", - return_value=[get_generator_run()], - ) as patch_gen_multiple: + "gen_for_multiple_trials_with_multiple_models", + return_value=[[get_generator_run()]], + ) as patch_gen_for_multiple_trials_with_multiple_models: scheduler = Scheduler( experiment=self.branin_experiment, generation_strategy=branin_gs, @@ -596,7 +596,7 @@ def test_run_multi_arm_generator_run_error(self) -> None: SchedulerInternalError, ".* only one was expected" ): scheduler.run_all_trials() - patch_gen_multiple.assert_called_once() + patch_gen_for_multiple_trials_with_multiple_models.assert_called_once() def test_run_all_trials_using_runner_and_metrics(self) -> None: branin_gs = self._get_generation_strategy_strategy_for_test( @@ -1352,7 +1352,7 @@ def test_unknown_generation_errors_eventually_exit(self) -> None: scheduler.run_n_trials(max_trials=1) with patch.object( GenerationStrategy, - "_gen_multiple", + "_gen_with_multiple_nodes", side_effect=AxGenerationException("model error"), ): with self.assertRaises(SchedulerInternalError): @@ -1653,12 +1653,12 @@ def test_optimization_complete(self) -> None: ) with patch.object( GenerationStrategy, - "_gen_multiple", + "gen_for_multiple_trials_with_multiple_models", side_effect=OptimizationComplete("test error"), - ) as mock_gen_multiple: + ) as mock_gen_with_multiple_nodes: scheduler.run_n_trials(max_trials=1) # no trials should run if _gen_multiple throws an OptimizationComplete error - mock_gen_multiple.assert_called_once() + mock_gen_with_multiple_nodes.assert_called_once() self.assertEqual(len(scheduler.experiment.trials), 0) @patch(