allow TrialAsTask with less than two trials
Summary:
The transform is a no-op in that case.

This is necessary to enable STGP vs. MTGP model selection without adding additional nodes to the generation strategy (GS), which would further complicate the GS DAG.
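
To make the behavior concrete, here is a minimal, hypothetical sketch of the gating logic this commit adds to TrialAsTask.__init__ (plain Python, not the Ax API; prune_trial_level_map and the example level maps are invented for illustration): a task parameter is only kept when its trial-to-level map contains at least two distinct levels, otherwise the entry is dropped and the transform methods return their inputs unchanged.

def prune_trial_level_map(
    trial_level_map: dict[str, dict[int, str]],
) -> dict[str, dict[int, str]]:
    # Keep only task parameters whose trial -> level map has >= 2 distinct levels.
    pruned = dict(trial_level_map)
    for p_name, trial_map in list(pruned.items()):
        if len(set(trial_map.values())) < 2:
            # Fewer than two distinct levels: no task parameter is needed.
            del pruned[p_name]
    return pruned

# One level shared by both trials: the map empties, so the transform is a no-op.
print(prune_trial_level_map({"TRIAL_PARAM": {0: "v1", 1: "v1"}}))  # {}
# Two distinct levels: the task parameter is kept.
print(prune_trial_level_map({"TRIAL_PARAM": {0: "v1", 1: "v2"}}))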

Reviewed By: mgarrard

Differential Revision: D69496509
sdaulton authored and facebook-github-bot committed Feb 14, 2025
1 parent ae6637a commit dd71b56
Showing 3 changed files with 88 additions and 2 deletions.
ax/modelbridge/transforms/stratified_standardize_y.py (4 changes: 3 additions & 1 deletion)
@@ -99,7 +99,9 @@ def __init__(
             ]
             if len(task_parameters) == 0:
                 raise ValueError(
-                    "Must specify parameter for stratified standardization"
+                    "Must specify parameter for stratified standardization. This can "
+                    "happen if TrialAsTask is a no-op, due to there only being a single"
+                    " task level."
                 )
             elif len(task_parameters) != 1:
                 raise ValueError(
ax/modelbridge/transforms/tests/test_trial_as_task_transform.py (56 changes: 56 additions & 0 deletions)
@@ -248,3 +248,59 @@ def test_w_robust_search_space(self) -> None:
                 observations=[],
                 modelbridge=self.modelbridge,
             )
+
+    def test_less_than_two_trials(self) -> None:
+        # test transform is a no-op with less than two trials
+        exp = get_branin_experiment()
+        exp.new_trial().add_arm(Arm(parameters={"x": 1}))
+        modelbridge = Adapter(
+            search_space=exp.search_space,
+            model=Generator(),
+            experiment=exp,
+        )
+        training_obs = self.training_obs[:1]
+        t = TrialAsTask(
+            search_space=exp.search_space,
+            observations=training_obs,
+            modelbridge=modelbridge,
+        )
+        self.assertEqual(t.trial_level_map, {})
+        training_feats = [training_obs[0].features]
+        training_feats_clone = deepcopy(training_feats)
+        self.assertEqual(
+            t.transform_observation_features(training_feats_clone), training_feats
+        )
+        self.assertEqual(
+            t.untransform_observation_features(training_feats), training_feats_clone
+        )
+        ss2 = exp.search_space.clone()
+        self.assertEqual(t.transform_search_space(ss2), exp.search_space)
+
+    def test_less_than_two_levels(self) -> None:
+        # test transform is a no-op with less than two distinct levels
+        exp = get_branin_experiment()
+        exp.new_trial().add_arm(Arm(parameters={"x": 1}))
+        exp.new_trial().add_arm(Arm(parameters={"x": 2}))
+        modelbridge = Adapter(
+            search_space=exp.search_space,
+            model=Generator(),
+            experiment=exp,
+        )
+        training_obs = self.training_obs[:1]
+        t = TrialAsTask(
+            search_space=exp.search_space,
+            observations=training_obs,
+            modelbridge=modelbridge,
+            config={"trial_level_map": {"t": {0: "v1", 1: "v1"}}},
+        )
+        self.assertEqual(t.trial_level_map, {})
+        training_feats = [training_obs[0].features]
+        training_feats_clone = deepcopy(training_feats)
+        self.assertEqual(
+            t.transform_observation_features(training_feats_clone), training_feats
+        )
+        self.assertEqual(
+            t.untransform_observation_features(training_feats), training_feats_clone
+        )
+        ss2 = exp.search_space.clone()
+        self.assertEqual(t.transform_search_space(ss2), exp.search_space)
ax/modelbridge/transforms/trial_as_task.py (30 changes: 29 additions & 1 deletion)
@@ -51,6 +51,8 @@ class TrialAsTask(Transform):
     Will raise if trial not specified for every point in the training data.
+
+    If there are fewer than 2 trials or levels, the transform is a no-op.
     Transform is done in-place.
     """

@@ -109,7 +111,13 @@ def __init__(
         self.inverse_map = None
         # Compute target values
         self.target_values: dict[str, int | str] = {}
-        for p_name, trial_map in self.trial_level_map.items():
+
+        for p_name, trial_map in list(self.trial_level_map.items()):
+            if len(set(trial_map.values())) < 2:
+                # If there are less than two distinct levels, then we don't need to
+                # create a task parameter and the transform is a no-op.
+                del self.trial_level_map[p_name]
+                continue
             if config is not None and "target_trial" in config:
                 target_trial = int(config["target_trial"])  # pyre-ignore [6]
             else:
@@ -133,6 +141,9 @@ def transform_observation_features(
         trials. Trial indices set to None are probably pending points passed in by the
         user.
         """
+        if len(self.trial_level_map) == 0:
+            # no-op
+            return observation_features
         for obsf in observation_features:
             for p_name, level_dict in self.trial_level_map.items():
                 if obsf.trial_index is not None and int(obsf.trial_index) in level_dict:
@@ -147,6 +158,9 @@ def transform_observation_features(
         return observation_features
 
     def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
+        if len(self.trial_level_map) == 0:
+            # no-op
+            return search_space
         for p_name, level_dict in self.trial_level_map.items():
             level_values = sorted(set(level_dict.values()))
             if len(level_values) < 2:
@@ -174,6 +188,20 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
     def untransform_observation_features(
         self, observation_features: list[ObservationFeatures]
     ) -> list[ObservationFeatures]:
+        """If task parameters have been added to observation features by
+        this transform, then remove those task parameters and restore
+        the trial index.
+
+        Args:
+            observation_features: List of observation features to untransform.
+
+        Returns:
+            List of observation features with task parameters removed and trial
+                index restored.
+        """
+        if len(self.trial_level_map) == 0:
+            # no-op
+            return observation_features
         for obsf in observation_features:
             for p_name in self.trial_level_map:
                 pval = obsf.parameters.pop(p_name)