From 6a003fa78182f6b302fd4b491d87c649e69c2c25 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 27 Feb 2025 20:50:21 -0800 Subject: [PATCH] Changes to Adapter constructors, require experiment (#3415) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3415 The motivation for this change is to bring us closer to relying on the `experiment` as the source of truth about the experiment state & attributes within the modeling layer. We currently support many inputs that are extracted from the `experiment`, just to be passed in alongside it. By making `experiment` a required input, we open the possibility of removing these extra inputs and extracting them directly from `experiment` where they're needed. Makes the following changes to Adapter constructors: - Requires keyword-only arguments. Positional inputs are no longer supported. - Makes `experiment` required and `search_space` optional. - Re-orders inputs for consistency across sub-classes. In addition: - Removes the `model` input to `Adapter._fit`. This is a private method that is only called through `_fit_if_implemented` (with `self.model`). Accepting multiple inputs for the same argument only makes the code harder to reason about. - Removes class-level attributes, some of which weren't initialized in `__init__`, leading to pyre complaints. All attributes are now initialized in `__init__`. This also eliminates misleading "optional" type hints with a `None` default for `model`, which is never `None` in practice. - Removes `Adapter.update`, which has been deprecated for quite some time. - Initializing a Generator from the registry with only `search_space` is being deprecated. It is temporarily supported using a dummy experiment for random & discrete adapters, which previously did not require an `experiment`. 
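For illustration, a minimal sketch of the new calling convention (using the Branin test experiment and Sobol generator from Ax's testing utilities; any experiment and generator would do):

```python
from ax.modelbridge.random import RandomAdapter
from ax.models.random.sobol import SobolGenerator
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment()

# Old, no longer supported: RandomAdapter(experiment.search_space, SobolGenerator())
# New: keyword-only arguments; `experiment` is required, while `search_space`
# and `data` default to `experiment.search_space` / `experiment.lookup_data()`.
adapter = RandomAdapter(experiment=experiment, model=SobolGenerator())
generator_run = adapter.gen(n=5)  # Generates 5 quasi-random candidates.
```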
Reviewed By: mgarrard Differential Revision: D70103442 fbshipit-source-id: 1cf4cf9138e482e166bf7b7e74d5505cc8bb227b --- .../surrogate/lcbench/transfer_learning.py | 6 +- .../tests/test_dispatch_utils.py | 7 +- .../tests/test_generation_strategy.py | 18 +- ax/modelbridge/base.py | 83 ++-- ax/modelbridge/discrete.py | 58 ++- ax/modelbridge/factory.py | 6 +- ax/modelbridge/map_torch.py | 50 +-- ax/modelbridge/pairwise.py | 2 +- ax/modelbridge/random.py | 71 +-- ax/modelbridge/registry.py | 31 +- ax/modelbridge/tests/test_base_modelbridge.py | 407 +++++++----------- .../tests/test_discrete_modelbridge.py | 63 +-- ax/modelbridge/tests/test_factory.py | 21 +- .../tests/test_hierarchical_search_space.py | 2 +- .../tests/test_model_fit_metrics.py | 2 +- .../tests/test_random_modelbridge.py | 59 +-- ax/modelbridge/tests/test_registry.py | 71 +-- .../tests/test_torch_modelbridge.py | 36 +- ax/modelbridge/tests/test_transform_utils.py | 9 +- ax/modelbridge/tests/test_utils.py | 2 +- ax/modelbridge/torch.py | 87 ++-- .../tests/test_derelativize_transform.py | 45 +- .../tests/test_relativize_transform.py | 5 +- .../tests/test_winsorize_transform.py | 19 +- ax/plot/tests/test_feature_importances.py | 11 +- 25 files changed, 518 insertions(+), 653 deletions(-) diff --git a/ax/benchmark/problems/surrogate/lcbench/transfer_learning.py b/ax/benchmark/problems/surrogate/lcbench/transfer_learning.py index efbcea94074..a8800d4e19a 100644 --- a/ax/benchmark/problems/surrogate/lcbench/transfer_learning.py +++ b/ax/benchmark/problems/surrogate/lcbench/transfer_learning.py @@ -7,7 +7,6 @@ import os from collections.abc import Mapping - from typing import Any import torch @@ -23,6 +22,7 @@ from ax.modelbridge.registry import Cont_X_trans, Generators, Y_trans from ax.modelbridge.torch import TorchAdapter from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.testing.mock import skip_fit_gpytorch_mll_context_manager from botorch.models import SingleTaskGP @@ -133,7 +133,9 @@ def get_surrogate() -> TorchAdapter: data=obj["data"], transforms=Cont_X_trans + Y_trans, ) - mb.model.surrogate.model.load_state_dict(obj["state_dict"]) + assert_is_instance(mb.model, BoTorchGenerator).surrogate.model.load_state_dict( + obj["state_dict"] + ) return assert_is_instance(mb, TorchAdapter) name = f"LCBench_Surrogate_{dataset_name}:v1" diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index b77d12f3db9..9e3bd2bc2af 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -22,6 +22,7 @@ from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans from ax.modelbridge.transforms.log_y import LogY from ax.modelbridge.transforms.winsorize import Winsorize +from ax.models.random.sobol import SobolGenerator from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -35,7 +36,7 @@ run_branin_experiment_with_generation_strategy, ) from ax.utils.testing.mock import mock_botorch_optimize -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws class TestDispatchUtils(TestCase): @@ -406,7 +407,9 @@ def test_setting_random_seed(self) -> None: ) sobol.gen(experiment=get_experiment(), n=1) # 
First model is actually a bridge, second is the Sobol engine. - self.assertEqual(none_throws(sobol.model).model.seed, 9) + self.assertEqual( + assert_is_instance(none_throws(sobol.model).model, SobolGenerator).seed, 9 + ) with self.subTest("warns if use_saasbo is true"): with self.assertLogs( diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 9792e2bc6ab..cdacfbd56c8 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -580,10 +580,13 @@ def test_sobol_MBM_strategy(self) -> None: ) ms = none_throws(g._model_state_after_gen).copy() # Compare the model state to Sobol state. - sobol_model = none_throws(gs.model).model + sobol_model = assert_is_instance( + none_throws(gs.model).model, SobolGenerator + ) self.assertTrue( np.array_equal( - ms.pop("generated_points"), sobol_model.generated_points + ms.pop("generated_points"), + none_throws(sobol_model.generated_points), ) ) # Replace expected seed with the one generated in __init__. @@ -714,9 +717,9 @@ def test_with_factory_function(self) -> None: """Checks that generation strategy works with custom factory functions. No information about the model should be saved on generator run.""" - def get_sobol(search_space: SearchSpace) -> RandomAdapter: + def get_sobol(experiment: Experiment) -> RandomAdapter: return RandomAdapter( - search_space=search_space, + experiment=experiment, model=SobolGenerator(), transforms=Cont_X_trans, ) @@ -1551,10 +1554,13 @@ def test_gs_with_generation_nodes(self) -> None: ) ms = none_throws(g._model_state_after_gen).copy() # Compare the model state to Sobol state. - sobol_model = none_throws(self.sobol_MBM_GS_nodes.model).model + sobol_model = assert_is_instance( + none_throws(self.sobol_MBM_GS_nodes.model).model, SobolGenerator + ) self.assertTrue( np.array_equal( - ms.pop("generated_points"), sobol_model.generated_points + ms.pop("generated_points"), + none_throws(sobol_model.generated_points), ) ) # Replace expected seed with the one generated in __init__. diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 5c4f9a32568..97e3d3cedea 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -44,6 +44,7 @@ from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.cast import Cast from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters +from ax.models.base import Generator from ax.models.types import TConfig from ax.utils.common.logger import get_logger from botorch.settings import validate_input_scaling @@ -89,18 +90,18 @@ class Adapter: receives appropriate inputs. Subclasses will implement what is here referred to as the "terminal - transform," which is a transform that changes types of the data and problem + transform", which is a transform that changes types of the data and problem specification. """ def __init__( self, - search_space: SearchSpace, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
- model: Any, - transforms: Sequence[type[Transform]] | None = None, - experiment: Experiment | None = None, + *, + experiment: Experiment, + model: Generator, + search_space: SearchSpace | None = None, data: Data | None = None, + transforms: Sequence[type[Transform]] | None = None, transform_configs: Mapping[str, TConfig] | None = None, status_quo_name: str | None = None, status_quo_features: ObservationFeatures | None = None, @@ -116,14 +117,22 @@ def __init__( Applies transforms and fits model. Args: - experiment: Is used to get arm parameters. Is not mutated. - search_space: Search space for fitting the model. Constraints need - not be the same ones used in gen. RangeParameter bounds are - considered soft and will be expanded to match the range of the - data sent in for fitting, if expand_model_space is True. - data: Ax Data. - model: Interface will be specified in subclass. If model requires + experiment: An ``Experiment`` object representing the setup and the + current state of the experiment, including the search space, + trials and observation data. It is used to extract various + attributes, and is not mutated. + model: A ``Generator`` that is used for generating candidates. + Its interface will be specified in subclasses. If model requires initialization, that should be done prior to its use here. + search_space: An optional ``SearchSpace`` for fitting the model. + If not provided, `experiment.search_space` is used. + The search space may be modified during ``Adapter.gen``, e.g., + to try out a different set of parameter bounds or constraints. + The bounds of the ``RangeParameter``s are considered soft and + will be expanded to match the range of the data sent in for fitting, + if `expand_model_space` is True. + data: An optional ``Data`` object, containing mean and SEM observations. + If `None`, extracted using `experiment.lookup_data()`. transforms: List of uninitialized transform classes. Forward transforms will be applied in this order, and untransforms in the reverse order. @@ -134,8 +143,8 @@ def __init__( that arm. status_quo_features: ObservationFeatures to use as status quo. Either this or status_quo_name should be specified, not both. - optimization_config: Optimization config defining how to optimize - the model. + optimization_config: An optional ``OptimizationConfig`` defining how to + optimize the model. Defaults to `experiment.optimization_config`. expand_model_space: If True, expand range parameter bounds in model space to cover given training data. This will make the modeling space larger than the search space if training data fall outside @@ -178,6 +187,7 @@ def __init__( self._model_kwargs: dict[str, Any] | None = None self._bridge_kwargs: dict[str, Any] | None = None # The space used for optimization. + search_space = search_space or experiment.search_space self._search_space: SearchSpace = search_space.clone() # The space used for modeling. Might be larger than the optimization # space to cover training data. 
@@ -193,13 +203,12 @@ def __init__( experiment is not None and experiment.immutable_search_space_and_opt_config ) self._experiment_properties: dict[str, Any] = {} - self._experiment: Experiment | None = experiment + self._experiment: Experiment = experiment - if experiment is not None: - if self._optimization_config is None: - self._optimization_config = experiment.optimization_config - self._arms_by_signature = experiment.arms_by_signature - self._experiment_properties = experiment._properties + if self._optimization_config is None: + self._optimization_config = experiment.optimization_config + self._arms_by_signature = experiment.arms_by_signature + self._experiment_properties = experiment._properties if self._fit_tracking_metrics is False: if self._optimization_config is None: @@ -211,6 +220,7 @@ def __init__( # Set training data (in the raw / untransformed space). This also omits # out-of-design and abandoned observations depending on the corresponding flags. + data = data if data is not None else experiment.lookup_data() observations_raw = self._prepare_observations(experiment=experiment, data=data) if expand_model_space: self._set_model_space(observations=observations_raw) @@ -258,11 +268,7 @@ def _fit_if_implemented( """ try: t_fit_start = time.monotonic() - self._fit( - model=self.model, - search_space=search_space, - observations=observations, - ) + self._fit(search_space=search_space, observations=observations) increment = time.monotonic() - t_fit_start + time_so_far self.fit_time += increment self.fit_time_since_gen += increment @@ -476,7 +482,7 @@ def _set_status_quo( if status_quo_name is not None: if status_quo_features is not None: - raise ValueError( + raise UserInputError( "Specify either status_quo_name or status_quo_features, not both." ) sq_obs = [ @@ -595,8 +601,6 @@ def training_in_design(self, training_in_design: list[bool]) -> None: def _fit( self, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - model: Any, search_space: SearchSpace, observations: list[Observation], ) -> None: @@ -735,21 +739,6 @@ def _predict( f"{self.__class__.__name__} does not implement `_predict`." ) - def update(self, new_data: Data, experiment: Experiment) -> None: - """Update the model bridge and the underlying model with new data. This - method should be used instead of `fit`, in cases where the underlying - model does not need to be re-fit from scratch, but rather updated. - - Note: `update` expects only new data (obtained since the model initialization - or last update) to be passed in, not all data in the experiment. - - Args: - new_data: Data from the experiment obtained since the last call to - `update`. - experiment: Experiment, in which this data was obtained. - """ - raise DeprecationWarning("Adapter.update is deprecated. Use `fit` instead.") - def _get_transformed_gen_args( self, search_space: SearchSpace, @@ -1079,14 +1068,12 @@ def _get_serialized_model_state(self) -> dict[str, Any]: """Obtains the state of the underlying model (if using a stateful one) in a readily JSON-serializable form. 
""" - model = none_throws(self.model) - return model.serialize_state(raw_state=model._get_state()) + return self.model.serialize_state(raw_state=self.model._get_state()) def _deserialize_model_state( self, serialized_state: dict[str, Any] ) -> dict[str, Any]: - model = none_throws(self.model) - return model.deserialize_state(serialized_state=serialized_state) + return self.model.deserialize_state(serialized_state=serialized_state) def feature_importances(self, metric_name: str) -> dict[str, float]: """Computes feature importances for a single metric. diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py index a0f3f656bb4..5517db0683e 100644 --- a/ax/modelbridge/discrete.py +++ b/ax/modelbridge/discrete.py @@ -7,6 +7,10 @@ # pyre-strict +from typing import Mapping, Sequence + +from ax.core.data import Data +from ax.core.experiment import Experiment from ax.core.observation import ( Observation, ObservationData, @@ -28,6 +32,7 @@ extract_outcome_constraints, validate_transformed_optimization_config, ) +from ax.modelbridge.transforms.base import Transform from ax.models.discrete_base import DiscreteGenerator from ax.models.types import TConfig @@ -38,25 +43,56 @@ class DiscreteAdapter(Adapter): """A model bridge for using models based on discrete parameters. - Requires that all parameters have been transformed to ChoiceParameters. + Requires that all parameters to have been transformed to ChoiceParameters. """ - # pyre-fixme[13]: Attribute `model` is never initialized. - model: DiscreteGenerator - # pyre-fixme[13]: Attribute `outcomes` is never initialized. - outcomes: list[str] - # pyre-fixme[13]: Attribute `parameters` is never initialized. - parameters: list[str] - # pyre-fixme[13]: Attribute `search_space` is never initialized. - search_space: SearchSpace | None + def __init__( + self, + *, + experiment: Experiment, + model: DiscreteGenerator, + search_space: SearchSpace | None = None, + data: Data | None = None, + transforms: Sequence[type[Transform]] | None = None, + transform_configs: Mapping[str, TConfig] | None = None, + status_quo_name: str | None = None, + status_quo_features: ObservationFeatures | None = None, + optimization_config: OptimizationConfig | None = None, + expand_model_space: bool = True, + fit_out_of_design: bool = False, + fit_abandoned: bool = False, + fit_tracking_metrics: bool = True, + fit_on_init: bool = True, + fit_only_completed_map_metrics: bool = True, + ) -> None: + # These are set in _fit. + self.parameters: list[str] = [] + self.outcomes: list[str] = [] + super().__init__( + experiment=experiment, + model=model, + search_space=search_space, + data=data, + transforms=transforms, + transform_configs=transform_configs, + status_quo_name=status_quo_name, + status_quo_features=status_quo_features, + optimization_config=optimization_config, + expand_model_space=expand_model_space, + fit_out_of_design=fit_out_of_design, + fit_abandoned=fit_abandoned, + fit_tracking_metrics=fit_tracking_metrics, + fit_on_init=fit_on_init, + fit_only_completed_map_metrics=fit_only_completed_map_metrics, + ) + # Re-assing for more precise typing. 
+ self.model: DiscreteGenerator = model def _fit( self, - model: DiscreteGenerator, search_space: SearchSpace, observations: list[Observation], ) -> None: - self.model = model # Convert observations to arrays self.parameters = list(search_space.parameters.keys()) all_metric_names: set[str] = set() diff --git a/ax/modelbridge/factory.py b/ax/modelbridge/factory.py index 387775481bb..8df2c2efef7 100644 --- a/ax/modelbridge/factory.py +++ b/ax/modelbridge/factory.py @@ -74,7 +74,7 @@ def get_sobol( """ return assert_is_instance( Generators.SOBOL( - search_space=search_space, + experiment=Experiment(search_space=search_space), seed=seed, deduplicate=deduplicate, init_position=init_position, @@ -99,7 +99,9 @@ def get_uniform( """ return assert_is_instance( Generators.UNIFORM( - search_space=search_space, seed=seed, deduplicate=deduplicate + experiment=Experiment(search_space=search_space), + seed=seed, + deduplicate=deduplicate, ), RandomAdapter, ) diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 068c85804cd..2d4d4241bd9 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -56,13 +56,13 @@ class should be used in the case where `model` makes use of map_key values. def __init__( self, + *, experiment: Experiment, - search_space: SearchSpace, - data: Data, model: TorchGenerator, - transforms: Sequence[type[Transform]], + search_space: SearchSpace | None = None, + data: Data | None = None, + transforms: Sequence[type[Transform]] | None = None, transform_configs: Mapping[str, TConfig] | None = None, - torch_device: torch.device | None = None, status_quo_name: str | None = None, status_quo_features: ObservationFeatures | None = None, optimization_config: OptimizationConfig | None = None, @@ -70,57 +70,27 @@ def __init__( fit_on_init: bool = True, fit_abandoned: bool = False, default_model_gen_options: TConfig | None = None, + torch_device: torch.device | None = None, map_data_limit_rows_per_metric: int | None = None, map_data_limit_rows_per_group: int | None = None, ) -> None: - """ - Applies transforms and fits model. + """In addition to common arguments documented in the ``Adapter`` and + ``TorchAdapter`` classes, ``MapTorchAdapter`` accepts the following arguments. Args: - experiment: Is used to get arm parameters. Is not mutated. - search_space: Search space for fitting the model. Constraints need - not be the same ones used in gen. - data: Ax Data. - model: Interface will be specified in subclass. If model requires - initialization, that should be done prior to its use here. - transforms: List of uninitialized transform classes. Forward - transforms will be applied in this order, and untransforms in - the reverse order. - transform_configs: A dictionary from transform name to the - transform config dictionary. - torch_device: Torch device. - status_quo_name: Name of the status quo arm. Can only be used if - Data has a single set of ObservationFeatures corresponding to - that arm. - status_quo_features: ObservationFeatures to use as status quo. - Either this or status_quo_name should be specified, not both. - optimization_config: Optimization config defining how to optimize - the model. - fit_out_of_design: If specified, all training data is returned. - Otherwise, only in design points are returned. - fit_on_init: Whether to fit the model on initialization. This can - be used to skip model fitting when a fitted model is not needed. 
- To fit the model afterwards, use `_process_and_transform_data` - to get the transformed inputs and call `_fit_if_implemented` with - the transformed inputs. - fit_abandoned: Whether data for abandoned arms or trials should be - included in model training data. If ``False``, only - non-abandoned points are returned. - default_model_gen_options: Options passed down to `model.gen(...)`. map_data_limit_rows_per_metric: Subsample the map data so that the total number of rows per metric is limited by this value. map_data_limit_rows_per_group: Subsample the map data so that the number of rows in the `map_key` column for each (arm, metric) is limited by this value. """ - + data = data or experiment.lookup_data() if not isinstance(data, MapData): raise ValueError("MapTorchAdapter expects `MapData` instead of `Data`.") if any(isinstance(t, BatchTrial) for t in experiment.trials.values()): raise ValueError("MapTorchAdapter does not support batch trials.") - # pyre-fixme[4]: Attribute must be annotated. - self._map_key_features = data.map_keys + self._map_key_features: list[str] = data.map_keys self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric self._map_data_limit_rows_per_group = map_data_limit_rows_per_group @@ -187,7 +157,6 @@ def _predict( def _fit( self, - model: TorchGenerator, search_space: SearchSpace, observations: list[Observation], parameters: list[str] | None = None, @@ -200,7 +169,6 @@ if parameters is None: parameters = self.parameters_with_map_keys super()._fit( - model=model, search_space=search_space, observations=observations, parameters=parameters, diff --git a/ax/modelbridge/pairwise.py b/ax/modelbridge/pairwise.py index 74936f6d9a1..e3fbaf0c0c7 100644 --- a/ax/modelbridge/pairwise.py +++ b/ax/modelbridge/pairwise.py @@ -45,7 +45,7 @@ def _convert_observations( ( Xs, Ys, - Yvars, + _, # Yvars is not used here. candidate_metadata_dict, any_candidate_metadata_is_not_none, trial_indices, diff --git a/ax/modelbridge/random.py b/ax/modelbridge/random.py index 55a1c8cbda5..ccf4a93397f 100644 --- a/ax/modelbridge/random.py +++ b/ax/modelbridge/random.py @@ -8,7 +8,6 @@ from collections.abc import Mapping, Sequence -from typing import Any from ax.core.data import Data from ax.core.experiment import Experiment @@ -28,70 +27,21 @@ from ax.models.types import TConfig -FIT_MODEL_ERROR = "Model must be fit before {action}." - - class RandomAdapter(Adapter): - """A model bridge for using purely random 'models'. + """An adapter for using purely random ``RandomGenerator``s. Data and optimization configs are not required. - This model bridge interfaces with RandomGenerator. - - Attributes: - model: A RandomGenerator used to generate candidates - (note: this an awkward use of the word 'model'). - parameters: Params found in search space on modelbridge init. - - Args: - experiment: Is used to get arm parameters. Is not mutated. - search_space: Search space for fitting the model. Constraints need - not be the same ones used in gen. RangeParameter bounds are - considered soft and will be expanded to match the range of the - data sent in for fitting, if expand_model_space is True. - data: Ax Data. - model: Interface will be specified in subclass. If model requires - initialization, that should be done prior to its use here. - transforms: List of uninitialized transform classes. Forward - transforms will be applied in this order, and untransforms in - the reverse order. - transform_configs: A dictionary from transform name to the - transform config dictionary. 
- status_quo_name: Name of the status quo arm. Can only be used if - Data has a single set of ObservationFeatures corresponding to - that arm. - status_quo_features: ObservationFeatures to use as status quo. - Either this or status_quo_name should be specified, not both. - optimization_config: Optimization config defining how to optimize - the model. - fit_out_of_design: If specified, all training data are used. - Otherwise, only in design points are used. - fit_abandoned: Whether data for abandoned arms or trials should be - included in model training data. If ``False``, only - non-abandoned points are returned. - fit_tracking_metrics: Whether to fit a model for tracking metrics. - Setting this to False will improve runtime at the expense of - models not being available for predicting tracking metrics. - NOTE: This can only be set to False when the optimization config - is provided. - fit_on_init: Whether to fit the model on initialization. This can - be used to skip model fitting when a fitted model is not needed. - To fit the model afterwards, use `_process_and_transform_data` - to get the transformed inputs and call `_fit_if_implemented` with - the transformed inputs. + Please refer to base ``Adapter`` class for documentation of constructor arguments. """ - model: RandomGenerator - # pyre-fixme[13]: Attribute `parameters` is never initialized. - parameters: list[str] - def __init__( self, - search_space: SearchSpace, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - model: Any, - transforms: Sequence[type[Transform]] | None = None, - experiment: Experiment | None = None, + *, + experiment: Experiment, + model: RandomGenerator, + search_space: SearchSpace | None = None, data: Data | None = None, + transforms: Sequence[type[Transform]] | None = None, transform_configs: Mapping[str, TConfig] | None = None, status_quo_name: str | None = None, status_quo_features: ObservationFeatures | None = None, @@ -101,6 +51,7 @@ def __init__( fit_tracking_metrics: bool = True, fit_on_init: bool = True, ) -> None: + self.parameters: list[str] = [] super().__init__( search_space=search_space, model=model, @@ -117,15 +68,15 @@ def __init__( fit_tracking_metrics=fit_tracking_metrics, fit_on_init=fit_on_init, ) + # Re-assign for more precise typing. + self.model: RandomGenerator = model def _fit( self, - model: RandomGenerator, search_space: SearchSpace, observations: list[Observation] | None = None, ) -> None: - self.model = model - # Extract and fix parameters from initial search space. 
+ """Extracts the list of parameters from the search space.""" self.parameters = list(search_space.parameters.keys()) def _gen( diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 4f3afe860b2..21a2d285170 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -17,6 +17,7 @@ from __future__ import annotations +import warnings from collections.abc import Mapping, Sequence from enum import Enum from inspect import isfunction, signature @@ -27,6 +28,7 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.search_space import SearchSpace +from ax.exceptions.core import UserInputError from ax.modelbridge.base import Adapter from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.random import RandomAdapter @@ -289,13 +291,34 @@ def __call__( silently_filter_kwargs: bool = False, **kwargs: Any, ) -> Adapter: - assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}" - # All model bridges require either a search space or an experiment. - assert search_space or experiment, "Search space or experiment required." - search_space = search_space or none_throws(experiment).search_space + if self.value not in MODEL_KEY_TO_MODEL_SETUP: + raise UserInputError(f"Unknown model {self.value}") model_setup_info = MODEL_KEY_TO_MODEL_SETUP[self.value] model_class = model_setup_info.model_class bridge_class = model_setup_info.bridge_class + if experiment is None: + # Some Adapters used to accept search_space as the only input. + # Temporarily support it with a deprecation warning. + if ( + issubclass(bridge_class, (RandomAdapter, DiscreteAdapter)) + and search_space is not None + ): + warnings.warn( + "Passing in a `search_space` to initialize a generator from a " + "registry is being deprecated. `experiment` is now a required " + "input for initializing `Adapters`. Please use `experiment` " + "when initializing generators going forward. " + "Support for `search_space` will be removed in Ax 0.7.0.", + DeprecationWarning, + stacklevel=2, + ) + # Construct a dummy experiment for temporary support. + experiment = Experiment(search_space=search_space) + else: + raise UserInputError( + "`experiment` is required to initialize a model from registry." 
+ ) + search_space = search_space or none_throws(experiment).search_space if not silently_filter_kwargs: # Check correct kwargs are present diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index 55bc9b7d1ea..0805e6167b6 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -14,7 +14,6 @@ import numpy as np import torch from ax.core.arm import Arm -from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.map_data import MapData from ax.core.metric import Metric @@ -86,38 +85,25 @@ def test_Adapter( # Test that on init transforms are stored and applied in the correct order transforms = [transform_1, transform_2] exp = get_experiment_for_value() - ss = get_search_space_for_value() - modelbridge = Adapter( - search_space=ss, - model=Generator(), - transforms=transforms, - experiment=exp, - data=Data(), - ) - self.assertFalse( - modelbridge._experiment_has_immutable_search_space_and_opt_config - ) + adapter = Adapter(experiment=exp, model=Generator(), transforms=transforms) + self.assertFalse(adapter._experiment_has_immutable_search_space_and_opt_config) self.assertEqual( - list(modelbridge.transforms.keys()), ["Cast", "transform_1", "transform_2"] + list(adapter.transforms.keys()), ["Cast", "transform_1", "transform_2"] ) fit_args = mock_fit.mock_calls[0][2] self.assertTrue(fit_args["search_space"] == get_search_space_for_value(8.0)) self.assertTrue(fit_args["observations"] == []) self.assertTrue(mock_observations_from_data.called) - # Test deprecation error on update. - with self.assertRaisesRegex(DeprecationWarning, "Adapter.update"): - modelbridge.update(Mock(), Mock()) - # Test prediction with arms. with self.assertRaisesRegex( UserInputError, "Input to predict must be a list of `ObservationFeatures`." ): # pyre-ignore[6]: Intentionally wrong argument type. - modelbridge.predict([Arm(parameters={"x": 1.0})]) + adapter.predict([Arm(parameters={"x": 1.0})]) # Test prediction on out of design features. - modelbridge._predict = mock.MagicMock( + adapter._predict = mock.MagicMock( "ax.modelbridge.base.Adapter._predict", autospec=True, side_effect=ValueError("Out of Design"), @@ -127,11 +113,11 @@ def test_Adapter( Adapter, "model_space", return_value=get_search_space_for_range_values ): with self.assertRaises(ValueError): - modelbridge.predict([get_observation2().features]) + adapter.predict([get_observation2().features]) # This point is out of design, and not in training data. with self.assertRaises(ValueError): - modelbridge.predict([get_observation_status_quo0().features]) + adapter.predict([get_observation_status_quo0().features]) # Now it's in the training data. with mock.patch.object( @@ -141,7 +127,7 @@ def test_Adapter( ): # Return raw training value. self.assertEqual( - modelbridge.predict([get_observation_status_quo0().features]), + adapter.predict([get_observation_status_quo0().features]), unwrap_observation_data([get_observation_status_quo0().data]), ) @@ -151,18 +137,18 @@ def test_Adapter( autospec=True, return_value=[get_observation2trans().data], ) - modelbridge._predict = mock_predict - modelbridge.predict([get_observation2().features]) + adapter._predict = mock_predict + adapter.predict([get_observation2().features]) # Observation features sent to _predict are un-transformed afterwards mock_predict.assert_called_with([get_observation2().features]) # Check that _single_predict is equivalent here. 
- modelbridge._single_predict([get_observation2().features]) + adapter._single_predict([get_observation2().features]) # Observation features sent to _predict are un-transformed afterwards mock_predict.assert_called_with([get_observation2().features]) # Test transforms applied on gen - modelbridge._gen = mock.MagicMock( + adapter._gen = mock.MagicMock( "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults( @@ -170,13 +156,13 @@ def test_Adapter( ), ) oc = get_optimization_config_no_constraints() - modelbridge._set_kwargs_to_save( + adapter._set_kwargs_to_save( model_key="TestModel", model_kwargs={}, bridge_kwargs={} ) # Test input error when generating 0 candidates. with self.assertRaisesRegex(UserInputError, "Attempted to generate"): - modelbridge.gen(n=0) - gr = modelbridge.gen( + adapter.gen(n=0) + gr = adapter.gen( n=1, search_space=get_search_space_for_value(), optimization_config=oc, @@ -185,7 +171,7 @@ def test_Adapter( ) self.assertEqual(gr._model_key, "TestModel") # pyre-fixme[16]: Callable `_gen` has no attribute `assert_called_with`. - modelbridge._gen.assert_called_with( + adapter._gen.assert_called_with( n=1, search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]), optimization_config=oc, @@ -198,10 +184,10 @@ def test_Adapter( ) # Gen with no pending observations and no fixed features - modelbridge.gen( + adapter.gen( n=1, search_space=get_search_space_for_value(), optimization_config=None ) - modelbridge._gen.assert_called_with( + adapter._gen.assert_called_with( n=1, search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]), optimization_config=None, @@ -216,10 +202,10 @@ def test_Adapter( metrics=[Metric(name="test_metric"), Metric(name="test_metric_2")] ) ) - modelbridge.gen( + adapter.gen( n=1, search_space=get_search_space_for_value(), optimization_config=oc2 ) - modelbridge._gen.assert_called_with( + adapter._gen.assert_called_with( n=1, search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]), optimization_config=oc2, @@ -248,7 +234,7 @@ def warn_and_return_mock_obs( autospec=True, side_effect=warn_and_return_mock_obs, ) - modelbridge._cross_validate = mock_cv + adapter._cross_validate = mock_cv cv_training_data = [get_observation2()] cv_test_points = [get_observation1().features] @@ -257,7 +243,7 @@ def warn_and_return_mock_obs( transformed_cv_training_data, transformed_cv_test_points, transformed_ss, - ) = modelbridge._transform_inputs_for_cv( + ) = adapter._transform_inputs_for_cv( cv_training_data=cv_training_data, cv_test_points=cv_test_points ) self.assertEqual(transformed_cv_training_data, [get_observation2trans()]) @@ -267,7 +253,7 @@ def warn_and_return_mock_obs( ) with warnings.catch_warnings(record=True) as ws: - cv_predictions = modelbridge.cross_validate( + cv_predictions = adapter.cross_validate( cv_training_data=cv_training_data, cv_test_points=cv_test_points ) self.assertTrue(called) @@ -282,7 +268,7 @@ def warn_and_return_mock_obs( self.assertTrue(cv_predictions == [get_observation1().data]) # Test use_posterior_predictive in CV - modelbridge.cross_validate( + adapter.cross_validate( cv_training_data=cv_training_data, cv_test_points=cv_test_points, use_posterior_predictive=True, @@ -296,53 +282,45 @@ def warn_and_return_mock_obs( ) # Test stored training data - obs = modelbridge.get_training_data() + obs = adapter.get_training_data() self.assertTrue(obs == [get_observation1(), get_observation2()]) - self.assertEqual(modelbridge.metric_names, {"a", "b"}) - 
self.assertIsNone(modelbridge.status_quo) - self.assertTrue(modelbridge.model_space == get_search_space_for_value()) - self.assertEqual(modelbridge.training_in_design, [False, False]) + self.assertEqual(adapter.metric_names, {"a", "b"}) + self.assertIsNone(adapter.status_quo) + self.assertTrue(adapter.model_space == get_search_space_for_value()) + self.assertEqual(adapter.training_in_design, [False, False]) with self.assertRaises(ValueError): - modelbridge.training_in_design = [True, True, False] + adapter.training_in_design = [True, True, False] with self.assertRaises(ValueError): - modelbridge.training_in_design = [True, True, False] + adapter.training_in_design = [True, True, False] # Test feature_importances with self.assertRaises(NotImplementedError): - modelbridge.feature_importances("a") + adapter.feature_importances("a") # Test transform observation features with mock.patch( "ax.modelbridge.base.Adapter._transform_observation_features", autospec=True, ) as mock_tr: - modelbridge.transform_observation_features([get_observation2().features]) - mock_tr.assert_called_with(modelbridge, [get_observation2trans().features]) + adapter.transform_observation_features([get_observation2().features]) + mock_tr.assert_called_with(adapter, [get_observation2trans().features]) # Test that fit is not called when fit_on_init = False. mock_fit.reset_mock() - modelbridge = Adapter( - search_space=ss, - model=Generator(), - fit_on_init=False, - ) + adapter = Adapter(experiment=exp, model=Generator(), fit_on_init=False) self.assertEqual(mock_fit.call_count, 0) # Test error when fit_tracking_metrics is False and optimization # config is not specified. with self.assertRaisesRegex(UserInputError, "fit_tracking_metrics"): - Adapter( - search_space=ss, - model=Generator(), - fit_tracking_metrics=False, - ) + Adapter(experiment=exp, model=Generator(), fit_tracking_metrics=False) # Test error when fit_tracking_metrics is False and optimization # config is updated to include new metrics. 
- modelbridge = Adapter( - search_space=ss, + adapter = Adapter( + experiment=exp, model=Generator(), optimization_config=oc, fit_tracking_metrics=False, @@ -351,7 +329,7 @@ def warn_and_return_mock_obs( objective=Objective(metric=Metric(name="test_metric2"), minimize=False), ) with self.assertRaisesRegex(UnsupportedError, "fit_tracking_metrics"): - modelbridge.gen(n=1, optimization_config=new_oc) + adapter.gen(n=1, optimization_config=new_oc) @mock.patch( "ax.modelbridge.base.observations_from_data", @@ -364,27 +342,24 @@ def warn_and_return_mock_obs( return_value=([Arm(parameters={})], None), ) @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) - def test_repeat_candidates( - self, mock_fit: Mock, mock_gen_arms: Mock, mock_observations_from_data: Mock - ) -> None: - modelbridge = Adapter( - search_space=get_search_space_for_value(), - model=Generator(), + def test_repeat_candidates(self, _: Mock, __: Mock, ___: Mock) -> None: + adapter = Adapter( experiment=get_experiment_for_value(), + model=Generator(), ) # mock _gen to return 1 result - modelbridge._gen = mock.MagicMock( + adapter._gen = mock.MagicMock( "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults( observation_features=[get_observation1trans().features], weights=[2] ), ) - modelbridge._set_kwargs_to_save( + adapter._set_kwargs_to_save( model_key="TestModel", model_kwargs={}, bridge_kwargs={} ) with self.assertLogs("ax", level="INFO") as cm: - modelbridge.gen( + adapter.gen( n=2, ) self.assertTrue( @@ -396,7 +371,7 @@ def test_repeat_candidates( ) with self.assertLogs("ax", level="INFO") as cm: - modelbridge.gen( + adapter.gen( n=1, ) get_logger("ax").info("log to prevent error if there are no other logs") @@ -421,40 +396,35 @@ def test_with_status_quo(self, mock_fit: Mock, mock_gen_arms: Mock) -> None: with_status_quo=True, with_completed_trial=True, ) - modelbridge = Adapter( - search_space=exp.search_space, + adapter = Adapter( + experiment=exp, model=Generator(), transforms=Y_trans, - experiment=exp, - data=exp.lookup_data(), - ) - self.assertIsNotNone(modelbridge.status_quo) - self.assertEqual( - modelbridge.status_quo.features.parameters, {"x1": 0.0, "x2": 0.0} ) + self.assertIsNotNone(adapter.status_quo) + self.assertEqual(adapter.status_quo.features.parameters, {"x1": 0.0, "x2": 0.0}) @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) @mock.patch("ax.modelbridge.base.Adapter._gen", autospec=True) - def test_timing(self, mock_fit: Mock, mock_gen: Mock) -> None: + def test_timing(self, _: Mock, __: Mock) -> None: search_space = get_search_space_for_value() - modelbridge = Adapter( - search_space=search_space, model=Generator(), fit_on_init=False - ) - self.assertEqual(modelbridge.fit_time, 0.0) - modelbridge._fit_if_implemented( + experiment = Experiment(search_space=search_space) + adapter = Adapter(experiment=experiment, model=Generator(), fit_on_init=False) + self.assertEqual(adapter.fit_time, 0.0) + adapter._fit_if_implemented( search_space=search_space, observations=[], time_so_far=3.0 ) - modelbridge._fit_if_implemented( + adapter._fit_if_implemented( search_space=search_space, observations=[], time_so_far=2.0 ) - modelbridge._fit_if_implemented( + adapter._fit_if_implemented( search_space=search_space, observations=[], time_so_far=1.0 ) - self.assertAlmostEqual(modelbridge.fit_time, 6.0, places=1) - self.assertAlmostEqual(modelbridge.fit_time_since_gen, 6.0, places=1) - modelbridge.gen(1) - self.assertAlmostEqual(modelbridge.fit_time, 6.0, places=1) - 
self.assertAlmostEqual(modelbridge.fit_time_since_gen, 0.0, places=1) + self.assertAlmostEqual(adapter.fit_time, 6.0, places=1) + self.assertAlmostEqual(adapter.fit_time_since_gen, 6.0, places=1) + adapter.gen(1) + self.assertAlmostEqual(adapter.fit_time, 6.0, places=1) + self.assertAlmostEqual(adapter.fit_time_since_gen, 0.0, places=1) @mock.patch( "ax.modelbridge.base.observations_from_data", @@ -463,43 +433,35 @@ def test_timing(self, mock_fit: Mock, mock_gen: Mock) -> None: ) def test_ood_gen(self, _) -> None: # Test fit_out_of_design by returning OOD candidats - exp = get_experiment_for_value() ss = SearchSpace([RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0)]) - modelbridge = Adapter( - search_space=ss, + experiment = Experiment(search_space=ss) + adapter = Adapter( + experiment=experiment, model=Generator(), - transforms=[], - experiment=exp, - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - data=0, fit_out_of_design=True, ) obs = ObservationFeatures(parameters={"x": 3.0}) - modelbridge._gen = mock.MagicMock( + adapter._gen = mock.MagicMock( "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults(observation_features=[obs], weights=[2]), ) - gr = modelbridge.gen(n=1) + gr = adapter.gen(n=1) self.assertEqual(gr.arms[0].parameters, obs.parameters) # Test clamping arms by setting fit_out_of_design=False - modelbridge = Adapter( - search_space=ss, + adapter = Adapter( + experiment=experiment, model=Generator(), - transforms=[], - experiment=exp, - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - data=0, fit_out_of_design=False, ) obs = ObservationFeatures(parameters={"x": 3.0}) - modelbridge._gen = mock.MagicMock( + adapter._gen = mock.MagicMock( "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults(observation_features=[obs], weights=[2]), ) - gr = modelbridge.gen(n=1) + gr = adapter.gen(n=1) self.assertEqual(gr.arms[0].parameters, {"x": 1.0}) @mock.patch( @@ -508,88 +470,58 @@ def test_ood_gen(self, _) -> None: return_value=([get_observation1()]), ) @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): - # NOTE: If empty data object is not passed, observations are not - # extracted, even with mock. - modelbridge = Adapter( - search_space=get_search_space_for_value(), - model=0, - experiment=get_experiment_for_value(), - data=Data(), - status_quo_name="1_1", - ) - self.assertEqual(modelbridge.status_quo, get_observation1()) - self.assertEqual(modelbridge.status_quo_name, "1_1") + def test_SetStatusQuo(self, _, __) -> None: + exp = get_experiment_for_value() + adapter = Adapter(experiment=exp, model=Generator(), status_quo_name="1_1") + self.assertEqual(adapter.status_quo, get_observation1()) + self.assertEqual(adapter.status_quo_name, "1_1") # Alternatively, we can specify by features - modelbridge = Adapter( - get_search_space_for_value(), - 0, - [], - get_experiment_for_value(), - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. 
- 0, + adapter = Adapter( + experiment=exp, + model=Generator(), status_quo_features=get_observation1().features, ) - self.assertEqual(modelbridge.status_quo, get_observation1()) - self.assertEqual(modelbridge.status_quo_name, "1_1") + self.assertEqual(adapter.status_quo, get_observation1()) + self.assertEqual(adapter.status_quo_name, "1_1") - # Alternatively, we can specify on experiment + # Alternatively, we can specify on experiment. # Put a dummy arm with SQ name 1_1 on the dummy experiment. - exp = get_experiment_for_value() sq = Arm(name="1_1", parameters={"x": 3.0}) exp._status_quo = sq # Check that we set SQ to arm 1_1 - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - modelbridge = Adapter(get_search_space_for_value(), 0, [], exp, 0) - self.assertEqual(modelbridge.status_quo, get_observation1()) - self.assertEqual(modelbridge.status_quo_name, "1_1") + adapter = Adapter(experiment=exp, model=Generator()) + self.assertEqual(adapter.status_quo, get_observation1()) + self.assertEqual(adapter.status_quo_name, "1_1") # Errors if features and name both specified - with self.assertRaises(ValueError): - modelbridge = Adapter( - get_search_space_for_value(), - 0, - [], - exp, - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - 0, + with self.assertRaisesRegex( + UserInputError, + "Specify either status_quo_name or status_quo_features, not both.", + ): + adapter = Adapter( + experiment=exp, + model=Generator(), status_quo_features=get_observation1().features, status_quo_name="1_1", ) # Left as None if features or name don't exist - modelbridge = Adapter( - get_search_space_for_value(), - 0, - [], - exp, - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - 0, - status_quo_name="1_0", - ) - self.assertIsNone(modelbridge.status_quo) - self.assertIsNone(modelbridge.status_quo_name) - modelbridge = Adapter( - get_search_space_for_value(), - 0, - [], - get_experiment_for_value(), - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - 0, + adapter = Adapter(experiment=exp, model=Generator(), status_quo_name="1_0") + self.assertIsNone(adapter.status_quo) + self.assertIsNone(adapter.status_quo_name) + adapter = Adapter( + experiment=exp, + model=Generator(), status_quo_features=ObservationFeatures(parameters={"x": 3.0, "y": 10.0}), ) - self.assertIsNone(modelbridge.status_quo) + self.assertIsNone(adapter.status_quo) @mock.patch( "ax.modelbridge.base.Adapter._gen", autospec=True, ) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def test_status_quo_for_non_monolithic_data(self, mock_gen): + def test_status_quo_for_non_monolithic_data(self, mock_gen: Mock) -> None: mock_gen.return_value = GenResults( observation_features=[ ObservationFeatures( @@ -600,7 +532,7 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen): weights=[1] * 5, ) exp = get_branin_experiment_with_multi_objective(with_status_quo=True) - sobol = Generators.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(experiment=exp) exp.new_batch_trial(sobol.gen(5)).set_status_quo_and_optimize_power( status_quo=exp.status_quo ).run() @@ -610,16 +542,15 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen): with warnings.catch_warnings(record=True) as ws: bridge = Adapter( experiment=exp, - data=data, model=Generator(), + data=data, search_space=exp.search_space, ) # just testing it doesn't error bridge.gen(5) self.assertTrue(any("start_time" in str(w.message) for w in ws)) self.assertTrue(any("end_time" in str(w.message) for w in ws)) - # pyre-fixme[16]: Optional type has no attribute `arm_name`. - self.assertEqual(bridge.status_quo.arm_name, "status_quo") + self.assertEqual(none_throws(bridge.status_quo).arm_name, "status_quo") @mock.patch( "ax.modelbridge.base.observations_from_data", @@ -634,9 +565,7 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen): ), ) @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def test_SetStatusQuoMultipleObs(self, mock_fit, mock_observations_from_data): + def test_SetStatusQuoMultipleObs(self, _, __) -> None: exp = get_experiment_with_repeated_arms(2) trial_index = 1 @@ -645,30 +574,25 @@ def test_SetStatusQuoMultipleObs(self, mock_fit, mock_observations_from_data): parameters=exp.trials[trial_index].status_quo.parameters, trial_index=trial_index, ) - modelbridge = Adapter( - get_search_space_for_value(), - 0, - [], - exp, - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - 0, + adapter = Adapter( + experiment=exp, + model=Generator(), status_quo_features=status_quo_features, ) # Check that for experiments with many trials the status quo is set # to the value of the status quo of the last trial. if len(exp.trials) >= 1: - self.assertEqual(modelbridge.status_quo, get_observation_status_quo1()) + self.assertEqual(adapter.status_quo, get_observation_status_quo1()) def test_transform_observations(self) -> None: """ This functionality is unused, even in the subclass where it is implemented. """ - ss = get_search_space_for_value() - modelbridge = Adapter(search_space=ss, model=Generator()) + adapter = Adapter(experiment=get_experiment_for_value(), model=Generator()) with self.assertRaises(NotImplementedError): - modelbridge.transform_observations([]) + adapter.transform_observations([]) with self.assertRaises(NotImplementedError): - modelbridge.transform_observations([]) + adapter.transform_observations([]) @mock.patch( "ax.modelbridge.base.observations_from_data", @@ -676,18 +600,12 @@ def test_transform_observations(self) -> None: return_value=([get_observation1(), get_observation1()]), ) @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) - def test_SetTrainingDataDupFeatures( - self, mock_fit: Mock, mock_observations_from_data: Mock - ) -> None: + def test_SetTrainingDataDupFeatures(self, _: Mock, __: Mock) -> None: # Throws an error if repeated features in observations. 
with self.assertRaises(ValueError): Adapter( - get_search_space_for_value(), - 0, - [], - get_experiment_for_value(), - # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - 0, + experiment=get_experiment_for_value(), + model=Generator(), status_quo_name="1_1", ) @@ -751,12 +669,10 @@ def test_GenWithDefaults(self, _, mock_gen: Mock) -> None: exp = get_experiment_for_value() exp.optimization_config = get_optimization_config_no_constraints() ss = get_search_space_for_range_value() - modelbridge = Adapter( - search_space=ss, model=Generator(), transforms=[], experiment=exp - ) - modelbridge.gen(1) + adapter = Adapter(experiment=exp, model=Generator(), search_space=ss) + adapter.gen(1) mock_gen.assert_called_with( - modelbridge, + adapter, n=1, search_space=ss, fixed_features=None, @@ -776,19 +692,13 @@ def test_GenWithDefaults(self, _, mock_gen: Mock) -> None: ), ) @mock.patch("ax.modelbridge.base.Adapter.predict", autospec=True, return_value=None) - # pyre-fixme[3]: Return type must be annotated. - def test_gen_on_experiment_with_imm_ss_and_opt_conf(self, _, __): + def test_gen_on_experiment_with_imm_ss_and_opt_conf(self, _, __) -> None: exp = get_experiment_for_value() exp._properties[Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF] = True exp.optimization_config = get_optimization_config_no_constraints() - ss = get_search_space_for_range_value() - modelbridge = Adapter( - search_space=ss, model=Generator(), transforms=[], experiment=exp - ) - self.assertTrue( - modelbridge._experiment_has_immutable_search_space_and_opt_config - ) - gr = modelbridge.gen(1) + adapter = Adapter(experiment=exp, model=Generator()) + self.assertTrue(adapter._experiment_has_immutable_search_space_and_opt_config) + gr = adapter.gen(1) self.assertIsNone(gr.optimization_config) self.assertIsNone(gr.search_space) @@ -800,20 +710,15 @@ def test_set_status_quo(self) -> None: num_batch_trial=1, with_completed_batch=True, ) - modelbridge = Adapter( - search_space=exp.search_space, - experiment=exp, - model=Generator, - data=exp.lookup_data(), - ) + adapter = Adapter(experiment=exp, model=Generator()) # we are able to set status_quo_data_by_trial when multiple # status_quos present in each trial - self.assertIsNotNone(modelbridge.status_quo_data_by_trial) + self.assertIsNotNone(adapter.status_quo_data_by_trial) # status_quo is set - self.assertIsNotNone(modelbridge.status_quo) + self.assertIsNotNone(adapter.status_quo) # Status quo name is logged - self.assertEqual(modelbridge._status_quo_name, none_throws(exp.status_quo).name) + self.assertEqual(adapter._status_quo_name, none_throws(exp.status_quo).name) # experiment with multiple status quos in different trials exp = get_branin_experiment( @@ -822,19 +727,14 @@ def test_set_status_quo(self) -> None: num_batch_trial=2, with_completed_batch=True, ) - modelbridge = Adapter( - search_space=exp.search_space, - experiment=exp, - model=Generator, - data=exp.lookup_data(), - ) + adapter = Adapter(experiment=exp, model=Generator()) # we are able to set status_quo_data_by_trial when multiple # status_quos present in each trial - self.assertIsNotNone(modelbridge.status_quo_data_by_trial) + self.assertIsNotNone(adapter.status_quo_data_by_trial) # status_quo is not set - self.assertIsNone(modelbridge.status_quo) + self.assertIsNone(adapter.status_quo) # Status quo name can still be logged - self.assertEqual(modelbridge._status_quo_name, none_throws(exp.status_quo).name) + self.assertEqual(adapter._status_quo_name, none_throws(exp.status_quo).name) # a unique status_quo 
can be identified (by trial index) # if status_quo_features is specified @@ -842,14 +742,12 @@ def test_set_status_quo(self) -> None: parameters=none_throws(exp.status_quo).parameters, trial_index=0, ) - modelbridge = Adapter( - search_space=exp.search_space, + adapter = Adapter( experiment=exp, - model=Generator, - data=exp.lookup_data(), + model=Generator(), status_quo_features=status_quo_features, ) - self.assertIsNotNone(modelbridge.status_quo) + self.assertIsNotNone(adapter.status_quo) class testClampObservationFeatures(TestCase): @@ -943,12 +841,7 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None: get_branin_data_batch(batch=trial, fill_vals=sq_vals) ) # Fit model without filling missing parameters - m = Adapter( - search_space=ss1, - model=None, - experiment=experiment, - data=experiment.lookup_data(), - ) + m = Adapter(experiment=experiment, model=Generator()) self.assertEqual( [t.__name__ for t in m._raw_transforms], # pyre-ignore[16] ["Cast"], @@ -961,10 +854,9 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None: ) # Fit with filling missing parameters m = Adapter( - search_space=ss2, - model=None, experiment=experiment, - data=experiment.lookup_data(), + model=Generator(), + search_space=ss2, transforms=[FillMissingParameters], transform_configs={"FillMissingParameters": {"fill_values": sq_vals}}, ) @@ -999,7 +891,6 @@ def test_SetModelSpace(self) -> None: trial.mark_running(no_runner_required=True) experiment.attach_data(get_branin_data_batch(batch=trial, fill_vals=sq_vals)) trial.mark_completed() - data = experiment.lookup_data() # Make search space with a parameter constraint ss = experiment.search_space.clone() ss.set_parameter_constraints( @@ -1014,10 +905,9 @@ def test_SetModelSpace(self) -> None: # Check that SQ and custom are OOD m = Adapter( - search_space=ss, - model=None, experiment=experiment, - data=data, + model=Generator(), + search_space=ss, expand_model_space=False, ) arm_names = [obs.arm_name for obs in m.get_training_data()] @@ -1029,10 +919,9 @@ def test_SetModelSpace(self) -> None: # With expand model space, custom is not OOD, and model space is expanded m = Adapter( - search_space=ss, - model=None, experiment=experiment, - data=data, + model=Generator(), + search_space=ss, ) arm_names = [obs.arm_name for obs in m.get_training_data()] ood_arms = [a for i, a in enumerate(arm_names) if not m.training_in_design[i]] @@ -1043,10 +932,9 @@ def test_SetModelSpace(self) -> None: # With fill values, SQ is also in design, and x2 is further expanded m = Adapter( - search_space=ss, - model=None, experiment=experiment, - data=data, + model=Generator(), + search_space=ss, transforms=[FillMissingParameters], transform_configs={"FillMissingParameters": {"fill_values": sq_vals}}, ) @@ -1062,19 +950,16 @@ def test_SetModelSpace(self) -> None: def test_fit_only_completed_map_metrics( self, mock_observations_from_data: Mock ) -> None: - # NOTE: If empty data object is not passed, observations are not - # extracted, even with mock. 
# _prepare_observations is called in the constructor and itself calls # observations_from_data with map_keys_as_parameters=True Adapter( - search_space=get_search_space_for_value(), - model=0, experiment=get_experiment_for_value(), + model=Generator(), data=MapData(), status_quo_name="1_1", fit_only_completed_map_metrics=False, ) - _, kwargs = mock_observations_from_data.call_args + kwargs = mock_observations_from_data.call_args.kwargs self.assertTrue(kwargs["map_keys_as_parameters"]) # assert `latest_rows_per_group` is not specified or is None self.assertIsNone(kwargs.get("latest_rows_per_group")) @@ -1083,12 +968,10 @@ def test_fit_only_completed_map_metrics( # calling without map data calls observations_from_data with # map_keys_as_parameters=False even if fit_only_completed_map_metrics is False Adapter( - search_space=get_search_space_for_value(), - model=0, experiment=get_experiment_for_value(), - data=Data(), + model=Generator(), status_quo_name="1_1", fit_only_completed_map_metrics=False, ) - _, kwargs = mock_observations_from_data.call_args + kwargs = mock_observations_from_data.call_args.kwargs self.assertFalse(kwargs["map_keys_as_parameters"]) diff --git a/ax/modelbridge/tests/test_discrete_modelbridge.py b/ax/modelbridge/tests/test_discrete_modelbridge.py index b8bb01ad2ad..0a0aa135d9a 100644 --- a/ax/modelbridge/tests/test_discrete_modelbridge.py +++ b/ax/modelbridge/tests/test_discrete_modelbridge.py @@ -74,14 +74,15 @@ def setUp(self) -> None: self.model_gen_options = {"option": "yes"} @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) - def test_fit(self, mock_init: Mock) -> None: + def test_fit(self, _: Mock) -> None: # pyre-fixme[20]: Argument `model` expected. - ma = DiscreteAdapter() - ma._training_data = self.observations + adapter = DiscreteAdapter() + adapter._training_data = self.observations model = mock.create_autospec(DiscreteGenerator, instance=True) - ma._fit(model, self.search_space, self.observations) - self.assertEqual(ma.parameters, ["x", "y", "z"]) - self.assertEqual(sorted(ma.outcomes), ["a", "b"]) + adapter.model = model + adapter._fit(self.search_space, self.observations) + self.assertEqual(adapter.parameters, ["x", "y", "z"]) + self.assertEqual(sorted(adapter.outcomes), ["a", "b"]) Xs = { "a": [[0, "foo", True], [1, "foo", True], [1, "bar", True]], "b": [[0, "foo", True], [1, "foo", True]], @@ -91,23 +92,23 @@ def test_fit(self, mock_init: Mock) -> None: parameter_values = [[0.0, 1.0], ["foo", "bar"], [True]] model_fit_args = model.fit.mock_calls[0][2] for i, x in enumerate(model_fit_args["Xs"]): - self.assertEqual(x, Xs[ma.outcomes[i]]) + self.assertEqual(x, Xs[adapter.outcomes[i]]) for i, y in enumerate(model_fit_args["Ys"]): - self.assertEqual(y, Ys[ma.outcomes[i]]) + self.assertEqual(y, Ys[adapter.outcomes[i]]) for i, v in enumerate(model_fit_args["Yvars"]): - self.assertEqual(v, Yvars[ma.outcomes[i]]) + self.assertEqual(v, Yvars[adapter.outcomes[i]]) self.assertEqual(model_fit_args["parameter_values"], parameter_values) sq_obs = Observation( features=ObservationFeatures({}), data=self.observation_data[0] ) with self.assertRaises(ValueError): - ma._fit(model, self.search_space, self.observations + [sq_obs]) + adapter._fit(self.search_space, self.observations + [sq_obs]) @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) def test_predict(self, mock_init: Mock) -> None: # pyre-fixme[20]: Argument `model` expected. 
- ma = DiscreteAdapter() + adapter = DiscreteAdapter() model = mock.MagicMock(DiscreteGenerator, autospec=True, instance=True) model.predict.return_value = ( np.array([[1.0, -1], [2.0, -2]]), @@ -115,10 +116,10 @@ def test_predict(self, mock_init: Mock) -> None: (np.array([[1.0, 4.0], [4.0, 6]]), np.array([[2.0, 5.0], [5.0, 7]])) ), ) - ma.model = model - ma.parameters = ["x", "y", "z"] - ma.outcomes = ["a", "b"] - observation_data = ma._predict(self.observation_features) + adapter.model = model + adapter.parameters = ["x", "y", "z"] + adapter.outcomes = ["a", "b"] + observation_data = adapter._predict(self.observation_features) X = [[0, "foo", True], [1, "foo", True], [1, "bar", True]] self.assertTrue(model.predict.mock_calls[0][2]["X"], X) for i, od in enumerate(observation_data): @@ -134,11 +135,11 @@ def test_gen(self, mock_init: Mock) -> None: ], ) # pyre-fixme[20]: Argument `model` expected. - ma = DiscreteAdapter() + adapter = DiscreteAdapter() # Test validation. with self.assertRaisesRegex(UserInputError, "positive integer or -1."): - ma._validate_gen_inputs(n=0) - ma._validate_gen_inputs(n=-1) + adapter._validate_gen_inputs(n=0) + adapter._validate_gen_inputs(n=-1) # Test rest of gen. model = mock.MagicMock(DiscreteGenerator, autospec=True, instance=True) best_x = [0.0, 2.0, 1.0] @@ -147,10 +148,10 @@ def test_gen(self, mock_init: Mock) -> None: [1.0, 2.0], {"best_x": best_x}, ) - ma.model = model - ma.parameters = ["x", "y", "z"] - ma.outcomes = ["a", "b"] - gen_results = ma._gen( + adapter.model = model + adapter.parameters = ["x", "y", "z"] + adapter.outcomes = ["a", "b"] + gen_results = adapter._gen( n=3, search_space=self.search_space, optimization_config=optimization_config, @@ -189,14 +190,14 @@ def test_gen(self, mock_init: Mock) -> None: self.assertEqual(gen_results.weights, [1.0, 2.0]) self.assertEqual( gen_results.best_observation_features, - ObservationFeatures(parameters=dict(zip(ma.parameters, best_x))), + ObservationFeatures(parameters=dict(zip(adapter.parameters, best_x))), ) # Test with no constraints, no fixed feature, no pending observations search_space = SearchSpace(self.parameters[:2]) optimization_config.outcome_constraints = [] - ma.parameters = ["x", "y"] - ma._gen( + adapter.parameters = ["x", "y"] + adapter._gen( n=3, search_space=search_space, optimization_config=optimization_config, @@ -217,7 +218,7 @@ def test_gen(self, mock_init: Mock) -> None: ], ) with self.assertRaises(ValueError): - ma._gen( + adapter._gen( n=3, search_space=search_space, optimization_config=optimization_config, @@ -229,7 +230,7 @@ def test_gen(self, mock_init: Mock) -> None: @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) def test_cross_validate(self, mock_init: Mock) -> None: # pyre-fixme[20]: Argument `model` expected. 
- ma = DiscreteAdapter() + adapter = DiscreteAdapter() model = mock.MagicMock(DiscreteGenerator, autospec=True, instance=True) model.cross_validate.return_value = ( np.array([[1.0, -1], [2.0, -2]]), @@ -237,10 +238,10 @@ def test_cross_validate(self, mock_init: Mock) -> None: (np.array([[1.0, 4.0], [4.0, 6]]), np.array([[2.0, 5.0], [5.0, 7]])) ), ) - ma.model = model - ma.parameters = ["x", "y", "z"] - ma.outcomes = ["a", "b"] - observation_data = ma._cross_validate( + adapter.model = model + adapter.parameters = ["x", "y", "z"] + adapter.outcomes = ["a", "b"] + observation_data = adapter._cross_validate( search_space=self.search_space, cv_training_data=self.observations, cv_test_points=self.observation_features, diff --git a/ax/modelbridge/tests/test_factory.py b/ax/modelbridge/tests/test_factory.py index 6b0375eed5e..ff86b942af5 100644 --- a/ax/modelbridge/tests/test_factory.py +++ b/ax/modelbridge/tests/test_factory.py @@ -6,6 +6,11 @@ # pyre-strict +from ax.core.experiment import Experiment +from ax.core.optimization_config import ( + MultiObjectiveOptimizationConfig, + OptimizationConfig, +) from ax.core.outcome_constraint import ComparisonOp, ObjectiveThreshold from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.factory import ( @@ -24,13 +29,12 @@ get_branin_experiment_with_multi_objective, get_factorial_experiment, ) +from pyre_extensions import assert_is_instance, none_throws -# pyre-fixme[3]: Return type must be annotated. -def get_multi_obj_exp_and_opt_config(): +def get_multi_obj_exp_and_opt_config() -> tuple[Experiment, OptimizationConfig]: multi_obj_exp = get_branin_experiment_with_multi_objective(with_batch=True) - # pyre-fixme[16]: Optional type has no attribute `objective`. - metrics = multi_obj_exp.optimization_config.objective.metrics + metrics = none_throws(multi_obj_exp.optimization_config).objective.metrics multi_objective_thresholds = [ ObjectiveThreshold( metric=metrics[0], bound=5.0, relative=False, op=ComparisonOp.LEQ @@ -39,14 +43,13 @@ def get_multi_obj_exp_and_opt_config(): metric=metrics[1], bound=10.0, relative=False, op=ComparisonOp.LEQ ), ] - # pyre-fixme[16]: Optional type has no attribute `clone_with_args`. 
- optimization_config = multi_obj_exp.optimization_config.clone_with_args( - objective_thresholds=multi_objective_thresholds - ) + optimization_config = assert_is_instance( + multi_obj_exp.optimization_config, MultiObjectiveOptimizationConfig + ).clone_with_args(objective_thresholds=multi_objective_thresholds) return multi_obj_exp, optimization_config -class AdapterFactoryTestSingleObjective(TestCase): +class TestAdapterFactorySingleObjective(TestCase): def test_model_kwargs(self) -> None: """Tests that model kwargs are passed correctly.""" exp = get_branin_experiment() diff --git a/ax/modelbridge/tests/test_hierarchical_search_space.py b/ax/modelbridge/tests/test_hierarchical_search_space.py index 137798c5a26..618ad9bcdf9 100644 --- a/ax/modelbridge/tests/test_hierarchical_search_space.py +++ b/ax/modelbridge/tests/test_hierarchical_search_space.py @@ -152,7 +152,7 @@ def _test_gen_base( runner=SyntheticRunner(), ) - sobol = Generators.SOBOL(search_space=hss) + sobol = Generators.SOBOL(experiment=experiment) for _ in range(num_sobol_trials): trial = experiment.new_trial(generator_run=sobol.gen(n=1)) trial.run().mark_completed() diff --git a/ax/modelbridge/tests/test_model_fit_metrics.py b/ax/modelbridge/tests/test_model_fit_metrics.py index 659fd28e183..487dfe4a569 100644 --- a/ax/modelbridge/tests/test_model_fit_metrics.py +++ b/ax/modelbridge/tests/test_model_fit_metrics.py @@ -153,7 +153,7 @@ class TestGetFitAndStdQualityAndGeneralizationDict(TestCase): def setUp(self) -> None: super().setUp() self.experiment = get_branin_experiment() - self.sobol = Generators.SOBOL(search_space=self.experiment.search_space) + self.sobol = Generators.SOBOL(experiment=self.experiment) def test_it_returns_empty_data_for_sobol(self) -> None: results = get_fit_and_std_quality_and_generalization_dict( diff --git a/ax/modelbridge/tests/test_random_modelbridge.py b/ax/modelbridge/tests/test_random_modelbridge.py index c8ba5af13dc..18deef9cdd0 100644 --- a/ax/modelbridge/tests/test_random_modelbridge.py +++ b/ax/modelbridge/tests/test_random_modelbridge.py @@ -6,7 +6,6 @@ # pyre-strict -from collections import OrderedDict from unittest import mock import numpy as np @@ -42,52 +41,35 @@ def setUp(self) -> None: SumConstraint([x, z], False, 3.5), ] self.search_space = SearchSpace(self.parameters, parameter_constraints) + self.experiment = Experiment(search_space=self.search_space) self.model_gen_options = {"option": "yes"} - @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) - def test_Fit(self, mock_init: mock.Mock) -> None: - # pyre-fixme[20]: Argument `model` expected. - modelbridge = RandomAdapter() - model = mock.create_autospec(RandomGenerator, instance=True) - modelbridge._fit(model, self.search_space, None) - self.assertEqual(modelbridge.parameters, ["x", "y", "z"]) - self.assertTrue(isinstance(modelbridge.model, RandomGenerator)) + def test_fit(self) -> None: + adapter = RandomAdapter(experiment=self.experiment, model=RandomGenerator()) + self.assertEqual(adapter.parameters, ["x", "y", "z"]) + self.assertTrue(isinstance(adapter.model, RandomGenerator)) - @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) - def test_Predict(self, mock_init: mock.Mock) -> None: - # pyre-fixme[20]: Argument `model` expected. 
- modelbridge = RandomAdapter() - modelbridge.transforms = OrderedDict() - modelbridge.parameters = ["x", "y", "z"] + def test_predict(self) -> None: + adapter = RandomAdapter(experiment=self.experiment, model=RandomGenerator()) with self.assertRaises(NotImplementedError): - modelbridge._predict([]) + adapter._predict([]) - @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) - def test_CrossValidate(self, mock_init: mock.Mock) -> None: - # pyre-fixme[20]: Argument `model` expected. - modelbridge = RandomAdapter() - modelbridge.transforms = OrderedDict() - modelbridge.parameters = ["x", "y", "z"] + def test_cross_validate(self) -> None: + adapter = RandomAdapter(experiment=self.experiment, model=RandomGenerator()) with self.assertRaises(NotImplementedError): - modelbridge._cross_validate(self.search_space, [], []) + adapter._cross_validate(self.search_space, [], []) - @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) - def test_Gen(self, mock_init: mock.Mock) -> None: - # Test with constraints - # pyre-fixme[20]: Argument `model` expected. - modelbridge = RandomAdapter(model=RandomGenerator()) - modelbridge.parameters = ["x", "y", "z"] - modelbridge.transforms = OrderedDict() - modelbridge.model = RandomGenerator() + def test_gen_w_constraints(self) -> None: + adapter = RandomAdapter(experiment=self.experiment, model=RandomGenerator()) with mock.patch.object( - modelbridge.model, + adapter.model, "gen", return_value=( np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 3.0]]), np.array([1.0, 2.0]), ), ) as mock_gen: - gen_results = modelbridge._gen( + gen_results = adapter._gen( n=3, search_space=self.search_space, pending_observations={}, @@ -118,15 +100,18 @@ def test_Gen(self, mock_init: mock.Mock) -> None: self.assertEqual(obsf[1].parameters, {"x": 3.0, "y": 4.0, "z": 3.0}) self.assertTrue(np.array_equal(gen_results.weights, np.array([1.0, 2.0]))) + def test_gen_simple(self) -> None: # Test with no constraints, no fixed feature, no pending observations search_space = SearchSpace(self.parameters[:2]) - modelbridge.parameters = ["x", "y"] + adapter = RandomAdapter( + experiment=Experiment(search_space=search_space), model=RandomGenerator() + ) with mock.patch.object( - modelbridge.model, + adapter.model, "gen", return_value=(np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([1.0, 2.0])), ) as mock_gen: - modelbridge._gen( + adapter._gen( n=3, search_space=search_space, pending_observations={}, @@ -145,7 +130,7 @@ def test_Gen(self, mock_init: mock.Mock) -> None: def test_deduplicate(self) -> None: sobol = RandomAdapter( - search_space=get_small_discrete_search_space(), + experiment=Experiment(search_space=get_small_discrete_search_space()), model=SobolGenerator(deduplicate=True), transforms=Cont_X_trans, ) diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 01bc6c7b5e0..e9642b04373 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -10,6 +10,7 @@ from ax.core.observation import ObservationFeatures from ax.core.optimization_config import MultiObjectiveOptimizationConfig +from ax.exceptions.core import UserInputError from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import ( @@ -37,6 +38,7 @@ get_branin_experiment, get_branin_experiment_with_status_quo_trials, get_branin_optimization_config, + get_branin_search_space, get_factorial_experiment, ) from ax.utils.testing.mock import 
mock_botorch_optimize @@ -52,6 +54,7 @@ from gpytorch.kernels.scale_kernel import ScaleKernel from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior +from pyre_extensions import assert_is_instance class ModelRegistryTest(TestCase): @@ -73,13 +76,13 @@ def test_botorch_modular(self) -> None: data=exp.fetch_data(), ) self.assertIsInstance(gpei, TorchAdapter) - self.assertIsInstance(gpei.model, BoTorchGenerator) - self.assertEqual(gpei.model.botorch_acqf_class, qExpectedImprovement) - self.assertEqual(gpei.model.acquisition_class, Acquisition) - self.assertEqual(gpei.model.acquisition_options, {"best_f": 0.0}) - self.assertIsInstance(gpei.model.surrogate, Surrogate) + generator = assert_is_instance(gpei.model, BoTorchGenerator) + self.assertEqual(generator.botorch_acqf_class, qExpectedImprovement) + self.assertEqual(generator.acquisition_class, Acquisition) + self.assertEqual(generator.acquisition_options, {"best_f": 0.0}) + self.assertIsInstance(generator.surrogate, Surrogate) # SingleTaskGP should be picked. - self.assertIsInstance(gpei.model.surrogate.model, SingleTaskGP) + self.assertIsInstance(generator.surrogate.model, SingleTaskGP) gr = gpei.gen(n=1) self.assertIsNotNone(gr.best_arm_predictions) @@ -87,7 +90,7 @@ def test_botorch_modular(self) -> None: @mock_botorch_optimize def test_SAASBO(self) -> None: exp = get_branin_experiment() - sobol = Generators.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(experiment=exp) self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): sobol_run = sobol.gen(n=1) @@ -96,14 +99,14 @@ def test_SAASBO(self) -> None: saasbo = Generators.SAASBO(experiment=exp, data=exp.fetch_data()) self.assertIsInstance(saasbo, TorchAdapter) self.assertEqual(saasbo._model_key, "SAASBO") - self.assertIsInstance(saasbo.model, BoTorchGenerator) - surrogate_spec = saasbo.model.surrogate_spec + generator = assert_is_instance(saasbo.model, BoTorchGenerator) + surrogate_spec = generator.surrogate_spec self.assertEqual( surrogate_spec, SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( - saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class, + generator.surrogate.surrogate_spec.model_configs[0].botorch_model_class, SaasFullyBayesianSingleTaskGP, ) @@ -112,7 +115,7 @@ def test_enum_sobol_legacy_GPEI(self) -> None: """Tests Sobol and Legacy GPEI instantiation through the Generators enum.""" exp = get_branin_experiment() # Check that factory generates a valid sobol modelbridge. 
- sobol = Generators.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(experiment=exp) self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): sobol_run = sobol.gen(n=1) @@ -191,7 +194,7 @@ def test_enum_model_kwargs(self) -> None: Generators enum.""" exp = get_branin_experiment() sobol = Generators.SOBOL( - search_space=exp.search_space, init_position=2, scramble=False, seed=239 + experiment=exp, init_position=2, scramble=False, seed=239 ) self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): @@ -201,7 +204,7 @@ def test_enum_model_kwargs(self) -> None: def test_enum_factorial(self) -> None: """Tests factorial instantiation through the Generators enum.""" exp = get_factorial_experiment() - factorial = Generators.FACTORIAL(exp.search_space) + factorial = Generators.FACTORIAL(experiment=exp) self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) self.assertEqual(len(factorial_run.arms), 24) @@ -209,7 +212,7 @@ def test_enum_factorial(self) -> None: def test_enum_empirical_bayes_thompson(self) -> None: """Tests EB/TS instantiation through the Generators enum.""" exp = get_factorial_experiment() - factorial = Generators.FACTORIAL(exp.search_space) + factorial = Generators.FACTORIAL(experiment=exp) self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) exp.new_batch_trial().add_generator_run(factorial_run).run().mark_completed() @@ -225,7 +228,7 @@ def test_enum_empirical_bayes_thompson(self) -> None: def test_enum_thompson(self) -> None: """Tests TS instantiation through the Generators enum.""" exp = get_factorial_experiment() - factorial = Generators.FACTORIAL(exp.search_space) + factorial = Generators.FACTORIAL(experiment=exp) self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) exp.new_batch_trial().add_generator_run(factorial_run).run().mark_completed() @@ -236,7 +239,7 @@ def test_enum_thompson(self) -> None: def test_enum_uniform(self) -> None: """Tests uniform random instantiation through the Generators enum.""" exp = get_branin_experiment() - uniform = Generators.UNIFORM(exp.search_space) + uniform = Generators.UNIFORM(experiment=exp) self.assertIsInstance(uniform, RandomAdapter) uniform_run = uniform.gen(n=5) self.assertEqual(len(uniform_run.arms), 5) @@ -309,8 +312,9 @@ def test_get_model_from_generator_run(self) -> None: models_enum=Generators, after_gen=False, ) - self.assertEqual(sobol.model.init_position, 0) - self.assertEqual(sobol.model.seed, 239) + generator = assert_is_instance(sobol.model, SobolGenerator) + self.assertEqual(generator.init_position, 0) + self.assertEqual(generator.seed, 239) # Restore the model as it was after generation (to resume generation). 
sobol_after_gen = get_model_from_generator_run( generator_run=gr, @@ -318,8 +322,9 @@ def test_get_model_from_generator_run(self) -> None: data=exp.fetch_data(), models_enum=Generators, ) - self.assertEqual(sobol_after_gen.model.init_position, 1) - self.assertEqual(sobol_after_gen.model.seed, 239) + generator = assert_is_instance(sobol_after_gen.model, SobolGenerator) + self.assertEqual(generator.init_position, 1) + self.assertEqual(generator.seed, 239) self.assertEqual(initial_sobol.gen(n=1).arms, sobol_after_gen.gen(n=1).arms) exp.new_trial(generator_run=gr) # Check restoration of GPEI, to ensure proper restoration of callable kwargs @@ -410,16 +415,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: surrogate=surrogate, ) self.assertIsInstance(mtgp, TorchAdapter) - self.assertIsInstance(mtgp.model, BoTorchGenerator) - self.assertEqual(mtgp.model.acquisition_class, Acquisition) + generator = assert_is_instance(mtgp.model, BoTorchGenerator) + self.assertEqual(generator.acquisition_class, Acquisition) is_moo = isinstance( exp.optimization_config, MultiObjectiveOptimizationConfig ) if is_moo: - self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) - models = mtgp.model.surrogate.model.models + self.assertIsInstance(generator.surrogate.model, ModelListGP) + models = generator.surrogate.model.models else: - models = [mtgp.model.surrogate.model] + models = [generator.surrogate.model] for model in models: self.assertIsInstance( @@ -452,7 +457,7 @@ def test_SAAS_MTGP(self) -> None: def test_extract_model_state_after_gen(self) -> None: # Test with actual state. exp = get_branin_experiment() - sobol = Generators.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(experiment=exp) gr = sobol.gen(n=1) expected_state = sobol.model._get_state() self.assertEqual(gr._model_state_after_gen, expected_state) @@ -475,3 +480,17 @@ def test_deprecation_warning(self) -> None: r" instead.", ): Models.BOTORCH_MODULAR + + def test_initialize_from_search_space(self) -> None: + search_space = get_branin_search_space() + with self.assertWarnsRegex( + DeprecationWarning, "Passing in a `search_space` to initialize" + ): + adapter = Generators.SOBOL(search_space=search_space) + self.assertEqual(adapter._model_space, search_space) + self.assertIsNotNone(adapter._experiment) + with self.assertRaisesRegex( + UserInputError, + "`experiment` is required to initialize a model from registry.", + ): + Generators.BOTORCH_MODULAR(search_space=search_space) diff --git a/ax/modelbridge/tests/test_torch_modelbridge.py b/ax/modelbridge/tests/test_torch_modelbridge.py index 0a7cd1b889f..717de284795 100644 --- a/ax/modelbridge/tests/test_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_modelbridge.py @@ -67,8 +67,6 @@ def _get_modelbridge_from_experiment( ) -> TorchAdapter: return TorchAdapter( experiment=experiment, - search_space=experiment.search_space, - data=experiment.lookup_data(), model=BoTorchGenerator(), transforms=transforms or [], torch_device=device, @@ -84,13 +82,13 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: min=0.0, max=5.0, parameter_names=feature_names ) experiment = Experiment(search_space=search_space, name="test") - model_bridge = _get_modelbridge_from_experiment( + adapter = _get_modelbridge_from_experiment( experiment=experiment, device=device, fit_on_init=False, ) - self.assertEqual(model_bridge.device, device) - self.assertIsNone(model_bridge._last_observations) + self.assertEqual(adapter.device, device) + 
self.assertIsNone(adapter._last_observations) tkwargs: dict[str, Any] = {"dtype": torch.double, "device": device} # Test `_fit`. X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs) @@ -129,12 +127,10 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: ] observations = recombine_observations(observation_features, observation_data) - model = BoTorchGenerator() + model = adapter.model with mock.patch.object(model, "fit", wraps=model.fit) as mock_fit: - model_bridge._fit( - model=model, search_space=search_space, observations=observations - ) - model_fit_args = mock_fit.mock_calls[0][2] + adapter._fit(search_space=search_space, observations=observations) + model_fit_args = mock_fit.call_args.kwargs self.assertEqual(model_fit_args["datasets"], list(datasets.values())) expected_ssd = SearchSpaceDigest( @@ -142,14 +138,10 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: ) self.assertEqual(model_fit_args["search_space_digest"], expected_ssd) self.assertIsNone(model_fit_args["candidate_metadata"]) - self.assertEqual(model_bridge._last_observations, observations) + self.assertEqual(adapter._last_observations, observations) with mock.patch(f"{TorchAdapter.__module__}.logger.debug") as mock_logger: - model_bridge._fit( - model=model, - search_space=search_space, - observations=observations, - ) + adapter._fit(search_space=search_space, observations=observations) mock_logger.assert_called_once_with( "The observations are identical to the last set of observations " "used to fit the model. Skipping model fitting." @@ -170,7 +162,7 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: with mock.patch.object( model, "predict", return_value=predict_return_value ) as mock_predict: - pr_obs_data = model_bridge._predict( + pr_obs_data = adapter._predict( observation_features=observation_features[:1] ) self.assertTrue(torch.equal(mock_predict.mock_calls[0][2]["X"], X[:1])) @@ -209,7 +201,7 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: # silence a warning about inability to generate unique candidates mock.patch(f"{Adapter.__module__}.logger.warning") ) - gen_run = model_bridge.gen( + gen_run = adapter.gen( n=3, search_space=search_space, optimization_config=opt_config, @@ -262,7 +254,7 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: "cross_validate", return_value=predict_return_value, ) as mock_cross_validate: - cv_obs_data = model_bridge._cross_validate( + cv_obs_data = adapter._cross_validate( search_space=search_space, cv_training_data=observations, cv_test_points=cv_test_points, @@ -276,12 +268,12 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None: # Transform observations # This functionality is likely to be deprecated (T134940274) # so this is not a thorough test. 
- model_bridge.transform_observations(observations=observations) + adapter.transform_observations(observations=observations) # Transform observation features obsf = [ObservationFeatures(parameters={"x": 1.0, "y": 2.0})] - model_bridge.parameters = ["x", "y"] - X = model_bridge._transform_observation_features(observation_features=obsf) + adapter.parameters = ["x", "y"] + X = adapter._transform_observation_features(observation_features=obsf) self.assertTrue(torch.equal(X, torch.tensor([[1.0, 2.0]], **tkwargs))) def test_TorchAdapter_cuda(self) -> None: diff --git a/ax/modelbridge/tests/test_transform_utils.py b/ax/modelbridge/tests/test_transform_utils.py index 7f104cfd0fc..64c613d32ce 100644 --- a/ax/modelbridge/tests/test_transform_utils.py +++ b/ax/modelbridge/tests/test_transform_utils.py @@ -9,7 +9,6 @@ from unittest import mock import numpy as np -from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.parameter import ParameterType, RangeParameter @@ -19,6 +18,7 @@ ClosestLookupDict, derelativize_optimization_config_with_raw_status_quo, ) +from ax.models.base import Generator from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_multi_objective_optimization_config @@ -68,11 +68,8 @@ def test_derelativize_optimization_config_with_raw_status_quo(self, _) -> None: ] ) modelbridge = Adapter( - search_space=dummy_search_space, - model=None, - transforms=[], - experiment=Experiment(dummy_search_space, "test"), - data=Data(), + experiment=Experiment(search_space=dummy_search_space), + model=Generator(), optimization_config=optimization_config, status_quo_name="1_1", ) diff --git a/ax/modelbridge/tests/test_utils.py b/ax/modelbridge/tests/test_utils.py index 640ece86460..97e1f56ccc3 100644 --- a/ax/modelbridge/tests/test_utils.py +++ b/ax/modelbridge/tests/test_utils.py @@ -56,7 +56,7 @@ def setUp(self) -> None: arm=self.trial.arm, trial_index=self.trial.index ) self.hss_exp = get_hierarchical_search_space_experiment() - self.hss_sobol = Generators.SOBOL(search_space=self.hss_exp.search_space) + self.hss_sobol = Generators.SOBOL(experiment=self.hss_exp) self.hss_gr = self.hss_sobol.gen(n=1) self.hss_trial = self.hss_exp.new_trial(self.hss_gr) self.hss_arm = none_throws(self.hss_trial.arm) diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index e4651c27d91..9194287e7e4 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -91,23 +91,15 @@ class TorchAdapter(Adapter): them to the model. """ - model: TorchGenerator | None = None - # pyre-fixme[13]: Attribute `outcomes` is never initialized. - outcomes: list[str] # pyre-ignore[13]: These are initialized in _fit. - # pyre-fixme[13]: Attribute `parameters` is never initialized. - parameters: list[str] # pyre-ignore[13]: These are initialized in _fit. 
- _default_model_gen_options: TConfig - _last_observations: list[Observation] | None = None - def __init__( self, + *, experiment: Experiment, - search_space: SearchSpace, - data: Data, model: TorchGenerator, - transforms: Sequence[type[Transform]], + search_space: SearchSpace | None = None, + data: Data | None = None, + transforms: Sequence[type[Transform]] | None = None, transform_configs: Mapping[str, TConfig] | None = None, - torch_device: torch.device | None = None, status_quo_name: str | None = None, status_quo_features: ObservationFeatures | None = None, optimization_config: OptimizationConfig | None = None, @@ -116,13 +108,22 @@ def __init__( fit_abandoned: bool = False, fit_tracking_metrics: bool = True, fit_on_init: bool = True, - default_model_gen_options: TConfig | None = None, fit_only_completed_map_metrics: bool = True, + default_model_gen_options: TConfig | None = None, + torch_device: torch.device | None = None, ) -> None: - self.device = torch_device - # pyre-ignore [4]: Attribute `_default_model_gen_options` of class - # `TorchAdapter` must have a type that does not contain `Any`. - self._default_model_gen_options = default_model_gen_options or {} + """In addition to common arguments documented in the base ``Adapter`` class, + ``TorchAdapter`` accepts the following arguments. + + Args: + default_model_gen_options: A dictionary of default options to use + during candidate generation. These will be overridden by any + `model_gen_options` passed to the `Adapter.gen` method. + torch_device: The device to use for any torch tensors and operations + on these tensors. + """ + self.device: torch.device | None = torch_device + self._default_model_gen_options: TConfig = default_model_gen_options or {} # Handle init for multi-objective optimization. self.is_moo_problem: bool = False @@ -132,6 +133,14 @@ def __init__( ) self.is_moo_problem = optimization_config.is_moo_problem + # Tracks last set of observations used to fit the model, to skip + # model fitting when it's not necessary. + self._last_observations: list[Observation] | None = None + + # These are set during model fitting. + self.parameters: list[str] = [] + self.outcomes: list[str] = [] + super().__init__( experiment=experiment, search_space=search_space, @@ -150,11 +159,14 @@ def __init__( fit_only_completed_map_metrics=fit_only_completed_map_metrics, ) + # Re-assign self.model for more precise typing. 
+ self.model: TorchGenerator = model + def feature_importances(self, metric_name: str) -> dict[str, float]: - importances_tensor = none_throws(self.model).feature_importances() - importances_dict = dict(zip(self.outcomes, importances_tensor)) + importances_tensor = self.model.feature_importances() + importances_dict = dict(zip(self.outcomes, importances_tensor, strict=True)) importances_arr = importances_dict[metric_name].flatten() - return dict(zip(self.parameters, importances_arr)) + return dict(zip(self.parameters, importances_arr, strict=True)) def infer_objective_thresholds( self, @@ -250,11 +262,11 @@ def model_best_point( search_space=base_gen_args.search_space, pending_observations=base_gen_args.pending_observations, fixed_features=base_gen_args.fixed_features, - model_gen_options=None, + model_gen_options=model_gen_options, optimization_config=base_gen_args.optimization_config, ) try: - xbest = none_throws(self.model).best_point( + xbest = self.model.best_point( search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, ) @@ -441,9 +453,9 @@ def _cross_validate( """Make predictions at cv_test_points using only the data in obs_feats and obs_data. """ - if self.model is None: + if not self.parameters: raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate")) - datasets, candidate_metadata, search_space_digest = self._get_fit_args( + datasets, _, search_space_digest = self._get_fit_args( search_space=search_space, observations=cv_training_data, parameters=parameters, @@ -457,7 +469,7 @@ def _cross_validate( device=self.device, ) # Use the model to do the cross validation - f_test, cov_test = none_throws(self.model).cross_validate( + f_test, cov_test = self.model.cross_validate( datasets=datasets, X_test=torch.as_tensor(X_test, dtype=torch.double, device=self.device), search_space_digest=search_space_digest, @@ -546,7 +558,7 @@ def _evaluate_acquisition_function( fixed_features: ObservationFeatures | None = None, acq_options: dict[str, Any] | None = None, ) -> list[float]: - if self.model is None: + if not self.parameters: raise RuntimeError( FIT_MODEL_ERROR.format(action="_evaluate_acquisition_function") ) @@ -562,7 +574,7 @@ def _evaluate_acquisition_function( for obsf in observation_features ] ) - evals = none_throws(self.model).evaluate_acquisition_function( + evals = self.model.evaluate_acquisition_function( X=self._array_to_tensor(X), search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, @@ -635,13 +647,12 @@ def _get_fit_args( def _fit( self, - model: TorchGenerator, search_space: SearchSpace, observations: list[Observation], parameters: list[str] | None = None, **kwargs: Any, ) -> None: - if self.model is not None and observations == self._last_observations: + if observations == self._last_observations: logger.debug( "The observations are identical to the last set of observations " "used to fit the model. Skipping model fitting." @@ -653,9 +664,7 @@ def _fit( parameters=parameters, update_outcomes_and_parameters=True, ) - # Fit - self.model = model - none_throws(self.model).fit( + self.model.fit( datasets=datasets, search_space_digest=search_space_digest, candidate_metadata=candidate_metadata, @@ -676,7 +685,7 @@ def _gen( The outcome constraints should be transformed to no longer be relative. 
""" - if self.model is None: + if not self.parameters: raise ValueError(FIT_MODEL_ERROR.format(action="_gen")) augmented_model_gen_options = { @@ -692,7 +701,7 @@ def _gen( ) # Generate the candidates - gen_results = none_throws(self.model).gen( + gen_results = self.model.gen( n=n, search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, @@ -722,7 +731,7 @@ def _gen( candidate_metadata=gen_results.candidate_metadata, ) try: - xbest = none_throws(self.model).best_point( + xbest = self.model.best_point( search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, ) @@ -747,11 +756,11 @@ def _gen( def _predict( self, observation_features: list[ObservationFeatures] ) -> list[ObservationData]: - if not self.model: + if not self.parameters: raise ValueError(FIT_MODEL_ERROR.format(action="_model_predict")) # Convert observation features to array X = observation_features_to_array(self.parameters, observation_features) - f, cov = none_throws(self.model).predict(X=self._array_to_tensor(X)) + f, cov = self.model.predict(X=self._array_to_tensor(X)) f = f.detach().cpu().clone().numpy() cov = cov.detach().cpu().clone().numpy() if f.shape[-2] != X.shape[-2]: @@ -857,10 +866,10 @@ def _get_transformed_model_gen_args( else None ) if risk_measure is not None: - if not none_throws(self.model)._supports_robust_optimization: + if not self.model._supports_robust_optimization: raise UnsupportedError( f"{self.model.__class__.__name__} does not support robust " - "optimization. Consider using modular BoTorch model instead." + "optimization. Consider using modular BoTorch generator instead." ) else: risk_measure = extract_risk_measure(risk_measure=risk_measure) diff --git a/ax/modelbridge/transforms/tests/test_derelativize_transform.py b/ax/modelbridge/transforms/tests/test_derelativize_transform.py index a8a85eca262..b57539cc636 100644 --- a/ax/modelbridge/transforms/tests/test_derelativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_derelativize_transform.py @@ -12,7 +12,6 @@ from unittest.mock import Mock, patch import numpy as np -from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.metric import Metric from ax.core.objective import Objective @@ -25,6 +24,7 @@ from ax.exceptions.core import DataRequiredError from ax.modelbridge.base import Adapter from ax.modelbridge.transforms.derelativize import Derelativize +from ax.models.base import Generator from ax.utils.common.testutils import TestCase @@ -119,11 +119,8 @@ def _test_DerelativizeTransform( ] ) g = Adapter( - search_space=search_space, - model=None, - transforms=[], - experiment=Experiment(search_space, "test"), - data=Data(), + experiment=Experiment(search_space=search_space), + model=Generator(), status_quo_name="1_1", ) @@ -202,11 +199,8 @@ def _test_DerelativizeTransform( # Test with relative constraint, out-of-design status quo mock_predict.side_effect = RuntimeError() g = Adapter( - search_space=search_space, - model=None, - transforms=[], - experiment=Experiment(search_space, "test"), - data=Data(), + experiment=Experiment(search_space=search_space), + model=Generator(), status_quo_name="1_2", ) oc = OptimizationConfig( @@ -252,11 +246,8 @@ def _test_DerelativizeTransform( # Raises error if predict fails with in-design status quo g = Adapter( - search_space=search_space, - model=None, - transforms=[], - experiment=Experiment(search_space, "test"), - data=Data(), + experiment=Experiment(search_space=search_space), + model=Generator(), status_quo_name="1_1", ) oc = 
OptimizationConfig( @@ -311,13 +302,7 @@ def _test_DerelativizeTransform( t2.transform_optimization_config(deepcopy(oc_scalarized_only), g, None) # Raises error with relative constraint, no status quo. - g = Adapter( - search_space=search_space, - model=None, - transforms=[], - experiment=Experiment(search_space, "test"), - data=Data(), - ) + g = Adapter(experiment=Experiment(search_space=search_space), model=Generator()) with self.assertRaises(DataRequiredError): t.transform_optimization_config(deepcopy(oc), g, None) @@ -325,7 +310,7 @@ def _test_DerelativizeTransform( with self.assertRaises(ValueError): t.transform_optimization_config(deepcopy(oc), None, None) - def test_Errors(self) -> None: + def test_errors(self) -> None: t = Derelativize( search_space=None, observations=[], @@ -339,8 +324,14 @@ def test_Errors(self) -> None: search_space = SearchSpace( parameters=[RangeParameter("x", ParameterType.FLOAT, 0, 20)] ) - g = Adapter(search_space, None, []) + adapter = Adapter( + experiment=Experiment(search_space=search_space), model=Generator() + ) with self.assertRaises(ValueError): - t.transform_optimization_config(oc, None, None) + t.transform_optimization_config( + optimization_config=oc, modelbridge=None, fixed_features=None + ) with self.assertRaises(DataRequiredError): - t.transform_optimization_config(oc, g, None) + t.transform_optimization_config( + optimization_config=oc, modelbridge=adapter, fixed_features=None + ) diff --git a/ax/modelbridge/transforms/tests/test_relativize_transform.py b/ax/modelbridge/transforms/tests/test_relativize_transform.py index 59f712b6ac3..367a942700e 100644 --- a/ax/modelbridge/transforms/tests/test_relativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_relativize_transform.py @@ -11,6 +11,7 @@ import numpy as np import numpy.typing as npt from ax.core import BatchTrial +from ax.core.experiment import Experiment from ax.core.observation import ( Observation, ObservationData, @@ -91,7 +92,7 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( ) -> None: for relativize_cls in self.relativize_classes: # modelbridge has no status quo - sobol = Generators.SOBOL(search_space=get_search_space()) + sobol = Generators.SOBOL(experiment=get_branin_experiment()) self.assertIsNone(sobol.status_quo) with self.assertRaisesRegex( AssertionError, f"{relativize_cls.__name__} requires status quo data." 
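The relativize-test hunks immediately above and below both migrate `Generators.SOBOL` from `search_space=` to the now-required `experiment=`. As a minimal sketch of the construction pattern this patch standardizes on — assuming the test stubs imported elsewhere in this diff; exact import paths may vary across Ax versions:

```python
from ax.core.experiment import Experiment
from ax.modelbridge.registry import Generators
from ax.utils.testing.core_stubs import get_branin_search_space

# The experiment is the source of truth: the adapter reads the search
# space (and, when available, data and the status quo) from it directly.
experiment = Experiment(search_space=get_branin_search_space())
sobol = Generators.SOBOL(experiment=experiment)

# All remaining adapter arguments are keyword-only after this change.
generator_run = sobol.gen(n=5)
```
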
@@ -457,7 +458,7 @@ class RelativizeDataOptConfigTest(TestCase): def setUp(self) -> None: super().setUp() search_space = get_search_space() - gr = Generators.SOBOL(search_space=search_space).gen(n=1) + gr = Generators.SOBOL(experiment=Experiment(search_space=search_space)).gen(n=1) self.model = Mock( search_space=search_space, status_quo=Mock( diff --git a/ax/modelbridge/transforms/tests/test_winsorize_transform.py b/ax/modelbridge/transforms/tests/test_winsorize_transform.py index 20045b8ed76..5103eee75ff 100644 --- a/ax/modelbridge/transforms/tests/test_winsorize_transform.py +++ b/ax/modelbridge/transforms/tests/test_winsorize_transform.py @@ -38,6 +38,7 @@ AUTO_WINS_QUANTILE, Winsorize, ) +from ax.models.base import Generator from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -601,10 +602,9 @@ def test_relative_constraints( ], ) modelbridge = Adapter( - search_space=search_space, - model=None, + experiment=Experiment(search_space=search_space), + model=Generator(), transforms=[], - experiment=Experiment(search_space, "test"), data=Data(), optimization_config=oc, ) @@ -619,10 +619,9 @@ def test_relative_constraints( ) modelbridge = Adapter( - search_space=search_space, - model=None, + experiment=Experiment(search_space=search_space), + model=Generator(), transforms=[], - experiment=Experiment(search_space, "test"), data=Data(), status_quo_name="1_1", optimization_config=oc, @@ -692,7 +691,9 @@ def get_default_transform_cutoffs( covariance=np.eye(obs_data_len), ) obs = Observation(features=ObservationFeatures({}), data=obsd) - modelbridge = _wrap_optimization_config_in_modelbridge(optimization_config) + modelbridge = _wrap_optimization_config_in_modelbridge( + optimization_config=optimization_config + ) transform = Winsorize( search_space=None, observations=[deepcopy(obs)], @@ -708,7 +709,7 @@ def _wrap_optimization_config_in_modelbridge( optimization_config: OptimizationConfig, ) -> Adapter: return Adapter( - search_space=SearchSpace(parameters=[]), - model=1, + experiment=Experiment(search_space=SearchSpace(parameters=[])), + model=Generator(), optimization_config=optimization_config, ) diff --git a/ax/plot/tests/test_feature_importances.py b/ax/plot/tests/test_feature_importances.py index 69aa5a7c508..51847626b42 100644 --- a/ax/plot/tests/test_feature_importances.py +++ b/ax/plot/tests/test_feature_importances.py @@ -11,6 +11,7 @@ import torch from ax.modelbridge.base import Adapter from ax.modelbridge.registry import Generators +from ax.models.torch.botorch import LegacyBoTorchGenerator from ax.plot.base import AxPlotConfig from ax.plot.feature_importances import ( plot_feature_importance_by_feature, @@ -24,6 +25,7 @@ from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import mock_botorch_optimize from plotly import graph_objects as go +from pyre_extensions import assert_is_instance DUMMY_CAPTION = "test_caption" @@ -46,10 +48,13 @@ def get_sensitivity_values(ax_model: Adapter) -> dict: Returns map {'metric_name': {'parameter_name': sensitivity_value}} """ - if hasattr(ax_model.model.model.covar_module, "outputscale"): - ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze() + generator = assert_is_instance(ax_model.model, LegacyBoTorchGenerator) + if hasattr(generator.model.covar_module, "outputscale"): + # pyre-ignore [16]: Covar modules are difficult to type. 
+ ls = generator.model.covar_module.base_kernel.lengthscale.squeeze() else: - ls = ax_model.model.model.covar_module.lengthscale.squeeze() + # pyre-ignore [16]: Covar modules are difficult to type. + ls = generator.model.covar_module.lengthscale.squeeze() if len(ls.shape) > 1: ls = ls.mean(dim=0) # pyre-fixme[16]: `float` has no attribute `detach`.
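To round out the section, a sketch of the registry behavior exercised by the new `test_initialize_from_search_space` test earlier in this patch; the quoted messages are paraphrased from that test's regexes, and the imports mirror the ones added to `test_registry.py`:

```python
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Generators
from ax.utils.testing.core_stubs import get_branin_search_space

search_space = get_branin_search_space()

# Random and discrete adapters temporarily still accept a bare search
# space: a dummy experiment is created internally and a
# DeprecationWarning ("Passing in a `search_space` to initialize ...")
# is emitted.
sobol = Generators.SOBOL(search_space=search_space)

# Every other registry entry now requires an experiment and raises a
# UserInputError: "`experiment` is required to initialize a model
# from registry."
try:
    Generators.BOTORCH_MODULAR(search_space=search_space)
except UserInputError:
    pass
```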