Skip to content

Commit

Permalink
Simplify observations_from_data/map_data (#3339)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3339

This diff cleans up a rarely used & untested part of `observations_from_data`. The code had logic for splitting the dataframe into two, and processing the dfs separately. However, this part of the code was not tested and was only executed in the rare cases where the feature columns included NaN / NaT values. The `_observations_from_dataframe` helper is slightly modified to support NaN / NaT values, eliminating the need for splitting the dfs.

Differential Revision: D69419519

Reviewed By: lena-kashtelyan
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 11, 2025
1 parent f7227a2 commit 8028892
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 105 deletions.
116 changes: 19 additions & 97 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def _observations_from_dataframe(
return []
observations = []
abandoned_arms_dict = {}
for g, d in df.groupby(by=cols):
# NOTE: dropna is important to avoid dropping the whole or part of the df if
# a feature column is filled with NaN / NaT values.
for g, d in df.groupby(by=cols, dropna=False):
obs_kwargs = {}
features = dict(zip(cols, g, strict=True))
arm_name = features["arm_name"]
Expand Down Expand Up @@ -321,7 +323,7 @@ def _observations_from_dataframe(
if obs_parameters:
obs_kwargs["parameters"] = obs_parameters
for f, val in features.items():
if f in OBS_KWARGS:
if f in OBS_KWARGS and not pd.isna(val):
obs_kwargs[f] = val
# add start and end time of trial if the start and end time
# is the same for all metrics and arms
Expand Down Expand Up @@ -487,51 +489,15 @@ def observations_from_data(
if statuses_to_include_map_metric is None:
statuses_to_include_map_metric = {TrialStatus.COMPLETED}
feature_cols = get_feature_cols(data)
observations = []
# One DataFrame where all rows have all features.
isnull = data.df[feature_cols].isnull()
isnull_any = isnull.any(axis=1)
incomplete_df_cols = isnull[isnull_any].any()

# Get the incomplete_df columns that are complete, and usable as groupby keys.
complete_feature_cols = list(
OBS_COLS.intersection(incomplete_df_cols.index[~incomplete_df_cols])
return _observations_from_dataframe(
experiment=experiment,
df=data.df,
cols=feature_cols,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys=[],
)

if set(feature_cols) == set(complete_feature_cols):
complete_df = data.df
incomplete_df = None
else:
# The groupby and filter is expensive, so do it only if we have to.
grouped = data.df.groupby(by=complete_feature_cols)
complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any())
incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any())

# Get Observations from complete_df
observations.extend(
_observations_from_dataframe(
experiment=experiment,
df=complete_df,
cols=feature_cols,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys=[],
)
)
if incomplete_df is not None:
# Get Observations from incomplete_df
observations.extend(
_observations_from_dataframe(
experiment=experiment,
df=incomplete_df,
cols=complete_feature_cols,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys=[],
)
)
return observations


def observations_from_map_data(
experiment: experiment.Experiment,
Expand Down Expand Up @@ -584,60 +550,16 @@ def observations_from_map_data(
include_first_last=True,
)
feature_cols = get_feature_cols(map_data, is_map_data=True)
observations = []
# One DataFrame where all rows have all features.
isnull = map_data.map_df[feature_cols].isnull()
isnull_any = isnull.any(axis=1)
incomplete_df_cols = isnull[isnull_any].any()

# Get the incomplete_df columns that are complete, and usable as groupby keys.
obs_cols_and_map = OBS_COLS.union(map_data.map_keys)
complete_feature_cols = list(
obs_cols_and_map.intersection(incomplete_df_cols.index[~incomplete_df_cols])
return _observations_from_dataframe(
experiment=experiment,
df=map_data.map_df,
cols=feature_cols,
map_keys=map_data.map_keys,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
)

if set(feature_cols) == set(complete_feature_cols):
complete_df = map_data.map_df
incomplete_df = None
else:
# The groupby and filter is expensive, so do it only if we have to.
grouped = map_data.map_df.groupby(
by=(
complete_feature_cols
if len(complete_feature_cols) > 1
else complete_feature_cols[0]
)
)
complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any())
incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any())

# Get Observations from complete_df
observations.extend(
_observations_from_dataframe(
experiment=experiment,
df=complete_df,
cols=feature_cols,
map_keys=map_data.map_keys,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
)
)
if incomplete_df is not None:
# Get Observations from incomplete_df
observations.extend(
_observations_from_dataframe(
experiment=experiment,
df=incomplete_df,
cols=complete_feature_cols,
map_keys=map_data.map_keys,
statuses_to_include=statuses_to_include,
statuses_to_include_map_metric=statuses_to_include_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
)
)
return observations


def separate_observations(
observations: list[Observation], copy: bool = False
Expand Down
25 changes: 17 additions & 8 deletions ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,9 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None:
)
self.assertEqual(obs.arm_name, cname_truth[i])

def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
def test_ObservationsFromDataWithDifferentTimesSingleTrial(
self, with_nat: bool = False
) -> None:
params0: TParameterization = {"x": 0, "y": "a"}
params1: TParameterization = {"x": 1, "y": "a"}
truth = [
Expand All @@ -754,7 +756,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
"trial_index": 0,
"metric_name": "a",
"start_time": "2024-03-20 08:45:00",
"end_time": "2024-03-20 08:47:00",
"end_time": pd.NaT if with_nat else "2024-03-20 08:47:00",
},
{
"arm_name": "0_0",
Expand All @@ -764,6 +766,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
"trial_index": 0,
"metric_name": "b",
"start_time": "2024-03-20 08:45:00",
"end_time": pd.NaT if with_nat else "2024-03-20 08:46:00",
},
{
"arm_name": "0_1",
Expand All @@ -773,7 +776,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
"trial_index": 0,
"metric_name": "a",
"start_time": "2024-03-20 08:43:00",
"end_time": "2024-03-20 08:46:00",
"end_time": pd.NaT if with_nat else "2024-03-20 08:46:00",
},
{
"arm_name": "0_1",
Expand All @@ -783,7 +786,7 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
"trial_index": 0,
"metric_name": "b",
"start_time": "2024-03-20 08:45:00",
"end_time": "2024-03-20 08:46:00",
"end_time": pd.NaT if with_nat else "2024-03-20 08:46:00",
},
]
arms_by_name = {
Expand Down Expand Up @@ -848,10 +851,16 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
self.assertIsNone(obs.features.end_time)
else:
self.assertIsNone(obs.features.start_time)
self.assertEqual(
none_throws(obs.features.end_time).strftime("%Y-%m-%d %X"),
"2024-03-20 08:46:00",
)
if with_nat:
self.assertIsNone(obs.features.end_time)
else:
self.assertEqual(
none_throws(obs.features.end_time).strftime("%Y-%m-%d %X"),
"2024-03-20 08:46:00",
)

def test_observations_from_dataframe_with_nat_timestamps(self) -> None:
self.test_ObservationsFromDataWithDifferentTimesSingleTrial(with_nat=True)

def test_SeparateObservations(self) -> None:
obs_arm_name = "0_0"
Expand Down

0 comments on commit 8028892

Please sign in to comment.