Skip to content

Commit

Permalink
to_pytorch: enable prefetching (#664)
Browse files Browse the repository at this point in the history
#653 did not really enable prefetching. Prefetching was only implemented for `map()`, so the example gave me a false impression that prefetching was working, but it was not.

Now, `to_pytorch` uses `AsyncMapper` to prefetch the data. The number of workers is set to 2 by default, but it can be changed by setting `prefetch` in the settings.

For me, this dropped the time to load the data by 90%, from ~300s to now ~35s.
  • Loading branch information
skshetry authored Dec 6, 2024
1 parent 911c22f commit 7c9d193
Showing 1 changed file with 52 additions and 38 deletions.
90 changes: 52 additions & 38 deletions src/datachain/lib/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tqdm import tqdm

from datachain import Session
from datachain.asyn import AsyncMapper
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
from datachain.lib.settings import Settings
Expand All @@ -31,6 +32,8 @@ def label_to_int(value: str, classes: list) -> int:


class PytorchDataset(IterableDataset):
prefetch: int = 2

def __init__(
self,
name: str,
Expand Down Expand Up @@ -67,7 +70,11 @@ def __init__(
if catalog is None:
catalog = get_catalog()
self._init_catalog(catalog)
self._dc_settings = dc_settings if dc_settings else Settings()

dc_settings = dc_settings or Settings()
self.cache = dc_settings.cache
if (prefetch := dc_settings.prefetch) is not None:
self.prefetch = prefetch

def _init_catalog(self, catalog: "Catalog"):
# For compatibility with multiprocessing,
Expand All @@ -85,51 +92,58 @@ def _get_catalog(self) -> "Catalog":
wh = wh_cls(*wh_args, **wh_kwargs)
return Catalog(ms, wh, **self._catalog_params)

def _rows_iter(self, total_rank: int, total_workers: int):
    """Yield raw rows for this worker's shard of the dataset.

    Builds a fresh catalog and session (the catalog is reconstructed from
    clone params for multiprocessing compatibility — see ``_init_catalog``),
    then streams only the chunk assigned to ``total_rank`` out of
    ``total_workers``.

    Args:
        total_rank: this worker's global rank.
        total_workers: total number of workers sharding the dataset.
    """
    catalog = self._get_catalog()
    session = Session("PyTorch", catalog=catalog)
    ds = DataChain.from_dataset(
        name=self.name, version=self.version, session=session
    ).settings(cache=self.cache, prefetch=self.prefetch)
    # Strip file signal columns not needed for training.
    ds = ds.remove_file_signals()

    if self.num_samples > 0:
        ds = ds.sample(self.num_samples)
    ds = ds.chunk(total_rank, total_workers)
    yield from ds.collect()

def __iter__(self) -> Iterator[Any]:
    """Iterate over processed rows for this worker's shard.

    Rows are optionally prefetched concurrently with ``AsyncMapper`` (when
    ``self.prefetch > 0``), then each row is materialized/transformed by
    ``_process_row``.
    """
    total_rank, total_workers = self.get_rank_and_workers()
    rows = self._rows_iter(total_rank, total_workers)
    if self.prefetch > 0:
        # NOTE(review): local import — presumably avoids a circular
        # dependency at module load time; confirm.
        from datachain.lib.udf import _prefetch_input

        rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()

    desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
    # position=total_rank keeps each worker's progress bar on its own line.
    with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
        yield from map(self._process_row, rows_it)

def _process_row(self, row_features):
    """Convert one raw row into a list of values ready for training.

    File-like features (anything exposing ``.read()``) are read into
    memory; then the optional ``self.transform`` and ``self.tokenizer``
    are applied in that order.

    Args:
        row_features: iterable of feature values for a single row.

    Returns:
        The processed row as a list.
    """
    row = []
    for fr in row_features:
        if hasattr(fr, "read"):
            row.append(fr.read())  # type: ignore[unreachable]
        else:
            row.append(fr)
    # Apply transforms
    if self.transform:
        try:
            if isinstance(self.transform, v2.Transform):
                row = self.transform(row)
            for i, val in enumerate(row):
                if isinstance(val, Image.Image):
                    row[i] = self.transform(val)
        except ValueError:
            # Disable the transform for all subsequent rows instead of
            # failing on every row with the same unsupported types.
            logger.warning("Skipping transform due to unsupported data types.")
            self.transform = None
    if self.tokenizer:
        for i, val in enumerate(row):
            # Guard against an empty list: val[0] would raise IndexError.
            if isinstance(val, str) or (
                isinstance(val, list) and val and isinstance(val[0], str)
            ):
                row[i] = convert_text(
                    val, self.tokenizer, self.tokenizer_kwargs
                ).squeeze(0)  # type: ignore[union-attr]
    return row

@staticmethod
def get_rank_and_workers() -> tuple[int, int]:
Expand Down

0 comments on commit 7c9d193

Please sign in to comment.