
Commit

create prefetch cache one level higher
skshetry committed Jan 7, 2025
1 parent 4bd3b4b commit a99007a
Showing 9 changed files with 171 additions and 148 deletions.
5 changes: 1 addition & 4 deletions examples/get_started/torch-loader.py
@@ -80,10 +80,7 @@ def forward(self, x):
 # Train the model
 for epoch in range(NUM_EPOCHS):
     with tqdm(
-        train_loader,
-        desc=f"epoch {epoch + 1}/{NUM_EPOCHS}",
-        unit="batch",
-        leave=False,
+        train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
     ) as loader:
         for data in loader:
             inputs, labels = data
2 changes: 2 additions & 0 deletions src/datachain/lib/dc.py
@@ -451,13 +451,15 @@ def from_storage(
             return dc
 
         if update or not list_ds_exists:
+            # disable prefetch for listing, as it pre-downloads all files
             (
                 cls.from_records(
                     DataChain.DEFAULT_FILE_RECORD,
                     session=session,
                     settings=settings,
                     in_memory=in_memory,
                 )
+                .settings(prefetch=0)
                 .gen(
                     list_bucket(list_uri, cache, client_config=client_config),
                     output={f"{object_name}": File},
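
The .settings(prefetch=0) call added above is the same knob available on any chain. A minimal sketch of turning prefetching off for an ordinary chain, assuming the from_storage and settings entry points shown in this hunk; the bucket URI is illustrative:

from datachain.lib.dc import DataChain

# Illustrative URI; prefetch=0 disables pre-downloading of files for this
# chain, exactly as the listing step above does.
files = DataChain.from_storage("s3://example-bucket/").settings(prefetch=0)
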
4 changes: 3 additions & 1 deletion src/datachain/lib/file.py
@@ -280,7 +280,9 @@ async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool:
            return False

        await client._download(self, callback=download_cb or self._download_cb)
        self._download_cb = DEFAULT_CALLBACK
        self._set_stream(
            self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK
        )
        return True

    def get_local_path(self) -> Optional[str]:
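
_prefetch now resets the download callback and re-attaches the stream with caching enabled, so a prefetched File can be read back from the cache. A rough sketch of that contract, using only the _prefetch and get_local_path methods visible in this hunk; the file object is assumed to come from an existing chain:

from typing import Optional

async def warm_up(file) -> Optional[str]:
    # Returns a local cached path if the prefetch succeeded, otherwise None.
    if await file._prefetch():
        return file.get_local_path()
    return None
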
21 changes: 13 additions & 8 deletions src/datachain/lib/pytorch.py
@@ -23,6 +23,8 @@
 if TYPE_CHECKING:
     from torchvision.transforms.v2 import Transform

+    from datachain.cache import DataChainCache as Cache
+

 logger = logging.getLogger("datachain")

@@ -80,17 +82,18 @@ def __init__(
         if (prefetch := dc_settings.prefetch) is not None:
             self.prefetch = prefetch

-        if self.cache or not self.prefetch:
-            self._cache = catalog.cache
-        else:
+        self._cache = catalog.cache
+        self._prefetch_cache: Optional[Cache] = None
+        if prefetch and not self.cache:
             tmp_dir = catalog.cache.tmp_dir
             assert tmp_dir
-            self._cache = get_temp_cache(tmp_dir, prefix="prefetch-")
-            weakref.finalize(self, self._cache.destroy)
+            self._prefetch_cache = get_temp_cache(tmp_dir, prefix="prefetch-")
+            self._cache = self._prefetch_cache
+            weakref.finalize(self, self._prefetch_cache.destroy)

     def close(self) -> None:
-        if not self.cache:
-            self._cache.destroy()
+        if self._prefetch_cache:
+            self._prefetch_cache.destroy()

     def _init_catalog(self, catalog: "Catalog"):
         # For compatibility with multiprocessing,
@@ -134,7 +137,9 @@ def _iter_with_prefetch(self) -> Generator[tuple[Any], None, None]:
         download_cb = CombinedDownloadCallback()
         if os.getenv("DATACHAIN_SHOW_PREFETCH_PROGRESS"):
             download_cb = get_download_callback(
-                f"{total_rank}/{total_workers}", position=total_rank
+                f"{total_rank}/{total_workers}",
+                position=total_rank,
+                leave=True,
             )

         rows = self._row_iter(total_rank, total_workers)
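
With the changes above, a PytorchDataset whose settings enable prefetching but not caching gets its own throwaway "prefetch-" cache, removed by close() or, as a fallback, by the weakref.finalize hook. A minimal sketch of wiring such a chain into a DataLoader, assuming the settings(cache=..., prefetch=...) and to_pytorch() entry points used by the datachain examples; the URI and loader parameters are illustrative:

from torch.utils.data import DataLoader

from datachain.lib.dc import DataChain

# Prefetch ahead of the training loop without keeping a persistent local
# cache, so the temporary prefetch cache created above is used and later
# destroyed with the dataset.
chain = DataChain.from_storage("s3://example-bucket/").settings(cache=False, prefetch=8)
loader = DataLoader(chain.to_pytorch(), batch_size=16, num_workers=2)
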
50 changes: 23 additions & 27 deletions src/datachain/lib/udf.py
@@ -11,7 +11,6 @@

 from datachain.asyn import AsyncMapper
 from datachain.cache import temporary_cache
-from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -106,6 +105,10 @@ def run(
             processed_cb,
         )

+    @property
+    def prefetch(self) -> int:
+        return self.inner.prefetch
+

 class UDFBase(AbstractUDF):
     """Base class for stateful user-defined functions.
@@ -156,6 +159,7 @@ def process(self, file) -> list[float]:
     """

     is_output_batched = False
+    prefetch: int = 0

     def __init__(self):
         self.params: Optional[SignalSchema] = None
@@ -346,20 +350,15 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
                     row, udf_fields, catalog, cache, download_cb
                 )

-        with _get_cache(catalog.cache, self.prefetch, use_cache=cache) as _cache:
-            catalog = clone_catalog_with_cache(catalog, _cache)
-
-            prepared_inputs = _prepare_rows(udf_inputs)
-            prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
-            with closing(prepared_inputs):
-                for id_, *udf_args in prepared_inputs:
-                    result_objs = self.process_safe(udf_args)
-                    udf_output = self._flatten_row(result_objs)
-                    output = [
-                        {"sys__id": id_} | dict(zip(self.signal_names, udf_output))
-                    ]
-                    processed_cb.relative_update(1)
-                    yield output
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
+            for id_, *udf_args in prepared_inputs:
+                result_objs = self.process_safe(udf_args)
+                udf_output = self._flatten_row(result_objs)
+                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                processed_cb.relative_update(1)
+                yield output

         self.teardown()

@@ -430,18 +429,15 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
                     row, udf_fields, catalog, cache, download_cb
                 )

-        with _get_cache(catalog.cache, self.prefetch, use_cache=cache) as _cache:
-            catalog = clone_catalog_with_cache(catalog, _cache)
-
-            prepared_inputs = _prepare_rows(udf_inputs)
-            prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
-            with closing(prepared_inputs):
-                for row in prepared_inputs:
-                    result_objs = self.process_safe(row)
-                    udf_outputs = (self._flatten_row(row) for row in result_objs)
-                    output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-                    processed_cb.relative_update(1)
-                    yield output
+        prepared_inputs = _prepare_rows(udf_inputs)
+        prepared_inputs = _prefetch_inputs(prepared_inputs, self.prefetch)
+        with closing(prepared_inputs):
+            for row in prepared_inputs:
+                result_objs = self.process_safe(row)
+                udf_outputs = (self._flatten_row(row) for row in result_objs)
+                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+                processed_cb.relative_update(1)
+                yield output

         self.teardown()

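
Prefetching in a UDF is now controlled only by the prefetch attribute that UDFAdapter.prefetch exposes; the temporary cache itself is created by the caller (see dataset.py below). A short sketch of a UDF that opts in, assuming Mapper and File are the public names datachain exposes for the base class documented above:

from datachain.lib.file import File
from datachain.lib.udf import Mapper

class ByteCount(Mapper):
    # Ask the runner to download up to 4 files ahead of process().
    prefetch = 4

    def process(self, file: File) -> list[float]:
        with file.open() as f:
            return [float(len(f.read()))]
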
188 changes: 100 additions & 88 deletions src/datachain/query/dataset.py
@@ -35,6 +35,7 @@
 from sqlalchemy.sql.selectable import Select

 from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
+from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.data_storage.schema import (
     PARTITION_COLUMN_ID,
     partition_col_names,
@@ -43,7 +44,7 @@
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
+from datachain.lib.udf import UDFAdapter, _get_cache
 from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
 from datachain.query.schema import C, UDFParamSpec, normalize_param
 from datachain.query.session import Session
@@ -420,97 +421,108 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

         udf_fields = [str(c.name) for c in query.selected_columns]

-        try:
-            if workers:
-                if self.catalog.in_memory:
-                    raise RuntimeError(
-                        "In-memory databases cannot be used with "
-                        "distributed processing."
-                    )
-
-                from datachain.catalog.loader import get_distributed_class
-
-                distributor = get_distributed_class(min_task_size=self.min_task_size)
-                distributor(
-                    self.udf,
-                    self.catalog,
-                    udf_table,
-                    query,
-                    workers,
-                    processes,
-                    udf_fields=udf_fields,
-                    is_generator=self.is_generator,
-                    use_partitioning=use_partitioning,
-                    cache=self.cache,
-                )
-            elif processes:
-                # Parallel processing (faster for more CPU-heavy UDFs)
-                if self.catalog.in_memory:
-                    raise RuntimeError(
-                        "In-memory databases cannot be used with parallel processing."
-                    )
-                udf_info: UdfInfo = {
-                    "udf_data": filtered_cloudpickle_dumps(self.udf),
-                    "catalog_init": self.catalog.get_init_params(),
-                    "metastore_clone_params": self.catalog.metastore.clone_params(),
-                    "warehouse_clone_params": self.catalog.warehouse.clone_params(),
-                    "table": udf_table,
-                    "query": query,
-                    "udf_fields": udf_fields,
-                    "batching": batching,
-                    "processes": processes,
-                    "is_generator": self.is_generator,
-                    "cache": self.cache,
-                }
-
-                # Run the UDFDispatcher in another process to avoid needing
-                # if __name__ == '__main__': in user scripts
-                exec_cmd = get_datachain_executable()
-                cmd = [*exec_cmd, "internal-run-udf"]
-                envs = dict(os.environ)
-                envs.update({"PYTHONPATH": os.getcwd()})
-                process_data = filtered_cloudpickle_dumps(udf_info)
-
-                with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
-                    process.communicate(process_data)
-                    if retval := process.poll():
-                        raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
-            else:
-                # Otherwise process single-threaded (faster for smaller UDFs)
-                warehouse = self.catalog.warehouse
-
-                udf_inputs = batching(warehouse.dataset_select_paginated, query)
-                download_cb = get_download_callback()
-                processed_cb = get_processed_callback()
-                generated_cb = get_generated_callback(self.is_generator)
-                try:
-                    udf_results = self.udf.run(
-                        udf_fields,
-                        udf_inputs,
-                        self.catalog,
-                        self.cache,
-                        download_cb,
-                        processed_cb,
-                    )
-                    process_udf_outputs(
-                        warehouse,
-                        udf_table,
-                        udf_results,
-                        self.udf,
-                        cb=generated_cb,
-                    )
-                finally:
-                    download_cb.close()
-                    processed_cb.close()
-                    generated_cb.close()
-
-        except QueryScriptCancelError:
-            self.catalog.warehouse.close()
-            sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
-        except (Exception, KeyboardInterrupt):
-            # Close any open database connections if an error is encountered
-            self.catalog.warehouse.close()
-            raise
+        prefetch = getattr(self.udf, "prefetch", 0)
+        with _get_cache(self.catalog.cache, prefetch, use_cache=self.cache) as _cache:
+            catalog = clone_catalog_with_cache(self.catalog, _cache)
+            try:
+                if workers:
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used with "
+                            "distributed processing."
+                        )
+
+                    from datachain.catalog.loader import get_distributed_class
+
+                    distributor = get_distributed_class(
+                        min_task_size=self.min_task_size
+                    )
+                    distributor(
+                        self.udf,
+                        catalog,
+                        udf_table,
+                        query,
+                        workers,
+                        processes,
+                        udf_fields=udf_fields,
+                        is_generator=self.is_generator,
+                        use_partitioning=use_partitioning,
+                        cache=self.cache,
+                    )
+                elif processes:
+                    # Parallel processing (faster for more CPU-heavy UDFs)
+                    if catalog.in_memory:
+                        raise RuntimeError(
+                            "In-memory databases cannot be used "
+                            "with parallel processing."
+                        )
+                    udf_info: UdfInfo = {
+                        "udf_data": filtered_cloudpickle_dumps(self.udf),
+                        "catalog_init": catalog.get_init_params(),
+                        "metastore_clone_params": catalog.metastore.clone_params(),
+                        "warehouse_clone_params": catalog.warehouse.clone_params(),
+                        "table": udf_table,
+                        "query": query,
+                        "udf_fields": udf_fields,
+                        "batching": batching,
+                        "processes": processes,
+                        "is_generator": self.is_generator,
+                        "cache": self.cache,
+                    }
+
+                    # Run the UDFDispatcher in another process to avoid needing
+                    # if __name__ == '__main__': in user scripts
+                    exec_cmd = get_datachain_executable()
+                    cmd = [*exec_cmd, "internal-run-udf"]
+                    envs = dict(os.environ)
+                    envs.update({"PYTHONPATH": os.getcwd()})
+                    process_data = filtered_cloudpickle_dumps(udf_info)
+
+                    with subprocess.Popen(  # noqa: S603
+                        cmd, env=envs, stdin=subprocess.PIPE
+                    ) as process:
+                        process.communicate(process_data)
+                        if retval := process.poll():
+                            raise RuntimeError(
+                                f"UDF Execution Failed! Exit code: {retval}"
+                            )
+                else:
+                    # Otherwise process single-threaded (faster for smaller UDFs)
+                    warehouse = catalog.warehouse
+
+                    udf_inputs = batching(warehouse.dataset_select_paginated, query)
+                    download_cb = get_download_callback()
+                    processed_cb = get_processed_callback()
+                    generated_cb = get_generated_callback(self.is_generator)
+
+                    try:
+                        udf_results = self.udf.run(
+                            udf_fields,
+                            udf_inputs,
+                            catalog,
+                            self.cache,
+                            download_cb,
+                            processed_cb,
+                        )
+                        process_udf_outputs(
+                            warehouse,
+                            udf_table,
+                            udf_results,
+                            self.udf,
+                            cb=generated_cb,
+                        )
+                    finally:
+                        download_cb.close()
+                        processed_cb.close()
+                        generated_cb.close()
+
+            except QueryScriptCancelError:
+                self.catalog.warehouse.close()
+                sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
+            except (Exception, KeyboardInterrupt):
+                # Close any open database connections if an error is encountered
+                self.catalog.warehouse.close()
+                raise

     def create_partitions_table(self, query: Select) -> "Table":
         """
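
Everything in the branch above now runs inside a single _get_cache block, so distributed, parallel, and single-threaded execution all see the same cache-backed catalog, and any temporary prefetch cache is destroyed when the block exits. A condensed sketch of that wrapper, reusing only the helpers imported in this diff (the function name is illustrative):

from contextlib import contextmanager

from datachain.catalog.catalog import clone_catalog_with_cache
from datachain.lib.udf import _get_cache

@contextmanager
def prefetch_catalog(catalog, udf, use_cache: bool):
    # Mirrors populate_udf_table above: one (possibly temporary) prefetch
    # cache per UDF execution, cleaned up when the with-block exits.
    prefetch = getattr(udf, "prefetch", 0)
    with _get_cache(catalog.cache, prefetch, use_cache=use_cache) as cache:
        yield clone_catalog_with_cache(catalog, cache)
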
