diff --git a/examples/mnist_lightning.py b/examples/mnist_lightning.py index 42d93a3d..2155feb0 100644 --- a/examples/mnist_lightning.py +++ b/examples/mnist_lightning.py @@ -99,11 +99,8 @@ def configure_optimizers(self): optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0) if self.enable_dp: - data_loader = ( - # soon there will be a fancy way to access train dataloader, - # see https://github.com/PyTorchLightning/pytorch-lightning/issues/10430 - self.trainer._data_connector._train_dataloader_source.dataloader() - ) + self.trainer.fit_loop.setup_data() + dataloader = self.trainer.train_dataloader # transform (model, optimizer, dataloader) to DP-versions if hasattr(self, "dp"):