Skip to content

Commit

Permalink
Fix signatures of TorchModel.cross_validate
Browse files Browse the repository at this point in the history
Summary:
Some subclasses were missing SSD input, which leads to errors when called from `TorchAdapter` as
```
        f_test, cov_test = none_throws(self.model).cross_validate(
            datasets=datasets,
            X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device),
            search_space_digest=search_space_digest,
            use_posterior_predictive=use_posterior_predictive,
        )
```
This was something raised by pyre but it was being suppressed.

Differential Revision: D69951061
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 20, 2025
1 parent f80caac commit 8c9c887
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,11 @@ def best_point(
)

@copy_doc(TorchGenerator.cross_validate)
def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed here
def cross_validate(
self,
datasets: list[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
use_posterior_predictive: bool = False,
) -> tuple[Tensor, Tensor]:
if self._model is None:
Expand Down
3 changes: 2 additions & 1 deletion ax/models/torch/randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ def predict(self, X: Tensor) -> tuple[Tensor, Tensor]:
return _rf_predict(self.models, X)

@copy_doc(TorchGenerator.cross_validate)
def cross_validate( # pyre-ignore [14]: not using metric_names or ssd
def cross_validate(
self,
datasets: list[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
use_posterior_predictive: bool = False,
) -> tuple[Tensor, Tensor]:
Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
Expand Down

0 comments on commit 8c9c887

Please sign in to comment.