[BUG] I think TensorDict doesn't work with pin_memory in a dataloader #679
Comments
This is somewhat similar to huggingface/accelerate#2405
…memory` and `collate_fn` (#120553): A user-defined `Mapping` type may carry metadata (e.g., pytorch/tensordict#679, #120195 (comment)). Simply using `type(mapping)({k: v for k, v in mapping.items()})` does not take this metadata into account. This PR uses `copy.copy(mapping)` to create a clone of the original collection and iteratively updates the elements in the cloned collection. This preserves the metadata of the original collection via `copy.copy(...)` rather than relying on the `__init__` method of the user-defined class. Reference: pytorch/tensordict#679, #120195. Closes #120195. Pull Request resolved: #120553. Approved by: https://github.com/vmoens
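A minimal sketch of the pattern described in that PR message; `MetaDict` and its `batch_size` attribute are hypothetical stand-ins for a user-defined `Mapping` that carries metadata, not torch's actual code:

```python
import copy


class MetaDict(dict):
    """Hypothetical Mapping subclass that carries metadata (e.g. a batch size)."""

    def __init__(self, *args, batch_size=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = batch_size


original = MetaDict({"a": 1, "b": 2}, batch_size=4)

# Rebuilding through the constructor silently drops the metadata:
rebuilt = type(original)({k: v for k, v in original.items()})
print(rebuilt.batch_size)  # None -- batch_size was not forwarded

# Cloning with copy.copy and updating the clone keeps the metadata,
# without relying on the subclass's __init__ signature:
clone = copy.copy(original)
for k, v in original.items():
    clone[k] = v  # e.g. pin each value here
print(clone.batch_size)  # 4
```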
This will now work on torch nightlies!
wooo thanks vincent!
Hello @vmoens, I'm having a possibly related issue. I have a custom collate function that returns a TensorDict and pin_memory=True on the DataLoader, and I am seeing this warning: /opt/conda/lib/python3.11/site-packages/tensordict/tensorclass.py:1108: UserWarning:
Any ideas?
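For context, a minimal sketch of the kind of setup described above; the dataset, field names, and shapes are illustrative, not the original code:

```python
import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    """Illustrative dataset; the real one in the report is more complex."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"obs": torch.randn(3), "reward": torch.randn(1)}


def collate_to_tensordict(samples):
    # Stack the per-sample dicts into a single batched TensorDict
    return TensorDict(
        {k: torch.stack([s[k] for s in samples]) for k in samples[0]},
        batch_size=[len(samples)],
    )


loader = DataLoader(
    MyDataset(), batch_size=4, collate_fn=collate_to_tensordict, pin_memory=True
)
for batch in loader:  # the pin-memory step is where the warning appears
    pass
```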
I'll fix that, thanks for reporting!
Any update on this?
I think it is! Trying to make a release ASAP with this and other fixes.
Describe the bug
It seems like the TensorDict's batch size goes missing when PyTorch attempts to pin its memory.
To Reproduce
Use a TensorDict as the dataset (or, in my case, the TensorDict is inside a more complex IterableDataset class) and feed it to a PyTorch DataLoader with pin_memory=True. I think this happens because the memory-pinning function tries to create a new TensorDict and doesn't pass the batch size.
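A minimal sketch of the reproduction described above, assuming torch.stack as the collate function to batch the indexed TensorDicts; names and shapes are illustrative:

```python
import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader

# A TensorDict used directly as a map-style dataset
data = TensorDict(
    {"obs": torch.randn(100, 3), "act": torch.randn(100, 2)},
    batch_size=[100],
)

# torch.stack on a list of TensorDicts returns a batched TensorDict,
# so it serves as the collate function here. The pin-memory step is
# where the batch size was reportedly dropped on tensordict 0.3.0.
loader = DataLoader(data, batch_size=8, collate_fn=torch.stack, pin_memory=True)

for batch in loader:
    pass
```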
Expected behavior
Pinning memory just works and doesn't cause an exception.
System info
tensordict 0.3.0 installed from pip, used with an NVIDIA A6000, PyTorch 2.2, and Python 3.9.16.