prefetch: use a separate temporary cache for prefetching #730

Merged: 13 commits, Jan 7, 2025
9 changes: 3 additions & 6 deletions examples/get_started/torch-loader.py
@@ -56,7 +56,7 @@ def forward(self, x):
if __name__ == "__main__":
ds = (
DataChain.from_storage(STORAGE, type="image")
.settings(cache=True, prefetch=25)
.settings(prefetch=25)
.filter(C("file.path").glob("*.jpg"))
.map(
label=lambda path: label_to_int(basename(path)[:3], CLASSES),
@@ -68,7 +68,7 @@ def forward(self, x):
train_loader = DataLoader(
ds.to_pytorch(transform=transform),
batch_size=25,
num_workers=max(4, os.cpu_count() or 2),
num_workers=min(4, os.cpu_count() or 2),
persistent_workers=True,
multiprocessing_context=multiprocessing.get_context("spawn"),
)
@@ -80,10 +80,7 @@ def forward(self, x):
# Train the model
for epoch in range(NUM_EPOCHS):
with tqdm(
train_loader,
desc=f"epoch {epoch + 1}/{NUM_EPOCHS}",
unit="batch",
leave=False,
train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
) as loader:
for data in loader:
inputs, labels = data
22 changes: 16 additions & 6 deletions src/datachain/asyn.py
@@ -8,12 +8,14 @@
Iterable,
Iterator,
)
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, wait
from heapq import heappop, heappush
from typing import Any, Callable, Generic, Optional, TypeVar

from fsspec.asyn import get_loop

from datachain.utils import safe_closing

ASYNC_WORKERS = 20

InputT = TypeVar("InputT", contravariant=True) # noqa: PLC0105
@@ -56,6 +58,7 @@ def __init__(
self.pool = ThreadPoolExecutor(workers)
self._tasks: set[asyncio.Task] = set()
self._shutdown_producer = threading.Event()
self._producer_is_shutdown = threading.Event()

def start_task(self, coro: Coroutine) -> asyncio.Task:
task = self.loop.create_task(coro)
@@ -64,11 +67,16 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
return task

def _produce(self) -> None:
for item in self.iterable:
if self._shutdown_producer.is_set():
return
fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
fut.result() # wait until the item is in the queue
try:
with safe_closing(self.iterable):
for item in self.iterable:
if self._shutdown_producer.is_set():
return
coro = self.work_queue.put(item)
fut = asyncio.run_coroutine_threadsafe(coro, self.loop)
fut.result() # wait until the item is in the queue
finally:
self._producer_is_shutdown.set()
Review comment (Member Author):
This _produce() is run via loop.run_in_executor. Since it runs in a different thread and is a synchronous function, .cancel() does not really cancel it, and it keeps running in the background.

We do ask it to shut down by setting the self._shutdown_producer event, but the function is still running in the background.
We also want the self.iterable to be closed, but that may not happen immediately when we .close() the mapper.iterate() generator.

E.g.:

mapper = AsyncMapper(func, iterable)
with closing(mapper.iterate()) as it:
    next(it)

In this case, iterable.close() may be called much later than mapper.iterate() gets closed. We want to ensure the iterable gets closed when mapper.iterate() gets closed.

So, for this, we set another event that gets set after iterable.close() has been called.
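
To illustrate, here is a minimal, self-contained sketch of this two-event handshake. The names and the plain queue.Queue are illustrative only; the actual AsyncMapper puts items on an asyncio queue via run_coroutine_threadsafe:

import itertools
import queue
import threading
from concurrent.futures import ThreadPoolExecutor

work_queue: queue.Queue = queue.Queue()  # unbounded here, so the sketch cannot block on put()
shutdown_requested = threading.Event()   # consumer asks the producer to stop
producer_finished = threading.Event()    # producer confirms it has fully wound down

def produce(items):
    try:
        for item in items:
            if shutdown_requested.is_set():
                return                   # stop early, but still run `finally`
            work_queue.put(item)
    finally:
        producer_finished.set()          # set on normal exit, error, or shutdown

with ThreadPoolExecutor(max_workers=1) as pool:
    pool.submit(produce, itertools.count())
    print(work_queue.get())              # consume a single item ...
    shutdown_requested.set()             # ... then ask the producer to stop ...
    producer_finished.wait()             # ... and wait until it actually has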


async def produce(self) -> None:
await self.to_thread(self._produce)
@@ -179,6 +187,8 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
self.shutdown_producer()
if not async_run.done():
async_run.cancel()
wait([async_run])
Review comment (@skshetry, Dec 31, 2024):
.cancel() does not immediately cancel the underlying asyncio task.

We could add a .result() to wait for the future, but that does not seem to work either for the cancelled future from run_coroutine_threadsafe(). See python/cpython#105836.

So, I have added wait(...), as it seems to wait for the cancelled future and for the underlying asyncio task.

Alternatively, we could add an asyncio.Event and wait for it.
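
For illustration, a minimal sketch of that event-based alternative in isolation (the coroutine and names are hypothetical; a threading.Event is used here so the calling thread can block on it directly, whereas the merged code calls wait([async_run]) and then waits on _producer_is_shutdown):

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

finished = threading.Event()

async def run():
    try:
        await asyncio.sleep(60)
    finally:
        finished.set()   # runs even when the task gets cancelled

fut = asyncio.run_coroutine_threadsafe(run(), loop)
fut.cancel()             # only *requests* cancellation of the asyncio task
finished.wait()          # block until the task has actually unwound
loop.call_soon_threadsafe(loop.stop)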

self._producer_is_shutdown.wait()

def __iter__(self):
return self.iterate()
40 changes: 31 additions & 9 deletions src/datachain/cache.py
@@ -1,8 +1,12 @@
import os
from collections.abc import Iterator
from contextlib import contextmanager
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Optional

from dvc_data.hashfile.db.local import LocalHashFileDB
from dvc_objects.fs.local import LocalFileSystem
from dvc_objects.fs.utils import remove
from fsspec.callbacks import Callback, TqdmCallback

from .progress import Tqdm
@@ -20,6 +24,23 @@ def try_scandir(path):
pass


def get_temp_cache(tmp_dir: str, prefix: Optional[str] = None) -> "DataChainCache":
cache_dir = mkdtemp(prefix=prefix, dir=tmp_dir)
return DataChainCache(cache_dir, tmp_dir=tmp_dir)


@contextmanager
def temporary_cache(
tmp_dir: str, prefix: Optional[str] = None, delete: bool = True
) -> Iterator["DataChainCache"]:
cache = get_temp_cache(tmp_dir, prefix=prefix)
try:
yield cache
finally:
if delete:
cache.destroy()


class DataChainCache:
def __init__(self, cache_dir: str, tmp_dir: str):
self.odb = LocalHashFileDB(
Expand All @@ -28,6 +49,9 @@ def __init__(self, cache_dir: str, tmp_dir: str):
tmp_dir=tmp_dir,
)

def __eq__(self, other) -> bool:
return self.odb == other.odb

@property
def cache_dir(self):
return self.odb.path
@@ -82,20 +106,18 @@ async def download(
os.unlink(tmp_info)

def store_data(self, file: "File", contents: bytes) -> None:
checksum = file.get_hash()
dst = self.path_from_checksum(checksum)
if not os.path.exists(dst):
# Create the file only if it's not already in cache
os.makedirs(os.path.dirname(dst), exist_ok=True)
with open(dst, mode="wb") as f:
f.write(contents)

def clear(self):
self.odb.add_bytes(file.get_hash(), contents)

def clear(self) -> None:
"""
Completely clear the cache.
"""
self.odb.clear()

def destroy(self) -> None:
# `clear` leaves the prefix directory structure intact.
remove(self.cache_dir)

def get_total_size(self) -> int:
total = 0
for subdir in try_scandir(self.odb.path):
6 changes: 6 additions & 0 deletions src/datachain/catalog/catalog.py
@@ -536,6 +536,12 @@ def find_column_to_str( # noqa: PLR0911
return ""


def clone_catalog_with_cache(catalog: "Catalog", cache: "DataChainCache") -> "Catalog":
clone = catalog.copy()
clone.cache = cache
return clone


class Catalog:
def __init__(
self,
2 changes: 2 additions & 0 deletions src/datachain/lib/dc.py
@@ -451,13 +451,15 @@ def from_storage(
return dc

if update or not list_ds_exists:
# disable prefetch for listing, as it pre-downloads all files
(
cls.from_records(
DataChain.DEFAULT_FILE_RECORD,
session=session,
settings=settings,
in_memory=in_memory,
)
.settings(prefetch=0)
.gen(
list_bucket(list_uri, cache, client_config=client_config),
output={f"{object_name}": File},
19 changes: 15 additions & 4 deletions src/datachain/lib/file.py
@@ -269,10 +269,21 @@
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self) -> None:
if self._caching_enabled:
client = self._catalog.get_client(self.source)
await client._download(self, callback=self._download_cb)
async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
from datachain.client.hf import HfClient

if self._catalog is None:
raise RuntimeError("cannot prefetch file because catalog is not setup")

client = self._catalog.get_client(self.source)
if client.protocol == HfClient.protocol:
return False

await client._download(self, callback=download_cb or self._download_cb)
self._set_stream(
self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
)
return True

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.
70 changes: 57 additions & 13 deletions src/datachain/lib/pytorch.py
@@ -1,5 +1,8 @@
import logging
from collections.abc import Iterator
import os
import weakref
from collections.abc import Generator, Iterable, Iterator
from contextlib import closing
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
@@ -9,15 +12,19 @@
from torchvision.transforms import v2

from datachain import Session
from datachain.asyn import AsyncMapper
from datachain.cache import get_temp_cache
from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
from datachain.lib.settings import Settings
from datachain.lib.text import convert_text
from datachain.progress import CombinedDownloadCallback
from datachain.query.dataset import get_download_callback

if TYPE_CHECKING:
from torchvision.transforms.v2 import Transform

from datachain.cache import DataChainCache as Cache


logger = logging.getLogger("datachain")

@@ -75,6 +82,19 @@
if (prefetch := dc_settings.prefetch) is not None:
self.prefetch = prefetch

self._cache = catalog.cache
self._prefetch_cache: Optional[Cache] = None
if prefetch and not self.cache:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
self._cache = self._prefetch_cache
weakref.finalize(self, self._prefetch_cache.destroy)

def close(self) -> None:
if self._prefetch_cache:
self._prefetch_cache.destroy()

def _init_catalog(self, catalog: "Catalog"):
# For compatibility with multiprocessing,
# we can only store params in __init__(), as Catalog isn't picklable
@@ -89,9 +109,15 @@
ms = ms_cls(*ms_args, **ms_kwargs)
wh_cls, wh_args, wh_kwargs = self._wh_params
wh = wh_cls(*wh_args, **wh_kwargs)
return Catalog(ms, wh, **self._catalog_params)
catalog = Catalog(ms, wh, **self._catalog_params)
catalog.cache = self._cache
return catalog

def _rows_iter(self, total_rank: int, total_workers: int):
def _row_iter(
self,
total_rank: int,
total_workers: int,
) -> Generator[tuple[Any, ...], None, None]:
catalog = self._get_catalog()
session = Session("PyTorch", catalog=catalog)
ds = DataChain.from_dataset(
Expand All @@ -104,16 +130,34 @@
ds = ds.chunk(total_rank, total_workers)
yield from ds.collect()

def __iter__(self) -> Iterator[Any]:
total_rank, total_workers = self.get_rank_and_workers()
rows = self._rows_iter(total_rank, total_workers)
if self.prefetch > 0:
from datachain.lib.udf import _prefetch_input

rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
yield from map(self._process_row, rows)
def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
from datachain.lib.udf import _prefetch_inputs

def _process_row(self, row_features):
total_rank, total_workers = self.get_rank_and_workers()
download_cb = CombinedDownloadCallback()
if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
download_cb = get_download_callback(
f"{total_rank}/{total_workers}",
position=total_rank,
leave=True,
)

rows = self._row_iter(total_rank, total_workers)
rows = _prefetch_inputs(
rows,
self.prefetch,
download_cb=download_cb,
after_prefetch=download_cb.increment_file_count,
)

with download_cb, closing(rows):
yield from rows

def __iter__(self) -> Iterator[list[Any]]:
with closing(self._iter_with_prefetch()) as rows:
yield from map(self._process_row, rows)

def _process_row(self, row_features: Iterable[Any]) -> list[Any]:
row = []
for fr in row_features:
if hasattr(fr, "read"):
Expand Down