Skip to content

Commit

Permalink
Adds method to retain the N most recently observed values from MapData (
Browse files Browse the repository at this point in the history
facebook#3112)

Summary:

**Context:** Early stopping produces learning curves truncated at varying progressions/lengths, not necessarily only near the highest progressions or in the asymptotic regime. Incorporating one or more of the last observed values — both for curves stopped early during the stages of steepest descent and for those stopped in the asymptotic regime — will likely provide useful, informative datapoints with which to improve the model.

This diff:

* Provide a new method `latest` for `MapData` to retrieve the *n* most recently observed values for each (arm, metric) group, determined by the `map_key` values, where higher implies more recent.
* Update `observations_from_data` to optionally utilize `latest` and retain only the most recently observed *n* values (the new option, if specified alongside the existing subsampling options, will now take precedence).
* Modify the "upcast" `MapData.df` property to leverage `latest`, which is a special case with *n=1*.
* Revise the docstring to reflect changes in the pertinent methods, as well as update related methods like `subsample` to ensure uniform and consistent writing.

Reviewed By: esantorella

Differential Revision: D66434621
  • Loading branch information
ltiao authored and facebook-github-bot committed Feb 19, 2025
1 parent fe7fd5a commit a529b8d
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 21 deletions.
55 changes: 48 additions & 7 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,13 @@ def df(self) -> pd.DataFrame:
if self._memo_df is not None:
return self._memo_df

# If map_keys is empty just return the df
if len(self.map_keys) == 0:
# If map_keys is empty just return the df
return self.map_df

self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates(
MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
self._memo_df = _tail(
map_df=self.map_df, map_keys=self.map_keys, n=1, sort=True
)

return self._memo_df
Expand Down Expand Up @@ -362,6 +363,32 @@ def clone(self) -> MapData:
description=self.description,
)

def latest(
    self,
    map_keys: list[str] | None = None,
    rows_per_group: int = 1,
) -> MapData:
    """Return a new ``MapData`` keeping only the ``rows_per_group`` most
    recently observed rows of each (arm, metric) group.

    Recency is determined by ordering on the ``map_key`` columns, where a
    higher value means more recent. Only the relative ordering of the
    ``map_key`` values is considered, so this is most meaningful when those
    values are equally spaced. Groups with fewer than ``rows_per_group``
    rows are returned in full.

    Args:
        map_keys: Columns to order by; defaults to ``self.map_keys``.
        rows_per_group: Number of most-recent rows to keep per group.
    """
    # Fall back to this MapData's own map keys when none are given.
    keys = self.map_keys if map_keys is None else map_keys
    latest_df = _tail(map_df=self.map_df, map_keys=keys, n=rows_per_group, sort=True)
    return MapData(
        df=latest_df,
        map_key_infos=self.map_key_infos,
        description=self.description,
    )

def subsample(
self,
map_key: str | None = None,
Expand All @@ -370,11 +397,13 @@ def subsample(
limit_rows_per_metric: int | None = None,
include_first_last: bool = True,
) -> MapData:
"""Subsample the `map_key` column in an equally-spaced manner (if there is
a `self.map_keys` is length one, then `map_key` can be set to None). The
values of the `map_key` column are not taken into account, so this function
is most reasonable when those values are equally-spaced. There are three
ways that this can be done:
"""Return a new MapData that subsamples the `map_key` column in an
equally-spaced manner. If `self.map_keys` has a length of one, `map_key`
can be set to None. This function considers only the relative ordering
of the `map_key` values, making it most suitable when these values are
equally spaced.
There are three ways that this can be done:
1. If `keep_every = k` is set, then every kth row of the DataFrame in the
`map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`.
In other words, every kth step of each (arm, metric) will be kept.
Expand Down Expand Up @@ -478,6 +507,18 @@ def _subsample_rate(
)


def _tail(
    map_df: pd.DataFrame,
    map_keys: list[str],
    n: int = 1,
    sort: bool = True,
) -> pd.DataFrame:
    """Return the last ``n`` rows of each deduplication group of ``map_df``.

    Rows are first ordered by ``map_keys`` (ascending, so the tail holds the
    highest / most recent values), then grouped by
    ``MapData.DEDUPLICATE_BY_COLUMNS`` and truncated to the final ``n`` rows
    per group. When ``sort`` is True, the result is re-sorted by the
    deduplication columns for stable, reproducible output.
    """
    grouped = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS)
    result = grouped.tail(n)
    if sort:
        result = result.sort_values(MapData.DEDUPLICATE_BY_COLUMNS)
    return result


def _subsample_one_metric(
map_df: pd.DataFrame,
map_key: str | None = None,
Expand Down
33 changes: 20 additions & 13 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def observations_from_data(
statuses_to_include: set[TrialStatus] | None = None,
statuses_to_include_map_metric: set[TrialStatus] | None = None,
map_keys_as_parameters: bool = False,
latest_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
limit_rows_per_group: int | None = None,
) -> list[Observation]:
Expand All @@ -485,17 +486,21 @@ def observations_from_data(
trials with statuses in this set. Defaults to all statuses except abandoned.
map_keys_as_parameters: Whether map_keys should be returned as part of
the parameters of the Observation objects.
limit_rows_per_metric: If specified, and if data is an instance of MapData,
uses MapData.subsample() with
`limit_rows_per_metric` equal to the specified value on the first
map_key (map_data.map_keys[0]) to subsample the MapData. This is
useful in, e.g., cases where learning curves are frequently
updated, leading to an intractable number of Observation objects
created.
limit_rows_per_group: If specified, and if data is an instance of MapData,
uses MapData.subsample() with
`limit_rows_per_group` equal to the specified value on the first
map_key (map_data.map_keys[0]) to subsample the MapData.
latest_rows_per_group: If specified and data is an instance of MapData,
uses MapData.latest() with `rows_per_group=latest_rows_per_group` to
retrieve the most recent rows for each group. Useful in cases where
learning curves are frequently updated, preventing an excessive
number of Observation objects. Overrides `limit_rows_per_metric`
and `limit_rows_per_group`.
limit_rows_per_metric: If specified and data is an instance of MapData,
uses MapData.subsample() with `limit_rows_per_metric` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Useful for
managing the number of Observation objects when learning curves are
frequently updated. Ignored if `latest_rows_per_group` is specified.
limit_rows_per_group: If specified and data is an instance of MapData,
uses MapData.subsample() with `limit_rows_per_group` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
`latest_rows_per_group` is specified.
Returns:
List of Observation objects.
Expand All @@ -509,9 +514,11 @@ def observations_from_data(
if is_map_data:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
if limit_rows_per_metric is not None or limit_rows_per_group is not None:
if latest_rows_per_group is not None:
data = data.latest(map_keys=map_keys, rows_per_group=latest_rows_per_group)
elif limit_rows_per_metric is not None or limit_rows_per_group is not None:
data = data.subsample(
map_key=map_keys[0],
map_key=data.map_keys[0],
limit_rows_per_metric=limit_rows_per_metric,
limit_rows_per_group=limit_rows_per_group,
include_first_last=True,
Expand Down
75 changes: 74 additions & 1 deletion ax/core/tests/test_map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-strict


import numpy as np
import pandas as pd
from ax.core.data import Data
from ax.core.map_data import MapData, MapKeyInfo
Expand Down Expand Up @@ -236,7 +237,17 @@ def test_upcast(self) -> None:

self.assertIsNotNone(fresh._memo_df) # Assert df is cached after first call

def test_subsample(self) -> None:
self.assertTrue(
fresh.df.equals(
fresh.map_df.sort_values(fresh.map_keys).drop_duplicates(
MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
)
)
)

def test_latest(self) -> None:
seed = 8888

arm_names = ["0_0", "1_0", "2_0", "3_0"]
max_epochs = [25, 50, 75, 100]
metric_names = ["a", "b"]
Expand All @@ -259,6 +270,68 @@ def test_subsample(self) -> None:
)
large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos)

shuffled_large_map_df = large_map_data.map_df.groupby(
MapData.DEDUPLICATE_BY_COLUMNS
).sample(frac=1, random_state=seed)
shuffled_large_map_data = MapData(
df=shuffled_large_map_df, map_key_infos=self.map_key_infos
)

for rows_per_group in [1, 40]:
large_map_data_latest = large_map_data.latest(rows_per_group=rows_per_group)

if rows_per_group == 1:
self.assertTrue(
large_map_data_latest.map_df.groupby("metric_name")
.epoch.transform(lambda col: set(col) == set(max_epochs))
.all()
)

# when rows_per_group is larger than the number of rows
# actually observed in a group
actual_rows_per_group = large_map_data_latest.map_df.groupby(
MapData.DEDUPLICATE_BY_COLUMNS
).size()
expected_rows_per_group = np.minimum(
large_map_data_latest.map_df.groupby(
MapData.DEDUPLICATE_BY_COLUMNS
).epoch.max(),
rows_per_group,
)
self.assertTrue(actual_rows_per_group.equals(expected_rows_per_group))

# behavior should be consistent even if map_keys are not in ascending order
shuffled_large_map_data_latest = shuffled_large_map_data.latest(
rows_per_group=rows_per_group
)
self.assertTrue(
shuffled_large_map_data_latest.map_df.equals(
large_map_data_latest.map_df
)
)

def test_subsample(self) -> None:
arm_names = ["0_0", "1_0", "2_0", "3_0"]
max_epochs = [25, 50, 75, 100]
metric_names = ["a", "b"]
large_map_df = pd.DataFrame(
[
{
"arm_name": arm_name,
"epoch": epoch + 1,
"mean": epoch * 0.1,
"sem": 0.1,
"trial_index": trial_index,
"metric_name": metric_name,
}
for metric_name in metric_names
for trial_index, (arm_name, max_epoch) in enumerate(
zip(arm_names, max_epochs)
)
for epoch in range(max_epoch)
]
)
large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos)
large_map_df_sparse_metric = pd.DataFrame(
[
{
Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def _prepare_observations(
return observations_from_data(
experiment=experiment,
data=data,
latest_rows_per_group=1,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
Expand Down

0 comments on commit a529b8d

Please sign in to comment.