Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up multi-surrogate utilities #3370

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
check_outcome_dataset_match,
choose_botorch_acqf_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
Expand Down Expand Up @@ -177,11 +176,6 @@ def fit(
additional_model_inputs: Additional kwargs to pass to the
model input constructor in ``Surrogate.fit``.
"""
outcome_names = sum((ds.outcome_names for ds in datasets), [])
check_outcome_dataset_match(
outcome_names=outcome_names, datasets=datasets, exact_match=True
) # Checks for duplicate outcome names

# Store search space info for later use (e.g. during generation)
self._search_space_digest = search_space_digest

Expand Down
108 changes: 1 addition & 107 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError, AxWarning, UnsupportedError
from ax.exceptions.core import AxWarning, UnsupportedError
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
Expand Down Expand Up @@ -408,112 +408,6 @@ def fit_botorch_model(
)


def _tensor_difference(A: Tensor, B: Tensor) -> Tensor:
"""Used to return B sans any Xs that also appear in A"""
C = torch.cat((A, B), dim=0)
D, inverse_ind = torch.unique(C, return_inverse=True, dim=0)
n = A.shape[0]
A_indices = inverse_ind[:n].tolist()
B_indices = inverse_ind[n:].tolist()
Bi_set = set(B_indices) - set(A_indices)
return D[list(Bi_set)]


def check_outcome_dataset_match(
    outcome_names: Sequence[str],
    datasets: Sequence[SupervisedDataset],
    exact_match: bool,
) -> None:
    """Check that the given outcome names match those of the datasets.

    Based on `exact_match` we either require that the outcome names are
    a subset of all dataset outcomes or require them to be the same.

    Also checks that there are no duplicates, either among the outcome
    names of the datasets or among the given outcome names.

    Args:
        outcome_names: A list of outcome names.
        datasets: A list of `SupervisedDataset` objects.
        exact_match: If True, outcome_names must be the same as the union of
            outcome names of the datasets. Otherwise, we check that the
            outcome_names are a subset of all outcomes.

    Raises:
        AxError: If duplicates are found or if there is no match.
    """
    # Flatten with a comprehension rather than `sum(lists, [])`, which
    # re-copies the accumulated list on every addition.
    all_outcomes = [name for ds in datasets for name in ds.outcome_names]
    set_all_outcomes = set(all_outcomes)
    set_all_spec_outcomes = set(outcome_names)
    if len(set_all_outcomes) != len(all_outcomes):
        raise AxError("Found duplicate outcomes in the datasets.")
    if len(set_all_spec_outcomes) != len(outcome_names):
        raise AxError("Found duplicate outcome names.")

    if not exact_match:
        if not set_all_spec_outcomes.issubset(set_all_outcomes):
            # Note the trailing space: the two literals are concatenated.
            raise AxError(
                "Outcome names must be a subset of the outcome names of the datasets. "
                f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
            )
    elif set_all_spec_outcomes != set_all_outcomes:
        raise AxError(
            "Each outcome name must correspond to an outcome in the datasets. "
            f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
        )


def get_subset_datasets(
    datasets: Sequence[SupervisedDataset],
    subset_outcome_names: Sequence[str],
) -> list[SupervisedDataset]:
    """Get the list of datasets corresponding to the given subset of
    outcome names. This is used to separate out datasets that are
    used by one surrogate.

    Args:
        datasets: A list of `SupervisedDataset` objects.
        subset_outcome_names: A list of outcome names to get datasets for.

    Returns:
        A list of `SupervisedDataset` objects corresponding to the given
        subset of outcome names, ordered by the first appearance of one of
        their outcomes in `subset_outcome_names`. A multi-outcome dataset
        is included only once.

    Raises:
        UnsupportedError: If only some of the outcomes of a multi-outcome
            dataset are requested.
    """
    # Validates that there are no duplicate outcomes and that every requested
    # outcome is modeled by some dataset, so the lookups below cannot miss.
    check_outcome_dataset_match(
        outcome_names=subset_outcome_names, datasets=datasets, exact_match=False
    )
    # One entry per outcome, pointing at the dataset that models it. This
    # covers single- and multi-outcome datasets uniformly and avoids relying
    # on a loop variable that may be unbound or left over from a previous
    # iteration.
    dataset_by_outcome = {
        name: ds for ds in datasets for name in ds.outcome_names
    }
    requested_outcomes = set(subset_outcome_names)
    subset_datasets = []
    outcomes_processed = set()
    for outcome_name in subset_outcome_names:
        if outcome_name in outcomes_processed:
            # This can happen if the outcome appears in a multi-outcome
            # dataset that is already processed.
            continue
        ds = dataset_by_outcome[outcome_name]
        # Trivially true for single-outcome datasets; for multi-outcome
        # datasets every one of their outcomes must have been requested.
        if not requested_outcomes.issuperset(ds.outcome_names):
            raise UnsupportedError(
                "Breaking up a multi-outcome dataset between "
                "surrogates is not supported."
            )
        subset_datasets.append(ds)
        outcomes_processed.update(ds.outcome_names)
    return subset_datasets


def subset_state_dict(
state_dict: OrderedDict[str, Tensor],
submodel_index: int,
Expand Down
116 changes: 1 addition & 115 deletions ax/models/torch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError, AxWarning, UnsupportedError
from ax.exceptions.core import AxWarning, UnsupportedError
from ax.models.torch.botorch_modular.utils import (
_get_shared_rows,
_tensor_difference,
check_outcome_dataset_match,
choose_botorch_acqf_class,
choose_model_class,
construct_acquisition_and_optimizer_options,
convert_to_block_design,
get_subset_datasets,
subset_state_dict,
use_model_list,
)
Expand Down Expand Up @@ -351,17 +348,6 @@ def test_use_model_list(self) -> None:
)
)

def test_tensor_difference(self) -> None:
    """`_tensor_difference` should keep only the rows of B that are not in A."""
    num_rows, num_cols = 3, 2
    base = torch.arange(num_rows * num_cols).reshape(num_rows, num_cols)
    # All but the last row of `base`, followed by two random rows.
    random_rows = torch.randn(2, num_cols)
    other = torch.cat((base[: num_rows - 1], random_rows), dim=0)
    # Shuffle the rows of `other`.
    shuffled = other[torch.randperm(len(other))]

    difference = _tensor_difference(A=base, B=shuffled)

    self.assertEqual(difference.size(dim=0), 2)

def test_get_shared_rows(self) -> None:
X1 = torch.rand(4, 2)

Expand Down Expand Up @@ -518,106 +504,6 @@ def test_to_inequality_constraints(self) -> None:
self.assertTrue(torch.allclose(ineq_constraints[1][1], torch.tensor([-1])))
self.assertEqual(ineq_constraints[1][2], -2.0)

def test_check_check_outcome_dataset_match(self) -> None:
    """Exercise the passing and failing paths of `check_outcome_dataset_match`."""
    first_ds = self.fixed_noise_datasets[0]
    # One metric, one dataset: passes in both modes.
    for exact in (True, False):
        result = check_outcome_dataset_match(
            outcome_names=first_ds.outcome_names,
            datasets=[first_ds],
            exact_match=exact,
        )
        self.assertIsNone(result)
    # Repeated entries in the requested outcome names are rejected.
    with self.assertRaisesRegex(AxError, "duplicate outcome names"):
        check_outcome_dataset_match(
            outcome_names=["y", "y"], datasets=[first_ds], exact_match=False
        )
    second_ds = self.supervised_datasets[0]
    both = [first_ds, second_ds]
    # Repeated outcomes across the datasets are rejected.
    with self.assertRaisesRegex(AxError, "duplicate outcomes"):
        check_outcome_dataset_match(
            outcome_names=["y", "y2"], datasets=both, exact_match=False
        )
    # Rename the second dataset's outcome; the remaining checks depend on it.
    second_ds.outcome_names = ["y2"]
    # Two metrics, two datasets: passes in both modes.
    for exact in (True, False):
        result = check_outcome_dataset_match(
            outcome_names=["y", "y2"],
            datasets=both,
            exact_match=exact,
        )
        self.assertIsNone(result)
    # With exact matching, an extra dataset is an error.
    with self.assertRaisesRegex(AxError, "must correspond to an outcome"):
        check_outcome_dataset_match(
            outcome_names=["y"],
            datasets=both,
            exact_match=True,
        )
    # Without exact matching, the same call is accepted.
    result = check_outcome_dataset_match(
        outcome_names=["y"],
        datasets=both,
        exact_match=False,
    )
    self.assertIsNone(result)
    # A metric that no dataset models is an error in both modes.
    for exact in (True, False):
        with self.assertRaisesRegex(AxError, "but the datasets model"):
            check_outcome_dataset_match(
                outcome_names=["z"],
                datasets=both,
                exact_match=exact,
            )

def test_get_subset_datasets(self) -> None:
    """Check dataset selection, ordering and multi-outcome handling."""
    single_a = self.fixed_noise_datasets[0]
    single_b = self.supervised_datasets[0]
    single_b.outcome_names = ["y2"]
    multi = SupervisedDataset(
        X=torch.zeros(1, 2),
        Y=torch.ones(1, 2),
        feature_names=["x1", "x2"],
        outcome_names=["y3", "y4"],
    )
    # Only one dataset available.
    selected = get_subset_datasets(datasets=[single_a], subset_outcome_names=["y"])
    self.assertEqual([single_a], selected)
    # Requesting no metrics yields no datasets.
    selected = get_subset_datasets(datasets=[single_a], subset_outcome_names=[])
    self.assertEqual([], selected)
    # Several datasets, one requested metric.
    selected = get_subset_datasets(
        datasets=[single_a, single_b, multi], subset_outcome_names=["y"]
    )
    self.assertEqual([single_a], selected)
    selected = get_subset_datasets(
        datasets=[single_a, single_b, multi], subset_outcome_names=["y2"]
    )
    self.assertEqual([single_b], selected)
    # Requesting just one outcome of a multi-output dataset is not allowed.
    with self.assertRaisesRegex(UnsupportedError, "multi-outcome dataset"):
        get_subset_datasets(
            datasets=[single_a, single_b, multi], subset_outcome_names=["y3"]
        )
    # Several datasets and metrics: results follow the order of the metrics.
    selected = get_subset_datasets(
        datasets=[single_a, single_b, multi], subset_outcome_names=["y2", "y"]
    )
    self.assertEqual([single_b, single_a], selected)
    selected = get_subset_datasets(
        datasets=[single_a, single_b, multi],
        subset_outcome_names=["y3", "y", "y4"],
    )
    self.assertEqual([multi, single_a], selected)

def test_subset_state_dict(self) -> None:
m0 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
m1 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
Expand Down