Use immutable type hints for transforms in Adapter & ModelSetup, make Adapter a normal (non-ABC) class (#3422)

Summary:
Pull Request resolved: #3422

- Adapter was an ABC, but it didn't have any abstract methods. Not being abstract allows us to instantiate it directly in tests, which seems like a good enough reason to keep it that way. Removed the ABC inheritance.
- Transforms & transform configs were typed as mutable list / dict, which necessitated various type casts in tests. Updated them to immutable Sequence & Mapping to improve the typing experience.
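
As a rough illustration of the second point (the Transform/Cast definitions and the fit function below are simplified stand-ins, not the real Ax classes): with Sequence and Mapping in the signature, callers can pass tuples and read-only mappings without casting, while lists and dicts still work.

    from collections.abc import Mapping, Sequence

    class Transform: ...
    class Cast(Transform): ...

    # Before, `transforms: list[type[Transform]]` rejected tuples, forcing casts in tests.
    def fit(transforms: Sequence[type[Transform]], config: Mapping[str, int]) -> None:
        for t in transforms:
            print(t.__name__, dict(config))

    fit((Cast,), {"num_samples": 16})  # a tuple and a plain dict both type-check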

Reviewed By: esantorella

Differential Revision: D70256202

fbshipit-source-id: f499fc1203529aad953daf9f63bfee953972509a
saitcakmak authored and facebook-github-bot committed Feb 26, 2025
1 parent 927dcec commit c146114
Showing 11 changed files with 31 additions and 40 deletions.
3 changes: 1 addition & 2 deletions ax/early_stopping/strategies/base.py
@@ -20,7 +20,6 @@
 from ax.core.map_metric import MapMetric
 from ax.core.objective import MultiObjective
 from ax.core.trial_status import TrialStatus
-
 from ax.early_stopping.utils import estimate_early_stopping_savings
 from ax.modelbridge.map_torch import MapTorchAdapter
 from ax.modelbridge.modelbridge_utils import (
@@ -515,7 +514,7 @@ def get_training_data(
 def get_transform_helper_model(
     experiment: Experiment,
     data: Data,
-    transforms: list[type[Transform]] | None = None,
+    transforms: Sequence[type[Transform]] | None = None,
 ) -> MapTorchAdapter:
     """
     Constructs a TorchAdapter, to be used as a helper for transforming parameters.
17 changes: 8 additions & 9 deletions ax/modelbridge/base.py
@@ -8,9 +8,8 @@

 import json
 import time
-from abc import ABC
 from collections import OrderedDict
-from collections.abc import MutableMapping
+from collections.abc import Mapping, MutableMapping, Sequence
 from copy import deepcopy
 from dataclasses import dataclass, field
 from logging import Logger
@@ -69,7 +68,7 @@ class GenResults:
     gen_metadata: dict[str, Any] = field(default_factory=dict)


-class Adapter(ABC):  # noqa: B024 -- Adapter doesn't have any abstract methods.
+class Adapter:
     """The main object for using models in Ax.

     Adapter specifies 3 methods for using models:
@@ -99,10 +98,10 @@ def __init__(
         search_space: SearchSpace,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         model: Any,
-        transforms: list[type[Transform]] | None = None,
+        transforms: Sequence[type[Transform]] | None = None,
         experiment: Experiment | None = None,
         data: Data | None = None,
-        transform_configs: dict[str, TConfig] | None = None,
+        transform_configs: Mapping[str, TConfig] | None = None,
         status_quo_name: str | None = None,
         status_quo_features: ObservationFeatures | None = None,
         optimization_config: OptimizationConfig | None = None,
@@ -163,7 +162,7 @@ def __init__(
         """
         t_fit_start = time.monotonic()
         transforms = transforms or []
-        transforms = [Cast] + transforms
+        transforms = [Cast] + list(transforms)

         self.fit_time: float = 0.0
         self.fit_time_since_gen: float = 0.0
@@ -184,7 +183,7 @@ def __init__(
         # space to cover training data.
         self._model_space: SearchSpace = search_space.clone()
         self._raw_transforms = transforms
-        self._transform_configs: dict[str, TConfig] | None = transform_configs
+        self._transform_configs: Mapping[str, TConfig] | None = transform_configs
         self._fit_out_of_design = fit_out_of_design
         self._fit_abandoned = fit_abandoned
         self._fit_tracking_metrics = fit_tracking_metrics
@@ -314,8 +313,8 @@ def _transform_data(
         self,
         observations: list[Observation],
         search_space: SearchSpace,
-        transforms: list[type[Transform]] | None,
-        transform_configs: dict[str, TConfig] | None,
+        transforms: Sequence[type[Transform]] | None,
+        transform_configs: Mapping[str, TConfig] | None,
         assign_transforms: bool = True,
     ) -> tuple[list[Observation], SearchSpace]:
         """Initialize transforms and apply them to provided data."""
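
A note on the `[Cast] + list(transforms)` change above: a Sequence argument may be a tuple, and concatenating a list with a tuple raises a TypeError, hence the explicit list() conversion. A minimal sketch of the failure mode (the two classes are stand-ins for the real transforms):

    class Cast: ...
    class FillMissingParameters: ...

    transforms = (FillMissingParameters,)    # any Sequence, e.g. a tuple
    # [Cast] + transforms                    # TypeError: can only concatenate list (not "tuple") to list
    transforms = [Cast] + list(transforms)   # safe for any Sequence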
5 changes: 3 additions & 2 deletions ax/modelbridge/factory.py
@@ -6,6 +6,7 @@

 # pyre-strict

+from collections.abc import Mapping, Sequence
 from logging import Logger

 import torch
@@ -109,8 +110,8 @@ def get_botorch(
     data: Data,
     search_space: SearchSpace | None = None,
     device: torch.device = DEFAULT_TORCH_DEVICE,
-    transforms: list[type[Transform]] = Cont_X_trans + Y_trans,
-    transform_configs: dict[str, TConfig] | None = None,
+    transforms: Sequence[type[Transform]] = Cont_X_trans + Y_trans,
+    transform_configs: Mapping[str, TConfig] | None = None,
     model_constructor: TModelConstructor = get_and_fit_model,
     model_predictor: TModelPredictor = predict_from_model,
     acqf_constructor: TAcqfConstructor = get_qLogNEI,
7 changes: 3 additions & 4 deletions ax/modelbridge/map_torch.py
@@ -5,12 +5,11 @@

 # pyre-strict

-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Any

 import numpy as np
 import numpy.typing as npt
-
 import torch
 from ax.core.batch_trial import BatchTrial
 from ax.core.data import Data
@@ -61,8 +60,8 @@ def __init__(
         search_space: SearchSpace,
         data: Data,
         model: TorchGenerator,
-        transforms: list[type[Transform]],
-        transform_configs: dict[str, TConfig] | None = None,
+        transforms: Sequence[type[Transform]],
+        transform_configs: Mapping[str, TConfig] | None = None,
         torch_device: torch.device | None = None,
         status_quo_name: str | None = None,
         status_quo_features: ObservationFeatures | None = None,
7 changes: 3 additions & 4 deletions ax/modelbridge/random.py
@@ -7,10 +7,10 @@
 # pyre-strict


+from collections.abc import Mapping, Sequence
 from typing import Any

 from ax.core.data import Data
-
 from ax.core.experiment import Experiment
 from ax.core.observation import Observation, ObservationData, ObservationFeatures
 from ax.core.optimization_config import OptimizationConfig
@@ -80,7 +80,6 @@ class RandomAdapter(Adapter):
     the transformed inputs.
     """

-    # pyre-fixme[13]: Attribute `model` is never initialized.
     model: RandomGenerator
     # pyre-fixme[13]: Attribute `parameters` is never initialized.
     parameters: list[str]
@@ -90,10 +89,10 @@ def __init__(
         search_space: SearchSpace,
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         model: Any,
-        transforms: list[type[Transform]] | None = None,
+        transforms: Sequence[type[Transform]] | None = None,
         experiment: Experiment | None = None,
         data: Data | None = None,
-        transform_configs: dict[str, TConfig] | None = None,
+        transform_configs: Mapping[str, TConfig] | None = None,
         status_quo_name: str | None = None,
         status_quo_features: ObservationFeatures | None = None,
         optimization_config: OptimizationConfig | None = None,
9 changes: 5 additions & 4 deletions ax/modelbridge/registry.py
@@ -17,6 +17,7 @@

 from __future__ import annotations

+from collections.abc import Mapping, Sequence
 from enum import Enum
 from inspect import isfunction, signature
 from logging import Logger
@@ -176,10 +177,10 @@ class ModelSetup(NamedTuple):

     bridge_class: type[Adapter]
     model_class: type[Generator]
-    transforms: list[type[Transform]]
-    default_model_kwargs: dict[str, Any] | None = None
-    standard_bridge_kwargs: dict[str, Any] | None = None
-    not_saved_model_kwargs: list[str] | None = None
+    transforms: Sequence[type[Transform]]
+    default_model_kwargs: Mapping[str, Any] | None = None
+    standard_bridge_kwargs: Mapping[str, Any] | None = None
+    not_saved_model_kwargs: Sequence[str] | None = None


 """A mapping of string keys that indicate a model, to the corresponding
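
With the NamedTuple fields typed as Sequence and Mapping, registry entries can be declared with immutable containers. A rough sketch of such an entry, assuming ModelSetup, Adapter, Generator, and Cast are imported from their Ax modules (this is not an entry from the real registry):

    example_setup = ModelSetup(
        bridge_class=Adapter,
        model_class=Generator,             # placeholder; real entries use a concrete generator
        transforms=(Cast,),                # a tuple now satisfies Sequence[type[Transform]]
        default_model_kwargs={"seed": 0},  # any Mapping is accepted; "seed" is illustrative
    )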
3 changes: 1 addition & 2 deletions ax/modelbridge/tests/test_base_modelbridge.py
@@ -34,7 +34,6 @@
 )
 from ax.modelbridge.factory import get_sobol
 from ax.modelbridge.registry import Generators, Y_trans
-from ax.modelbridge.transforms.base import Transform
 from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
 from ax.models.base import Generator
 from ax.utils.common.constants import Keys
@@ -85,7 +84,7 @@ def test_Adapter(
         self, mock_fit: Mock, mock_gen_arms: Mock, mock_observations_from_data: Mock
     ) -> None:
         # Test that on init transforms are stored and applied in the correct order
-        transforms: list[type[Transform]] = [transform_1, transform_2]
+        transforms = [transform_1, transform_2]
         exp = get_experiment_for_value()
         ss = get_search_space_for_value()
         modelbridge = Adapter(
3 changes: 2 additions & 1 deletion ax/modelbridge/tests/test_torch_modelbridge.py
@@ -6,6 +6,7 @@

 # pyre-strict

+from collections.abc import Sequence
 from contextlib import ExitStack
 from typing import Any
 from unittest import mock
@@ -60,7 +61,7 @@

 def _get_modelbridge_from_experiment(
     experiment: Experiment,
-    transforms: list[type[Transform]] | None = None,
+    transforms: Sequence[type[Transform]] | None = None,
     device: torch.device | None = None,
     fit_on_init: bool = True,
 ) -> TorchAdapter:
6 changes: 3 additions & 3 deletions ax/modelbridge/torch.py
@@ -9,7 +9,7 @@
 from __future__ import annotations

 from collections import defaultdict
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from copy import deepcopy
 from logging import Logger
 from typing import Any
@@ -105,8 +105,8 @@ def __init__(
         search_space: SearchSpace,
         data: Data,
         model: TorchGenerator,
-        transforms: list[type[Transform]],
-        transform_configs: dict[str, TConfig] | None = None,
+        transforms: Sequence[type[Transform]],
+        transform_configs: Mapping[str, TConfig] | None = None,
         torch_device: torch.device | None = None,
         status_quo_name: str | None = None,
         status_quo_features: ObservationFeatures | None = None,
6 changes: 0 additions & 6 deletions (file name not shown)
@@ -29,12 +29,6 @@


 class DerelativizeTransformTest(TestCase):
-    def setUp(self) -> None:
-        super().setUp()
-        m = mock.patch.object(Adapter, "__abstractmethods__", frozenset())
-        self.addCleanup(m.stop)
-        m.start()
-
     def test_DerelativizeTransform(self) -> None:
         for negative_metrics in [False, True]:
             sq_sign = -1.0 if negative_metrics else 1.0
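
This deletion is the test-side payoff of dropping the ABC base: the setUp hook that patched Adapter.__abstractmethods__ to an empty frozenset before instantiation is no longer needed. A before/after sketch, with constructor arguments reduced to placeholders:

    # Before: tests patched out __abstractmethods__ to instantiate Adapter.
    with mock.patch.object(Adapter, "__abstractmethods__", frozenset()):
        adapter = Adapter(search_space=search_space, model=None)

    # After: Adapter is a plain class, so direct construction just works.
    adapter = Adapter(search_space=search_space, model=None)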
5 changes: 2 additions & 3 deletions ax/utils/common/kwargs.py
@@ -6,9 +6,8 @@

 # pyre-strict

-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Iterable, Mapping
 from inspect import Parameter, signature
-
 from logging import Logger
 from typing import Any

@@ -20,7 +19,7 @@


 def consolidate_kwargs(
-    kwargs_iterable: Iterable[dict[str, Any] | None], keywords: Iterable[str]
+    kwargs_iterable: Iterable[Mapping[str, Any] | None], keywords: Iterable[str]
 ) -> dict[str, Any]:
     """Combine an iterable of kwargs into a single dict of kwargs, where kwargs
     by duplicate keys that appear later in the iterable get priority over the
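
A usage sketch of the relaxed signature, assuming (per the docstring) that later entries take priority and that keywords filters which keys are kept:

    from types import MappingProxyType

    from ax.utils.common.kwargs import consolidate_kwargs

    frozen = MappingProxyType({"lr": 0.1, "seed": 7})  # a read-only Mapping is now accepted
    merged = consolidate_kwargs(
        kwargs_iterable=[frozen, {"lr": 0.01}, None],
        keywords=["lr", "seed"],
    )
    # expected: {"lr": 0.01, "seed": 7} -- the later "lr" wins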
