Skip to content

Commit

Permalink
Fix signatures of TorchModel.cross_validate (#3398)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3398

Some subclasses were missing SSD input, which leads to errors when called from `TorchAdapter` as
```
        f_test, cov_test = none_throws(self.model).cross_validate(
            datasets=datasets,
            X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device),
            search_space_digest=search_space_digest,
            use_posterior_predictive=use_posterior_predictive,
        )
```
This was something raised by pyre but it was being suppressed.

Reviewed By: esantorella

Differential Revision: D69951061

fbshipit-source-id: 2199761ea28ae3fbab5e16b68486d8518ae209c7
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 21, 2025
1 parent f80caac commit a96cc6f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
8 changes: 7 additions & 1 deletion ax/models/tests/test_botorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def test_LegacyBoTorchGenerator(
]
mean, variance = model.cross_validate(
datasets=combined_datasets,
search_space_digest=search_space_digest,
X_test=torch.tensor([[1.2, 3.2, 4.2], [2.4, 5.2, 3.2]], **tkwargs),
)
self.assertEqual(mean.shape, torch.Size([2, 2]))
Expand All @@ -566,6 +567,7 @@ def test_LegacyBoTorchGenerator(
model.refit_on_cv = True
mean, variance = model.cross_validate(
datasets=combined_datasets,
search_space_digest=search_space_digest,
X_test=torch.tensor([[1.2, 3.2, 4.2], [2.4, 5.2, 3.2]], **tkwargs),
)
self.assertEqual(mean.shape, torch.Size([2, 2]))
Expand All @@ -580,7 +582,11 @@ def test_LegacyBoTorchGenerator(
with self.assertRaisesRegex(
RuntimeError, r"Cannot cross-validate model that has not been fitted"
):
unfit_model.cross_validate(datasets=combined_datasets, X_test=Xs1[0])
unfit_model.cross_validate(
datasets=combined_datasets,
search_space_digest=search_space_digest,
X_test=Xs1[0],
)
with self.assertRaisesRegex(
RuntimeError,
r"Cannot calculate feature_importances without a fitted model",
Expand Down
15 changes: 10 additions & 5 deletions ax/models/tests/test_randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ def test_RFModel(self) -> None:
)
for i in range(2)
]
search_space_digest = SearchSpaceDigest(
feature_names=["x1", "x2"],
bounds=[(0, 1)] * 2,
)

m = RandomForest(num_trees=5)
m.fit(
datasets=datasets,
search_space_digest=SearchSpaceDigest(
feature_names=["x1", "x2"],
bounds=[(0, 1)] * 2,
),
search_space_digest=search_space_digest,
)
self.assertEqual(len(m.models), 2)
# pyre-fixme[16]: `RandomForestRegressor` has no attribute `estimators_`.
Expand All @@ -42,6 +43,10 @@ def test_RFModel(self) -> None:
self.assertEqual(f.shape, torch.Size((5, 2)))
self.assertEqual(cov.shape, torch.Size((5, 2, 2)))

f, cov = m.cross_validate(datasets=datasets, X_test=torch.rand(3, 2))
f, cov = m.cross_validate(
datasets=datasets,
search_space_digest=search_space_digest,
X_test=torch.rand(3, 2),
)
self.assertEqual(f.shape, torch.Size((3, 2)))
self.assertEqual(cov.shape, torch.Size((3, 2, 2)))
3 changes: 2 additions & 1 deletion ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,11 @@ def best_point(
)

@copy_doc(TorchGenerator.cross_validate)
def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed here
def cross_validate(
self,
datasets: list[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
use_posterior_predictive: bool = False,
) -> tuple[Tensor, Tensor]:
if self._model is None:
Expand Down
3 changes: 2 additions & 1 deletion ax/models/torch/randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ def predict(self, X: Tensor) -> tuple[Tensor, Tensor]:
return _rf_predict(self.models, X)

@copy_doc(TorchGenerator.cross_validate)
def cross_validate( # pyre-ignore [14]: not using metric_names or ssd
def cross_validate(
self,
datasets: list[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
use_posterior_predictive: bool = False,
) -> tuple[Tensor, Tensor]:
Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
Expand Down

0 comments on commit a96cc6f

Please sign in to comment.