diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index dc1e918896f..e2c6825e0a5 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -8,13 +8,11 @@ import json import time -import warnings from abc import ABC from collections import OrderedDict from collections.abc import MutableMapping from copy import deepcopy from dataclasses import dataclass, field - from logging import Logger from typing import Any @@ -50,7 +48,7 @@ from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters from ax.models.types import TConfig from ax.utils.common.logger import get_logger -from botorch.exceptions.warnings import InputDataWarning +from botorch.settings import validate_input_scaling from pyre_extensions import assert_is_instance, none_throws logger: Logger = get_logger(__name__) @@ -1004,15 +1002,10 @@ def cross_validate( ) # Apply terminal transform, and get predictions. - with warnings.catch_warnings(): - # Since each CV fold removes points from the training data, the remaining - # observations will not pass the standardization test. To avoid confusing - # users with this warning, we filter it out. - warnings.filterwarnings( - "ignore", - message=r"Data \(outcome observations\) is not standardized", - category=InputDataWarning, - ) + # Since each CV fold removes points from the training data, the + # remaining observations will not pass the input scaling checks. + # To avoid confusing users with warnings, we disable these checks. + with validate_input_scaling(False): cv_predictions = self._cross_validate( search_space=search_space, cv_training_data=cv_training_data, diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index 4ffaffcadbf..058910b80e1 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -8,7 +8,6 @@ from __future__ import annotations -import warnings from collections import defaultdict from collections.abc import Callable, Iterable from copy import deepcopy @@ -31,7 +30,7 @@ ModelFitMetricProtocol, std_of_the_standardized_error, ) -from botorch.exceptions.warnings import InputDataWarning +from botorch.settings import validate_input_scaling logger: Logger = get_logger(__name__) @@ -161,15 +160,10 @@ def cross_validate( ) = model._transform_inputs_for_cv( cv_training_data=cv_training_data, cv_test_points=cv_test_points ) - with warnings.catch_warnings(): - # Since each CV fold removes points from the training data, the - # remaining observations will not pass the standardization test. - # To avoid confusing users with this warning, we filter it out. - warnings.filterwarnings( - "ignore", - message=r"Data \(outcome observations\) is not standardized", - category=InputDataWarning, - ) + # Since each CV fold removes points from the training data, the + # remaining observations will not pass the input scaling checks. + # To avoid confusing users with warnings, we disable these checks. + with validate_input_scaling(False): cv_test_predictions = model._cross_validate( search_space=search_space, cv_training_data=cv_training_data, diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index aa78a0a51dc..f96909bcf31 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -35,6 +35,7 @@ from ax.utils.common.docutils import copy_doc from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.deterministic import FixedSingleSampleModel +from botorch.settings import validate_input_scaling from botorch.utils.datasets import SupervisedDataset from pyre_extensions import assert_is_instance from torch import Tensor @@ -362,15 +363,19 @@ def cross_validate( ) try: - self.fit( - datasets=datasets, - search_space_digest=search_space_digest, - # pyre-fixme [6]: state_dict() has a generic dict[str, Any] return type - # but it is actually an OrderedDict[str, Tensor]. - state_dict=state_dict, - refit=self.refit_on_cv, - **additional_model_inputs, - ) + # Since each CV fold removes points from the training data, the + # remaining observations will not pass the input scaling checks. + # To avoid confusing users with warnings, we disable these checks. + with validate_input_scaling(False): + self.fit( + datasets=datasets, + search_space_digest=search_space_digest, + # pyre-fixme [6]: state_dict() has a generic dict[str, Any] + # return type but it is actually an OrderedDict[str, Tensor]. + state_dict=state_dict, + refit=self.refit_on_cv, + **additional_model_inputs, + ) X_test_prediction = self.predict( X=X_test, use_posterior_predictive=use_posterior_predictive, diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 55aa37019cb..3a49f59609f 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -62,7 +62,6 @@ RANK_CORRELATION, ) from botorch.exceptions.errors import ModelFittingError -from botorch.exceptions.warnings import InputDataWarning from botorch.models.model import Model from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP @@ -74,6 +73,7 @@ ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform from botorch.posteriors.gpytorch import GPyTorchPosterior +from botorch.settings import validate_input_scaling from botorch.utils.containers import SliceContainer from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher @@ -1153,14 +1153,10 @@ def cross_validate( test_X = X[i : i + 1] # fit model to all but one data point # TODO: consider batchifying - with warnings.catch_warnings(): - # Suppress BoTorch input standardization warnings here, since they're - # expected to be triggered due to subsetting of the data. - warnings.filterwarnings( - "ignore", - message=r"Data \(outcome observations\) is not standardized", - category=InputDataWarning, - ) + # Since each CV fold removes points from the training data, the + # remaining observations will not pass the input scaling checks. + # To avoid confusing users with warnings, we disable these checks. + with validate_input_scaling(False): loo_model = self._construct_model( dataset=train_dataset, search_space_digest=search_space_digest,