Skip to content

Commit

Permalink
Update pending points to be modified in place for efficiency (faceboo…
Browse files Browse the repository at this point in the history
…k#3437)

Summary:

NOTE: this must be landed at the same time as 4/n 

Previously (xxx), we updated the extend-pending-points method to not modify its input in place; however, that creates a significant slowdown, which is unacceptable to waveguide. In order to merge the gen methods and have all paths (including waveguide) go through the new _gen_for_multi_with_multi method, we need this method to modify its input in place again.

We include a note in the docstring for this. In follow-up diffs we will also:
1. explore making pending points a set instead of a list of observation features
2. remove all external setting of pending points in calls to gen - now we set it inside the gen methods so it's redundant to do so in multiple places
3. try to unify the various pending point utils

Reviewed By: saitcakmak

Differential Revision: D68790218
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Feb 28, 2025
1 parent 38bd100 commit bded2a8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 47 deletions.
8 changes: 4 additions & 4 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,10 @@ def _test_replication_async(self, map_data: bool) -> None:
"Complete out of order": [0, 0, 1, 2],
}
expected_pending_in_each_gen = {
"All complete at different times": [[], [0], [1], [2]],
"Trials complete immediately": [[], [0], [], [2]],
"Trials complete at same time": [[], [0], [], [2]],
"Complete out of order": [[], [0], [0], [2]],
"All complete at different times": [[None], [0], [1], [2]],
"Trials complete immediately": [[None], [0], [None], [2]],
"Trials complete at same time": [[None], [0], [None], [2]],
"Complete out of order": [[None], [0], [0], [2]],
}
# When two trials complete at the same time, the inference trace uses
# data from both to get the best point, and repeats it.
Expand Down
29 changes: 14 additions & 15 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,32 +397,31 @@ def get_pending_observation_features_based_on_trial_status(
def extend_pending_observations(
experiment: Experiment,
pending_observations: dict[str, list[ObservationFeatures]],
generator_runs: list[GeneratorRun],
) -> dict[str, list[ObservationFeatures]]:
generator_run: GeneratorRun,
) -> None:
"""Extend given pending observations dict (from metric name to observations
that are pending for that metric), with arms in a given generator run.
Note: This function performs this operation in-place for performance reasons.
It is only used within the ``GenerationStrategy`` class, and is not intended
for wide re-use. Please use caution when re-using this function.
Args:
experiment: Experiment, for which the generation strategy is producing
``GeneratorRun``s.
pending_observations: Dict from metric name to pending observations for
that metric, used to avoid resuggesting arms that will be explored soon.
generator_runs: List of ``GeneratorRun``s currently produced by the
``GenerationStrategy``.
generator_run: ``GeneratorRun`` currently produced by the
``GenerationStrategy`` to add to the pending points.
Returns:
A new dictionary of pending observations to avoid in-place modification
"""
pending_observations = deepcopy(pending_observations)
extended_observations: dict[str, list[ObservationFeatures]] = {}
for m in experiment.metrics:
extended_obs_set = set(pending_observations.get(m, []))
for generator_run in generator_runs:
for a in generator_run.arms:
ob_ft = ObservationFeatures.from_arm(a)
extended_obs_set.add(ob_ft)
extended_observations[m] = list(extended_obs_set)
return extended_observations
if m not in pending_observations:
pending_observations[m] = []
pending_observations[m].extend(
ObservationFeatures.from_arm(a) for a in generator_run.arms
)
return


# -------------------- Get target trial utils. ---------------------
Expand Down
21 changes: 7 additions & 14 deletions ax/generation_strategy/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,14 +427,6 @@ def gen_for_multiple_trials_with_multiple_models(
fixed_features=fixed_features,
)
)

extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# pass in the most recently generated grs each time to avoid
# duplication
generator_runs=trial_grs[-1],
)
return trial_grs

def current_generator_run_limit(
Expand Down Expand Up @@ -745,7 +737,10 @@ def _gen_with_multiple_nodes(
)
grs = []
continue_gen_for_trial = True
pending_observations = deepcopy(pending_observations) or {}
pending_observations = (
pending_observations if pending_observations is not None else {}
)
self.experiment = experiment
self._validate_arms_per_node(arms_per_node=arms_per_node)
pack_gs_gen_kwargs = self._initialize_gen_kwargs(
experiment=experiment,
Expand Down Expand Up @@ -808,13 +803,11 @@ def _gen_with_multiple_nodes(
self._generator_runs.append(curr_node_gr)
grs.append(curr_node_gr)
# ensure that the points generated from each node are marked as pending
# points for future calls to gen
pending_observations = extend_pending_observations(
# points for future calls to gen, or further generation for this trial
extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# only pass in the most recent generator run to avoid unnecessary
# deduplication in extend_pending_observations
generator_runs=[grs[-1]],
generator_run=curr_node_gr,
)
continue_gen_for_trial = self._should_continue_gen_for_trial()
return grs
Expand Down
105 changes: 91 additions & 14 deletions ax/generation_strategy/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.core.trial_status import TrialStatus
from ax.core.utils import (
extend_pending_observations,
extract_pending_observations,
get_pending_observation_features_based_on_trial_status as get_pending,
)
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
Expand Down Expand Up @@ -959,6 +961,82 @@ def test_hierarchical_search_space(self) -> None:
)
)

def test_gen_for_multiple_trials_with_multiple_models_bw_comp(self) -> None:
# This test initially tested _gen_multiple; however, that method has
# been replaced with gen_for_multiple_trials_with_multiple_models, so here we
# ensure the original _gen_multiple behavior is preserved.
exp = get_experiment_with_multi_objective()
sobol_MBM_gs = self.sobol_MBM_step_GS

with mock_patch_method_original(
mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen",
original_method=GeneratorSpec.gen,
) as gen_spec_gen_mock, mock_patch_method_original(
mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.fit",
original_method=GeneratorSpec.fit,
) as gen_spec_fit_mock:
# Generate first three Sobol GRs (two more remain to gen after that if the
# first three become trials).
grs = sobol_MBM_gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp, num_trials=3
)
self.assertEqual(len(grs), 3)
# We should only fit once; refitting for each `gen` would be
# wasteful as there is no new data.
# TODO[@drfreund]: This is currently regressed to fit three times, which is
# expected as of changeset 5/n; fixed in 8/n. Will change expected value to
# 1 in 8/n.
self.assertEqual(gen_spec_fit_mock.call_count, 3)
self.assertEqual(gen_spec_gen_mock.call_count, 3)
pending_in_each_gen = enumerate(
args_and_kwargs.kwargs.get("pending_observations")
for args_and_kwargs in gen_spec_gen_mock.call_args_list
)
for gr, (idx, pending) in zip(grs, pending_in_each_gen):
exp.new_trial(generator_run=gr[0]).mark_running(no_runner_required=True)
if idx > 0:
prev_gr = grs[idx - 1][0]
for arm in prev_gr.arms:
for m in pending:
self.assertIn(ObservationFeatures.from_arm(arm), pending[m])
gen_spec_gen_mock.reset_mock()

# Check case with pending features initially specified; we should get two
# GRs now (remaining in Sobol step) even though we requested 3.
original_pending = none_throws(get_pending(experiment=exp))
first_3_trials_obs_feats = [
ObservationFeatures.from_arm(arm=a, trial_index=idx)
for idx, trial in exp.trials.items()
for a in trial.arms
]
for m in original_pending:
self.assertTrue(
same_elements(original_pending[m], first_3_trials_obs_feats)
)

grs = sobol_MBM_gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp,
num_trials=3,
pending_observations=get_pending(experiment=exp),
)
self.assertEqual(len(grs), 2)

pending_in_each_gen = enumerate(
args_and_kwargs[1].get("pending_observations")
for args_and_kwargs in gen_spec_gen_mock.call_args_list
)
for gr, (idx, pending) in zip(grs, pending_in_each_gen):
exp.new_trial(generator_run=gr[0]).mark_running(no_runner_required=True)
if idx > 0:
prev_gr = grs[idx - 1][0]
for arm in prev_gr.arms:
for m in pending:
# In this case, we should see both the originally-pending
# and the new arms as pending observation features.
self.assertIn(ObservationFeatures.from_arm(arm), pending[m])
for p in original_pending[m]:
self.assertIn(p, pending[m])

def test_gen_for_multiple_uses_total_concurrent_arms_for_a_default(
self,
) -> None:
Expand Down Expand Up @@ -999,6 +1077,9 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
args_and_kwargs.kwargs.get("pending_observations")
for args_and_kwargs in model_spec_gen_mock.call_args_list
)
# pending points are updated in place, so we can't check each intermediate
# call state; however, we can confirm that all arms in the grs produced by
# _gen_with_multiple_nodes are present in the pending points
for gr, (idx, pending) in zip(grs, pending_in_each_gen):
exp.new_trial(generator_run=gr[0]).mark_running(no_runner_required=True)
if idx > 0:
Expand Down Expand Up @@ -1341,30 +1422,26 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None:
with mock_patch_method_original(
mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen",
original_method=GeneratorSpec.gen,
) as model_spec_gen_mock:
) as gen_spec_gen_mock:
# Generate a trial that should be composed of arms from 3 nodes
grs = gs._gen_with_multiple_nodes(
experiment=exp, arms_per_node=arms_per_node
)

self.assertEqual(len(grs), 3) # len == 3 due to 3 nodes contributing
self.assertEqual(gen_spec_gen_mock.call_count, 3)
pending_in_each_gen = enumerate(
call_kwargs.get("pending_observations")
for _, call_kwargs in model_spec_gen_mock.call_args_list
for _, call_kwargs in gen_spec_gen_mock.call_args_list
)

# for each call to gen after the first call to gen, which should have no
# pending points the number of pending points should be equal to the sum of
# the number of arms we suspect from the previous nodes
expected_pending_per_call = [2, 3]
# pending points are updated in place, so we can't check each intermediate
# call state; however, we can confirm that all arms in the grs produced by
# _gen_with_multiple_nodes are present in the pending points
for idx, pending in pending_in_each_gen:
# the first pending call will be empty because we didn't pass in any
# additional points, start checking after the first position
# that the pending points we expect are present
if idx > 0:
self.assertEqual(
len(pending["m2"]), expected_pending_per_call[idx - 1]
)
prev_gr = grs[idx - 1]
for arm in prev_gr.arms:
for m in pending:
Expand All @@ -1373,7 +1450,7 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None:
exp.new_batch_trial(generator_runs=grs).mark_running(
no_runner_required=True
)
model_spec_gen_mock.reset_mock()
gen_spec_gen_mock.reset_mock()

# check that the pending points line up
original_pending = none_throws(get_pending(experiment=exp))
Expand All @@ -1396,10 +1473,10 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None:
self.assertEqual(len(grs), 3) # len == 3 due to 3 nodes contributing
pending_in_each_gen = enumerate(
call_kwargs.get("pending_observations")
for _, call_kwargs in model_spec_gen_mock.call_args_list
for _, call_kwargs in gen_spec_gen_mock.call_args_list
)
# check first call is 6 (from the previous trial having 6 arms)
self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 6)
# check pending points is now 12 (from the previous trial having 6 arms)
self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 12)

def test_gs_initializes_default_props_correctly(self) -> None:
"""Test that all previous nodes are initialized to None"""
Expand Down

0 comments on commit bded2a8

Please sign in to comment.