Skip to content

Commit

Permalink
Clean up multi-surrogate utilities (facebook#3370)
Browse files Browse the repository at this point in the history
Summary:

Removes some used utilities that were originally introduced to ensure correct dataset splitting between multiple surrogates. Multiple surrogate support has since been deprecated and these are not needed.

Reviewed By: sdaulton

Differential Revision: D69632792
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 14, 2025
1 parent c763b43 commit 5d53e36
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 228 deletions.
6 changes: 0 additions & 6 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
check_outcome_dataset_match,
choose_botorch_acqf_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
Expand Down Expand Up @@ -177,11 +176,6 @@ def fit(
additional_model_inputs: Additional kwargs to pass to the
model input constructor in ``Surrogate.fit``.
"""
outcome_names = sum((ds.outcome_names for ds in datasets), [])
check_outcome_dataset_match(
outcome_names=outcome_names, datasets=datasets, exact_match=True
) # Checks for duplicate outcome names

# Store search space info for later use (e.g. during generation)
self._search_space_digest = search_space_digest

Expand Down
108 changes: 1 addition & 107 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError, AxWarning, UnsupportedError
from ax.exceptions.core import AxWarning, UnsupportedError
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
Expand Down Expand Up @@ -408,112 +408,6 @@ def fit_botorch_model(
)


def _tensor_difference(A: Tensor, B: Tensor) -> Tensor:
"""Used to return B sans any Xs that also appear in A"""
C = torch.cat((A, B), dim=0)
D, inverse_ind = torch.unique(C, return_inverse=True, dim=0)
n = A.shape[0]
A_indices = inverse_ind[:n].tolist()
B_indices = inverse_ind[n:].tolist()
Bi_set = set(B_indices) - set(A_indices)
return D[list(Bi_set)]


def check_outcome_dataset_match(
outcome_names: Sequence[str],
datasets: Sequence[SupervisedDataset],
exact_match: bool,
) -> None:
"""Check that the given outcome names match those of datasets.
Based on `exact_match` we either require that outcome names are
a subset of all outcomes or require the them to be the same.
Also checks that there are no duplicates in outcome names.
Args:
outcome_names: A list of outcome names.
datasets: A list of `SupervisedDataset` objects.
exact_match: If True, outcome_names must be the same as the union of
outcome names of the datasets. Otherwise, we check that the
outcome_names are a subset of all outcomes.
Raises:
ValueError: If there is no match.
"""
all_outcomes = sum((ds.outcome_names for ds in datasets), [])
set_all_outcomes = set(all_outcomes)
set_all_spec_outcomes = set(outcome_names)
if len(set_all_outcomes) != len(all_outcomes):
raise AxError("Found duplicate outcomes in the datasets.")
if len(set_all_spec_outcomes) != len(outcome_names):
raise AxError("Found duplicate outcome names.")

if not exact_match:
if not set_all_spec_outcomes.issubset(set_all_outcomes):
raise AxError(
"Outcome names must be a subset of the outcome names of the datasets."
f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
)
elif set_all_spec_outcomes != set_all_outcomes:
raise AxError(
"Each outcome name must correspond to an outcome in the datasets. "
f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
)


def get_subset_datasets(
datasets: Sequence[SupervisedDataset],
subset_outcome_names: Sequence[str],
) -> list[SupervisedDataset]:
"""Get the list of datasets corresponding to the given subset of
outcome names. This is used to separate out datasets that are
used by one surrogate.
Args:
datasets: A list of `SupervisedDataset` objects.
subset_outcome_names: A list of outcome names to get datasets for.
Returns:
A list of `SupervisedDataset` objects corresponding to the given
subset of outcome names.
"""
check_outcome_dataset_match(
outcome_names=subset_outcome_names, datasets=datasets, exact_match=False
)
single_outcome_datasets = {
ds.outcome_names[0]: ds for ds in datasets if len(ds.outcome_names) == 1
}
multi_outcome_datasets = {
tuple(ds.outcome_names): ds for ds in datasets if len(ds.outcome_names) > 1
}
subset_datasets = []
outcomes_processed = []
for outcome_name in subset_outcome_names:
if outcome_name in outcomes_processed:
# This can happen if the outcome appears in a multi-outcome
# dataset that is already processed.
continue
if outcome_name in single_outcome_datasets:
# The default case of outcome with a corresponding dataset.
ds = single_outcome_datasets[outcome_name]
else:
# The case of outcome being part of a multi-outcome dataset.
for outcome_names in multi_outcome_datasets.keys():
if outcome_name in outcome_names:
ds = multi_outcome_datasets[outcome_names]
if not set(ds.outcome_names).issubset(subset_outcome_names):
raise UnsupportedError(
"Breaking up a multi-outcome dataset between "
"surrogates is not supported."
)
break
# Pyre-ignore [61]: `ds` may not be defined but it is guaranteed to be defined.
subset_datasets.append(ds)
outcomes_processed.extend(ds.outcome_names)
return subset_datasets


def subset_state_dict(
state_dict: OrderedDict[str, Tensor],
submodel_index: int,
Expand Down
116 changes: 1 addition & 115 deletions ax/models/torch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError, AxWarning, UnsupportedError
from ax.exceptions.core import AxWarning, UnsupportedError
from ax.models.torch.botorch_modular.utils import (
_get_shared_rows,
_tensor_difference,
check_outcome_dataset_match,
choose_botorch_acqf_class,
choose_model_class,
construct_acquisition_and_optimizer_options,
convert_to_block_design,
get_subset_datasets,
subset_state_dict,
use_model_list,
)
Expand Down Expand Up @@ -351,17 +348,6 @@ def test_use_model_list(self) -> None:
)
)

def test_tensor_difference(self) -> None:
n, m = 3, 2
A = torch.arange(n * m).reshape(n, m)
B = torch.cat((A[: n - 1], torch.randn(2, m)), dim=0)
# permute B
B = B[torch.randperm(len(B))]

C = _tensor_difference(A=A, B=B)

self.assertEqual(C.size(dim=0), 2)

def test_get_shared_rows(self) -> None:
X1 = torch.rand(4, 2)

Expand Down Expand Up @@ -518,106 +504,6 @@ def test_to_inequality_constraints(self) -> None:
self.assertTrue(torch.allclose(ineq_constraints[1][1], torch.tensor([-1])))
self.assertEqual(ineq_constraints[1][2], -2.0)

def test_check_check_outcome_dataset_match(self) -> None:
ds = self.fixed_noise_datasets[0]
# Simple test with one metric & dataset.
for exact_match in (True, False):
self.assertIsNone(
check_outcome_dataset_match(
outcome_names=ds.outcome_names,
datasets=[ds],
exact_match=exact_match,
)
)
# Error with duplicate outcome names.
with self.assertRaisesRegex(AxError, "duplicate outcome names"):
check_outcome_dataset_match(
outcome_names=["y", "y"], datasets=[ds], exact_match=False
)
ds2 = self.supervised_datasets[0]
# Error with duplicate outcomes in datasets.
with self.assertRaisesRegex(AxError, "duplicate outcomes"):
check_outcome_dataset_match(
outcome_names=["y", "y2"], datasets=[ds, ds2], exact_match=False
)
ds2.outcome_names = ["y2"]
# Simple test with two metrics & datasets.
for exact_match in (True, False):
self.assertIsNone(
check_outcome_dataset_match(
outcome_names=["y", "y2"],
datasets=[ds, ds2],
exact_match=exact_match,
)
)
# Exact match required but too many datasets provided.
with self.assertRaisesRegex(AxError, "must correspond to an outcome"):
check_outcome_dataset_match(
outcome_names=["y"],
datasets=[ds, ds2],
exact_match=True,
)
# The same check passes if we don't require exact match.
self.assertIsNone(
check_outcome_dataset_match(
outcome_names=["y"],
datasets=[ds, ds2],
exact_match=False,
)
)
# Error if metric doesn't exist in the datasets.
for exact_match in (True, False):
with self.assertRaisesRegex(AxError, "but the datasets model"):
check_outcome_dataset_match(
outcome_names=["z"],
datasets=[ds, ds2],
exact_match=exact_match,
)

def test_get_subset_datasets(self) -> None:
ds = self.fixed_noise_datasets[0]
ds2 = self.supervised_datasets[0]
ds2.outcome_names = ["y2"]
ds3 = SupervisedDataset(
X=torch.zeros(1, 2),
Y=torch.ones(1, 2),
feature_names=["x1", "x2"],
outcome_names=["y3", "y4"],
)
# Test with single dataset.
self.assertEqual(
[ds], get_subset_datasets(datasets=[ds], subset_outcome_names=["y"])
)
# Edge case of empty metric list.
self.assertEqual(
[], get_subset_datasets(datasets=[ds], subset_outcome_names=[])
)
# Multiple datasets, single metric.
self.assertEqual(
[ds],
get_subset_datasets(datasets=[ds, ds2, ds3], subset_outcome_names=["y"]),
)
self.assertEqual(
[ds2],
get_subset_datasets(datasets=[ds, ds2, ds3], subset_outcome_names=["y2"]),
)
# Multi-output dataset, 1 metric -- not allowed.
with self.assertRaisesRegex(UnsupportedError, "multi-outcome dataset"):
get_subset_datasets(datasets=[ds, ds2, ds3], subset_outcome_names=["y3"])
# Multiple datasets, multiple metrics -- datasets in the same order as metrics.
self.assertEqual(
[ds2, ds],
get_subset_datasets(
datasets=[ds, ds2, ds3], subset_outcome_names=["y2", "y"]
),
)
self.assertEqual(
[ds3, ds],
get_subset_datasets(
datasets=[ds, ds2, ds3], subset_outcome_names=["y3", "y", "y4"]
),
)

def test_subset_state_dict(self) -> None:
m0 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
m1 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
Expand Down

0 comments on commit 5d53e36

Please sign in to comment.