From a96cc6f0ea6ab502158a02419dca37e13600a0f4 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 20 Feb 2025 17:58:07 -0800 Subject: [PATCH] Fix signatures of TorchModel.cross_validate (#3398) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3398 Some subclasses were missing SSD input, which leads to errors when called from `TorchAdapter` as ``` f_test, cov_test = none_throws(self.model).cross_validate( datasets=datasets, X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device), search_space_digest=search_space_digest, use_posterior_predictive=use_posterior_predictive, ) ``` This was something raised by pyre but it was being suppressed. Reviewed By: esantorella Differential Revision: D69951061 fbshipit-source-id: 2199761ea28ae3fbab5e16b68486d8518ae209c7 --- ax/models/tests/test_botorch_model.py | 8 +++++++- ax/models/tests/test_randomforest.py | 15 ++++++++++----- ax/models/torch/botorch.py | 3 ++- ax/models/torch/randomforest.py | 3 ++- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py index 5c355ad0222..81820e7298e 100644 --- a/ax/models/tests/test_botorch_model.py +++ b/ax/models/tests/test_botorch_model.py @@ -557,6 +557,7 @@ def test_LegacyBoTorchGenerator( ] mean, variance = model.cross_validate( datasets=combined_datasets, + search_space_digest=search_space_digest, X_test=torch.tensor([[1.2, 3.2, 4.2], [2.4, 5.2, 3.2]], **tkwargs), ) self.assertEqual(mean.shape, torch.Size([2, 2])) @@ -566,6 +567,7 @@ def test_LegacyBoTorchGenerator( model.refit_on_cv = True mean, variance = model.cross_validate( datasets=combined_datasets, + search_space_digest=search_space_digest, X_test=torch.tensor([[1.2, 3.2, 4.2], [2.4, 5.2, 3.2]], **tkwargs), ) self.assertEqual(mean.shape, torch.Size([2, 2])) @@ -580,7 +582,11 @@ def test_LegacyBoTorchGenerator( with self.assertRaisesRegex( RuntimeError, r"Cannot cross-validate model that has not been fitted" ): - unfit_model.cross_validate(datasets=combined_datasets, X_test=Xs1[0]) + unfit_model.cross_validate( + datasets=combined_datasets, + search_space_digest=search_space_digest, + X_test=Xs1[0], + ) with self.assertRaisesRegex( RuntimeError, r"Cannot calculate feature_importances without a fitted model", diff --git a/ax/models/tests/test_randomforest.py b/ax/models/tests/test_randomforest.py index aac1f721266..2784a67da16 100644 --- a/ax/models/tests/test_randomforest.py +++ b/ax/models/tests/test_randomforest.py @@ -25,14 +25,15 @@ def test_RFModel(self) -> None: ) for i in range(2) ] + search_space_digest = SearchSpaceDigest( + feature_names=["x1", "x2"], + bounds=[(0, 1)] * 2, + ) m = RandomForest(num_trees=5) m.fit( datasets=datasets, - search_space_digest=SearchSpaceDigest( - feature_names=["x1", "x2"], - bounds=[(0, 1)] * 2, - ), + search_space_digest=search_space_digest, ) self.assertEqual(len(m.models), 2) # pyre-fixme[16]: `RandomForestRegressor` has no attribute `estimators_`. @@ -42,6 +43,10 @@ def test_RFModel(self) -> None: self.assertEqual(f.shape, torch.Size((5, 2))) self.assertEqual(cov.shape, torch.Size((5, 2, 2))) - f, cov = m.cross_validate(datasets=datasets, X_test=torch.rand(3, 2)) + f, cov = m.cross_validate( + datasets=datasets, + search_space_digest=search_space_digest, + X_test=torch.rand(3, 2), + ) self.assertEqual(f.shape, torch.Size((3, 2))) self.assertEqual(cov.shape, torch.Size((3, 2, 2))) diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 3a0536b266f..cb2a584330a 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -468,10 +468,11 @@ def best_point( ) @copy_doc(TorchGenerator.cross_validate) - def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed here + def cross_validate( self, datasets: list[SupervisedDataset], X_test: Tensor, + search_space_digest: SearchSpaceDigest, use_posterior_predictive: bool = False, ) -> tuple[Tensor, Tensor]: if self._model is None: diff --git a/ax/models/torch/randomforest.py b/ax/models/torch/randomforest.py index 233fbd7a815..908fe3fbdc7 100644 --- a/ax/models/torch/randomforest.py +++ b/ax/models/torch/randomforest.py @@ -66,10 +66,11 @@ def predict(self, X: Tensor) -> tuple[Tensor, Tensor]: return _rf_predict(self.models, X) @copy_doc(TorchGenerator.cross_validate) - def cross_validate( # pyre-ignore [14]: not using metric_names or ssd + def cross_validate( self, datasets: list[SupervisedDataset], X_test: Tensor, + search_space_digest: SearchSpaceDigest, use_posterior_predictive: bool = False, ) -> tuple[Tensor, Tensor]: Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)