From efc957ed16e678a091c353c2bd686a303e9b4cb9 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Fri, 14 Feb 2025 07:27:22 -0800
Subject: [PATCH] Clean up multi-surrogate utilities

Summary:
Removes some unused utilities that were originally introduced to ensure
correct dataset splitting between multiple surrogates. Multiple surrogate
support has since been deprecated, so these are no longer needed.

Reviewed By: sdaulton

Differential Revision: D69632792
---
 ax/models/torch/botorch_modular/model.py |   6 --
 ax/models/torch/botorch_modular/utils.py | 108 +--------------------
 ax/models/torch/tests/test_utils.py      | 116 +----------------------
 3 files changed, 2 insertions(+), 228 deletions(-)

diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py
index d7a8eb338b5..aa78a0a51dc 100644
--- a/ax/models/torch/botorch_modular/model.py
+++ b/ax/models/torch/botorch_modular/model.py
@@ -24,7 +24,6 @@
 from ax.models.torch.botorch_modular.acquisition import Acquisition
 from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
 from ax.models.torch.botorch_modular.utils import (
-    check_outcome_dataset_match,
     choose_botorch_acqf_class,
     construct_acquisition_and_optimizer_options,
     ModelConfig,
@@ -177,11 +176,6 @@ def fit(
             additional_model_inputs: Additional kwargs to pass to the model
                 input constructor in ``Surrogate.fit``.
         """
-        outcome_names = sum((ds.outcome_names for ds in datasets), [])
-        check_outcome_dataset_match(
-            outcome_names=outcome_names, datasets=datasets, exact_match=True
-        )  # Checks for duplicate outcome names
-
         # Store search space info for later use (e.g. during generation)
         self._search_space_digest = search_space_digest
diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py
index 982d41cfaa0..343b982796c 100644
--- a/ax/models/torch/botorch_modular/utils.py
+++ b/ax/models/torch/botorch_modular/utils.py
@@ -15,7 +15,7 @@
 import torch
 from ax.core.search_space import SearchSpaceDigest
-from ax.exceptions.core import AxError, AxWarning, UnsupportedError
+from ax.exceptions.core import AxWarning, UnsupportedError
 from ax.models.torch_base import TorchOptConfig
 from ax.models.types import TConfig
 from ax.utils.common.constants import Keys
@@ -408,112 +408,6 @@ def fit_botorch_model(
             )
 
 
-def _tensor_difference(A: Tensor, B: Tensor) -> Tensor:
-    """Used to return B sans any Xs that also appear in A"""
-    C = torch.cat((A, B), dim=0)
-    D, inverse_ind = torch.unique(C, return_inverse=True, dim=0)
-    n = A.shape[0]
-    A_indices = inverse_ind[:n].tolist()
-    B_indices = inverse_ind[n:].tolist()
-    Bi_set = set(B_indices) - set(A_indices)
-    return D[list(Bi_set)]
-
-
-def check_outcome_dataset_match(
-    outcome_names: Sequence[str],
-    datasets: Sequence[SupervisedDataset],
-    exact_match: bool,
-) -> None:
-    """Check that the given outcome names match those of datasets.
-
-    Based on `exact_match` we either require that outcome names are
-    a subset of all outcomes or require the them to be the same.
-
-    Also checks that there are no duplicates in outcome names.
-
-    Args:
-        outcome_names: A list of outcome names.
-        datasets: A list of `SupervisedDataset` objects.
-        exact_match: If True, outcome_names must be the same as the union of
-            outcome names of the datasets. Otherwise, we check that the
-            outcome_names are a subset of all outcomes.
-
-    Raises:
-        ValueError: If there is no match.
-    """
-    all_outcomes = sum((ds.outcome_names for ds in datasets), [])
-    set_all_outcomes = set(all_outcomes)
-    set_all_spec_outcomes = set(outcome_names)
-    if len(set_all_outcomes) != len(all_outcomes):
-        raise AxError("Found duplicate outcomes in the datasets.")
-    if len(set_all_spec_outcomes) != len(outcome_names):
-        raise AxError("Found duplicate outcome names.")
-
-    if not exact_match:
-        if not set_all_spec_outcomes.issubset(set_all_outcomes):
-            raise AxError(
-                "Outcome names must be a subset of the outcome names of the datasets."
-                f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
-            )
-    elif set_all_spec_outcomes != set_all_outcomes:
-        raise AxError(
-            "Each outcome name must correspond to an outcome in the datasets. "
-            f"Got {outcome_names=} but the datasets model {set_all_outcomes}."
-        )
-
-
-def get_subset_datasets(
-    datasets: Sequence[SupervisedDataset],
-    subset_outcome_names: Sequence[str],
-) -> list[SupervisedDataset]:
-    """Get the list of datasets corresponding to the given subset of
-    outcome names. This is used to separate out datasets that are
-    used by one surrogate.
-
-    Args:
-        datasets: A list of `SupervisedDataset` objects.
-        subset_outcome_names: A list of outcome names to get datasets for.
-
-    Returns:
-        A list of `SupervisedDataset` objects corresponding to the given
-        subset of outcome names.
-    """
-    check_outcome_dataset_match(
-        outcome_names=subset_outcome_names, datasets=datasets, exact_match=False
-    )
-    single_outcome_datasets = {
-        ds.outcome_names[0]: ds for ds in datasets if len(ds.outcome_names) == 1
-    }
-    multi_outcome_datasets = {
-        tuple(ds.outcome_names): ds for ds in datasets if len(ds.outcome_names) > 1
-    }
-    subset_datasets = []
-    outcomes_processed = []
-    for outcome_name in subset_outcome_names:
-        if outcome_name in outcomes_processed:
-            # This can happen if the outcome appears in a multi-outcome
-            # dataset that is already processed.
-            continue
-        if outcome_name in single_outcome_datasets:
-            # The default case of outcome with a corresponding dataset.
-            ds = single_outcome_datasets[outcome_name]
-        else:
-            # The case of outcome being part of a multi-outcome dataset.
-            for outcome_names in multi_outcome_datasets.keys():
-                if outcome_name in outcome_names:
-                    ds = multi_outcome_datasets[outcome_names]
-                    if not set(ds.outcome_names).issubset(subset_outcome_names):
-                        raise UnsupportedError(
-                            "Breaking up a multi-outcome dataset between "
-                            "surrogates is not supported."
-                        )
-                    break
-        # Pyre-ignore [61]: `ds` may not be defined but it is guaranteed to be defined.
-        subset_datasets.append(ds)
-        outcomes_processed.extend(ds.outcome_names)
-    return subset_datasets
-
-
 def subset_state_dict(
     state_dict: OrderedDict[str, Tensor],
     submodel_index: int,
diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py
index 035b5c52e34..d5f5c7e1e98 100644
--- a/ax/models/torch/tests/test_utils.py
+++ b/ax/models/torch/tests/test_utils.py
@@ -12,16 +12,13 @@
 import numpy as np
 import torch
 from ax.core.search_space import SearchSpaceDigest
-from ax.exceptions.core import AxError, AxWarning, UnsupportedError
+from ax.exceptions.core import AxWarning, UnsupportedError
 from ax.models.torch.botorch_modular.utils import (
     _get_shared_rows,
-    _tensor_difference,
-    check_outcome_dataset_match,
     choose_botorch_acqf_class,
     choose_model_class,
     construct_acquisition_and_optimizer_options,
     convert_to_block_design,
-    get_subset_datasets,
     subset_state_dict,
     use_model_list,
 )
@@ -351,17 +348,6 @@ def test_use_model_list(self) -> None:
             )
         )
 
-    def test_tensor_difference(self) -> None:
-        n, m = 3, 2
-        A = torch.arange(n * m).reshape(n, m)
-        B = torch.cat((A[: n - 1], torch.randn(2, m)), dim=0)
-        # permute B
-        B = B[torch.randperm(len(B))]
-
-        C = _tensor_difference(A=A, B=B)
-
-        self.assertEqual(C.size(dim=0), 2)
-
     def test_get_shared_rows(self) -> None:
         X1 = torch.rand(4, 2)
@@ -518,106 +504,6 @@ def test_to_inequality_constraints(self) -> None:
         self.assertTrue(torch.allclose(ineq_constraints[1][1], torch.tensor([-1])))
         self.assertEqual(ineq_constraints[1][2], -2.0)
 
-    def test_check_check_outcome_dataset_match(self) -> None:
-        ds = self.fixed_noise_datasets[0]
-        # Simple test with one metric & dataset.
-        for exact_match in (True, False):
-            self.assertIsNone(
-                check_outcome_dataset_match(
-                    outcome_names=ds.outcome_names,
-                    datasets=[ds],
-                    exact_match=exact_match,
-                )
-            )
-        # Error with duplicate outcome names.
-        with self.assertRaisesRegex(AxError, "duplicate outcome names"):
-            check_outcome_dataset_match(
-                outcome_names=["y", "y"], datasets=[ds], exact_match=False
-            )
-        ds2 = self.supervised_datasets[0]
-        # Error with duplicate outcomes in datasets.
-        with self.assertRaisesRegex(AxError, "duplicate outcomes"):
-            check_outcome_dataset_match(
-                outcome_names=["y", "y2"], datasets=[ds, ds2], exact_match=False
-            )
-        ds2.outcome_names = ["y2"]
-        # Simple test with two metrics & datasets.
-        for exact_match in (True, False):
-            self.assertIsNone(
-                check_outcome_dataset_match(
-                    outcome_names=["y", "y2"],
-                    datasets=[ds, ds2],
-                    exact_match=exact_match,
-                )
-            )
-        # Exact match required but too many datasets provided.
-        with self.assertRaisesRegex(AxError, "must correspond to an outcome"):
-            check_outcome_dataset_match(
-                outcome_names=["y"],
-                datasets=[ds, ds2],
-                exact_match=True,
-            )
-        # The same check passes if we don't require exact match.
-        self.assertIsNone(
-            check_outcome_dataset_match(
-                outcome_names=["y"],
-                datasets=[ds, ds2],
-                exact_match=False,
-            )
-        )
-        # Error if metric doesn't exist in the datasets.
-        for exact_match in (True, False):
-            with self.assertRaisesRegex(AxError, "but the datasets model"):
-                check_outcome_dataset_match(
-                    outcome_names=["z"],
-                    datasets=[ds, ds2],
-                    exact_match=exact_match,
-                )
-
-    def test_get_subset_datasets(self) -> None:
-        ds = self.fixed_noise_datasets[0]
-        ds2 = self.supervised_datasets[0]
-        ds2.outcome_names = ["y2"]
-        ds3 = SupervisedDataset(
-            X=torch.zeros(1, 2),
-            Y=torch.ones(1, 2),
-            feature_names=["x1", "x2"],
-            outcome_names=["y3", "y4"],
-        )
-        # Test with single dataset.
-        self.assertEqual(
-            [ds], get_subset_datasets(datasets=[ds], subset_outcome_names=["y"])
-        )
-        # Edge case of empty metric list.
-        self.assertEqual(
-            [], get_subset_datasets(datasets=[ds], subset_outcome_names=[])
-        )
-        # Multiple datasets, single metric.
-        self.assertEqual(
-            [ds],
-            get_subset_datasets(datasets=[ds, ds2, ds3], subset_outcome_names=["y"]),
-        )
-        self.assertEqual(
-            [ds2],
-            get_subset_datasets(datasets=[ds, ds2, ds3], subset_outcome_names=["y2"]),
-        )
-        # Multi-output dataset, 1 metric -- not allowed.
-        with self.assertRaisesRegex(UnsupportedError, "multi-outcome dataset"):
-            get_subset_datasets(datasets=[ds, ds2, ds3], subset_outcome_names=["y3"])
-        # Multiple datasets, multiple metrics -- datasets in the same order as metrics.
-        self.assertEqual(
-            [ds2, ds],
-            get_subset_datasets(
-                datasets=[ds, ds2, ds3], subset_outcome_names=["y2", "y"]
-            ),
-        )
-        self.assertEqual(
-            [ds3, ds],
-            get_subset_datasets(
-                datasets=[ds, ds2, ds3], subset_outcome_names=["y3", "y", "y4"]
-            ),
-        )
-
     def test_subset_state_dict(self) -> None:
         m0 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
         m1 = SingleTaskGP(train_X=torch.rand(5, 2), train_Y=torch.rand(5, 1))
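
Note: code that still needs the row-wise set difference provided by the removed
`_tensor_difference` helper can reproduce it with a few lines of plain PyTorch.
The sketch below mirrors the deleted implementation shown in this patch; the
function name `rows_only_in_b` is illustrative and is not part of Ax.

    import torch
    from torch import Tensor


    def rows_only_in_b(A: Tensor, B: Tensor) -> Tensor:
        """Return the rows of `B` that do not also appear in `A`.

        Same trick as the deleted `_tensor_difference`: unique-ify the row-wise
        concatenation of both tensors, then keep the unique rows that are only
        referenced from the `B` half of the concatenation.
        """
        C = torch.cat((A, B), dim=0)
        D, inverse_ind = torch.unique(C, return_inverse=True, dim=0)
        n = A.shape[0]
        a_ids = set(inverse_ind[:n].tolist())
        b_ids = set(inverse_ind[n:].tolist())
        return D[sorted(b_ids - a_ids)]


    if __name__ == "__main__":
        A = torch.arange(6.0).reshape(3, 2)
        B = torch.cat((A[:2], torch.randn(2, 2)), dim=0)
        # The two rows shared with A are dropped; only the random rows remain.
        print(rows_only_in_b(A, B).shape)  # torch.Size([2, 2])

As with the original helper, the returned rows follow the sorted order produced
by `torch.unique`, not the order in which they appear in `B`.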