prefetching: remove prefetched item after use in udf (#818)
* prefetching: remove prefetched item after use in udf

This PR removes the prefetched item after it has been used in the UDF.
This is enabled by default when `prefetch>0`, unless `cache=True` is set for the UDF, in
which case the prefetched item is kept.
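As an illustration, here is a minimal sketch of the UDF-side behavior (the storage URI and signal name are hypothetical; the chain methods mirror the tests below):

```py
from datachain import DataChain  # assumed package-root import
from datachain.lib.file import File

def signal(file: File) -> str:
    # While the UDF body runs, the prefetched copy is available locally.
    with file.open() as f:
        return file.name + " -> " + f.read().decode("utf-8")

# With prefetch>0 and cache=False, each prefetched file is removed from
# the temporary prefetch cache as soon as the UDF has consumed it.
(
    DataChain.from_storage("s3://my-bucket/data/")  # hypothetical URI
    .settings(cache=False, prefetch=4)
    .map(signal=signal)
    .save("signals")
)
```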

For the PyTorch dataloader, this is not enabled by default, but can be enabled by setting
`remove_prefetched=True` on the `PytorchDataset` class.
It is off by default because the dataset can be used across multiple epochs, and removing
a prefetched item after use would force it to be re-downloaded in the next epoch.
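A corresponding sketch for the PyTorch path, where removal is opt-in (same hypothetical URI; `to_pytorch` returns an `IterableDataset`, so it plugs into a standard `DataLoader`):

```py
from torch.utils.data import DataLoader

# remove_prefetched stays False by default so that files prefetched during
# one epoch can be reused in later epochs without re-downloading.
ds = (
    DataChain.from_storage("s3://my-bucket/data/")  # hypothetical URI
    .settings(cache=False, prefetch=4)
    .to_pytorch(remove_prefetched=True)  # opt in to removal after use
)
loader = DataLoader(ds, batch_size=16)
```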

The exposed `remove_prefetched=True|False` setting could be renamed to
something better. Feedback is welcome.

* close iterable properly
skshetry authored Jan 16, 2025
1 parent aad99e2 commit 2e89875
Showing 8 changed files with 155 additions and 38 deletions.
9 changes: 8 additions & 1 deletion src/datachain/lib/dc.py
@@ -1277,7 +1277,12 @@ def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]
yield ret[0] if len(cols) == 1 else tuple(ret)

def to_pytorch(
self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
self,
transform=None,
tokenizer=None,
tokenizer_kwargs=None,
num_samples=0,
remove_prefetched: bool = False,
):
"""Convert to pytorch dataset format.
@@ -1287,6 +1292,7 @@ def to_pytorch(
tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
num_samples (int): Number of random samples to draw for each epoch.
This argument is ignored if `num_samples=0` (the default).
remove_prefetched (bool): Whether to remove prefetched files after reading.
Example:
```py
@@ -1313,6 +1319,7 @@ def to_pytorch(
tokenizer_kwargs=tokenizer_kwargs,
num_samples=num_samples,
dc_settings=chain._settings,
remove_prefetched=remove_prefetched,
)

def remove_file_signals(self) -> "Self": # noqa: D102
4 changes: 3 additions & 1 deletion src/datachain/lib/pytorch.py
@@ -50,6 +50,7 @@ def __init__(
tokenizer_kwargs: Optional[dict[str, Any]] = None,
num_samples: int = 0,
dc_settings: Optional[Settings] = None,
remove_prefetched: bool = False,
):
"""
Pytorch IterableDataset that streams DataChain datasets.
@@ -84,6 +85,7 @@ def __init__(

self._cache = catalog.cache
self._prefetch_cache: Optional[Cache] = None
self._remove_prefetched = remove_prefetched
if prefetch and not self.cache:
tmp_dir = catalog.cache.tmp_dir
assert tmp_dir
@@ -147,7 +149,7 @@ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
rows,
self.prefetch,
download_cb=download_cb,
after_prefetch=download_cb.increment_file_count,
remove_prefetched=self._remove_prefetched,
)

with download_cb, closing(rows):
70 changes: 53 additions & 17 deletions src/datachain/lib/udf.py
@@ -16,6 +16,7 @@
from datachain.lib.data_model import DataValue
from datachain.lib.file import File
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.progress import CombinedDownloadCallback
from datachain.query.batch import (
Batch,
BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
return row


def _remove_prefetched(row: T) -> None:
for obj in row:
if isinstance(obj, File):
catalog = obj._catalog
assert catalog is not None
try:
catalog.cache.remove(obj)
except Exception as e: # noqa: BLE001
print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")


def _prefetch_inputs(
prepared_inputs: "Iterable[T]",
prefetch: int = 0,
download_cb: Optional["Callback"] = None,
after_prefetch: "Callable[[], None]" = noop,
after_prefetch: Optional[Callable[[], None]] = None,
remove_prefetched: bool = False,
) -> "abc.Generator[T, None, None]":
if prefetch > 0:
f = partial(
_prefetch_input,
download_cb=download_cb,
after_prefetch=after_prefetch,
)
prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate() # type: ignore[assignment]
yield from prepared_inputs
if not prefetch:
yield from prepared_inputs
return

if after_prefetch is None:
after_prefetch = noop
if isinstance(download_cb, CombinedDownloadCallback):
after_prefetch = download_cb.increment_file_count

f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
with closing(mapper.iterate()) as row_iter:
for row in row_iter:
try:
yield row # type: ignore[misc]
finally:
if remove_prefetched:
_remove_prefetched(row)


def _get_cache(
@@ -351,7 +374,13 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
)

prepared_inputs = _prepare_rows(udf_inputs)
prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
prepared_inputs = _prefetch_inputs(
prepared_inputs,
self.prefetch,
download_cb=download_cb,
remove_prefetched=bool(self.prefetch) and not cache,
)

with closing(prepared_inputs):
for id_, *udf_args in prepared_inputs:
result_objs = self.process_safe(udf_args)
@@ -429,15 +458,22 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
row, udf_fields, catalog, cache, download_cb
)

def _process_row(row):
with safe_closing(self.process_safe(row)) as result_objs:
for result_obj in result_objs:
udf_output = self._flatten_row(result_obj)
yield dict(zip(self.signal_names, udf_output))

prepared_inputs = _prepare_rows(udf_inputs)
prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
prepared_inputs = _prefetch_inputs(
prepared_inputs,
self.prefetch,
download_cb=download_cb,
remove_prefetched=bool(self.prefetch) and not cache,
)
with closing(prepared_inputs):
for row in prepared_inputs:
result_objs = self.process_safe(row)
udf_outputs = (self._flatten_row(row) for row in result_objs)
output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
processed_cb.relative_update(1)
yield output
for row in processed_cb.wrap(prepared_inputs):
yield _process_row(row)

self.teardown()

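The restructured `_prefetch_inputs` above yields each row inside `try`/`finally` so that cleanup runs even when the consumer stops iterating early; a self-contained sketch of that pattern (names here are illustrative, not from the codebase):

```py
from contextlib import closing

def stream_with_cleanup(items, cleanup):
    for item in items:
        try:
            yield item
        finally:
            # Runs when the consumer resumes the generator, closes it
            # early (GeneratorExit), or an exception propagates through.
            cleanup(item)

removed = []
with closing(stream_with_cleanup(["a", "b", "c"], removed.append)) as rows:
    for row in rows:
        if row == "b":
            break  # closing() still triggers cleanup for "b"

assert removed == ["a", "b"]
```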
19 changes: 10 additions & 9 deletions src/datachain/query/dataset.py
@@ -336,15 +336,16 @@ def process_udf_outputs(
for udf_output in udf_results:
if not udf_output:
continue
for row in udf_output:
cb.relative_update()
rows.append(adjust_outputs(warehouse, row, udf_col_types))
if len(rows) >= batch_size or (
len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
):
for row_chunk in batched(rows, batch_size):
warehouse.insert_rows(udf_table, row_chunk)
rows.clear()
with safe_closing(udf_output):
for row in udf_output:
cb.relative_update()
rows.append(adjust_outputs(warehouse, row, udf_col_types))
if len(rows) >= batch_size or (
len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
):
for row_chunk in batched(rows, batch_size):
warehouse.insert_rows(udf_table, row_chunk)
rows.clear()

if rows:
for row_chunk in batched(rows, batch_size):
5 changes: 5 additions & 0 deletions tests/benchmarks/test_datachain.py
@@ -6,6 +6,11 @@ def test_datachain(tmp_dir, test_session, datasets, benchmark):
def run_script(uri, **kwargs):
DataChain.from_storage(uri, session=test_session, **kwargs).gen(
emd=process_laion_meta
).settings(
# Disable `prefetch` for `map()` because `process_laion_meta` repeatedly
# returns the dataset file. This causes `prefetch` to download and
# remove the file multiple times unnecessarily, slowing down the process.
prefetch=0,
).map(
stem=lambda file: file.get_file_stem(),
params=["emd.file"],
39 changes: 31 additions & 8 deletions tests/func/test_datachain.py
@@ -1,3 +1,4 @@
import functools
import math
import os
import pickle
@@ -234,17 +235,28 @@ def verify_cache_used(file):
assert head == catalog.cache.tmp_dir
assert tail.startswith("prefetch-")

def new_signal(file: File) -> str:
assert is_prefetched(file) == (prefetch > 0)
verify_cache_used(file)
def with_checks(func, seen=[]): # noqa: B006
@functools.wraps(func)
def wrapped(file, *args, **kwargs):
# previously prefetched files should be removed if `cache` is disabled.
for f in seen:
assert f._catalog.cache.contains(f) == use_cache
seen.append(file)

assert is_prefetched(file) == (prefetch > 0)
verify_cache_used(file)
return func(file, *args, **kwargs)

return wrapped

def new_signal(file: File) -> str:
with file.open() as f:
return file.name + " -> " + f.read().decode("utf-8")

dc = (
DataChain.from_storage(ctc.src_uri, session=ctc.session)
.settings(cache=use_cache, prefetch=prefetch)
.map(signal=new_signal)
.map(signal=with_checks(new_signal))
.save()
)

@@ -1307,17 +1319,28 @@ def verify_cache_used(file):
assert head == catalog.cache.tmp_dir
assert tail.startswith("prefetch-")

def new_signal(file: File) -> list[str]:
assert is_prefetched(file) == (prefetch > 0)
verify_cache_used(file)
def with_checks(func, seen=[]): # noqa: B006
@functools.wraps(func)
def wrapped(file, *args, **kwargs):
# previously prefetched files should be removed if `cache` is disabled.
for f in seen:
assert f._catalog.cache.contains(f) == use_cache
seen.append(file)

assert is_prefetched(file) == (prefetch > 0)
verify_cache_used(file)
return func(file, *args, **kwargs)

return wrapped

def new_signal(file: File) -> list[str]:
with file.open("rb") as f:
return [file.name, f.read().decode("utf-8")]

dc = (
DataChain.from_storage(ctc.src_uri, session=ctc.session)
.settings(cache=use_cache, prefetch=prefetch)
.gen(signal=new_signal, output=str)
.gen(signal=with_checks(new_signal), output=str)
.save()
)
expected = {
17 changes: 15 additions & 2 deletions tests/func/test_pytorch.py
@@ -86,11 +86,16 @@ def test_to_pytorch(fake_dataset):

@pytest.mark.parametrize("use_cache", (True, False))
@pytest.mark.parametrize("prefetch", (0, 2))
def test_prefetch(mocker, catalog, fake_dataset, use_cache, prefetch):
@pytest.mark.parametrize("remove_prefetched", (True, False))
def test_prefetch(
mocker, catalog, fake_dataset, use_cache, prefetch, remove_prefetched
):
catalog.cache.clear()

dataset = fake_dataset.limit(10)
ds = dataset.settings(cache=use_cache, prefetch=prefetch).to_pytorch()
ds = dataset.settings(cache=use_cache, prefetch=prefetch).to_pytorch(
remove_prefetched=remove_prefetched
)

iter_with_prefetch = ds._iter_with_prefetch
cache = ds._cache
@@ -101,8 +106,16 @@ def is_prefetched(file: File):
return cache.contains(file)

def check_prefetched():
seen = []
for row in iter_with_prefetch():
# prefetched files should persist if `cache` is enabled, or
# `remove_prefetched` is not set. Otherwise, the old prefetched
# files should be removed.
for f in seen:
assert is_prefetched(f) == (cache and not remove_prefetched)

files = [f for f in row if isinstance(f, File)]
seen.extend(files)
assert files
files_not_in_cache = [f for f in files if not is_prefetched(f)]
if prefetch:
30 changes: 30 additions & 0 deletions tests/unit/test_pytorch.py
@@ -4,6 +4,7 @@
import pytest

from datachain.cache import DataChainCache
from datachain.lib.file import File
from datachain.lib.pytorch import PytorchDataset
from datachain.lib.settings import Settings

@@ -42,6 +43,35 @@ def test_close(mocker, catalog, cache):
spy.assert_called_once()


def test_prefetched_files_are_removed_after_yield(tmp_dir, mocker, catalog, cache):
files = []
for name in "abc":
(tmp_dir / name).write_text(name, encoding="utf-8")
file = File(path=tmp_dir / name)
file._set_stream(catalog)
files.append((file,))

ds = PytorchDataset(
"fake",
1,
catalog,
dc_settings=Settings(prefetch=10),
remove_prefetched=True,
)
mocker.patch.object(ds, "_row_iter", return_value=iter(files))

seen = []
for (file,) in ds._iter_with_prefetch():
# previously prefetched files should have been removed by now
for f in seen:
assert not f._catalog.cache.contains(f)
assert not f.get_local_path()
seen.append(file)

assert file._catalog.cache.contains(file)
assert file.get_local_path()


@pytest.mark.parametrize("cache", [True, False])
def test_prefetch_cache_gets_destroyed_on_gc(mocker, catalog, cache):
spy = mocker.patch.object(DataChainCache, "destroy")
