-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
prefetch: use a separate temporary cache for prefetching #730
Changes from all commits
1ffd71b
a3a4322
1027708
07ba315
e085052
a4903c4
b0ea2f2
9a7510e
00966a0
84967b3
4e230de
4bd3b4b
d8f8f39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,12 +8,14 @@ | |
Iterable, | ||
Iterator, | ||
) | ||
from concurrent.futures import ThreadPoolExecutor | ||
from concurrent.futures import ThreadPoolExecutor, wait | ||
from heapq import heappop, heappush | ||
from typing import Any, Callable, Generic, Optional, TypeVar | ||
|
||
from fsspec.asyn import get_loop | ||
|
||
from datachain.utils import safe_closing | ||
|
||
ASYNC_WORKERS = 20 | ||
|
||
InputT = TypeVar("InputT", contravariant=True) # noqa: PLC0105 | ||
|
@@ -56,6 +58,7 @@ def __init__( | |
self.pool = ThreadPoolExecutor(workers) | ||
self._tasks: set[asyncio.Task] = set() | ||
self._shutdown_producer = threading.Event() | ||
self._producer_is_shutdown = threading.Event() | ||
|
||
def start_task(self, coro: Coroutine) -> asyncio.Task: | ||
task = self.loop.create_task(coro) | ||
|
@@ -64,11 +67,16 @@ def start_task(self, coro: Coroutine) -> asyncio.Task: | |
return task | ||
|
||
def _produce(self) -> None: | ||
for item in self.iterable: | ||
if self._shutdown_producer.is_set(): | ||
return | ||
fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop) | ||
fut.result() # wait until the item is in the queue | ||
try: | ||
with safe_closing(self.iterable): | ||
for item in self.iterable: | ||
if self._shutdown_producer.is_set(): | ||
return | ||
coro = self.work_queue.put(item) | ||
fut = asyncio.run_coroutine_threadsafe(coro, self.loop) | ||
fut.result() # wait until the item is in the queue | ||
finally: | ||
self._producer_is_shutdown.set() | ||
|
||
async def produce(self) -> None: | ||
await self.to_thread(self._produce) | ||
|
@@ -179,6 +187,8 @@ def iterate(self, timeout=None) -> Generator[ResultT, None, None]: | |
self.shutdown_producer() | ||
if not async_run.done(): | ||
async_run.cancel() | ||
wait([async_run]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
We could add a So, I have added Alternatively, we could add an |
||
self._producer_is_shutdown.wait() | ||
|
||
def __iter__(self): | ||
return self.iterate() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -269,10 +269,21 @@ | |
client = self._catalog.get_client(self.source) | ||
client.download(self, callback=self._download_cb) | ||
|
||
async def _prefetch(self) -> None: | ||
if self._caching_enabled: | ||
client = self._catalog.get_client(self.source) | ||
await client._download(self, callback=self._download_cb) | ||
async def _prefetch(self, download_cb: Optional["Callback"] = None) -> bool: | ||
from datachain.client.hf import HfClient | ||
|
||
if self._catalog is None: | ||
raise RuntimeError("cannot prefetch file because catalog is not setup") | ||
|
||
client = self._catalog.get_client(self.source) | ||
if client.protocol == HfClient.protocol: | ||
return False | ||
Comment on lines
+279
to
+280
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
||
await client._download(self, callback=download_cb or self._download_cb) | ||
self._set_stream( | ||
self._catalog, caching_enabled=True, download_cb=DEFAULT_CALLBACK | ||
) | ||
return True | ||
|
||
def get_local_path(self) -> Optional[str]: | ||
"""Return path to a file in a local cache. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This `_produce()` is run via `loop.run_in_executor`. Since it is a synchronous function running in a different thread, `.cancel()` does not really cancel the task/function, and it will keep running in the background.

We do ask it to shut down by setting the `self._shutdown_producer` event, but the function is still running in the background.

We want to wait for `self.iterable` to be closed, but that may not happen immediately when we `.close()` the `mapper.iterate()` iterator. For example, `iterable.close()` may be called much later than `mapper.iterate.close()`. We want to ensure the iterable gets closed when `mapper.iterate()` gets closed.

So, for this, we are setting another event that gets set after `iterable.close()` gets called.