Skip to content

Commit

Permalink
DataLoaderConfig - Only load last observation of map data by default (
Browse files Browse the repository at this point in the history
#3403)

Summary:

The original purpose of this commit was to ensure that the `Adapter` only loads a single observation by default, even for running and completed map metrics. This ensures that any method, including ones that are not map-data aware, can be applied by default.

To this end, we need control over the `latest_rows_per_group` argument of `observations_from_data`.

In order to avoid increasing the number of arguments to the adapter, we package the new argument and other parameters controlling what data the model is fit to in a new `DataLoaderConfig` dataclass.

Reviewed By: saitcakmak

Differential Revision: D69992533
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Mar 3, 2025
1 parent b3c32a0 commit 7a5fcc3
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 53 deletions.
1 change: 1 addition & 0 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class MapData(Data):
`experiment.attach_data()` (this requires a description to be set.)
"""

REQUIRED_COLUMNS = {"trial_index", "arm_name", "metric_name"}
DEDUPLICATE_BY_COLUMNS = ["trial_index", "arm_name", "metric_name"]

_map_df: pd.DataFrame
Expand Down
8 changes: 2 additions & 6 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def observations_from_data(
latest_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
limit_rows_per_group: int | None = None,
load_only_completed_map_metrics: bool = True,
) -> list[Observation]:
"""Convert Data (or MapData) to observations.
Expand Down Expand Up @@ -502,8 +501,6 @@ def observations_from_data(
uses MapData.subsample() with `limit_rows_per_group` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
`latest_rows_per_group` is specified.
load_only_completed_map_metrics: If True, only loads the last observation
for each completed MapMetric.
Returns:
List of Observation objects.
Expand All @@ -514,8 +511,7 @@ def observations_from_data(
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
is_map_data = isinstance(data, MapData)
map_keys = []
take_map_branch = is_map_data and not load_only_completed_map_metrics
if take_map_branch:
if is_map_data:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
if latest_rows_per_group is not None:
Expand All @@ -530,7 +526,7 @@ def observations_from_data(
df = data.map_df
else:
df = data.df
feature_cols = get_feature_cols(data, is_map_data=take_map_branch)
feature_cols = get_feature_cols(data, is_map_data=is_map_data)
return _observations_from_dataframe(
experiment=experiment,
df=df,
Expand Down
12 changes: 9 additions & 3 deletions ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,7 @@ def test_ObservationsFromMapData(self) -> None:
MapKeyInfo(key="timestamp", default_value=0.0),
],
)
observations = observations_from_data(
experiment, data, load_only_completed_map_metrics=False
)
observations = observations_from_data(experiment, data)

self.assertEqual(len(observations), 3)

Expand All @@ -494,6 +492,14 @@ def test_ObservationsFromMapData(self) -> None:
self.assertEqual(obs.arm_name, t["arm_name"])
self.assertEqual(obs.features.metadata, {"timestamp": t["timestamp"]})

# testing that we can handle empty data with latest_rows_per_group
empty_data = MapData()
observations = observations_from_data(
experiment,
empty_data,
latest_rows_per_group=1,
)

def test_ObservationsFromDataAbandoned(self) -> None:
truth = [
{
Expand Down
1 change: 0 additions & 1 deletion ax/core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ def test_get_pending_observation_features_multi_trial(self) -> None:
),
):
pending = get_pending_observation_features(self.experiment)
print(pending)
self.assertEqual(
pending,
{
Expand Down
135 changes: 111 additions & 24 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ class GenResults:
gen_metadata: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class DataLoaderConfig:
"""This dataclass contains parameters that control the `Adapter._set_training_data`.
Args:
fit_out_of_design: If specified, all training data are used.
Otherwise, only in design points are used.
fit_abandoned: Whether data for abandoned arms or trials should be included in
model training data. If `False`, only non-abandoned points are returned.
fit_only_completed_map_metrics: Whether to fit a model to map metrics only when
the trial is completed. This is useful for applications like modeling
partially completed learning curves in AutoML.
latest_rows_per_group: If specified and data is an instance of MapData, uses
MapData.latest() with `latest_rows_per_group` to retrieve the most recent
rows for each group. Useful in cases where learning curves are frequently
updated, preventing an excessive number of Observation objects.
limit_rows_per_metric: Subsample the map data so that the total number of
rows per metric is limited by this value.
limit_rows_per_group: Subsample the map data so that the number of rows
in the `map_key` column for each (arm, metric) is limited by this value.
"""

fit_out_of_design: bool = False
fit_abandoned: bool = False
fit_only_completed_map_metrics: bool = True
latest_rows_per_group: int | None = 1
limit_rows_per_metric: int | None = None
limit_rows_per_group: int | None = None


class Adapter:
"""The main object for using models in Ax.
Expand Down Expand Up @@ -105,11 +135,12 @@ def __init__(
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
expand_model_space: bool = True,
fit_out_of_design: bool = False,
fit_abandoned: bool = False,
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
data_loader_config: DataLoaderConfig | None = None,
fit_out_of_design: bool | None = None,
fit_abandoned: bool | None = None,
fit_only_completed_map_metrics: bool | None = None,
) -> None:
"""
Applies transforms and fits model.
Expand Down Expand Up @@ -145,11 +176,6 @@ def __init__(
space larger than the search space if training data fall outside
the search space. Will also include training points that violate
parameter constraints in the modeling.
fit_out_of_design: If specified, all training data are used.
Otherwise, only in design points are used.
fit_abandoned: Whether data for abandoned arms or trials should be
included in model training data. If ``False``, only
non-abandoned points are returned.
fit_tracking_metrics: Whether to fit a model for tracking metrics.
Setting this to False will improve runtime at the expense of
models not being available for predicting tracking metrics.
Expand All @@ -160,10 +186,26 @@ def __init__(
To fit the model afterwards, use `_process_and_transform_data`
to get the transformed inputs and call `_fit_if_implemented` with
the transformed inputs.
fit_only_completed_map_metrics: Whether to fit a model to map metrics only
when the trial is completed. This is useful for applications like
modeling partially completed learning curves in AutoML.
data_loader_config: A DataLoaderConfig of options for loading data. See the
docstring of DataLoaderConfig for more details.
fit_out_of_design: Deprecation warning: `fit_out_of_design` is deprecated.
Overwrites `data_loader_config.fit_out_of_design` if not None.
fit_abandoned: Deprecation warning: `fit_abandoned` is deprecated.
Overwrites `data_loader_config.fit_abandoned` if not None.
fit_only_completed_map_metrics: Deprecation warning:
`fit_only_completed_map_metrics` is deprecated. If not None, overwrites
`data_loader_config.fit_only_completed_map_metrics`.
"""
if data_loader_config is None:
data_loader_config = DataLoaderConfig()

data_loader_config = _legacy_overwrite_data_loader_options(
data_loader_config=data_loader_config,
fit_out_of_design=fit_out_of_design,
fit_abandoned=fit_abandoned,
fit_only_completed_map_metrics=fit_only_completed_map_metrics,
)

t_fit_start = time.monotonic()
transforms = transforms or []
transforms = [Cast] + list(transforms)
Expand All @@ -189,10 +231,8 @@ def __init__(
self._model_space: SearchSpace = search_space.clone()
self._raw_transforms = transforms
self._transform_configs: Mapping[str, TConfig] | None = transform_configs
self._fit_out_of_design = fit_out_of_design
self._fit_abandoned = fit_abandoned
self._data_loader_config: DataLoaderConfig = data_loader_config
self._fit_tracking_metrics = fit_tracking_metrics
self._fit_only_completed_map_metrics = fit_only_completed_map_metrics
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
experiment is not None and experiment.immutable_search_space_and_opt_config
Expand Down Expand Up @@ -295,17 +335,17 @@ def _prepare_observations(
) -> list[Observation]:
if experiment is None or data is None:
return []
map_keys_as_parameters = (
not self._fit_only_completed_map_metrics and isinstance(data, MapData)
)
fit_only_completed = self._data_loader_config.fit_only_completed_map_metrics
map_keys_as_parameters = not fit_only_completed and isinstance(data, MapData)
return observations_from_data(
experiment=experiment,
data=data,
latest_rows_per_group=None,
latest_rows_per_group=self._data_loader_config.latest_rows_per_group,
limit_rows_per_metric=self._data_loader_config.limit_rows_per_metric,
limit_rows_per_group=self._data_loader_config.limit_rows_per_group,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
load_only_completed_map_metrics=self._fit_only_completed_map_metrics,
)

def _transform_data(
Expand Down Expand Up @@ -373,7 +413,7 @@ def _process_in_design(
) -> list[Observation]:
"""Set training_in_design, and decide whether to filter out of design points."""
# Don't filter points.
if self._fit_out_of_design:
if self._data_loader_config.fit_out_of_design:
# Use all data for training
# Set training_in_design to True for all observations so that
# all observations are used in CV and plotting
Expand Down Expand Up @@ -489,7 +529,10 @@ def _set_status_quo(
# observations of the status quo.
# This is useful for getting status_quo_data_by_trial
self._status_quo_name = status_quo_name
if len(sq_obs) > 1 and self._fit_only_completed_map_metrics:
if (
len(sq_obs) > 1
and self._data_loader_config.fit_only_completed_map_metrics
):
# it is expected to have multiple observations for map data
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
Expand Down Expand Up @@ -554,7 +597,7 @@ def training_in_design(self) -> list[bool]:
@property
def statuses_to_fit(self) -> set[TrialStatus]:
"""Statuses to fit the model on."""
if self._fit_abandoned:
if self._data_loader_config.fit_abandoned:
return set(TrialStatus)
return NON_ABANDONED_STATUSES

Expand All @@ -563,7 +606,7 @@ def statuses_to_fit_map_metric(self) -> set[TrialStatus]:
"""Statuses to fit the model on."""
return (
{TrialStatus.COMPLETED}
if self._fit_only_completed_map_metrics
if self._data_loader_config.fit_only_completed_map_metrics
else self.statuses_to_fit
)

Expand Down Expand Up @@ -871,7 +914,7 @@ def gen(

# Clamp the untransformed data to the original search space if
# we don't fit/gen OOD points
if not self._fit_out_of_design:
if not self._data_loader_config.fit_out_of_design:
observation_features = clamp_observation_features(
observation_features, orig_search_space
)
Expand Down Expand Up @@ -1214,3 +1257,47 @@ def clamp_observation_features(
)
obsf.parameters[p.name] = p.upper
return observation_features


def _legacy_overwrite_data_loader_options(
    data_loader_config: DataLoaderConfig,
    fit_out_of_design: bool | None = None,
    fit_abandoned: bool | None = None,
    fit_only_completed_map_metrics: bool | None = None,
    warn_if_legacy: bool = True,
) -> DataLoaderConfig:
    """Overwrite data loader config fields with legacy keyword arguments.

    Every legacy argument that is not None takes precedence over the
    corresponding field of `data_loader_config`. All other fields are carried
    over unchanged. The input config is never modified; a new instance is
    always returned.

    Args:
        data_loader_config: Data loader config to start from.
        fit_out_of_design: Whether to fit out-of-design points. Ignored if None.
        fit_abandoned: Whether to fit abandoned arms. Ignored if None.
        fit_only_completed_map_metrics: Whether to fit only completed map
            metrics. Ignored if None.
        warn_if_legacy: Whether to log a deprecation warning for each legacy
            keyword argument that is used (i.e. is not None).

    Returns:
        Updated data loader config.
    """
    # Local import: only this helper needs `replace`, so the module-level
    # import block is left untouched.
    from dataclasses import replace

    legacy_values = {
        "fit_out_of_design": fit_out_of_design,
        "fit_abandoned": fit_abandoned,
        "fit_only_completed_map_metrics": fit_only_completed_map_metrics,
    }
    # None means "not provided"; an explicit False is a real override.
    overrides = {
        var_name: value
        for var_name, value in legacy_values.items()
        if value is not None
    }

    if warn_if_legacy:
        for var_name in overrides:
            # Point users at the actual replacement argument, which is named
            # `data_loader_config`.
            logger.warning(
                f"`{var_name}` is deprecated. Please pass as "
                f"`data_loader_config.{var_name}` instead."
            )

    # `replace` copies every field that is not overridden, so fields added to
    # the config later are preserved instead of being reset to their defaults.
    return replace(data_loader_config, **overrides)
25 changes: 19 additions & 6 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
observation_features_to_array,
parse_observation_features,
)
from ax.modelbridge.torch import FIT_MODEL_ERROR, TorchAdapter
from ax.modelbridge.torch import DataLoaderConfig, FIT_MODEL_ERROR, TorchAdapter
from ax.modelbridge.transforms.base import Transform
from ax.models.torch_base import TorchGenerator
from ax.models.types import TConfig
Expand Down Expand Up @@ -65,30 +65,42 @@ def __init__(
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
fit_out_of_design: bool = False,
fit_on_init: bool = True,
fit_abandoned: bool = False,
default_model_gen_options: TConfig | None = None,
torch_device: torch.device | None = None,
map_data_limit_rows_per_metric: int | None = None,
map_data_limit_rows_per_group: int | None = None,
data_loader_config: DataLoaderConfig | None = None,
fit_out_of_design: bool = False,
fit_abandoned: bool = False,
) -> None:
"""In addition to common arguments documented in the ``Adapter`` and
``TorchAdapter`` classes, ``MapTorchAdapter`` accepts the following arguments.
Args:
map_data_limit_rows_per_metric: Subsample the map data so that the
total number of rows per metric is limited by this value.
total number of rows per metric is limited by this value. Used in place
of `limit_rows_per_metric` in `data_loader_config` for MapTorchAdapter.
map_data_limit_rows_per_group: Subsample the map data so that the
number of rows in the `map_key` column for each (arm, metric)
is limited by this value.
is limited by this value. Is used in place of `limit_rows_per_group` in
`data_loader_config` for MapTorchAdapter.
data_loader_config: A `DataLoaderConfig` of options for loading data. If None,
defaults to `DataLoaderConfig(latest_rows_per_group=None)`.
fit_out_of_design: Overwrites `data_loader_config.fit_out_of_design` if
not None.
fit_abandoned: Overwrites `data_loader_config.fit_abandoned` if not None.
"""
data = data or experiment.lookup_data()

if data_loader_config is None:
data_loader_config = DataLoaderConfig(latest_rows_per_group=None)

if not isinstance(data, MapData):
raise ValueError("`MapTorchAdapter expects `MapData` instead of `Data`.")

if any(isinstance(t, BatchTrial) for t in experiment.trials.values()):
raise ValueError("MapTorchAdapter does not support batch trials.")

self._map_key_features: list[str] = data.map_keys
self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric
self._map_data_limit_rows_per_group = map_data_limit_rows_per_group
Expand All @@ -107,6 +119,7 @@ def __init__(
fit_abandoned=fit_abandoned,
fit_on_init=fit_on_init,
default_model_gen_options=default_model_gen_options,
data_loader_config=data_loader_config,
)

@property
Expand Down Expand Up @@ -224,10 +237,10 @@ def _prepare_observations(
data=data,
limit_rows_per_metric=self._map_data_limit_rows_per_metric,
limit_rows_per_group=self._map_data_limit_rows_per_group,
latest_rows_per_group=self._data_loader_config.latest_rows_per_group,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=True,
load_only_completed_map_metrics=False,
)

def _compute_in_design(
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ def test_fit_only_completed_map_metrics(
)
kwargs = mock_observations_from_data.call_args.kwargs
self.assertTrue(kwargs["map_keys_as_parameters"])
# assert `latest_rows_per_group` is not specified or is None
self.assertIsNone(kwargs.get("latest_rows_per_group"))
# assert `latest_rows_per_group` is 1
self.assertEqual(kwargs["latest_rows_per_group"], 1)
mock_observations_from_data.reset_mock()

# calling without map data calls observations_from_data with
Expand Down
Loading

0 comments on commit 7a5fcc3

Please sign in to comment.