diff --git a/ax/core/observation.py b/ax/core/observation.py index 5b91f0b2774..c5834975091 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -514,7 +514,11 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - take_map_branch = is_map_data and not load_only_completed_map_metrics + take_map_branch = ( + is_map_data + and not load_only_completed_map_metrics + and len(assert_is_instance(data, MapData).map_df) > 0 + ) if take_map_branch: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index cef84761cad..db5e923a6ad 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -112,6 +112,7 @@ def __init__( fit_tracking_metrics: bool = True, fit_on_init: bool = True, fit_only_completed_map_metrics: bool = True, + latest_rows_per_group: int | None = 1, ) -> None: """ Applies transforms and fits model. @@ -160,6 +161,11 @@ def __init__( fit_only_completed_map_metrics: Whether to fit a model to map metrics only when the trial is completed. This is useful for applications like modeling partially completed learning curves in AutoML. + latest_rows_per_group: If specified and data is an instance of MapData, + uses MapData.latest() with `rows_per_group=latest_rows_per_group` to + retrieve the most recent rows for each group. Useful in cases where + learning curves are frequently updated, preventing an excessive + number of Observation objects. 
""" t_fit_start = time.monotonic() transforms = transforms or [] @@ -189,6 +195,7 @@ def __init__( self._fit_abandoned = fit_abandoned self._fit_tracking_metrics = fit_tracking_metrics self._fit_only_completed_map_metrics = fit_only_completed_map_metrics + self._latest_rows_per_group = latest_rows_per_group self.outcomes: list[str] = [] self._experiment_has_immutable_search_space_and_opt_config: bool = ( experiment is not None and experiment.immutable_search_space_and_opt_config @@ -303,7 +310,7 @@ def _prepare_observations( return observations_from_data( experiment=experiment, data=data, - latest_rows_per_group=None, + latest_rows_per_group=self._latest_rows_per_group, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index c00acbeab51..d4a390bb756 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -73,6 +73,7 @@ def __init__( default_model_gen_options: TConfig | None = None, map_data_limit_rows_per_metric: int | None = None, map_data_limit_rows_per_group: int | None = None, + latest_rows_per_group: int | None = None, ) -> None: """ Applies transforms and fits model. @@ -113,6 +114,11 @@ def __init__( map_data_limit_rows_per_group: Subsample the map data so that the number of rows in the `map_key` column for each (arm, metric) is limited by this value. + latest_rows_per_group: If specified and data is an instance of MapData, + uses MapData.latest() with `rows_per_group=latest_rows_per_group` to + retrieve the most recent rows for each group. Useful in cases where + learning curves are frequently updated, preventing an excessive + number of Observation objects. 
""" if not isinstance(data, MapData): @@ -140,6 +146,7 @@ def __init__( fit_abandoned=fit_abandoned, fit_on_init=fit_on_init, default_model_gen_options=default_model_gen_options, + latest_rows_per_group=latest_rows_per_group, ) @property @@ -259,6 +266,7 @@ def _prepare_observations( data=data, limit_rows_per_metric=self._map_data_limit_rows_per_metric, limit_rows_per_group=self._map_data_limit_rows_per_group, + latest_rows_per_group=self._latest_rows_per_group, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=True, diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index 6232aae351b..ea3e25e16a9 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -1077,8 +1077,8 @@ def test_fit_only_completed_map_metrics( ) _, kwargs = mock_observations_from_data.call_args self.assertTrue(kwargs["map_keys_as_parameters"]) - # assert `latest_rows_per_group` is not specified or is None - self.assertIsNone(kwargs.get("latest_rows_per_group")) + # assert `latest_rows_per_group` is explicitly passed with the value 1 + self.assertEqual(kwargs.get("latest_rows_per_group"), 1) mock_observations_from_data.reset_mock() # calling without map data calls observations_from_data with diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 01bc6c7b5e0..41079ad221a 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -171,6 +171,7 @@ def test_enum_sobol_legacy_GPEI(self) -> None: "fit_on_init": True, "default_model_gen_options": None, "fit_only_completed_map_metrics": True, + "latest_rows_per_group": 1, }, ) prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)} diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 02da786a8a4..1a398c2df3b 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@
-118,6 +118,7 @@ def __init__( fit_on_init: bool = True, default_model_gen_options: TConfig | None = None, fit_only_completed_map_metrics: bool = True, + latest_rows_per_group: int | None = 1, ) -> None: self.device = torch_device # pyre-ignore [4]: Attribute `_default_model_gen_options` of class @@ -148,6 +149,7 @@ def __init__( fit_tracking_metrics=fit_tracking_metrics, fit_on_init=fit_on_init, fit_only_completed_map_metrics=fit_only_completed_map_metrics, + latest_rows_per_group=latest_rows_per_group, ) def feature_importances(self, metric_name: str) -> dict[str, float]: