
prefetch: use a separate temporary cache for prefetching #730

Merged: 13 commits, Jan 7, 2025
use separate cache for prefetch by default
Unless `cache=True` is set, a separate temporary cache is used for
prefetching; it is removed once the iterator is closed.
skshetry committed Jan 7, 2025

commit 1ffd71bd8a5ab70618d5aa1b98cc9f3bd80158d6
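
The commit message above describes the new default. A minimal sketch of the user-facing difference (the storage URI and prefetch value are illustrative):

    from datachain import DataChain

    # Default: prefetched files go into a temporary cache that is removed
    # once the iterator is closed.
    chain = DataChain.from_storage("s3://bucket/images/").settings(prefetch=10)

    # Opting in to the persistent cache keeps prefetched files around
    # after the run.
    chain = DataChain.from_storage("s3://bucket/images/").settings(cache=True, prefetch=10)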
22 changes: 12 additions & 10 deletions examples/get_started/torch-loader.py
@@ -7,6 +7,7 @@

 import multiprocessing
 import os
+from contextlib import closing
 from posixpath import basename

 import torch
@@ -56,7 +57,7 @@ def forward(self, x):
 if __name__ == "__main__":
     ds = (
         DataChain.from_storage(STORAGE, type="image")
-        .settings(cache=True, prefetch=25)
+        .settings(prefetch=25)
         .filter(C("file.path").glob("*.jpg"))
         .map(
             label=lambda path: label_to_int(basename(path)[:3], CLASSES),
@@ -65,10 +66,11 @@ def forward(self, x):
         )
     )

+    dataset = ds.to_pytorch(transform=transform)
     train_loader = DataLoader(
-        ds.to_pytorch(transform=transform),
+        dataset,
         batch_size=25,
-        num_workers=max(4, os.cpu_count() or 2),
+        num_workers=min(4, os.cpu_count() or 2),
         persistent_workers=True,
         multiprocessing_context=multiprocessing.get_context("spawn"),
     )
@@ -89,11 +91,11 @@ def forward(self, x):
             inputs, labels = data
             optimizer.zero_grad()

-            # Forward pass
-            outputs = model(inputs)
-            loss = criterion(outputs, labels)
+            # Forward pass
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)

-            # Backward pass and optimize
-            loss.backward()
-            optimizer.step()
-            loader.set_postfix(loss=loss.item())
+            # Backward pass and optimize
+            loss.backward()
+            optimizer.step()
+            loader.set_postfix(loss=loss.item())
30 changes: 29 additions & 1 deletion src/datachain/cache.py
@@ -1,8 +1,12 @@
 import os
+from collections.abc import Iterator
+from contextlib import contextmanager
+from tempfile import mkdtemp
 from typing import TYPE_CHECKING, Optional

 from dvc_data.hashfile.db.local import LocalHashFileDB
 from dvc_objects.fs.local import LocalFileSystem
+from dvc_objects.fs.utils import remove
 from fsspec.callbacks import Callback, TqdmCallback

 from .progress import Tqdm
@@ -20,6 +24,23 @@ def try_scandir(path):
         pass


+def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
+    cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
+    return DataChainCache(cache_dir, tmp_dir=tmp_dir)
+
+
+@contextmanager
+def temporary_cache(
+    tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
+) -> Iterator["DataChainCache"]:
+    cache = get_temp_cache(tmp_dir, prefix=prefix)
+    try:
+        yield cache
+    finally:
+        if delete:
+            cache.destroy()
+
+
 class DataChainCache:
     def __init__(self, cache_dir: str, tmp_dir: str):
         self.odb = LocalHashFileDB(
@@ -28,6 +49,9 @@ def __init__(self, cache_dir: str, tmp_dir: str):
             tmp_dir=tmp_dir,
         )

+    def __eq__(self, other) -> bool:
+        return self.odb == other.odb
+
     @property
     def cache_dir(self):
         return self.odb.path
@@ -90,12 +114,16 @@ def store_data(self, file: "File", contents: bytes) -> None:
         with open(dst, mode="wb") as f:
             f.write(contents)

-    def clear(self):
+    def clear(self) -> None:
         """
         Completely clear the cache.
         """
         self.odb.clear()

+    def destroy(self) -> None:
+        # `clear` leaves the prefix directory structure intact.
+        remove(self.cache_dir)
+
     def get_total_size(self) -> int:
         total = 0
         for subdir in try_scandir(self.odb.path):
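
A sketch of how these new helpers are meant to compose (the tmp_dir path here is illustrative): get_temp_cache creates a fresh cache directory under tmp_dir, and temporary_cache wraps it in a context manager that destroys the directory on exit unless delete=False is passed.

    from datachain.cache import temporary_cache

    with temporary_cache("/tmp/datachain", prefix="prefetch-") as cache:
        ...  # files downloaded through `cache` land in a throwaway directory
    # here cache.destroy() has already removed the directory (delete=True default)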
65 changes: 65 additions & 0 deletions src/datachain/lib/prefetcher.py
@@ -0,0 +1,65 @@
+from collections.abc import Generator, Iterable, Sequence
+from contextlib import nullcontext
+from functools import partial
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
+
+from fsspec.callbacks import DEFAULT_CALLBACK, Callback
+
+from datachain.asyn import AsyncMapper
+from datachain.cache import temporary_cache
+from datachain.lib.file import File
+
+if TYPE_CHECKING:
+    from contextlib import AbstractContextManager
+
+    from datachain.cache import DataChainCache as Cache
+    from datachain.catalog.catalog import Catalog
+
+
+T = TypeVar("T", bound=Sequence[Any])
+
+
+def noop(*args, **kwargs):
+    pass
+
+
+async def _prefetch_input(row: T, catalog: "Catalog", download_cb: Callback) -> T:
+    try:
+        callback = download_cb.increment_file_count
+    except AttributeError:
+        callback = noop
+
+    for obj in row:
+        if isinstance(obj, File):
+            obj._set_stream(catalog, True, download_cb)
+            await obj._prefetch()
+            callback()
+    return row
+
+
+def clone_catalog_with_cache(catalog: "Catalog", cache: "Cache") -> "Catalog":
+    clone = catalog.copy()
+    clone.cache = cache
+    return clone
+
+
+def rows_prefetcher(
+    catalog: "Catalog",
+    rows: Iterable[T],
+    prefetch: int,
+    cache: Optional["Cache"] = None,
+    download_cb: Callback = DEFAULT_CALLBACK,
+) -> Generator[T, None, None]:
+    cache_ctx: AbstractContextManager[Cache]
+    if cache:
+        cache_ctx = nullcontext(cache)
+    else:
+        tmp_dir = catalog.cache.tmp_dir
+        assert tmp_dir
+        cache_ctx = temporary_cache(tmp_dir, prefix="prefetch-")
+
+    with cache_ctx as prefetch_cache:
+        catalog = clone_catalog_with_cache(catalog, prefetch_cache)
+        func = partial(_prefetch_input, download_cb=download_cb, catalog=catalog)
+        mapper = AsyncMapper(func, rows, workers=prefetch)
+        yield from cast("Generator[T, None, None]", mapper.iterate())
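
A hedged sketch of the intended call pattern, mirroring how the UDF and PyTorch code paths below wire this up (the catalog and rows values are assumed to already exist):

    from contextlib import closing

    from datachain.lib.prefetcher import rows_prefetcher

    # cache=None (the default) gives the prefetcher its own temporary cache.
    rows = rows_prefetcher(catalog, rows, prefetch=10)
    with closing(rows):  # closing the generator tears the temporary cache down
        for row in rows:
            ...  # each File in `row` has already been prefetched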
48 changes: 40 additions & 8 deletions src/datachain/lib/pytorch.py
@@ -1,5 +1,7 @@
 import logging
-from collections.abc import Iterator
+import os
+from collections.abc import Generator, Iterator
+from contextlib import closing
 from typing import TYPE_CHECKING, Any, Callable, Optional

 from PIL import Image
@@ -9,11 +11,14 @@
 from torchvision.transforms import v2

 from datachain import Session
-from datachain.asyn import AsyncMapper
+from datachain.cache import get_temp_cache
 from datachain.catalog import Catalog, get_catalog
 from datachain.lib.dc import DataChain
+from datachain.lib.prefetcher import rows_prefetcher
 from datachain.lib.settings import Settings
 from datachain.lib.text import convert_text
+from datachain.progress import CombinedDownloadCallback
+from datachain.query.dataset import get_download_callback

 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform
@@ -75,6 +80,17 @@ def __init__(
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch

+        if self.cache:
+            self._cache = catalog.cache
+        else:
+            tmp_dir = catalog.cache.tmp_dir
+            assert tmp_dir
+            self._cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+
+    def close(self) -> None:
+        if not self.cache:
+            self._cache.destroy()
+
     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
         # we can only store params in __init__(), as Catalog isn't picklable
@@ -91,7 +107,9 @@ def _get_catalog(self) -> "Catalog":
         wh = wh_cls(*wh_args, **wh_kwargs)
         return Catalog(ms, wh, **self._catalog_params)

-    def _rows_iter(self, total_rank: int, total_workers: int):
+    def _row_iter(
+        self, total_rank: int, total_workers: int
+    ) -> Generator[tuple[Any, ...], None, None]:
         catalog = self._get_catalog()
         session = Session("PyTorch", catalog=catalog)
         ds = DataChain.from_dataset(
@@ -106,12 +124,26 @@ def _rows_iter(self, total_rank: int, total_workers: int):

     def __iter__(self) -> Iterator[Any]:
         total_rank, total_workers = self.get_rank_and_workers()
-        rows = self._rows_iter(total_rank, total_workers)
-        if self.prefetch > 0:
-            from datachain.lib.udf import _prefetch_input
-
-            rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-        yield from map(self._process_row, rows)
+        download_cb = CombinedDownloadCallback()
+        if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
+            download_cb = get_download_callback(
+                f"{total_rank}/{total_workers}", position=total_rank
+            )
Review comment from skshetry (Member, Author):

This shows a prefetch download progress bar for each worker, which will be useful for debugging.

We cannot enable this by default, as it would mess up the user's progress bar due to multiprocessing.


+
+        rows = self._row_iter(total_rank, total_workers)
+        if self.prefetch > 0:
+            catalog = self._get_catalog()
+            rows = rows_prefetcher(
+                catalog,
+                rows,
+                self.prefetch,
+                cache=self._cache,
+                download_cb=download_cb,
+            )
+
+        with download_cb, closing(rows):
+            yield from map(self._process_row, rows)

     def _process_row(self, row_features):
         row = []
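
Putting the PyTorch pieces together, a consumer-side sketch (ds is the chain from the earlier example; the environment variable is the opt-in checked above and must be set before workers spawn):

    import os
    from contextlib import closing

    from torch.utils.data import DataLoader

    os.environ["DATACHAIN_SHOW_PREFETCH_PROGRESS"] = "1"  # per-worker progress bars

    dataset = ds.to_pytorch(transform=transform)
    with closing(dataset):  # close() destroys the temp prefetch cache unless cache=True
        for batch in DataLoader(dataset, batch_size=25):
            ...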
31 changes: 17 additions & 14 deletions src/datachain/lib/udf.py
@@ -8,11 +8,11 @@
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel

-from datachain.asyn import AsyncMapper
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
+from datachain.lib.prefetcher import rows_prefetcher
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
 from datachain.query.batch import (
     Batch,
@@ -279,13 +279,6 @@ def process_safe(self, obj_rows):
         return result_objs


-async def _prefetch_input(row):
-    for obj in row:
-        if isinstance(obj, File):
-            await obj._prefetch()
-    return row
-
-
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""

@@ -307,9 +300,14 @@ def run(
             for row in udf_inputs
         )
         if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
+            _cache = self.catalog.cache if cache else None
+            prepared_inputs = rows_prefetcher(
+                self.catalog,
+                prepared_inputs,
+                self.prefetch,
+                cache=_cache,
+                download_cb=download_cb,
+            )

         with contextlib.closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
@@ -384,9 +382,14 @@ def run(
             self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
         )
         if self.prefetch > 0:
-            prepared_inputs = AsyncMapper(
-                _prefetch_input, prepared_inputs, workers=self.prefetch
-            ).iterate()
+            _cache = self.catalog.cache if cache else None
+            prepared_inputs = rows_prefetcher(
+                self.catalog,
+                prepared_inputs,
+                self.prefetch,
+                cache=_cache,
+                download_cb=download_cb,
+            )

         with contextlib.closing(prepared_inputs):
             for row in prepared_inputs:
19 changes: 18 additions & 1 deletion src/datachain/progress.py
@@ -5,6 +5,7 @@
 from threading import RLock
 from typing import Any, ClassVar

+from fsspec import Callback
 from fsspec.callbacks import TqdmCallback
 from tqdm import tqdm

@@ -132,8 +133,24 @@ def format_dict(self):
         return d


-class CombinedDownloadCallback(TqdmCallback):
+class CombinedDownloadCallback(Callback):
     def set_size(self, size):
         # This is a no-op to prevent fsspec's .get_file() from setting the combined
         # download size to the size of the current file.
         pass

+    def increment_file_count(self, n: int = 1) -> None:
+        pass
+
+
+class TqdmCombinedDownloadCallback(CombinedDownloadCallback, TqdmCallback):
Review comment from skshetry (Member, Author):

I have modified the callback to also show file counts during prefetching. (This will not show up in the PyTorch loader, however.) E.g.:

Download: 1.03MB [00:01, 605kB/s, 50 files]

+    def __init__(self, tqdm_kwargs=None, *args, **kwargs):
+        self.files_count = 0
+        tqdm_kwargs = tqdm_kwargs or {}
+        tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
+        super().__init__(tqdm_kwargs, *args, **kwargs)
+
+    def increment_file_count(self, n: int = 1) -> None:
+        self.files_count += n
+        if self.tqdm is not None:
+            self.tqdm.postfix = f"{self.files_count} files"
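
A small usage sketch of the split between the two callbacks (the tqdm_kwargs values here are assumptions, not library defaults):

    from datachain.progress import TqdmCombinedDownloadCallback

    cb = TqdmCombinedDownloadCallback(tqdm_kwargs={"desc": "Download", "unit": "B"})
    with cb:
        cb.relative_update(1024)   # byte progress, as reported by fsspec transfers
        cb.increment_file_count()  # renders as "1 files" in the bar's postfix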
6 changes: 4 additions & 2 deletions src/datachain/query/dataset.py
@@ -43,7 +43,8 @@
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.progress import CombinedDownloadCallback
+from datachain.lib.udf import UDFAdapter
+from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
 from datachain.sql.functions.random import rand
@@ -357,7 +358,8 @@ def get_download_callback() -> Callback:
             "unit_scale": True,
             "unit_divisor": 1024,
             "leave": False,
-        }
+            **kwargs,
+        },
     )


2 changes: 1 addition & 1 deletion tests/func/test_datachain.py
@@ -222,7 +222,7 @@ def test_map_file(cloud_test_catalog, use_cache, prefetch):
     ctc = cloud_test_catalog

     def new_signal(file: File) -> str:
-        assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
+        assert bool(file.get_local_path()) is (prefetch > 0)
         with file.open() as f:
             return file.name + " -> " + f.read().decode("utf-8")
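
In other words, once prefetch is enabled the UDF can rely on a local copy of the file regardless of whether the persistent cache (use_cache) is on: the temporary prefetch cache now provides it.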