Skip to content

Commit

Permalink
Only load last observation of map data by default
Browse files Browse the repository at this point in the history
Summary: This commit ensures that the `Adapter` only loads a single observation by default, even for map metrics. This makes sure that any method (not-necessarily map-data aware) can be applied by default.

Differential Revision: D69992533
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Feb 21, 2025
1 parent 63a1eaf commit a78451e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def observations_from_data(
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
is_map_data = isinstance(data, MapData)
map_keys = []
if is_map_data:
if is_map_data and len(assert_is_instance(data, MapData).map_df) > 0:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
if latest_rows_per_group is not None:
Expand Down
9 changes: 8 additions & 1 deletion ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
latest_rows_per_group: int | None = 1,
) -> None:
"""
Applies transforms and fits model.
Expand Down Expand Up @@ -159,6 +160,11 @@ def __init__(
fit_only_completed_map_metrics: Whether to fit a model to map metrics only
when the trial is completed. This is useful for applications like
modeling partially completed learning curves in AutoML.
latest_rows_per_group: If specified and data is an instance of MapData,
uses MapData.latest() with `rows_per_group=latest_rows_per_group` to
retrieve the most recent rows for each group. Useful in cases where
learning curves are frequently updated, preventing an excessive
number of Observation objects.
"""
t_fit_start = time.monotonic()
transforms = transforms or []
Expand Down Expand Up @@ -188,6 +194,7 @@ def __init__(
self._fit_abandoned = fit_abandoned
self._fit_tracking_metrics = fit_tracking_metrics
self._fit_only_completed_map_metrics = fit_only_completed_map_metrics
self._latest_rows_per_group = latest_rows_per_group
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
experiment is not None and experiment.immutable_search_space_and_opt_config
Expand Down Expand Up @@ -302,7 +309,7 @@ def _prepare_observations(
return observations_from_data(
experiment=experiment,
data=data,
latest_rows_per_group=None,
latest_rows_per_group=self._latest_rows_per_group,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,8 +1062,8 @@ def test_fit_only_completed_map_metrics(
)
_, kwargs = mock_observations_from_data.call_args
self.assertTrue(kwargs["map_keys_as_parameters"])
# assert `latest_rows_per_group` is not specified or is None
self.assertIsNone(kwargs.get("latest_rows_per_group"))
# assert `latest_rows_per_group` is not specified or is 1
self.assertEqual(kwargs.get("latest_rows_per_group"), 1)
mock_observations_from_data.reset_mock()

# calling without map data calls observations_from_data with
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
fit_on_init: bool = True,
default_model_gen_options: TConfig | None = None,
fit_only_completed_map_metrics: bool = True,
latest_rows_per_group: int | None = 1,
) -> None:
# This warning is being added while we are on 0.4.3, so it will be
# released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed
Expand Down Expand Up @@ -163,6 +164,7 @@ def __init__(
fit_tracking_metrics=fit_tracking_metrics,
fit_on_init=fit_on_init,
fit_only_completed_map_metrics=fit_only_completed_map_metrics,
latest_rows_per_group=latest_rows_per_group,
)

def feature_importances(self, metric_name: str) -> dict[str, float]:
Expand Down

0 comments on commit a78451e

Please sign in to comment.