Skip to content

Commit

Permalink
use separate cache for prefetch by default
Browse files Browse the repository at this point in the history
Unless `cache=True` is set, a separate temporary cache will be
used for prefetching. It will get removed after the iteration is closed.
  • Loading branch information
skshetry committed Jan 7, 2025
1 parent 6862726 commit 1ffd71b
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 37 deletions.
22 changes: 12 additions & 10 deletions examples/get_started/torch-loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import multiprocessing
import os
from contextlib import closing
from posixpath import basename

import torch
Expand Down Expand Up @@ -56,7 +57,7 @@ def forward(self, x):
if __name__ == "__main__":
ds = (
DataChain.from_storage(STORAGE, type="image")
.settings(cache=True, prefetch=25)
.settings(prefetch=25)
.filter(C("file.path").glob("*.jpg"))
.map(
label=lambda path: label_to_int(basename(path)[:3], CLASSES),
Expand All @@ -65,10 +66,11 @@ def forward(self, x):
)
)

dataset = ds.to_pytorch(transform=transform)
train_loader = DataLoader(
ds.to_pytorch(transform=transform),
dataset,
batch_size=25,
num_workers=max(4, os.cpu_count() or 2),
num_workers=min(4, os.cpu_count() or 2),
persistent_workers=True,
multiprocessing_context=multiprocessing.get_context("spawn"),
)
Expand All @@ -89,11 +91,11 @@ def forward(self, x):
inputs, labels = data
optimizer.zero_grad()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)

# Backward pass and optimize
loss.backward()
optimizer.step()
loader.set_postfix(loss=loss.item())
# Backward pass and optimize
loss.backward()
optimizer.step()
loader.set_postfix(loss=loss.item())
30 changes: 29 additions & 1 deletion src/datachain/cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
from collections.abc import Iterator
from contextlib import contextmanager
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional

from dvc_data.hashfile.db.local import LocalHashFileDB
from dvc_objects.fs.local import LocalFileSystem
from dvc_objects.fs.utils import remove
from fsspec.callbacks import Callback, TqdmCallback

from .progress import Tqdm
Expand All @@ -20,6 +24,23 @@ def try_scandir(path):
pass


def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
return DataChainCache(cache_dir, tmp_dir=tmp_dir)


@contextmanager
def temporary_cache(
tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
) -> Iterator["DataChainCache"]:
cache = get_temp_cache(tmp_dir, prefix=prefix)
try:
yield cache
finally:
if delete:
cache.destroy()


class DataChainCache:
def __init__(self, cache_dir: str, tmp_dir: str):
self.odb = LocalHashFileDB(
Expand All @@ -28,6 +49,9 @@ def __init__(self, cache_dir: str, tmp_dir: str):
tmp_dir=tmp_dir,
)

def __eq__(self, other) -> bool:
return self.odb == other.odb

@property
def cache_dir(self):
return self.odb.path
Expand Down Expand Up @@ -90,12 +114,16 @@ def store_data(self, file: "File", contents: bytes) -> None:
with open(dst, mode="wb") as f:
f.write(contents)

def clear(self):
def clear(self) -> None:
"""
Completely clear the cache.
"""
self.odb.clear()

def destroy(self) -> None:
# `clear` leaves the prefix directory structure intact.
remove(self.cache_dir)

def get_total_size(self) -> int:
total = 0
for subdir in try_scandir(self.odb.path):
Expand Down
65 changes: 65 additions & 0 deletions src/datachain/lib/prefetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from collections.abc import Generator, Iterable, Sequence
from contextlib import nullcontext
from functools import partial
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

from fsspec.callbacks import DEFAULT_CALLBACK, Callback

from datachain.asyn import AsyncMapper
from datachain.cache import temporary_cache
from datachain.lib.file import File

if TYPE_CHECKING:
from contextlib import AbstractContextManager

from datachain.cache import DataChainCache as Cache
from datachain.catalog.catalog import Catalog


T = TypeVar("T", bound=Sequence[Any])


def noop(*args, **kwargs):
pass


async def _prefetch_input(row: T, catalog: "Catalog", download_cb: Callback) -> T:
try:
callback = download_cb.increment_file_count
except AttributeError:
callback = noop

for obj in row:
if isinstance(obj, File):
obj._set_stream(catalog, True, download_cb)
await obj._prefetch()
callback()
return row


def clone_catalog_with_cache(catalog: "Catalog", cache: "Cache") -> "Catalog":
clone = catalog.copy()
clone.cache = cache
return clone


def rows_prefetcher(
catalog: "Catalog",
rows: Iterable[T],
prefetch: int,
cache: Optional["Cache"] = None,
download_cb: Callback = DEFAULT_CALLBACK,
) -> Generator[T, None, None]:
cache_ctx: AbstractContextManager[Cache]
if cache:
cache_ctx = nullcontext(cache)
else:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")

with cache_ctx as prefetch_cache:
catalog = clone_catalog_with_cache(catalog, prefetch_cache)
func = partial(_prefetch_input, download_cb=download_cb, catalog=catalog)
mapper = AsyncMapper(func, rows, workers=prefetch)
yield from cast("Generator[T, None, None]", mapper.iterate())
48 changes: 40 additions & 8 deletions src/datachain/lib/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from collections.abc import Iterator
import os
from collections.abc import Generator, Iterator
from contextlib import closing
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
Expand All @@ -9,11 +11,14 @@
from torchvision.transforms import v2

from datachain import Session
from datachain.asyn import AsyncMapper
from datachain.cache import get_temp_cache
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
from datachain.lib.prefetcher import rows_prefetcher
from datachain.lib.settings import Settings
from datachain.lib.text import convert_text
from datachain.progress import CombinedDownloadCallback
from datachain.query.dataset import get_download_callback

if TYPE_CHECKING:
from torchvision.transforms.v2 import Transform
Expand Down Expand Up @@ -75,6 +80,17 @@ def __init__(
if (prefetch := dc_settings.prefetch) is not None:
self.prefetch = prefetch

if self.cache:
self._cache = catalog.cache
else:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
self._cache = get_temp_cache(tmp_dir, prefix="prefetch-")

def close(self) -> None:
if not self.cache:
self._cache.destroy()

def _init_catalog(self, catalog: "Catalog"):
# For compatibility with multiprocessing,
# we can only store params in __init__(), as Catalog isn't picklable
Expand All @@ -91,7 +107,9 @@ def _get_catalog(self) -> "Catalog":
wh = wh_cls(*wh_args, **wh_kwargs)
return Catalog(ms, wh, **self._catalog_params)

def _rows_iter(self, total_rank: int, total_workers: int):
def _row_iter(
self, total_rank: int, total_workers: int
) -> Generator[tuple[Any, ...], None, None]:
catalog = self._get_catalog()
session = Session("PyTorch", catalog=catalog)
ds = DataChain.from_dataset(
Expand All @@ -106,12 +124,26 @@ def _rows_iter(self, total_rank: int, total_workers: int):

def __iter__(self) -> Iterator[Any]:
total_rank, total_workers = self.get_rank_and_workers()
rows = self._rows_iter(total_rank, total_workers)
if self.prefetch > 0:
from datachain.lib.udf import _prefetch_input

rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
yield from map(self._process_row, rows)
download_cb = CombinedDownloadCallback()
if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
download_cb = get_download_callback(
f"{total_rank}/{total_workers}", position=total_rank
)

rows = self._row_iter(total_rank, total_workers)
if self.prefetch > 0:
catalog = self._get_catalog()
rows = rows_prefetcher(
catalog,
rows,
self.prefetch,
cache=self._cache,
download_cb=download_cb,
)

with download_cb, closing(rows):
yield from map(self._process_row, rows)

def _process_row(self, row_features):
row = []
Expand Down
31 changes: 17 additions & 14 deletions src/datachain/lib/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from pydantic import BaseModel

from datachain.asyn import AsyncMapper
from datachain.dataset import RowDict
from datachain.lib.convert.flatten import flatten
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.prefetcher import rows_prefetcher
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import (
Batch,
Expand Down Expand Up @@ -279,13 +279,6 @@ def process_safe(self, obj_rows):
return result_objs


async def _prefetch_input(row):
for obj in row:
if isinstance(obj, File):
await obj._prefetch()
return row


class Mapper(UDFBase):
"""Inherit from this class to pass to `DataChain.map()`."""

Expand All @@ -307,9 +300,14 @@ def run(
for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
_cache = self.catalog.cache if cache else None
prepared_inputs = rows_prefetcher(
self.catalog,
prepared_inputs,
self.prefetch,
cache=_cache,
download_cb=download_cb,
)

with contextlib.closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
Expand Down Expand Up @@ -384,9 +382,14 @@ def run(
self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
)
if self.prefetch > 0:
prepared_inputs = AsyncMapper(
_prefetch_input, prepared_inputs, workers=self.prefetch
).iterate()
_cache = self.catalog.cache if cache else None
prepared_inputs = rows_prefetcher(
self.catalog,
prepared_inputs,
self.prefetch,
cache=_cache,
download_cb=download_cb,
)

with contextlib.closing(prepared_inputs):
for row in prepared_inputs:
Expand Down
19 changes: 18 additions & 1 deletion src/datachain/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from threading import RLock
from typing import Any, ClassVar

from fsspec import Callback
from fsspec.callbacks import TqdmCallback
from tqdm import tqdm

Expand Down Expand Up @@ -132,8 +133,24 @@ def format_dict(self):
return d


class CombinedDownloadCallback(TqdmCallback):
class CombinedDownloadCallback(Callback):
def set_size(self, size):
# This is a no-op to prevent fsspec's .get_file() from setting the combined
# download size to the size of the current file.
pass

def increment_file_count(self, n: int = 1) -> None:
pass


class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
def __init__(self, tqdm_kwargs=None, *args, **kwargs):
self.files_count = 0
tqdm_kwargs = tqdm_kwargs or {}
tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
super().__init__(tqdm_kwargs, *args, **kwargs)

def increment_file_count(self, n: int = 1) -> None:
self.files_count += n
if self.tqdm is not None:
self.tqdm.postfix = f"{self.files_count} files"
6 changes: 4 additions & 2 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
from datachain.dataset import DatasetStatus, RowDict
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
from datachain.func.base import Function
from datachain.progress import CombinedDownloadCallback
from datachain.lib.udf import UDFAdapter
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
from datachain.query.schema import C, UDFParamSpec, normalize_param
from datachain.query.session import Session
from datachain.sql.functions.random import rand
Expand Down Expand Up @@ -357,7 +358,8 @@ def get_download_callback() -> Callback:
"unit_scale": True,
"unit_divisor": 1024,
"leave": False,
}
**kwargs,
},
)


Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_map_file(cloud_test_catalog, use_cache, prefetch):
ctc = cloud_test_catalog

def new_signal(file: File) -> str:
assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
assert bool(file.get_local_path()) is (prefetch > 0)
with file.open() as f:
return file.name + " -> " + f.read().decode("utf-8")

Expand Down

0 comments on commit 1ffd71b

Please sign in to comment.