Skip to content

Commit

Permalink
prefetch: disable for huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Jan 7, 2025
1 parent 1ffd71b commit a3a4322
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
24 changes: 20 additions & 4 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,26 @@ def ensure_cached(self) -> None:
client = self._catalog.get_client(self.source)
client.download(self, callback=self._download_cb)

async def _prefetch(self) -> None:
    """Download this file through its source client when caching is enabled.

    Does nothing when ``self._caching_enabled`` is falsy. Otherwise the
    client for ``self.source`` is obtained from ``self._catalog`` and its
    ``_download`` coroutine is awaited, reporting progress to
    ``self._download_cb``.
    """
    if not self._caching_enabled:
        return

    source_client = self._catalog.get_client(self.source)
    await source_client._download(self, callback=self._download_cb)
async def _prefetch(
    self,
    catalog: Optional["Catalog"] = None,
    download_cb: Optional["Callback"] = None,
) -> bool:
    """Download this file ahead of use, unless it is served by ``HfClient``.

    Args:
        catalog: Catalog used to look up the client for ``self.source``.
            Falls back to ``self._catalog`` when not given.
        download_cb: Progress callback passed to the download. Falls back
            to ``self._download_cb`` when not given.

    Returns:
        ``True`` if the client's ``_download`` coroutine was awaited for
        this file; ``False`` if the client's protocol matches
        ``HfClient.protocol``, in which case nothing is downloaded and only
        the stream is configured.

    Raises:
        RuntimeError: If neither ``catalog`` nor ``self._catalog`` is set.
    """
    # Imported inside the function, as in the original.
    from datachain.client.hf import HfClient

    active_catalog = catalog or self._catalog
    progress_cb = download_cb or self._download_cb
    if active_catalog is None:
        raise RuntimeError("cannot prefetch file because catalog is not setup")

    source_client = active_catalog.get_client(self.source)
    if source_client.protocol != HfClient.protocol:
        await source_client._download(self, callback=progress_cb)
        # reset download callback
        self._set_stream(active_catalog, caching_enabled=True)
        return True

    # HfClient protocol: skip the download and only configure the stream,
    # keeping the file's current caching setting and the chosen callback.
    self._set_stream(
        active_catalog, self._caching_enabled, download_cb=progress_cb
    )
    return False

def get_local_path(self) -> Optional[str]:
"""Return path to a file in a local cache.
Expand Down
6 changes: 3 additions & 3 deletions src/datachain/lib/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ async def _prefetch_input(row: T, catalog: "Catalog", download_cb: Callback) ->

for obj in row:
if isinstance(obj, File):
obj._set_stream(catalog, True, download_cb)
await obj._prefetch()
callback()
prefetched = await obj._prefetch(catalog, download_cb)
if prefetched:
callback()
return row


Expand Down

0 comments on commit a3a4322

Please sign in to comment.