Skip to content

Commit

Permalink
add argparse for StratifiedStandardize (facebook#3343)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3343

See title.

Reviewed By: saitcakmak

Differential Revision: D69466499

fbshipit-source-id: ca58f6419ac52c36fc0a473126f361b6fb4ae134
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 12, 2025
1 parent bb6404a commit 4f0d86e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,17 @@

from typing import Any

import torch

from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import SupervisedDataset
from botorch.models.transforms.outcome import (
OutcomeTransform,
Standardize,
StratifiedStandardize,
)
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from pyre_extensions import assert_is_instance

outcome_transform_argparse = Dispatcher(
name="outcome_transform_argparse", encoder=_argparse_type_encoder
Expand Down Expand Up @@ -70,3 +77,36 @@ def _outcome_transform_argparse_standardize(
outcome_transform_options.setdefault("m", m)

return outcome_transform_options


@outcome_transform_argparse.register(StratifiedStandardize)
def _outcome_transform_argparse_stratified_standardize(
outcome_transform_class: type[StratifiedStandardize],
dataset: SupervisedDataset,
outcome_transform_options: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Extract the outcome transform kwargs form the given arguments.
Args:
outcome_transform_class: Outcome transform class, which is Standardize in this
case.
dataset: Dataset containing feature matrix and the response.
outcome_transform_options: Outcome transform kwargs.
See botorch.models.transforms.outcome.Standardize for all available options
Returns:
A dictionary with outcome transform kwargs.
"""

outcome_transform_options = outcome_transform_options or {}
dataset = assert_is_instance(dataset, MultiTaskDataset)
if dataset.has_heterogeneous_features:
task_feature_index = dataset.task_feature_index or -1
task_values = torch.arange(len(dataset.datasets), dtype=torch.long)
else:
task_feature_index = dataset.task_feature_index
task_values = dataset.X[..., dataset.task_feature_index].unique().long()
outcome_transform_options.setdefault("stratification_idx", task_feature_index)
outcome_transform_options.setdefault("task_values", task_values)

return outcome_transform_options
50 changes: 48 additions & 2 deletions ax/models/torch/tests/test_outcome_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@
outcome_transform_argparse,
)
from ax.utils.common.testutils import TestCase
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import SupervisedDataset
from botorch.models.transforms.outcome import (
OutcomeTransform,
Standardize,
StratifiedStandardize,
)
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from pyre_extensions import assert_is_instance


class DummyOutcomeTransform(OutcomeTransform):
Expand Down Expand Up @@ -60,3 +65,44 @@ def test_argparse_standardize(self) -> None:
)
self.assertEqual(outcome_transform_kwargs_a, {"m": 1})
self.assertEqual(outcome_transform_kwargs_b, {"m": 10})

def test_argparse_stratified_standardize(self) -> None:
X = self.dataset.X
X[:5, 3] = 0
X[5:, 3] = 1
mt_dataset = MultiTaskDataset.from_joint_dataset(
dataset=self.dataset,
task_feature_index=3,
target_task_value=1,
)
outcome_transform_kwargs_a = outcome_transform_argparse(
StratifiedStandardize, dataset=mt_dataset
)
options_b = {
"stratification_idx": 2,
"task_values": torch.tensor([0, 3]),
}
outcome_transform_kwargs_b = outcome_transform_argparse(
StratifiedStandardize,
dataset=mt_dataset,
outcome_transform_options=options_b,
)
expected_options_a = {
"stratification_idx": 3,
"task_values": torch.tensor([0, 1]),
}
for expected_options, actual_options in zip(
(expected_options_a, options_b),
(outcome_transform_kwargs_a, outcome_transform_kwargs_b),
):
self.assertEqual(len(actual_options), 2)
self.assertEqual(
actual_options["stratification_idx"],
expected_options["stratification_idx"],
)
self.assertTrue(
torch.equal(
actual_options["task_values"],
assert_is_instance(expected_options["task_values"], torch.Tensor),
)
)

0 comments on commit 4f0d86e

Please sign in to comment.