diff --git a/ax/analysis/healthcheck/can_generate_candidates.py b/ax/analysis/healthcheck/can_generate_candidates.py index 30eb127913e..624f5fabf57 100644 --- a/ax/analysis/healthcheck/can_generate_candidates.py +++ b/ax/analysis/healthcheck/can_generate_candidates.py @@ -7,7 +7,6 @@ import json from datetime import datetime -from typing import Optional import pandas as pd from ax.analysis.analysis import AnalysisCardLevel @@ -46,7 +45,7 @@ def __init__( def compute( self, - experiment: Optional[Experiment] = None, + experiment: Experiment | None = None, generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: status = HealthcheckStatus.PASS diff --git a/ax/analysis/healthcheck/constraints_feasibility.py b/ax/analysis/healthcheck/constraints_feasibility.py index 82399c186b0..1d21fe823a3 100644 --- a/ax/analysis/healthcheck/constraints_feasibility.py +++ b/ax/analysis/healthcheck/constraints_feasibility.py @@ -6,7 +6,6 @@ # pyre-strict import json -from typing import Tuple import pandas as pd @@ -156,7 +155,7 @@ def constraints_feasibility( optimization_config: OptimizationConfig, model: Adapter, prob_threshold: float = 0.99, -) -> Tuple[bool, pd.DataFrame]: +) -> tuple[bool, pd.DataFrame]: r""" Check the feasibility of the constraints for the experiment. diff --git a/ax/analysis/healthcheck/regression_analysis.py b/ax/analysis/healthcheck/regression_analysis.py index 43f456e5a66..2c58c79e211 100644 --- a/ax/analysis/healthcheck/regression_analysis.py +++ b/ax/analysis/healthcheck/regression_analysis.py @@ -6,7 +6,6 @@ # pyre-strict import json -from typing import Tuple import pandas as pd from ax.analysis.analysis import AnalysisCardLevel @@ -111,7 +110,7 @@ def compute( def process_regression_dict( regressions_by_trial: dict[int, dict[str, dict[str, float]]], -) -> Tuple[pd.DataFrame, str]: +) -> tuple[pd.DataFrame, str]: r""" Process the dictionary of trial indices, regressing arms and metrics into a dataframe and a string. diff --git a/ax/analysis/healthcheck/tests/test_search_space_analysis.py b/ax/analysis/healthcheck/tests/test_search_space_analysis.py index 725057fa145..0c8b2f5d7aa 100644 --- a/ax/analysis/healthcheck/tests/test_search_space_analysis.py +++ b/ax/analysis/healthcheck/tests/test_search_space_analysis.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import List, Union import pandas as pd from ax.analysis.analysis import AnalysisCardLevel @@ -103,7 +102,7 @@ def test_search_space_boundary_proportions(self) -> None: ], ) - parametrizations: List[dict[str, Union[None, bool, float, int, str]]] = [ + parametrizations: list[dict[str, None | bool | float | int | str]] = [ { "float_range_1": 1.0, "float_range_2": 1.0, diff --git a/ax/analysis/plotly/scatter.py b/ax/analysis/plotly/scatter.py index e058b195202..2154bc6fa14 100644 --- a/ax/analysis/plotly/scatter.py +++ b/ax/analysis/plotly/scatter.py @@ -5,7 +5,6 @@ # pyre-strict -from typing import Optional import pandas as pd from ax.analysis.analysis import AnalysisCardLevel @@ -50,8 +49,8 @@ def __init__( def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategy] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("ScatterPlot requires an Experiment") diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index 216766c159b..60bb32a7ede 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -6,7 +6,6 @@ # pyre-strict import math -from typing import Optional import pandas as pd from ax.analysis.analysis import AnalysisCardLevel @@ -61,8 +60,8 @@ def __init__( def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategy] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("ContourPlot requires an Experiment") diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index b69754979be..d0cbc6a7a7c 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -6,7 +6,6 @@ # pyre-strict import math -from typing import Optional import pandas as pd from ax.analysis.analysis import AnalysisCardLevel @@ -55,8 +54,8 @@ def __init__( def compute( self, - experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategy] = None, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("SlicePlot requires an Experiment") diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 264f3480a6f..22513d32dea 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -24,7 +24,6 @@ from itertools import product from logging import Logger, WARNING from time import monotonic, time -from typing import Set import numpy as np import numpy.typing as npt @@ -289,7 +288,7 @@ def benchmark_replication( best_params_by_trial: list[list[TParameterization]] = [] is_mf_or_mt = len(problem.target_fidelity_and_task) > 0 - trials_used_for_best_point: Set[int] = set() + trials_used_for_best_point: set[int] = set() # Run the optimization loop. timeout_hours = method.timeout_hours diff --git a/ax/benchmark/benchmark_runner.py b/ax/benchmark/benchmark_runner.py index 9b5868bd864..ad5ac39695b 100644 --- a/ax/benchmark/benchmark_runner.py +++ b/ax/benchmark/benchmark_runner.py @@ -127,7 +127,7 @@ def get_total_runtime( # By default, each step takes 1 virtual second. if step_runtime_function is not None: max_step_runtime = max( - (step_runtime_function(arm.parameters) for arm in trial.arms) + step_runtime_function(arm.parameters) for arm in trial.arms ) else: max_step_runtime = 1 diff --git a/ax/benchmark/tests/test_benchmark_metric.py b/ax/benchmark/tests/test_benchmark_metric.py index 9850c41bf2d..ed07f9b8b26 100644 --- a/ax/benchmark/tests/test_benchmark_metric.py +++ b/ax/benchmark/tests/test_benchmark_metric.py @@ -89,12 +89,10 @@ def get_test_trial( n_steps = 3 if multiple_time_steps else 1 dfs = { name: pd.concat( - ( - _get_one_step_df( - batch=batch, metric_name=name, step=i, observe_noise_sd=True - ) - for i in range(n_steps) + _get_one_step_df( + batch=batch, metric_name=name, step=i, observe_noise_sd=True ) + for i in range(n_steps) ) for name in ["test_metric1", "test_metric2"] } diff --git a/ax/benchmark/tests/test_benchmark_runner.py b/ax/benchmark/tests/test_benchmark_runner.py index c21bd5d37d7..f45cb180740 100644 --- a/ax/benchmark/tests/test_benchmark_runner.py +++ b/ax/benchmark/tests/test_benchmark_runner.py @@ -248,7 +248,7 @@ def test_runner(self) -> None: for i, df in enumerate(res.values()): if isinstance(noise_std, list): self.assertEqual(df["sem"].item(), noise_std[i]) - if all((n == 0 for n in noise_std)): + if all(n == 0 for n in noise_std): self.assertTrue(np.array_equal(df["mean"], Y[i, :])) else: # float self.assertEqual(df["sem"].item(), noise_std) diff --git a/ax/core/search_space.py b/ax/core/search_space.py index 9b214f27205..b75cf380596 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -10,12 +10,11 @@ import math import warnings -from collections.abc import Callable, Hashable, Mapping +from collections.abc import Callable, Hashable, Mapping, Sequence from dataclasses import dataclass, field from functools import reduce from logging import Logger from random import choice, uniform -from typing import Sequence import numpy.typing as npt import pandas as pd diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 597407556e6..e8aaac8e51c 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -5,7 +5,8 @@ # pyre-strict -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import numpy as np import numpy.typing as npt diff --git a/ax/modelbridge/modelbridge_utils.py b/ax/modelbridge/modelbridge_utils.py index 9a34bc084f4..547fd086e4c 100644 --- a/ax/modelbridge/modelbridge_utils.py +++ b/ax/modelbridge/modelbridge_utils.py @@ -9,11 +9,11 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Iterable, Mapping, MutableMapping +from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence from copy import deepcopy from functools import partial from logging import Logger -from typing import Any, Sequence, SupportsFloat, TYPE_CHECKING +from typing import Any, SupportsFloat, TYPE_CHECKING import numpy as np import numpy.typing as npt diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 6bf4e0f555a..54ee8d13811 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -9,10 +9,10 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Sequence from copy import deepcopy from logging import Logger -from typing import Any, Sequence +from typing import Any from warnings import warn import numpy as np diff --git a/ax/modelbridge/transforms/log_y.py b/ax/modelbridge/transforms/log_y.py index 1376b270262..500d4a99d37 100644 --- a/ax/modelbridge/transforms/log_y.py +++ b/ax/modelbridge/transforms/log_y.py @@ -8,8 +8,10 @@ from __future__ import annotations +from collections.abc import Callable + from logging import Logger -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt diff --git a/ax/modelbridge/transforms/metadata_to_float.py b/ax/modelbridge/transforms/metadata_to_float.py index 0e5ed2d585b..0d379c46c6f 100644 --- a/ax/modelbridge/transforms/metadata_to_float.py +++ b/ax/modelbridge/transforms/metadata_to_float.py @@ -8,8 +8,10 @@ from __future__ import annotations +from collections.abc import Iterable + from logging import Logger -from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING +from typing import Any, SupportsFloat, TYPE_CHECKING from ax.core import ParameterType @@ -51,7 +53,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.Adapter"] = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: if observations is None or not observations: diff --git a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py index ee94c237af2..969cd08d7bc 100644 --- a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py +++ b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py @@ -6,8 +6,8 @@ # pyre-strict +from collections.abc import Iterator from copy import deepcopy -from typing import Iterator import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures diff --git a/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py b/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py index 46a014c0ad8..a26c991162d 100644 --- a/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py +++ b/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py @@ -6,8 +6,8 @@ # pyre-strict +from collections.abc import Iterator from copy import deepcopy -from typing import Iterator import numpy as np from ax.core.observation import Observation, ObservationData, ObservationFeatures diff --git a/ax/modelbridge/transforms/time_as_feature.py b/ax/modelbridge/transforms/time_as_feature.py index cb1053dc51d..84d4d15b8cb 100644 --- a/ax/modelbridge/transforms/time_as_feature.py +++ b/ax/modelbridge/transforms/time_as_feature.py @@ -10,7 +10,7 @@ from logging import Logger from time import time -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import pandas as pd @@ -48,7 +48,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.Adapter"] = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: assert observations is not None, "TimeAsFeature requires observations" diff --git a/ax/models/torch_base.py b/ax/models/torch_base.py index 7cc3ff23d00..3bd10d3792d 100644 --- a/ax/models/torch_base.py +++ b/ax/models/torch_base.py @@ -8,10 +8,10 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field -from typing import Any, Mapping, Sequence +from typing import Any import torch from ax.core.metric import Metric diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 4053e30c8ef..99140177d05 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -6,8 +6,9 @@ # pyre-strict import json +from collections.abc import Sequence from logging import Logger -from typing import Any, Sequence +from typing import Any import numpy as np diff --git a/ax/preview/api/configs.py b/ax/preview/api/configs.py index d7585c55937..dcc29a08e2e 100644 --- a/ax/preview/api/configs.py +++ b/ax/preview/api/configs.py @@ -5,9 +5,10 @@ # pyre-strict +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, List, Mapping, Sequence +from typing import Any from ax.preview.api.types import TParameterValue from ax.storage.registry_bundle import RegistryBundleBase @@ -58,7 +59,7 @@ class ChoiceParameterConfig: """ name: str - values: List[float] | List[int] | List[str] | List[bool] + values: list[float] | list[int] | list[str] | list[bool] parameter_type: ParameterType is_ordered: bool | None = None dependent_parameters: Mapping[TParameterValue, Sequence[str]] | None = None diff --git a/ax/preview/api/protocols/metric.py b/ax/preview/api/protocols/metric.py index 4ec1d71b3d3..d354408a28f 100644 --- a/ax/preview/api/protocols/metric.py +++ b/ax/preview/api/protocols/metric.py @@ -6,7 +6,8 @@ # pyre-strict -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from ax.preview.api.protocols.utils import _APIMetric from pyre_extensions import override diff --git a/ax/preview/api/protocols/runner.py b/ax/preview/api/protocols/runner.py index 91219c0c3e8..b6e9e426314 100644 --- a/ax/preview/api/protocols/runner.py +++ b/ax/preview/api/protocols/runner.py @@ -6,7 +6,8 @@ # pyre-strict -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from ax.core.trial_status import TrialStatus from ax.preview.api.protocols.utils import _APIRunner diff --git a/ax/preview/api/protocols/utils.py b/ax/preview/api/protocols/utils.py index 878c9edc5b4..55615a3f364 100644 --- a/ax/preview/api/protocols/utils.py +++ b/ax/preview/api/protocols/utils.py @@ -9,7 +9,8 @@ import json from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Iterable, Mapping +from collections.abc import Iterable, Mapping +from typing import Any import pandas as pd diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 8e09030aa45..4a378dd549a 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -5,7 +5,8 @@ # pyre-strict -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import numpy as np diff --git a/ax/preview/api/types.py b/ax/preview/api/types.py index c3cdbc47487..bd1a3b84818 100644 --- a/ax/preview/api/types.py +++ b/ax/preview/api/types.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import Mapping +from collections.abc import Mapping TParameterValue = int | float | str | bool TParameterization = Mapping[str, TParameterValue] diff --git a/ax/preview/api/utils/instantiation/from_string.py b/ax/preview/api/utils/instantiation/from_string.py index 474b561c646..f6e23969e80 100644 --- a/ax/preview/api/utils/instantiation/from_string.py +++ b/ax/preview/api/utils/instantiation/from_string.py @@ -6,7 +6,7 @@ # pyre-strict import re -from typing import Sequence +from collections.abc import Sequence from ax.core.map_metric import MapMetric diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 9f360c6463f..dbe966b36af 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -13,7 +13,7 @@ from functools import partial from logging import Logger -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar import ax.service.utils.early_stopping as early_stopping_utils import numpy as np @@ -182,7 +182,7 @@ class AxClient(AnalysisBase, BestPointMixin, InstantiationBase): def __init__( self, generation_strategy: GenerationStrategy | None = None, - db_settings: Optional[DBSettings] = None, + db_settings: DBSettings | None = None, enforce_sequential_optimization: bool = True, random_seed: int | None = None, torch_device: torch.device | None = None, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index bd9f85c8c70..c24d83609fd 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -15,7 +15,7 @@ from enum import IntEnum from logging import LoggerAdapter from time import sleep -from typing import Any, cast, NamedTuple, Optional +from typing import Any, cast, NamedTuple import ax.service.utils.early_stopping as early_stopping_utils from ax.core.base_trial import BaseTrial, TrialStatus @@ -213,7 +213,7 @@ def __init__( experiment: Experiment, generation_strategy: GenerationStrategy, options: SchedulerOptions, - db_settings: Optional[DBSettings] = None, + db_settings: DBSettings | None = None, _skip_experiment_save: bool = False, ) -> None: self.experiment = experiment @@ -277,7 +277,7 @@ def from_stored_experiment( cls, experiment_name: str, options: SchedulerOptions, - db_settings: Optional[DBSettings] = None, + db_settings: DBSettings | None = None, generation_strategy: GenerationStrategy | None = None, reduced_state: bool = True, **kwargs: Any, @@ -574,7 +574,7 @@ def run_n_trials( max_trials: int, ignore_global_stopping_strategy: bool = False, timeout_hours: float | None = None, - idle_callback: Optional[Callable[[Scheduler], None]] = None, + idle_callback: Callable[[Scheduler], None] | None = None, ) -> OptimizationResult: """Run up to ``max_trials`` trials; will run all ``max_trials`` unless completion criterion is reached. For base ``Scheduler``, completion criterion @@ -623,7 +623,7 @@ def run_n_trials( def run_all_trials( self, timeout_hours: float | None = None, - idle_callback: Optional[Callable[[Scheduler], None]] = None, + idle_callback: Callable[[Scheduler], None] | None = None, ) -> OptimizationResult: """Run all trials until ``should_consider_optimization_complete`` yields true (by default, ``should_consider_optimization_complete`` will yield true when @@ -1540,7 +1540,7 @@ def _abort_optimization(self, num_preexisting_trials: int) -> dict[str, Any]: def _complete_optimization( self, num_preexisting_trials: int, - idle_callback: Optional[Callable[[Scheduler], None]] = None, + idle_callback: Callable[[Scheduler], None] | None = None, ) -> dict[str, Any]: """Conclude optimization with waiting for anymore running trials and return final results via `wait_for_completed_trials_and_report_results`. diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index d353b28346c..a51a5a2f5da 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -9,12 +9,12 @@ import logging import os import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable from datetime import datetime, timedelta from math import ceil from random import randint from tempfile import NamedTemporaryFile -from typing import Any, Callable, cast, Optional +from typing import Any, cast from unittest.mock import call, Mock, patch import pandas as pd @@ -296,7 +296,7 @@ class AxSchedulerTestCase(TestCase): str, Callable[ [...], - Optional[dict[str, list[ObservationFeatures]]], + dict[str, list[ObservationFeatures]] | None, ], ] = ( f"{Scheduler.__module__}." @@ -307,7 +307,7 @@ class AxSchedulerTestCase(TestCase): str, Callable[ [...], - Optional[dict[str, list[ObservationFeatures]]], + dict[str, list[ObservationFeatures]] | None, ], ] = ( f"{GenerationStrategy.__module__}.extract_pending_observations", @@ -435,7 +435,7 @@ def db_settings(self) -> DBSettings: return DBSettings(encoder=encoder, decoder=decoder) @property - def db_settings_if_always_needed(self) -> Optional[DBSettings]: + def db_settings_if_always_needed(self) -> DBSettings | None: if self.ALWAYS_USE_DB: return self.db_settings return None @@ -1041,7 +1041,7 @@ def test_logging_file_stream(self) -> None: testScheduler.logger.debug(testDebugMessage) - with open(temp_file.name, "r") as f: + with open(temp_file.name) as f: log_contents = f.read() self.assertIn(testDebugMessage, log_contents) temp_file.close() diff --git a/ax/service/utils/analysis_base.py b/ax/service/utils/analysis_base.py index c2c0b612b22..fe7204a1d19 100644 --- a/ax/service/utils/analysis_base.py +++ b/ax/service/utils/analysis_base.py @@ -5,7 +5,7 @@ # pyre-strict import traceback -from typing import Iterable +from collections.abc import Iterable import pandas as pd diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index bbd1a5e8d66..a2d7ce48af6 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -7,11 +7,10 @@ # pyre-strict from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from functools import reduce from logging import Logger -from typing import Mapping import pandas as pd import torch diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 92f2d7ccd94..94267d2f090 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -1186,8 +1186,8 @@ def _objective_vs_true_objective_scatter( # TODO: may want to have a way to do this with a plot_fn # that returns a list of plots, such as get_standard_plots def get_figure_and_callback( - plot_fn: Callable[["Scheduler"], go.Figure], -) -> tuple[go.Figure, Callable[["Scheduler"], None]]: + plot_fn: Callable[[Scheduler], go.Figure], +) -> tuple[go.Figure, Callable[[Scheduler], None]]: """ Produce a figure and a callback for updating the figure in place. @@ -1212,7 +1212,7 @@ def get_figure_and_callback( fig = go.FigureWidget(layout=go.Layout()) # pyre-fixme[53]: Captured variable `fig` is not annotated. - def _update_fig_in_place(scheduler: "Scheduler") -> None: + def _update_fig_in_place(scheduler: Scheduler) -> None: try: new_fig = plot_fn(scheduler) except RuntimeError as e: diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index f00a0ad25b2..ac5cf2c78e9 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -8,10 +8,9 @@ import re import time -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from logging import INFO, Logger -from typing import Optional, Sequence from ax.analysis.analysis import AnalysisCard @@ -94,11 +93,11 @@ class WithDBSettingsBase: if `db_settings` property is set to a non-None value on the instance. """ - _db_settings: Optional[DBSettings] = None + _db_settings: DBSettings | None = None def __init__( self, - db_settings: Optional[DBSettings] = None, + db_settings: DBSettings | None = None, logging_level: int = INFO, suppress_all_errors: bool = False, ) -> None: @@ -118,7 +117,7 @@ def __init__( logger.setLevel(logging_level) @staticmethod - def _get_default_db_settings() -> Optional[DBSettings]: + def _get_default_db_settings() -> DBSettings | None: """Overridable method to get default db_settings if none are passed in __init__ """ diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 560f286207d..da5ec6a2ef4 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -10,7 +10,7 @@ from datetime import datetime from decimal import Decimal -from typing import Any, List +from typing import Any from ax.core.batch_trial import LifecycleStage from ax.core.parameter import ParameterType @@ -80,10 +80,10 @@ class SQAParameter(Base): upper: Column[Decimal | None] = Column(Float) # Attributes for Choice Parameters - choice_values: Column[List[TParamValue] | None] = Column(JSONEncodedList) + choice_values: Column[list[TParamValue] | None] = Column(JSONEncodedList) is_ordered: Column[bool | None] = Column(Boolean) is_task: Column[bool | None] = Column(Boolean) - dependents: Column[dict[TParamValue, List[str]] | None] = Column(JSONEncodedObject) + dependents: Column[dict[TParamValue, list[str]] | None] = Column(JSONEncodedObject) # Attributes for Fixed Parameters fixed_value: Column[TParamValue | None] = Column(JSONEncodedObject) @@ -135,7 +135,7 @@ class SQAMetric(Base): # of Multi/Scalarized Objective contains all children of the parent metric # join_depth argument: used for loading self-referential relationships # https://docs.sqlalchemy.org/en/13/orm/self_referential.html#configuring-self-referential-eager-loading - scalarized_objective_children_metrics: List["SQAMetric"] = relationship( + scalarized_objective_children_metrics: list[SQAMetric] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy=True, @@ -147,7 +147,7 @@ class SQAMetric(Base): scalarized_outcome_constraint_id: Column[int | None] = Column( Integer, ForeignKey("metric_v2.id") ) - scalarized_outcome_constraint_children_metrics: List["SQAMetric"] = relationship( + scalarized_outcome_constraint_children_metrics: list[SQAMetric] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy=True, @@ -214,19 +214,19 @@ class SQAGeneratorRun(Base): # relationships # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - arms: List[SQAArm] = relationship( + arms: list[SQAArm] = relationship( "SQAArm", cascade="all, delete-orphan", lazy="selectin", order_by=lambda: SQAArm.id, ) - metrics: List[SQAMetric] = relationship( + metrics: list[SQAMetric] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy="selectin" ) - parameters: List[SQAParameter] = relationship( + parameters: list[SQAParameter] = relationship( "SQAParameter", cascade="all, delete-orphan", lazy="selectin" ) - parameter_constraints: List[SQAParameterConstraint] = relationship( + parameter_constraints: list[SQAParameterConstraint] = relationship( "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin" ) @@ -268,15 +268,15 @@ class SQAGenerationStrategy(Base): id: Column[int] = Column(Integer, primary_key=True) name: Column[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH), nullable=False) - steps: Column[List[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) + steps: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=False) curr_index: Column[int | None] = Column(Integer, nullable=True) experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) - nodes: Column[List[dict[str, Any]]] = Column(JSONEncodedList, nullable=True) + nodes: Column[list[dict[str, Any]]] = Column(JSONEncodedList, nullable=True) curr_node_name: Column[str | None] = Column( String(NAME_OR_TYPE_FIELD_LENGTH), nullable=True ) - generator_runs: List[SQAGeneratorRun] = relationship( + generator_runs: list[SQAGeneratorRun] = relationship( "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin", @@ -322,10 +322,10 @@ class SQATrial(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - abandoned_arms: List[SQAAbandonedArm] = relationship( + abandoned_arms: list[SQAAbandonedArm] = relationship( "SQAAbandonedArm", cascade="all, delete-orphan", lazy="selectin" ) - generator_runs: List[SQAGeneratorRun] = relationship( + generator_runs: list[SQAGeneratorRun] = relationship( "SQAGeneratorRun", cascade="all, delete-orphan", lazy="selectin" ) runner: SQARunner = relationship( @@ -372,7 +372,7 @@ class SQAExperiment(Base): # pyre-fixme[8]: Incompatible attribute type [8]: Attribute # `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has # type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]` - auxiliary_experiments_by_purpose: dict[str, List[dict[str, Any]]] | None = Column( + auxiliary_experiments_by_purpose: dict[str, list[dict[str, Any]]] | None = Column( JSONEncodedTextDict, nullable=True, default={} ) @@ -382,22 +382,22 @@ class SQAExperiment(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) - data: List[SQAData] = relationship( + data: list[SQAData] = relationship( "SQAData", cascade="all, delete-orphan", lazy="selectin" ) - metrics: List[SQAMetric] = relationship( + metrics: list[SQAMetric] = relationship( "SQAMetric", cascade="all, delete-orphan", lazy="selectin" ) - parameters: List[SQAParameter] = relationship( + parameters: list[SQAParameter] = relationship( "SQAParameter", cascade="all, delete-orphan", lazy="selectin" ) - parameter_constraints: List[SQAParameterConstraint] = relationship( + parameter_constraints: list[SQAParameterConstraint] = relationship( "SQAParameterConstraint", cascade="all, delete-orphan", lazy="selectin" ) - runners: List[SQARunner] = relationship( + runners: list[SQARunner] = relationship( "SQARunner", cascade="all, delete-orphan", lazy=False ) - trials: List[SQATrial] = relationship( + trials: list[SQATrial] = relationship( "SQATrial", cascade="all, delete-orphan", lazy="selectin" ) generation_strategy: SQAGenerationStrategy | None = relationship( @@ -406,6 +406,6 @@ class SQAExperiment(Base): uselist=False, lazy=True, ) - analysis_cards: List[SQAAnalysisCard] = relationship( + analysis_cards: list[SQAAnalysisCard] = relationship( "SQAAnalysisCard", cascade="all, delete-orphan", lazy="selectin" ) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 0c03cd49dd7..39765bd8719 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -7,11 +7,12 @@ # pyre-strict import logging +from collections.abc import Callable from datetime import datetime from decimal import Decimal from enum import Enum, unique from logging import Logger -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar from unittest import mock from unittest.mock import MagicMock, Mock, patch diff --git a/ax/utils/common/func_enum.py b/ax/utils/common/func_enum.py index 2a5d8c5d92d..125f93ed493 100644 --- a/ax/utils/common/func_enum.py +++ b/ax/utils/common/func_enum.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +from collections.abc import Callable from enum import Enum, unique from importlib import import_module -from typing import Any, Callable +from typing import Any from ax.exceptions.core import UnsupportedError diff --git a/ax/utils/common/tests/test_testutils.py b/ax/utils/common/tests/test_testutils.py index 13e72920bf5..be09fe492e1 100644 --- a/ax/utils/common/tests/test_testutils.py +++ b/ax/utils/common/tests/test_testutils.py @@ -100,11 +100,6 @@ def test_silence_warning(self) -> None: sys.stderr = old_err self.assertTrue(new_stderr.getvalue().startswith("A message\n")) - def test_fail_deprecated(self) -> None: - self.assertEqual(1, 1) - with self.assertRaises(RuntimeError): - self.assertEquals(1, 1) - def test_ax_long_test_decorator(self) -> None: testReason: str = "testReason" diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index fdd52a05c1f..ca79ed0e714 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -5,8 +5,9 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +from collections.abc import Iterator from dataclasses import dataclass, field -from typing import Any, Iterator +from typing import Any import numpy as np import torch diff --git a/scripts/convert_ipynb_to_mdx.py b/scripts/convert_ipynb_to_mdx.py index b2d28cf6d54..86961987a43 100644 --- a/scripts/convert_ipynb_to_mdx.py +++ b/scripts/convert_ipynb_to_mdx.py @@ -12,7 +12,7 @@ import subprocess import uuid from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Union import mdformat import nbformat @@ -45,7 +45,7 @@ ] -def load_nb_metadata() -> Dict[str, Dict[str, str]]: +def load_nb_metadata() -> dict[str, dict[str, str]]: """ Load the metadata and list of notebooks that are to be converted to MDX. @@ -83,7 +83,7 @@ def load_notebook(path: Path) -> NotebookNode: return nb -def create_folders(path: Path) -> Tuple[str, Path]: +def create_folders(path: Path) -> tuple[str, Path]: """ Create asset folders for the tutorial. @@ -109,7 +109,7 @@ def create_folders(path: Path) -> Tuple[str, Path]: return filename, assets_folder -def create_frontmatter(path: Path, nb_metadata: Dict[str, Dict[str, str]]) -> str: +def create_frontmatter(path: Path, nb_metadata: dict[str, dict[str, str]]) -> str: """ Create frontmatter for the resulting MDX file. @@ -154,7 +154,7 @@ def create_imports() -> str: return f"{imports}\n" -def get_current_git_tag() -> Optional[str]: +def get_current_git_tag() -> str | None: """ Retrieve the current Git tag if the current commit is tagged. @@ -175,7 +175,7 @@ def get_current_git_tag() -> Optional[str]: def create_buttons( - nb_metadata: Dict[str, Dict[str, str]], + nb_metadata: dict[str, dict[str, str]], ) -> str: """ Create buttons that link to Colab and GitHub for the tutorial. @@ -397,8 +397,8 @@ def handle_cell_input(cell: NotebookNode, language: str) -> str: def handle_image( - values: List[Dict[str, Union[int, str, NotebookNode]]], -) -> List[Tuple[int, str]]: + values: list[dict[str, int | str | NotebookNode]], +) -> list[tuple[int, str]]: """ Convert embedded images to string MDX can consume. @@ -421,8 +421,8 @@ def handle_image( def handle_markdown( - values: List[Dict[str, Union[int, str, NotebookNode]]], -) -> List[Tuple[int, str]]: + values: list[dict[str, int | str | NotebookNode]], +) -> list[tuple[int, str]]: """ Convert and format Markdown for MDX. @@ -445,8 +445,8 @@ def handle_markdown( def handle_pandas( - values: List[Dict[str, Union[int, str, NotebookNode]]], -) -> List[Tuple[int, str]]: + values: list[dict[str, int | str | NotebookNode]], +) -> list[tuple[int, str]]: """ Handle how to display pandas DataFrames. @@ -488,8 +488,8 @@ def handle_pandas( def handle_plain( - values: List[Dict[str, Union[int, str, NotebookNode]]], -) -> List[Tuple[int, str]]: + values: list[dict[str, int | str | NotebookNode]], +) -> list[tuple[int, str]]: """ Handle how to plain cell output should be displayed in MDX. @@ -519,9 +519,9 @@ def handle_plain( def handle_plotly( - values: List[Dict[str, Union[int, str, NotebookNode]]], + values: list[dict[str, int | str | NotebookNode]], plot_data_folder: Path, -) -> List[Tuple[int, str]]: +) -> list[tuple[int, str]]: """ Convert Plotly outputs to MDX. @@ -552,8 +552,8 @@ def handle_plotly( def handle_tqdm( - values: List[Dict[str, Union[int, str, NotebookNode]]], -) -> List[Tuple[int, str]]: + values: list[dict[str, int | str | NotebookNode]], +) -> list[tuple[int, str]]: """ Handle the output of tqdm. @@ -575,9 +575,9 @@ def handle_tqdm( return [(index, f"\n{{\n `{md}`\n}}\n\n\n")] -CELL_OUTPUTS_TO_PROCESS = Dict[ +CELL_OUTPUTS_TO_PROCESS = dict[ str, - List[Dict[str, Union[int, str, NotebookNode]]], + list[dict[str, Union[int, str, NotebookNode]]], ] @@ -622,8 +622,8 @@ def aggregate_mdx( def prioritize_dtypes( - cell_outputs: List[NotebookNode], -) -> Tuple[List[List[str]], List[bool]]: + cell_outputs: list[NotebookNode], +) -> tuple[list[list[str]], list[bool]]: """ Prioritize cell output data types. @@ -664,7 +664,7 @@ def aggregate_images_and_plotly( prioritized_data_dtype: str, cell_output: NotebookNode, data: NotebookNode, - plotly_flags: List[bool], + plotly_flags: list[bool], cell_outputs_to_process: CELL_OUTPUTS_TO_PROCESS, i: int, ) -> None: @@ -727,7 +727,7 @@ def aggregate_plain_output( cell_outputs_to_process["plain"].append({"index": i, "data": data}) -def aggregate_output_types(cell_outputs: List[NotebookNode]) -> CELL_OUTPUTS_TO_PROCESS: +def aggregate_output_types(cell_outputs: list[NotebookNode]) -> CELL_OUTPUTS_TO_PROCESS: """ Aggregate cell outputs into a dictionary for further processing. diff --git a/scripts/validate_sphinx.py b/scripts/validate_sphinx.py index 8625830c64c..fc4d2963705 100755 --- a/scripts/validate_sphinx.py +++ b/scripts/validate_sphinx.py @@ -34,7 +34,7 @@ def parse_rst(rst_filename: str) -> Set[str]: """Extract automodule directives from rst.""" ret = set() - with open(rst_filename, "r") as f: + with open(rst_filename) as f: lines = f.readlines() for line in lines: line = line.strip()