Skip to content

Commit

Permalink
Use validate_input_scaling context manager to skip validation / silen…
Browse files Browse the repository at this point in the history
…ce warnings during cross validation (#3371)

Summary:

Previously, we were letting these validations happen, knowing that they'd fail, just to filter the raised warnings based on the expected text. BoTorch has a context manager that can disable these warnings, so let's use that instead.

Differential Revision: D69663108
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 14, 2025
1 parent e2c96be commit a3cda5c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 41 deletions.
17 changes: 5 additions & 12 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@

import json
import time
import warnings
from abc import ABC
from collections import OrderedDict
from collections.abc import MutableMapping
from copy import deepcopy
from dataclasses import dataclass, field

from logging import Logger
from typing import Any

Expand Down Expand Up @@ -50,7 +48,7 @@
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from botorch.exceptions.warnings import InputDataWarning
from botorch.settings import validate_input_scaling
from pyre_extensions import assert_is_instance, none_throws

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -1004,15 +1002,10 @@ def cross_validate(
)

# Apply terminal transform, and get predictions.
with warnings.catch_warnings():
# Since each CV fold removes points from the training data, the remaining
# observations will not pass the standardization test. To avoid confusing
# users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
# Since each CV fold removes points from the training data, the
# remaining observations will not pass the input scaling checks.
# To avoid confusing users with warnings, we disable these checks.
with validate_input_scaling(False):
cv_predictions = self._cross_validate(
search_space=search_space,
cv_training_data=cv_training_data,
Expand Down
16 changes: 5 additions & 11 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from __future__ import annotations

import warnings
from collections import defaultdict
from collections.abc import Callable, Iterable
from copy import deepcopy
Expand All @@ -31,7 +30,7 @@
ModelFitMetricProtocol,
std_of_the_standardized_error,
)
from botorch.exceptions.warnings import InputDataWarning
from botorch.settings import validate_input_scaling

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -161,15 +160,10 @@ def cross_validate(
) = model._transform_inputs_for_cv(
cv_training_data=cv_training_data, cv_test_points=cv_test_points
)
with warnings.catch_warnings():
# Since each CV fold removes points from the training data, the
# remaining observations will not pass the standardization test.
# To avoid confusing users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
# Since each CV fold removes points from the training data, the
# remaining observations will not pass the input scaling checks.
# To avoid confusing users with warnings, we disable these checks.
with validate_input_scaling(False):
cv_test_predictions = model._cross_validate(
search_space=search_space,
cv_training_data=cv_training_data,
Expand Down
23 changes: 14 additions & 9 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ax.utils.common.docutils import copy_doc
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.settings import validate_input_scaling
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import assert_is_instance
from torch import Tensor
Expand Down Expand Up @@ -362,15 +363,19 @@ def cross_validate(
)

try:
self.fit(
datasets=datasets,
search_space_digest=search_space_digest,
# pyre-fixme [6]: state_dict() has a generic dict[str, Any] return type
# but it is actually an OrderedDict[str, Tensor].
state_dict=state_dict,
refit=self.refit_on_cv,
**additional_model_inputs,
)
# Since each CV fold removes points from the training data, the
# remaining observations will not pass the input scaling checks.
# To avoid confusing users with warnings, we disable these checks.
with validate_input_scaling(False):
self.fit(
datasets=datasets,
search_space_digest=search_space_digest,
# pyre-fixme [6]: state_dict() has a generic dict[str, Any]
# return type but it is actually an OrderedDict[str, Tensor].
state_dict=state_dict,
refit=self.refit_on_cv,
**additional_model_inputs,
)
X_test_prediction = self.predict(
X=X_test,
use_posterior_predictive=use_posterior_predictive,
Expand Down
14 changes: 5 additions & 9 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
RANK_CORRELATION,
)
from botorch.exceptions.errors import ModelFittingError
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
Expand All @@ -74,6 +73,7 @@
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.settings import validate_input_scaling
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
Expand Down Expand Up @@ -1153,14 +1153,10 @@ def cross_validate(
test_X = X[i : i + 1]
# fit model to all but one data point
# TODO: consider batchifying
with warnings.catch_warnings():
# Suppress BoTorch input standardization warnings here, since they're
# expected to be triggered due to subsetting of the data.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
# Since each CV fold removes points from the training data, the
# remaining observations will not pass the input scaling checks.
# To avoid confusing users with warnings, we disable these checks.
with validate_input_scaling(False):
loo_model = self._construct_model(
dataset=train_dataset,
search_space_digest=search_space_digest,
Expand Down

0 comments on commit a3cda5c

Please sign in to comment.