Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebook#3385

Reviewed By: esantorella

Differential Revision: D69738468

fbshipit-source-id: cdf8316d0d9d09db0d0306b74c3509e575a1209a
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 18, 2025
1 parent 185cea1 commit ec07e1c
Show file tree
Hide file tree
Showing 43 changed files with 127 additions and 131 deletions.
3 changes: 1 addition & 2 deletions ax/analysis/healthcheck/can_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import json
from datetime import datetime
from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
Expand Down Expand Up @@ -46,7 +45,7 @@ def __init__(

def compute(
self,
experiment: Optional[Experiment] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard:
status = HealthcheckStatus.PASS
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-strict

import json
from typing import Tuple

import pandas as pd

Expand Down Expand Up @@ -156,7 +155,7 @@ def constraints_feasibility(
optimization_config: OptimizationConfig,
model: Adapter,
prob_threshold: float = 0.99,
) -> Tuple[bool, pd.DataFrame]:
) -> tuple[bool, pd.DataFrame]:
r"""
Check the feasibility of the constraints for the experiment.
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/healthcheck/regression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-strict

import json
from typing import Tuple

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
Expand Down Expand Up @@ -111,7 +110,7 @@ def compute(

def process_regression_dict(
regressions_by_trial: dict[int, dict[str, dict[str, float]]],
) -> Tuple[pd.DataFrame, str]:
) -> tuple[pd.DataFrame, str]:
r"""
Process the dictionary of trial indices, regressing arms and metrics into
a dataframe and a string.
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/healthcheck/tests/test_search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import List, Union

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
Expand Down Expand Up @@ -103,7 +102,7 @@ def test_search_space_boundary_proportions(self) -> None:
],
)

parametrizations: List[dict[str, Union[None, bool, float, int, str]]] = [
parametrizations: list[dict[str, None | bool | float | int | str]] = [
{
"float_range_1": 1.0,
"float_range_2": 1.0,
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/plotly/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
Expand Down Expand Up @@ -50,8 +49,8 @@ def __init__(

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ScatterPlot requires an Experiment")
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/plotly/surface/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-strict

import math
from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
Expand Down Expand Up @@ -61,8 +60,8 @@ def __init__(

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ContourPlot requires an Experiment")
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/plotly/surface/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-strict

import math
from typing import Optional

import pandas as pd
from ax.analysis.analysis import AnalysisCardLevel
Expand Down Expand Up @@ -55,8 +54,8 @@ def __init__(

def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategy] = None,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("SlicePlot requires an Experiment")
Expand Down
3 changes: 1 addition & 2 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from itertools import product
from logging import Logger, WARNING
from time import monotonic, time
from typing import Set

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -289,7 +288,7 @@ def benchmark_replication(
best_params_by_trial: list[list[TParameterization]] = []

is_mf_or_mt = len(problem.target_fidelity_and_task) > 0
trials_used_for_best_point: Set[int] = set()
trials_used_for_best_point: set[int] = set()

# Run the optimization loop.
timeout_hours = method.timeout_hours
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_total_runtime(
# By default, each step takes 1 virtual second.
if step_runtime_function is not None:
max_step_runtime = max(
(step_runtime_function(arm.parameters) for arm in trial.arms)
step_runtime_function(arm.parameters) for arm in trial.arms
)
else:
max_step_runtime = 1
Expand Down
8 changes: 3 additions & 5 deletions ax/benchmark/tests/test_benchmark_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ def get_test_trial(
n_steps = 3 if multiple_time_steps else 1
dfs = {
name: pd.concat(
(
_get_one_step_df(
batch=batch, metric_name=name, step=i, observe_noise_sd=True
)
for i in range(n_steps)
_get_one_step_df(
batch=batch, metric_name=name, step=i, observe_noise_sd=True
)
for i in range(n_steps)
)
for name in ["test_metric1", "test_metric2"]
}
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/tests/test_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_runner(self) -> None:
for i, df in enumerate(res.values()):
if isinstance(noise_std, list):
self.assertEqual(df["sem"].item(), noise_std[i])
if all((n == 0 for n in noise_std)):
if all(n == 0 for n in noise_std):
self.assertTrue(np.array_equal(df["mean"], Y[i, :]))
else: # float
self.assertEqual(df["sem"].item(), noise_std)
Expand Down
3 changes: 1 addition & 2 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

import math
import warnings
from collections.abc import Callable, Hashable, Mapping
from collections.abc import Callable, Hashable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import reduce
from logging import Logger
from random import choice, uniform
from typing import Sequence

import numpy.typing as npt
import pandas as pd
Expand Down
3 changes: 2 additions & 1 deletion ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

# pyre-strict

from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any

import numpy as np
import numpy.typing as npt
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/modelbridge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from __future__ import annotations

import warnings
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
from copy import deepcopy
from functools import partial
from logging import Logger
from typing import Any, Sequence, SupportsFloat, TYPE_CHECKING
from typing import Any, SupportsFloat, TYPE_CHECKING

import numpy as np
import numpy.typing as npt
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable
from collections.abc import Callable, Sequence
from copy import deepcopy
from logging import Logger
from typing import Any, Sequence
from typing import Any
from warnings import warn

import numpy as np
Expand Down
4 changes: 3 additions & 1 deletion ax/modelbridge/transforms/log_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

from __future__ import annotations

from collections.abc import Callable

from logging import Logger
from typing import Callable, TYPE_CHECKING
from typing import TYPE_CHECKING

import numpy as np
import numpy.typing as npt
Expand Down
6 changes: 4 additions & 2 deletions ax/modelbridge/transforms/metadata_to_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

from __future__ import annotations

from collections.abc import Iterable

from logging import Logger
from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING
from typing import Any, SupportsFloat, TYPE_CHECKING

from ax.core import ParameterType

Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(
self,
search_space: SearchSpace | None = None,
observations: list[Observation] | None = None,
modelbridge: Optional["modelbridge_module.base.Adapter"] = None,
modelbridge: modelbridge_module.base.Adapter | None = None,
config: TConfig | None = None,
) -> None:
if observations is None or not observations:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

# pyre-strict

from collections.abc import Iterator
from copy import deepcopy
from typing import Iterator

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

# pyre-strict

from collections.abc import Iterator
from copy import deepcopy
from typing import Iterator

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/transforms/time_as_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from logging import Logger
from time import time
from typing import Optional, TYPE_CHECKING
from typing import TYPE_CHECKING

import pandas as pd

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(
self,
search_space: SearchSpace | None = None,
observations: list[Observation] | None = None,
modelbridge: Optional["modelbridge_module.base.Adapter"] = None,
modelbridge: modelbridge_module.base.Adapter | None = None,
config: TConfig | None = None,
) -> None:
assert observations is not None, "TimeAsFeature requires observations"
Expand Down
4 changes: 2 additions & 2 deletions ax/models/torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Mapping, Sequence

from dataclasses import dataclass, field
from typing import Any, Mapping, Sequence
from typing import Any

import torch
from ax.core.metric import Metric
Expand Down
3 changes: 2 additions & 1 deletion ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
# pyre-strict

import json
from collections.abc import Sequence
from logging import Logger
from typing import Any, Sequence
from typing import Any

import numpy as np

Expand Down
5 changes: 3 additions & 2 deletions ax/preview/api/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

# pyre-strict

from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, List, Mapping, Sequence
from typing import Any

from ax.preview.api.types import TParameterValue
from ax.storage.registry_bundle import RegistryBundleBase
Expand Down Expand Up @@ -58,7 +59,7 @@ class ChoiceParameterConfig:
"""

name: str
values: List[float] | List[int] | List[str] | List[bool]
values: list[float] | list[int] | list[str] | list[bool]
parameter_type: ParameterType
is_ordered: bool | None = None
dependent_parameters: Mapping[TParameterValue, Sequence[str]] | None = None
Expand Down
3 changes: 2 additions & 1 deletion ax/preview/api/protocols/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# pyre-strict


from typing import Any, Mapping
from collections.abc import Mapping
from typing import Any

from ax.preview.api.protocols.utils import _APIMetric
from pyre_extensions import override
Expand Down
3 changes: 2 additions & 1 deletion ax/preview/api/protocols/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# pyre-strict


from typing import Any, Mapping
from collections.abc import Mapping
from typing import Any

from ax.core.trial_status import TrialStatus
from ax.preview.api.protocols.utils import _APIRunner
Expand Down
3 changes: 2 additions & 1 deletion ax/preview/api/protocols/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Iterable, Mapping
from collections.abc import Iterable, Mapping
from typing import Any

import pandas as pd

Expand Down
3 changes: 2 additions & 1 deletion ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

# pyre-strict

from typing import Any, Mapping
from collections.abc import Mapping
from typing import Any

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion ax/preview/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from typing import Mapping
from collections.abc import Mapping

TParameterValue = int | float | str | bool
TParameterization = Mapping[str, TParameterValue]
Expand Down
Loading

0 comments on commit ec07e1c

Please sign in to comment.