Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/9947 Add str method to datamodule #20301

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/lightning/pytorch/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,34 @@ def load_from_checkpoint(
**kwargs,
)
return cast(Self, loaded)

def __str__(self) -> str:
    """Return a string representation of the datasets that are set up.

    Every name reported by ``dir(self)`` is inspected, and each attribute that is a
    ``Dataset`` instance contributes one ``name=<attribute>, size=<length>`` line.
    The size is reported as ``Unavailable`` when ``len`` cannot be taken of the dataset.

    Returns:
        One newline-terminated line per dataset, or ``"No datasets are set up."``
        when no attribute is a ``Dataset``.

    """

    def size_of(dataset: Dataset) -> str:
        # ``len`` is evaluated exactly once. ``TypeError`` covers objects without
        # ``__len__`` (and ``__len__`` returning a non-integer), ``NotImplementedError``
        # covers datasets that declare ``__len__`` but do not implement it.
        try:
            return str(len(dataset))
        except (NotImplementedError, TypeError):
            return "Unavailable"

    datasets_info = []
    for attr_name in dir(self):
        # The default keeps a name that ``dir`` lists but whose lookup raises
        # ``AttributeError`` from aborting the whole string conversion.
        attr = getattr(self, attr_name, None)
        if isinstance(attr, Dataset):
            datasets_info.append(f"name={attr_name}, size={size_of(attr)}")

    if not datasets_info:
        return "No datasets are set up."

    return "\n".join(datasets_info) + "\n"
32 changes: 32 additions & 0 deletions src/lightning/pytorch/demos/boring_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,38 @@ def predict_dataloader(self) -> DataLoader:
return DataLoader(self.random_predict)


class BoringDataModuleNoLen(LightningDataModule):
    """Data module whose only dataset attribute, ``random_full``, is a ``RandomIterableDataset``.

    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        # Positional arguments forwarded unchanged to the iterable dataset.
        dataset_args = (32, 64 * 4)
        self.random_full = RandomIterableDataset(*dataset_args)


class BoringDataModuleLenNotImplemented(LightningDataModule):
    """Data module whose ``random_full`` dataset declares ``__len__`` but raises from it.

    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()

        class DS(Dataset):
            """Map-style dataset over random rows whose length query is not implemented."""

            def __init__(self, size: int, length: int):
                # ``length`` rows of ``size`` random values each.
                self.data = torch.randn(length, size)
                self.len = length

            def __len__(self) -> int:
                # The length is stored on the instance but deliberately not exposed.
                raise NotImplementedError

            def __getitem__(self, index: int) -> Tensor:
                return self.data[index]

        row_width, row_count = 32, 64 * 4
        self.random_full = DS(row_width, row_count)


class ManualOptimBoringModel(BoringModel):
"""
.. warning:: This is meant for testing/debugging and is experimental.
Expand Down
63 changes: 62 additions & 1 deletion tests/tests_pytorch/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import torch
from lightning.pytorch import LightningDataModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.demos.boring_classes import (
BoringDataModule,
BoringDataModuleLenNotImplemented,
BoringDataModuleNoLen,
BoringModel,
)
from lightning.pytorch.profilers.simple import SimpleProfiler
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities import AttributeDict
Expand Down Expand Up @@ -510,3 +515,59 @@ def prepare_data(self):
durations = profiler.recorded_durations[key]
assert len(durations) == 1
assert durations[0] > 0


def test_datamodule_string_no_datasets():
    """``str`` falls back to the placeholder message once ``random_full`` is deleted."""
    datamodule = BoringDataModule()
    del datamodule.random_full
    assert str(datamodule) == "No datasets are set up."


def test_datamodule_string_no_length():
    """The dataset of ``BoringDataModuleNoLen`` is listed with an unavailable size."""
    rendered = str(BoringDataModuleNoLen())
    assert rendered == "name=random_full, size=Unavailable\n"


def test_datamodule_string_length_not_implemented():
    """A dataset whose ``__len__`` raises ``NotImplementedError`` is listed with an unavailable size."""
    rendered = str(BoringDataModuleLenNotImplemented())
    assert rendered == "name=random_full, size=Unavailable\n"


def test_datamodule_string_fit_setup():
    """After ``setup(stage="fit")`` the string lists the full, train and val datasets."""
    datamodule = BoringDataModule()
    datamodule.setup(stage="fit")
    rendered = str(datamodule)

    for line in (
        "name=random_full, size=256\n",
        "name=random_train, size=64\n",
        "name=random_val, size=64\n",
    ):
        assert line in rendered


def test_datamodule_string_validation_setup():
    """After ``setup(stage="validate")`` the string lists the full and val datasets."""
    datamodule = BoringDataModule()
    datamodule.setup(stage="validate")
    rendered = str(datamodule)

    for line in (
        "name=random_full, size=256\n",
        "name=random_val, size=64\n",
    ):
        assert line in rendered


def test_datamodule_string_test_setup():
    """After ``setup(stage="test")`` the string lists the full and test datasets."""
    datamodule = BoringDataModule()
    datamodule.setup(stage="test")
    rendered = str(datamodule)

    for line in (
        "name=random_full, size=256\n",
        "name=random_test, size=64\n",
    ):
        assert line in rendered


def test_datamodule_string_predict_setup():
    """After ``setup(stage="predict")`` the string lists the full and predict datasets."""
    datamodule = BoringDataModule()
    datamodule.setup(stage="predict")
    rendered = str(datamodule)

    for line in (
        "name=random_full, size=256\n",
        "name=random_predict, size=64\n",
    ):
        assert line in rendered