Skip to content

Commit

Permalink
Only load last observation of map data by default (#3403)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3403

This commit ensures that the `Adapter` only loads a single observation by default, even for map metrics. As a result, any method, including ones that are not map-data aware, can be applied by default.

Reviewed By: saitcakmak

Differential Revision: D69992533
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Feb 24, 2025
1 parent c64d122 commit 96ef4e4
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 4 deletions.
6 changes: 5 additions & 1 deletion ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,11 @@ def observations_from_data(
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
is_map_data = isinstance(data, MapData)
map_keys = []
take_map_branch = is_map_data and not load_only_completed_map_metrics
take_map_branch = (
is_map_data
and not load_only_completed_map_metrics
and len(assert_is_instance(data, MapData).map_df) > 0
)
if take_map_branch:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
Expand Down
9 changes: 8 additions & 1 deletion ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
latest_rows_per_group: int | None = 1,
) -> None:
"""
Applies transforms and fits model.
Expand Down Expand Up @@ -160,6 +161,11 @@ def __init__(
fit_only_completed_map_metrics: Whether to fit a model to map metrics only
when the trial is completed. This is useful for applications like
modeling partially completed learning curves in AutoML.
latest_rows_per_group: If specified and data is an instance of MapData,
uses MapData.latest() with `rows_per_group=latest_rows_per_group` to
retrieve the most recent rows for each group. Useful in cases where
learning curves are frequently updated, preventing an excessive
number of Observation objects.
"""
t_fit_start = time.monotonic()
transforms = transforms or []
Expand Down Expand Up @@ -189,6 +195,7 @@ def __init__(
self._fit_abandoned = fit_abandoned
self._fit_tracking_metrics = fit_tracking_metrics
self._fit_only_completed_map_metrics = fit_only_completed_map_metrics
self._latest_rows_per_group = latest_rows_per_group
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
experiment is not None and experiment.immutable_search_space_and_opt_config
Expand Down Expand Up @@ -303,7 +310,7 @@ def _prepare_observations(
return observations_from_data(
experiment=experiment,
data=data,
latest_rows_per_group=None,
latest_rows_per_group=self._latest_rows_per_group,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,8 +1077,8 @@ def test_fit_only_completed_map_metrics(
)
_, kwargs = mock_observations_from_data.call_args
self.assertTrue(kwargs["map_keys_as_parameters"])
# assert `latest_rows_per_group` is not specified or is None
self.assertIsNone(kwargs.get("latest_rows_per_group"))
# assert `latest_rows_per_group` is passed explicitly with its default value of 1
self.assertEqual(kwargs.get("latest_rows_per_group"), 1)
mock_observations_from_data.reset_mock()

# calling without map data calls observations_from_data with
Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def test_enum_sobol_legacy_GPEI(self) -> None:
"fit_on_init": True,
"default_model_gen_options": None,
"fit_only_completed_map_metrics": True,
"latest_rows_per_group": 1,
},
)
prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)}
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
fit_on_init: bool = True,
default_model_gen_options: TConfig | None = None,
fit_only_completed_map_metrics: bool = True,
latest_rows_per_group: int | None = 1,
) -> None:
self.device = torch_device
# pyre-ignore [4]: Attribute `_default_model_gen_options` of class
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
fit_tracking_metrics=fit_tracking_metrics,
fit_on_init=fit_on_init,
fit_only_completed_map_metrics=fit_only_completed_map_metrics,
latest_rows_per_group=latest_rows_per_group,
)

def feature_importances(self, metric_name: str) -> dict[str, float]:
Expand Down

0 comments on commit 96ef4e4

Please sign in to comment.