diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py
index 6cb8f79f09284..18bee4474e83d 100644
--- a/src/lightning/pytorch/core/datamodule.py
+++ b/src/lightning/pytorch/core/datamodule.py
@@ -243,3 +243,39 @@ def load_from_checkpoint(
             **kwargs,
         )
         return cast(Self, loaded)
+
+    def __str__(self) -> str:
+        """Return a string representation of the datasets that are set up.
+
+        Every attribute of this datamodule that is a ``Dataset`` instance is reported on its own line as
+        ``name=<attribute name>, size=<length>``. The size is ``Unavailable`` when the dataset has no
+        ``__len__`` or when ``len()`` raises ``NotImplementedError`` or ``TypeError``.
+
+        Returns:
+            One newline-terminated line per dataset, or ``"No datasets are set up."`` when no attribute is
+            a ``Dataset``.
+
+        """
+
+        def size_of(dataset: Dataset) -> str:
+            # Call ``len()`` at most once per dataset: it can be expensive, and a dataset may define
+            # ``__len__`` only to raise from it.
+            if not hasattr(dataset, "__len__"):
+                return "Unavailable"
+            try:
+                return str(len(dataset))
+            except (NotImplementedError, TypeError):
+                return "Unavailable"
+
+        datasets_info = []
+        for attr_name in dir(self):
+            # The default keeps a property that raises ``AttributeError`` (e.g. one that reads state
+            # created later in ``setup``) from breaking ``str()``.
+            attr = getattr(self, attr_name, None)
+            if isinstance(attr, Dataset):
+                datasets_info.append(f"name={attr_name}, size={size_of(attr)}")
+
+        if not datasets_info:
+            return "No datasets are set up."
+
+        return "\n".join(datasets_info) + "\n"
diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py
index fd2660228146e..637a1e131809a 100644
--- a/src/lightning/pytorch/demos/boring_classes.py
+++ b/src/lightning/pytorch/demos/boring_classes.py
@@ -187,6 +187,38 @@ def predict_dataloader(self) -> DataLoader:
         return DataLoader(self.random_predict)
 
 
+class BoringDataModuleNoLen(LightningDataModule):
+    """
+    .. warning:: This is meant for testing/debugging and is experimental.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.random_full = RandomIterableDataset(32, 64 * 4)
+
+
+class BoringDataModuleLenNotImplemented(LightningDataModule):
+    """
+    .. warning:: This is meant for testing/debugging and is experimental.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        class DS(Dataset):
+            def __init__(self, size: int, length: int):
+                self.len = length
+                self.data = torch.randn(length, size)
+
+            def __getitem__(self, index: int) -> Tensor:
+                return self.data[index]
+
+            def __len__(self) -> int:
+                raise NotImplementedError
+
+        self.random_full = DS(32, 64 * 4)
+
+
 class ManualOptimBoringModel(BoringModel):
     """
     .. warning:: This is meant for testing/debugging and is experimental.
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 65fccb691a33d..0739359d6b5ae 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -22,7 +22,12 @@
 import torch
 from lightning.pytorch import LightningDataModule, Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
+from lightning.pytorch.demos.boring_classes import (
+    BoringDataModule,
+    BoringDataModuleLenNotImplemented,
+    BoringDataModuleNoLen,
+    BoringModel,
+)
 from lightning.pytorch.profilers.simple import SimpleProfiler
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities import AttributeDict
@@ -510,3 +515,59 @@ def prepare_data(self):
         durations = profiler.recorded_durations[key]
         assert len(durations) == 1
         assert durations[0] > 0
+
+
+def test_datamodule_string_no_datasets():
+    dm = BoringDataModule()
+    del dm.random_full
+    expected_output = "No datasets are set up."
+    assert str(dm) == expected_output
+
+
+def test_datamodule_string_no_length():
+    dm = BoringDataModuleNoLen()
+    expected_output = "name=random_full, size=Unavailable\n"
+    assert str(dm) == expected_output
+
+
+def test_datamodule_string_length_not_implemented():
+    dm = BoringDataModuleLenNotImplemented()
+    expected_output = "name=random_full, size=Unavailable\n"
+    assert str(dm) == expected_output
+
+
+def test_datamodule_string_fit_setup():
+    dm = BoringDataModule()
+    dm.setup(stage="fit")
+
+    expected_outputs = ["name=random_full, size=256\n", "name=random_train, size=64\n", "name=random_val, size=64\n"]
+    output = str(dm)
+    for expected_output in expected_outputs:
+        assert expected_output in output
+
+
+def test_datamodule_string_validation_setup():
+    dm = BoringDataModule()
+    dm.setup(stage="validate")
+    expected_outputs = ["name=random_full, size=256\n", "name=random_val, size=64\n"]
+    output = str(dm)
+    for expected_output in expected_outputs:
+        assert expected_output in output
+
+
+def test_datamodule_string_test_setup():
+    dm = BoringDataModule()
+    dm.setup(stage="test")
+    expected_outputs = ["name=random_full, size=256\n", "name=random_test, size=64\n"]
+    output = str(dm)
+    for expected_output in expected_outputs:
+        assert expected_output in output
+
+
+def test_datamodule_string_predict_setup():
+    dm = BoringDataModule()
+    dm.setup(stage="predict")
+    expected_outputs = ["name=random_full, size=256\n", "name=random_predict, size=64\n"]
+    output = str(dm)
+    for expected_output in expected_outputs:
+        assert expected_output in output