diff --git a/cellarium/ml/models/logistic_regression.py b/cellarium/ml/models/logistic_regression.py index 41ea4f00..e2f022e5 100644 --- a/cellarium/ml/models/logistic_regression.py +++ b/cellarium/ml/models/logistic_regression.py @@ -72,7 +72,8 @@ def __init__( self.log_metrics = log_metrics def reset_parameters(self) -> None: - rng = torch.Generator() + rng_device = self.W_gc.device.type if self.W_gc.device.type != "meta" else "cpu" + rng = torch.Generator(device=rng_device) rng.manual_seed(self.seed) self.W_prior_scale.fill_(self._W_prior_scale) self.W_gc.data.normal_(0, self.W_init_scale, generator=rng) diff --git a/cellarium/ml/models/probabilistic_pca.py b/cellarium/ml/models/probabilistic_pca.py index 9ff50028..032e8b5e 100644 --- a/cellarium/ml/models/probabilistic_pca.py +++ b/cellarium/ml/models/probabilistic_pca.py @@ -95,7 +95,8 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - rng = torch.Generator() + rng_device = self.W_kg.device.type if self.W_kg.device.type != "meta" else "cpu" + rng = torch.Generator(device=rng_device) rng.manual_seed(self.seed) if isinstance(self.mean_g, torch.nn.Parameter): self.mean_g.data.zero_()