diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index ebc727e32..842ee9d1a 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -1277,7 +1277,12 @@ def collect(self, *cols: str) -> Iterator[Union[DataValue, tuple[DataValue, ...]]]:
             yield ret[0] if len(cols) == 1 else tuple(ret)
 
     def to_pytorch(
-        self, transform=None, tokenizer=None, tokenizer_kwargs=None, num_samples=0
+        self,
+        transform=None,
+        tokenizer=None,
+        tokenizer_kwargs=None,
+        num_samples=0,
+        remove_prefetched: bool = False,
     ):
         """Convert to pytorch dataset format.
 
@@ -1287,6 +1292,7 @@ def to_pytorch(
             tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
             num_samples (int): Number of random samples to draw for each epoch.
                 This argument is ignored if `num_samples=0` (the default).
+            remove_prefetched (bool): Whether to remove prefetched files after reading.
 
         Example:
             ```py
@@ -1313,6 +1319,7 @@ def to_pytorch(
             tokenizer_kwargs=tokenizer_kwargs,
             num_samples=num_samples,
            dc_settings=chain._settings,
+            remove_prefetched=remove_prefetched,
         )
 
     def remove_file_signals(self) -> "Self":  # noqa: D102
diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py
index 985af387c..f3b504d75 100644
--- a/src/datachain/lib/pytorch.py
+++ b/src/datachain/lib/pytorch.py
@@ -50,6 +50,7 @@ def __init__(
         tokenizer_kwargs: Optional[dict[str, Any]] = None,
         num_samples: int = 0,
         dc_settings: Optional[Settings] = None,
+        remove_prefetched: bool = False,
     ):
         """
         Pytorch IterableDataset that streams DataChain datasets.
@@ -84,6 +85,7 @@ def __init__(
         self._cache = catalog.cache
         self._prefetch_cache: Optional[Cache] = None
+        self._remove_prefetched = remove_prefetched
         if prefetch and not self.cache:
             tmp_dir = catalog.cache.tmp_dir
             assert tmp_dir
@@ -147,7 +149,7 @@ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
             rows,
             self.prefetch,
             download_cb=download_cb,
-            after_prefetch=download_cb.increment_file_count,
+            remove_prefetched=self._remove_prefetched,
         )
 
         with download_cb, closing(rows):
diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index 717d940b9..750e58dc8 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -16,6 +16,7 @@
 from datachain.lib.data_model import DataValue
 from datachain.lib.file import File
 from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
+from datachain.progress import CombinedDownloadCallback
 from datachain.query.batch import (
     Batch,
     BatchingStrategy,
@@ -301,20 +302,42 @@ async def _prefetch_input(
     return row
 
 
+def _remove_prefetched(row: T) -> None:
+    for obj in row:
+        if isinstance(obj, File):
+            catalog = obj._catalog
+            assert catalog is not None
+            try:
+                catalog.cache.remove(obj)
+            except Exception as e:  # noqa: BLE001
+                print(f"Failed to remove prefetched item {obj.name!r}: {e!s}")
+
+
 def _prefetch_inputs(
     prepared_inputs: "Iterable[T]",
     prefetch: int = 0,
     download_cb: Optional["Callback"] = None,
-    after_prefetch: "Callable[[], None]" = noop,
+    after_prefetch: Optional[Callable[[], None]] = None,
+    remove_prefetched: bool = False,
 ) -> "abc.Generator[T, None, None]":
-    if prefetch > 0:
-        f = partial(
-            _prefetch_input,
-            download_cb=download_cb,
-            after_prefetch=after_prefetch,
-        )
-        prepared_inputs = AsyncMapper(f, prepared_inputs, workers=prefetch).iterate()  # type: ignore[assignment]
-    yield from prepared_inputs
+    if not prefetch:
+        yield from prepared_inputs
+        return
+
+    if after_prefetch is None:
+        after_prefetch = noop
+        if isinstance(download_cb, CombinedDownloadCallback):
+            after_prefetch = download_cb.increment_file_count
+
+    f = partial(_prefetch_input, download_cb=download_cb, after_prefetch=after_prefetch)
+    mapper = AsyncMapper(f, prepared_inputs, workers=prefetch)
+    with closing(mapper.iterate()) as row_iter:
+        for row in row_iter:
+            try:
+                yield row  # type: ignore[misc]
+            finally:
+                if remove_prefetched:
+                    _remove_prefetched(row)
 
 
 def _get_cache(
@@ -351,7 +374,13 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
         )
 
         prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
+
         with closing(prepared_inputs):
             for id_, *udf_args in prepared_inputs:
                 result_objs = self.process_safe(udf_args)
@@ -429,15 +458,22 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
                 row, udf_fields, catalog, cache, download_cb
             )
 
+        def _process_row(row):
+            with safe_closing(self.process_safe(row)) as result_objs:
+                for result_obj in result_objs:
+                    udf_output = self._flatten_row(result_obj)
+                    yield dict(zip(self.signal_names, udf_output))
+
         prepared_inputs = _prepare_rows(udf_inputs)
-        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        prepared_inputs = _prefetch_inputs(
+            prepared_inputs,
+            self.prefetch,
+            download_cb=download_cb,
+            remove_prefetched=bool(self.prefetch) and not cache,
+        )
 
         with closing(prepared_inputs):
-            for row in prepared_inputs:
-                result_objs = self.process_safe(row)
-                udf_outputs = (self._flatten_row(row) for row in result_objs)
-                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-                processed_cb.relative_update(1)
-                yield output
+            for row in processed_cb.wrap(prepared_inputs):
+                yield _process_row(row)
 
         self.teardown()
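Review note: the new `_prefetch_inputs` places the removal in a `try`/`finally` around the `yield`, so a row's files are deleted only after the consumer has advanced past that row (or closed the generator), never while the row is still in use. A minimal sketch of this yield-then-cleanup pattern, with hypothetical names (not datachain API):

```py
from contextlib import closing

def yield_then_cleanup(rows, cleanup):
    for row in rows:
        try:
            yield row
        finally:
            # runs when the consumer requests the next row, or on close()
            cleanup(row)

with closing(yield_then_cleanup(iter("abc"), print)) as it:
    for item in it:
        pass  # `item` is cleaned up right after this loop body finishes
```

This is also why every consumer in this patch wraps the iterator in `closing(...)`: abandoning the generator without closing it would defer the last row's cleanup to garbage collection.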
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index b510b604d..1da0eaaad 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -336,15 +336,16 @@ def process_udf_outputs(
     for udf_output in udf_results:
         if not udf_output:
             continue
-        for row in udf_output:
-            cb.relative_update()
-            rows.append(adjust_outputs(warehouse, row, udf_col_types))
-            if len(rows) >= batch_size or (
-                len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
-            ):
-                for row_chunk in batched(rows, batch_size):
-                    warehouse.insert_rows(udf_table, row_chunk)
-                rows.clear()
+        with safe_closing(udf_output):
+            for row in udf_output:
+                cb.relative_update()
+                rows.append(adjust_outputs(warehouse, row, udf_col_types))
+                if len(rows) >= batch_size or (
+                    len(rows) % 10 == 0 and psutil.virtual_memory().percent > 80
+                ):
+                    for row_chunk in batched(rows, batch_size):
+                        warehouse.insert_rows(udf_table, row_chunk)
+                    rows.clear()
 
     if rows:
         for row_chunk in batched(rows, batch_size):
diff --git a/tests/benchmarks/test_datachain.py b/tests/benchmarks/test_datachain.py
index 3dc5b4e92..88139885d 100644
--- a/tests/benchmarks/test_datachain.py
+++ b/tests/benchmarks/test_datachain.py
@@ -6,6 +6,11 @@ def test_datachain(tmp_dir, test_session, datasets, benchmark):
     def run_script(uri, **kwargs):
         DataChain.from_storage(uri, session=test_session, **kwargs).gen(
             emd=process_laion_meta
+        ).settings(
+            # Disable `prefetch` for `map()` because `process_laion_meta` repeatedly
+            # returns the dataset file. This causes `prefetch` to download and
+            # remove the file multiple times unnecessarily, slowing down the process.
+            prefetch=0,
        ).map(
             stem=lambda file: file.get_file_stem(),
             params=["emd.file"],
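Review note: the `safe_closing(udf_output)` wrapper matters because generator UDFs now yield a lazy sub-generator per input row (`_process_row` above), and that sub-generator's cleanup only fires once it is exhausted or closed. A rough stdlib equivalent of the consumption contract (sketch; assumes `safe_closing` behaves like a close-tolerant `contextlib.closing`):

```py
from contextlib import closing

def drain(udf_results, handle_row):
    for udf_output in udf_results:
        # close() must run even if handle_row raises or iteration stops
        # early; otherwise the sub-generator's cleanup waits for GC
        with closing(udf_output):
            for row in udf_output:
                handle_row(row)
```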
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index 979a624e5..a486d2c28 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -1,3 +1,4 @@
+import functools
 import math
 import os
 import pickle
@@ -234,17 +235,28 @@ def verify_cache_used(file):
         assert head == catalog.cache.tmp_dir
         assert tail.startswith("prefetch-")
 
-    def new_signal(file: File) -> str:
-        assert is_prefetched(file) == (prefetch > 0)
-        verify_cache_used(file)
+    def with_checks(func, seen=[]):  # noqa: B006
+        @functools.wraps(func)
+        def wrapped(file, *args, **kwargs):
+            # previously prefetched files should be removed if `cache` is disabled.
+            for f in seen:
+                assert f._catalog.cache.contains(f) == use_cache
+            seen.append(file)
+
+            assert is_prefetched(file) == (prefetch > 0)
+            verify_cache_used(file)
+            return func(file, *args, **kwargs)
+
+        return wrapped
+
+    def new_signal(file: File) -> str:
         with file.open() as f:
             return file.name + " -> " + f.read().decode("utf-8")
 
     dc = (
         DataChain.from_storage(ctc.src_uri, session=ctc.session)
         .settings(cache=use_cache, prefetch=prefetch)
-        .map(signal=new_signal)
+        .map(signal=with_checks(new_signal))
         .save()
     )
 
@@ -1307,17 +1319,28 @@ def verify_cache_used(file):
         assert head == catalog.cache.tmp_dir
         assert tail.startswith("prefetch-")
 
-    def new_signal(file: File) -> list[str]:
-        assert is_prefetched(file) == (prefetch > 0)
-        verify_cache_used(file)
+    def with_checks(func, seen=[]):  # noqa: B006
+        @functools.wraps(func)
+        def wrapped(file, *args, **kwargs):
+            # previously prefetched files should be removed if `cache` is disabled.
+            for f in seen:
+                assert f._catalog.cache.contains(f) == use_cache
+            seen.append(file)
+
+            assert is_prefetched(file) == (prefetch > 0)
+            verify_cache_used(file)
+            return func(file, *args, **kwargs)
+
+        return wrapped
+
+    def new_signal(file: File) -> list[str]:
         with file.open("rb") as f:
             return [file.name, f.read().decode("utf-8")]
 
     dc = (
         DataChain.from_storage(ctc.src_uri, session=ctc.session)
         .settings(cache=use_cache, prefetch=prefetch)
-        .gen(signal=new_signal, output=str)
+        .gen(signal=with_checks(new_signal), output=str)
         .save()
     )
     expected = {
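Review note: the functional tests above pin down the user-visible contract for UDFs: with `prefetch` enabled and `cache` disabled, input files are fetched ahead of time and deleted once their row has been processed. In chain form (illustrative sketch; the URI and signal are invented):

```py
from datachain.lib.dc import DataChain

sizes = (
    DataChain.from_storage("s3://example-bucket/images/")  # hypothetical URI
    .settings(cache=False, prefetch=4)
    # each file is downloaded before its row reaches the UDF, then removed
    # from the temporary prefetch cache once the row has been processed
    .map(size=lambda file: len(file.read()), params=["file"], output=int)
    .save("image_sizes")
)
```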
diff --git a/tests/func/test_pytorch.py b/tests/func/test_pytorch.py
index 2ebcbc9ab..a39df26b7 100644
--- a/tests/func/test_pytorch.py
+++ b/tests/func/test_pytorch.py
@@ -86,11 +86,16 @@ def test_to_pytorch(fake_dataset):
 
 @pytest.mark.parametrize("use_cache", (True, False))
 @pytest.mark.parametrize("prefetch", (0, 2))
-def test_prefetch(mocker, catalog, fake_dataset, use_cache, prefetch):
+@pytest.mark.parametrize("remove_prefetched", (True, False))
+def test_prefetch(
+    mocker, catalog, fake_dataset, use_cache, prefetch, remove_prefetched
+):
     catalog.cache.clear()
 
     dataset = fake_dataset.limit(10)
-    ds = dataset.settings(cache=use_cache, prefetch=prefetch).to_pytorch()
+    ds = dataset.settings(cache=use_cache, prefetch=prefetch).to_pytorch(
+        remove_prefetched=remove_prefetched
+    )
 
     iter_with_prefetch = ds._iter_with_prefetch
     cache = ds._cache
@@ -101,8 +106,16 @@ def is_prefetched(file: File):
         return cache.contains(file)
 
     def check_prefetched():
+        seen = []
         for row in iter_with_prefetch():
+            # files prefetched for earlier rows should persist only when
+            # prefetching is enabled and `remove_prefetched` is not set;
+            # otherwise they should have been removed by now.
+            for f in seen:
+                assert is_prefetched(f) == (bool(prefetch) and not remove_prefetched)
+
             files = [f for f in row if isinstance(f, File)]
+            seen.extend(files)
             assert files
             files_not_in_cache = [f for f in files if not is_prefetched(f)]
             if prefetch:
diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py
index 2018837aa..6e30c34f2 100644
--- a/tests/unit/test_pytorch.py
+++ b/tests/unit/test_pytorch.py
@@ -4,6 +4,7 @@
 import pytest
 
 from datachain.cache import DataChainCache
+from datachain.lib.file import File
 from datachain.lib.pytorch import PytorchDataset
 from datachain.lib.settings import Settings
 
@@ -42,6 +43,35 @@ def test_close(mocker, catalog, cache):
     spy.assert_called_once()
 
 
+def test_prefetched_files_are_removed_after_yield(tmp_dir, mocker, catalog, cache):
+    files = []
+    for name in "abc":
+        (tmp_dir / name).write_text(name, encoding="utf-8")
+        file = File(path=tmp_dir / name)
+        file._set_stream(catalog)
+        files.append((file,))
+
+    ds = PytorchDataset(
+        "fake",
+        1,
+        catalog,
+        dc_settings=Settings(prefetch=10),
+        remove_prefetched=True,
+    )
+    mocker.patch.object(ds, "_row_iter", return_value=iter(files))
+
+    seen = []
+    for (file,) in ds._iter_with_prefetch():
+        # previously prefetched files should have been removed by now
+        for f in seen:
+            assert not f._catalog.cache.contains(f)
+            assert not f.get_local_path()
+        seen.append(file)
+
+        assert file._catalog.cache.contains(file)
+        assert file.get_local_path()
+
+
 @pytest.mark.parametrize("cache", [True, False])
 def test_prefetch_cache_gets_destroyed_on_gc(mocker, catalog, cache):
     spy = mocker.patch.object(DataChainCache, "destroy")
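Review note: end to end, the new flag targets streaming datasets larger than local storage. Mirroring the unit test above, a hypothetical consumer would look like this (sketch; `chain`, the batch size, and the prefetch depth are invented):

```py
from torch.utils.data import DataLoader

# with remove_prefetched=True, each prefetched file is dropped from the
# prefetch cache as soon as its row has been yielded, keeping disk usage
# bounded for the whole epoch
loader = DataLoader(
    chain.settings(cache=False, prefetch=4).to_pytorch(remove_prefetched=True),
    batch_size=16,
)
for batch in loader:
    ...
```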