to_pytorch: enable prefetching #664

Conversation
#653 did not actually enable prefetching. Prefetch was only implemented for map(), so the example gave the false impression that prefetching was working when it was not. Now, to_pytorch uses AsyncMapper to prefetch the data. The number of prefetch workers defaults to 2, but it can be changed via the `prefetch` setting. For me, this cut the data loading time by roughly 90%, from ~300s to ~35s.
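For context, a minimal usage sketch (assuming a DataChain instance named `chain` and that `prefetch` is passed through `.settings()`; the batch size and loop body are illustrative, not taken from this PR):

```python
from torch.utils.data import DataLoader

# Illustrative sketch only: `chain` is a placeholder DataChain instance.
# The default of 2 prefetch workers is overridden via settings before
# converting the chain to an IterableDataset.
dataset = (
    chain
    .settings(prefetch=4)   # prefetch workers; defaults to 2 if not set
    .to_pytorch()
)

loader = DataLoader(dataset, batch_size=16)
for batch in loader:
    ...  # training / inference step
```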
@@ -31,6 +32,8 @@ def label_to_int(value: str, classes: list) -> int:

class PytorchDataset(IterableDataset):
    prefetch: int = 2
Keeping it the same as the default in datachain/src/datachain/lib/udf.py (line 293 at 911c22f):

prefetch: int = 2
    total_rank, total_workers = self.get_rank_and_workers()

    def _rows_iter(self, total_rank: int, total_workers: int):
        catalog = self._get_catalog()
        session = Session("PyTorch", catalog=catalog)
We need to create a new session and catalog here, since AsyncMapper runs this on a separate thread.
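As a rough illustration of why prefetching helps here (a simplified stand-in, not datachain's actual AsyncMapper), a background thread can pull rows ahead of the consumer into a bounded queue so that row fetching and iteration overlap:

```python
import queue
import threading
from typing import Iterable, Iterator

_DONE = object()  # sentinel marking the end of the producer's output

def prefetched(rows: Iterable, prefetch: int = 2) -> Iterator:
    """Yield items from `rows`, fetching up to `prefetch` items ahead
    on a background thread. Simplified sketch, not the real AsyncMapper."""
    q: queue.Queue = queue.Queue(maxsize=prefetch)

    def producer() -> None:
        for row in rows:
            q.put(row)          # blocks once the queue holds `prefetch` items
        q.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while (item := q.get()) is not _DONE:
        yield item
```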
Codecov Report

Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #664      +/-   ##
==========================================
- Coverage   87.32%   87.32%   -0.01%
==========================================
  Files         113      113
  Lines       10717    10727      +10
  Branches     1469     1469
==========================================
+ Hits         9359     9367       +8
- Misses        985      986       +1
- Partials      373      374       +1

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
Looks good to me! 👍
Closes #631.