
[BUG] I think TensorDict doesn't work with pin_memory in a dataloader #679

Closed
andersonbcdefg opened this issue Feb 18, 2024 · 8 comments
Labels: bug (Something isn't working)
Describe the bug

It seems like the batch size goes missing when PyTorch attempts to pin the TensorDict's memory.

To Reproduce

Use a TensorDict as the dataset (in my case, the TensorDict lives inside a more complex IterableDataset class), and feed it to a PyTorch DataLoader with pin_memory=True. I believe this happens because the memory-pinning function tries to create a new TensorDict without passing the batch size.

ValueError: Caught ValueError in pin memory thread for device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/tensordict/_td.py", line 1234, in _parse_batch_size
    return torch.Size(batch_size)
TypeError: 'NoneType' object is not iterable

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/pin_memory.py", line 36, in do_one_step
    data = pin_memory(data, device)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/pin_memory.py", line 72, in pin_memory
    return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/pin_memory.py", line 72, in <listcomp>
    return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/pin_memory.py", line 62, in pin_memory
    return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
  File "/usr/local/lib/python3.9/dist-packages/tensordict/_td.py", line 223, in __init__
    self._batch_size = self._parse_batch_size(source, batch_size)
  File "/usr/local/lib/python3.9/dist-packages/tensordict/_td.py", line 1240, in _parse_batch_size
    raise ValueError(
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source
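The root cause is independent of tensordict: the dataloader's pin-memory helper rebuilds a Mapping by calling `type(data)(...)` with only the key/value pairs, which breaks any subclass whose constructor requires extra arguments. A minimal stand-in sketch (`BatchDict` is a hypothetical class mimicking TensorDict's required `batch_size`, not part of any library):

```python
# Hypothetical Mapping subclass that, like TensorDict, needs a
# constructor argument beyond the key/value pairs.
class BatchDict(dict):
    def __init__(self, source, batch_size=None):
        if batch_size is None:
            # Mirrors tensordict's _parse_batch_size error.
            raise ValueError(
                "batch size was not specified when creating the instance"
            )
        super().__init__(source)
        self.batch_size = batch_size

data = BatchDict({"obs": [1, 2, 3]}, batch_size=(3,))

# What torch/utils/data/_utils/pin_memory.py effectively did (pre-fix):
# rebuild the mapping from its items alone, dropping batch_size.
try:
    rebuilt = type(data)({k: v for k, v in data.items()})
except ValueError as exc:
    print("reconstruction failed:", exc)
```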

Expected behavior

Pinning memory just works and doesn't cause an exception.

System info

Installed from pip, 0.3.0, used with NVIDIA A6000 and Torch 2.2, Python 3.9.16

Describe the characteristics of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@andersonbcdefg andersonbcdefg added the bug Something isn't working label Feb 18, 2024
vmoens commented Feb 18, 2024

This is somewhat similar to huggingface/accelerate#2405.
There are two things we can do here: (1) you could call tensordict.pin_memory within the collate_fn, and (2) PyTorch should use PyTree within the dataloader's pin_memory.
I will make a PR for (2)
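Workaround (1) can be sketched with a stub container (`StubTensorDict` and all names below are hypothetical stand-ins; with the real library you would call `TensorDict.pin_memory()` in the same spot). The idea: pin inside the collate function, via the object's own method, and leave `pin_memory=False` on the DataLoader so its built-in reconstruction is never invoked and the batch metadata survives.

```python
# Stub standing in for TensorDict: pinning via the object's own method
# keeps its metadata (batch_size) intact, unlike rebuilding it from items.
class StubTensorDict(dict):
    def __init__(self, source, batch_size):
        super().__init__(source)
        self.batch_size = batch_size

    def pin_memory(self):
        # Real code would pin each underlying tensor; here we just mark it.
        self.pinned = True
        return self

def collate_fn(batch):
    # Merge samples, then pin inside the collate step instead of
    # relying on DataLoader(pin_memory=True).
    merged = StubTensorDict(
        {k: [sample[k] for sample in batch] for k in batch[0]},
        batch_size=(len(batch),),
    )
    return merged.pin_memory()

batch = [StubTensorDict({"obs": i}, batch_size=()) for i in range(4)]
out = collate_fn(batch)
print(out.batch_size, out.pinned)  # (4,) True
```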

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Mar 1, 2024
…memory` and `collate_fn` (#120553)

A user-defined `Mapping` type may contain metadata (e.g., pytorch/tensordict#679, #120195 (comment)). Simply using `type(mapping)({k: v for k, v in mapping.items()})` does not take this metadata into account. This PR uses `copy.copy(mapping)` to create a clone of the original collection and iteratively updates the elements in the cloned collection. This preserves the metadata in the original collection via `copy.copy(...)` rather than relying on the `__init__` method of the user-defined class.
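The difference between the two reconstruction strategies can be demonstrated in plain Python (`MetaDict` is a hypothetical dict subclass carrying metadata, standing in for TensorDict):

```python
import copy

# Hypothetical Mapping subclass carrying metadata, like TensorDict's batch_size.
class MetaDict(dict):
    def __init__(self, source=(), batch_size=None):
        super().__init__(source)
        self.batch_size = batch_size

original = MetaDict({"a": 1}, batch_size=(1,))

# Old approach: re-calling __init__ with only the items drops the metadata.
rebuilt = type(original)({k: v for k, v in original.items()})
print(rebuilt.batch_size)  # None

# Fixed approach (as described for #120553): clone with copy.copy, which
# preserves instance attributes, then update the elements in place.
clone = copy.copy(original)
for k, v in original.items():
    clone[k] = v
print(clone.batch_size)  # (1,)
```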

Reference:

- pytorch/tensordict#679
- #120195

Closes #120195

Pull Request resolved: #120553
Approved by: https://github.com/vmoens
vmoens commented Mar 6, 2024

This will now work on torch nightlies!

@vmoens vmoens closed this as completed Mar 6, 2024
andersonbcdefg commented Mar 6, 2024 via email

SamGalanakis commented:
Hello @vmoens, I'm having a possibly related issue. I have a custom collate function that returns a TensorDict, and with pin_memory=True on the DataLoader I am seeing this warning:

/opt/conda/lib/python3.11/site-packages/tensordict/tensorclass.py:1108: UserWarning:

The method <bound method TensorDictBase.pin_memory of TensorDict(
    fields={
    },
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)> wasn't explicitly implemented for tensorclass. This fallback will be deprecated in future releases because it is inefficient and non-compilable. Please raise an issue in tensordict repo to support this method!

Any ideas?
I am using the latest nightly from: ghcr.io/pytorch/pytorch-nightly:2.5.0.dev20240806-cuda12.1-cudnn9-devel

vmoens commented Aug 6, 2024

I'll fix that, thanks for reporting

haithamkhedr commented:

> I'll fix that thanks for reporting

@vmoens Is this fixed already? Thanks

Mxbonn commented Sep 20, 2024

> I'll fix that thanks for reporting

any update on this?

vmoens commented Sep 20, 2024

I think it is! Trying to make a release asap with this and other fixes
