Merge pull request #38 from nanxstats/torch-disk-dataset
Implement `TorchDiskDataset` class
nanxstats authored Jan 5, 2025
2 parents 4ad07b4 + f31a472 commit 0d216b3
Showing 6 changed files with 179 additions and 9 deletions.
6 changes: 3 additions & 3 deletions docs/articles/memory.md
@@ -52,9 +52,9 @@ A more general solution in PyTorch is to use map-style and
iterable-style datasets to stream data from disk on-demand, without
loading the entire tensor into system memory.

-Starting from tinytopics 0.6.0, you can use the `NumpyDiskDataset` class
-to load `.npy` datasets from disk as training data, supported by
-`fit_model()`. Here is an example:
+You can use the `NumpyDiskDataset` or `TorchDiskDataset` classes to load
+`.npy` or `.pt` datasets from disk as training data, supported by both
+`fit_model()` and `fit_model_distributed()`. Here is an example:

``` python
import numpy as np
# ... (rest of the example collapsed in the diff view)
```
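
For context, here is a minimal sketch of the usage pattern the collapsed example illustrates. The file names, the synthetic data, and the `fit_model(data, k)` call shape returning a fitted model and per-epoch losses are illustrative assumptions, not lines from the collapsed block:

```python
import numpy as np
import torch

from tinytopics import fit_model
from tinytopics.data import NumpyDiskDataset, TorchDiskDataset

# Hypothetical files holding a document-term matrix in each format
X = np.random.poisson(lam=1.0, size=(1000, 500)).astype(np.float32)
np.save("X.npy", X)
torch.save(torch.from_numpy(X), "X.pt")

# Both datasets stream rows from disk instead of loading the full matrix
npy_data = NumpyDiskDataset("X.npy")
pt_data = TorchDiskDataset("X.pt")

# Assumed call shape: dataset in, fitted model and per-epoch losses out
model, losses = fit_model(npy_data, k=10)
```
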
6 changes: 3 additions & 3 deletions docs/articles/memory.qmd
@@ -56,9 +56,9 @@ A more general solution in PyTorch is to use map-style and iterable-style
datasets to stream data from disk on-demand, without loading the entire
tensor into system memory.

-Starting from tinytopics 0.6.0, you can use the `NumpyDiskDataset` class to
-load `.npy` datasets from disk as training data, supported by `fit_model()`.
-Here is an example:
+You can use the `NumpyDiskDataset` or `TorchDiskDataset` classes to load
+`.npy` or `.pt` datasets from disk as training data, supported by both
+`fit_model()` and `fit_model_distributed()`. Here is an example:

```{python}
import numpy as np
# ... (rest of the example collapsed in the diff view)
```
1 change: 1 addition & 0 deletions docs/reference/data.md
@@ -4,6 +4,7 @@
options:
members:
- NumpyDiskDataset
- TorchDiskDataset
- IndexTrackingDataset
show_root_heading: true
show_source: false
2 changes: 1 addition & 1 deletion src/tinytopics/__init__.py
@@ -6,7 +6,7 @@
from .fit_distributed import fit_model_distributed
from .models import NeuralPoissonNMF
from .loss import poisson_nmf_loss
-from .data import NumpyDiskDataset
+from .data import NumpyDiskDataset, TorchDiskDataset
from .utils import (
set_random_seed,
generate_synthetic_data,
90 changes: 89 additions & 1 deletion src/tinytopics/data.py
@@ -28,7 +28,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:

class NumpyDiskDataset(Dataset):
"""
-    A PyTorch Dataset class for loading document-term matrices from disk.
+    A PyTorch Dataset class for loading document-term matrices from `.npy` files.
The dataset can be initialized with either a path to a `.npy` file or
a NumPy array. When a file path is provided, the data is accessed
@@ -77,3 +77,91 @@ def __getitem__(self, idx: int) -> torch.Tensor:
def num_terms(self) -> int:
"""Return vocabulary size (number of columns)."""
return self.shape[1]


class TorchDiskDataset(Dataset):
"""
A PyTorch Dataset class for loading document-term matrices from `.pt` files.
The dataset can be initialized with either a path to a `.pt` file or
a PyTorch tensor. When a file path is provided, the data is accessed
lazily using memory mapping, which is useful for handling large datasets
that do not fit entirely in (CPU) memory.
The input `.pt` file should contain a single tensor with document-term
matrix data.
"""

    def _validate_tensor_data(self, tensor_data: Any) -> torch.Tensor:
        """Validate that the loaded data is a single tensor and return it.

        Args:
            tensor_data: Data loaded from the `.pt` file.

        Returns:
            The validated tensor.

        Raises:
            ValueError: If the loaded data is not a tensor.
        """
        if not isinstance(tensor_data, torch.Tensor):
            raise ValueError(
                f"File {self.data_path} must contain a single tensor, "
                f"got {type(tensor_data)}"
            )
        return tensor_data

def __init__(
self,
data: str | Path,
indices: Sequence[int] | None = None,
) -> None:
"""
Args:
data: Path to `.pt` file (str or Path).
            indices: Optional sequence of row indices to use. Defaults to all rows.
"""
self.data_path = Path(data)
if not self.data_path.exists():
raise FileNotFoundError(f"Data file not found: {self.data_path}")

# Try loading with mmap first to get shape
try:
tensor_data = self._validate_tensor_data(
torch.load(
self.data_path, map_location="cpu", weights_only=True, mmap=True
)
)
self.shape = tuple(tensor_data.shape)
self.mmap_supported = True
self.mmap_data: torch.Tensor | None = None

except RuntimeError:
# Fallback to regular loading if mmap not supported
tensor_data = self._validate_tensor_data(
torch.load(self.data_path, map_location="cpu", weights_only=True)
)
self.shape = tuple(tensor_data.shape)
self.data = tensor_data
self.mmap_supported = False

self.indices = indices or range(self.shape[0])

def __len__(self) -> int:
return len(self.indices)

def __getitem__(self, idx: int) -> torch.Tensor:
real_idx = self.indices[idx]

if self.mmap_supported:
if self.mmap_data is None:
self.mmap_data = torch.load(
self.data_path, map_location="cpu", weights_only=True, mmap=True
)
return self.mmap_data[real_idx]
else:
return self.data[real_idx]

@property
def num_terms(self) -> int:
"""Return vocabulary size (number of columns)."""
return self.shape[1]
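
A note on the loading strategy above: `torch.load(..., mmap=True)` only works for files saved in the default zipfile serialization format, which is what the `RuntimeError` fallback branch appears to handle. A likely motivation for deferring the actual mapped load to the first `__getitem__` call is that a freshly constructed dataset holds no tensor handle, so it ships cheaply to multi-process `DataLoader` workers, each of which then maps the file itself. A minimal standalone sketch of using the class with a `DataLoader` (file name hypothetical):

```python
import torch
from torch.utils.data import DataLoader

from tinytopics.data import TorchDiskDataset

# Hypothetical file: a single 2D tensor saved with torch.save()
torch.save(torch.rand(5_000, 1_000), "docs_terms.pt")

dataset = TorchDiskDataset("docs_terms.pt")
print(dataset.shape, dataset.num_terms)  # (5000, 1000) 1000

# Rows are materialized per batch; the full tensor stays on disk
loader = DataLoader(dataset, batch_size=64, shuffle=True)
for batch in loader:
    assert batch.shape == (64, 1_000)
    break
```
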
83 changes: 82 additions & 1 deletion tests/test_data.py
@@ -2,7 +2,7 @@
import torch
import numpy as np

-from tinytopics.data import NumpyDiskDataset
+from tinytopics.data import NumpyDiskDataset, TorchDiskDataset


def test_numpy_disk_dataset_from_array():
@@ -88,3 +88,84 @@ def test_numpy_disk_dataset_memory_efficiency(tmp_path):

# Memory mapping should be initialized only after first access
assert dataset.mmap_data is not None


def test_torch_disk_dataset_from_file(tmp_path):
"""Test TorchDiskDataset with .pt file input."""
data = torch.rand(10, 5, dtype=torch.float32)
file_path = tmp_path / "test_data.pt"
torch.save(data, file_path)

dataset = TorchDiskDataset(file_path)

# Test basic properties
assert len(dataset) == 10
assert dataset.num_terms == 5
assert dataset.shape == (10, 5)

# Test data access
for i in range(len(dataset)):
item = dataset[i]
assert isinstance(item, torch.Tensor)
assert item.shape == (5,)
assert torch.allclose(item, data[i])


def test_torch_disk_dataset_with_indices(tmp_path):
"""Test TorchDiskDataset with custom indices."""
data = torch.rand(10, 5, dtype=torch.float32)
file_path = tmp_path / "test_data.pt"
torch.save(data, file_path)
indices = [3, 1, 4]

dataset = TorchDiskDataset(file_path, indices=indices)

# Test basic properties
assert len(dataset) == len(indices)
assert dataset.num_terms == 5
assert dataset.shape == (10, 5)

# Test data access
for i, orig_idx in enumerate(indices):
item = dataset[i]
assert isinstance(item, torch.Tensor)
assert item.shape == (5,)
assert torch.allclose(item, data[orig_idx])


def test_torch_disk_dataset_file_not_found():
"""Test TorchDiskDataset with non-existent file."""
with pytest.raises(FileNotFoundError):
TorchDiskDataset("non_existent_file.pt")


def test_torch_disk_dataset_invalid_content(tmp_path):
"""Test TorchDiskDataset with invalid file content."""
file_path = tmp_path / "invalid_data.pt"
invalid_data = {"not_a_tensor": 42}
torch.save(invalid_data, file_path)

with pytest.raises(ValueError, match="must contain a single tensor"):
TorchDiskDataset(file_path)


def test_torch_disk_dataset_memory_efficiency(tmp_path):
"""Test that TorchDiskDataset uses memory mapping efficiently."""
shape = (1000, 500) # 500K elements
data = torch.rand(*shape, dtype=torch.float32)
file_path = tmp_path / "large_data.pt"
torch.save(data, file_path)

dataset = TorchDiskDataset(file_path)

# Access data in random order
indices = torch.randperm(shape[0])[:100] # Sample 100 random rows
for idx in indices:
item = dataset[idx]
assert torch.allclose(item, data[idx])

# Memory mapping should be initialized only after first access
if dataset.mmap_supported:
assert dataset.mmap_data is not None
else:
assert hasattr(dataset, "data")
