[BUG] TensorDict.expand() allocates no new memory making indexing along expanded dimensions point to the original object #1008

Closed
3 tasks done
pancan21 opened this issue Sep 23, 2024 · 2 comments


pancan21 commented Sep 23, 2024

Describe the bug

If a tensor value is set by index after TensorDict.expand() has been called, the operation changes all values along the expanded dimension instead of only the value at the given index. All entries along the expanded dimension share the same memory location, which makes it impossible to change them individually and results in unintended side effects.

To Reproduce

import torch
from tensordict import TensorDict

# create tensor dict with zeros
zeros = TensorDict({"a": torch.zeros((5,))})
assert (zeros["a"] == torch.zeros((5,))).all()

# expand the tensor dict
zeros_expanded = zeros.expand(10)
assert (zeros_expanded["a"] == torch.zeros((10,5,))).all()

# try to set a value with any method below
zeros_expanded["a"][0] = torch.ones((5,))
# zeros_expanded.set_at_("a", torch.ones((5,)), 1)

print(zeros_expanded["a"][2]) # Expected tensor([0., 0., 0., 0., 0.]) but got tensor([1., 1., 1., 1., 1.])

assert (zeros_expanded["a"][0] == torch.ones((5,))).all() # Passes
assert (zeros_expanded["a"][2:] == torch.zeros((5,))).all() # Fails
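# compare data pointers rather than id(): id() of temporary Python objects says nothing about tensor memory
assert zeros_expanded["a"][0].data_ptr() != zeros_expanded["a"][1].data_ptr() # Fails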

Expected behavior

I expect a change to apply only at the index used to set the value, not at all other indices along the expanded dimension.

System info

Describe the characteristics of your environment:

  • tensordict==0.5.0
    via rye
  • torch==2.4.1
    via rye
    via tensordict
  • numpy==2.1.1
    via rye
    via tensordict
  • Python version 3.12.4
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
# 0.5.0 2.1.1 3.12.4 (main, Jul 13 2024, 23:45:08) [Clang 17.0.6 ] linux 2.4.1+cu121

Reason and Possible fixes

As the memory location of the indexed elements seems to be the same, this is probably an issue with expand, which does not allocate new memory and instead uses the same pointer for all elements.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
pancan21 added the bug label Sep 23, 2024
@vmoens
Contributor

vmoens commented Sep 23, 2024

Isn't that the expected behaviour of expand?
If you want a new memory allocation with expand, you should call clone.
It's the same with a regular tensor:

import torch
x = torch.randn(3)
y = x.expand(10, 3)
x.data_ptr()
Out[7]: 5628813248
y.data_ptr()
Out[8]: 5628813248

@pancan21
Author

Thanks, it wasn't clear to me from the tensordict documentation that this is the default behavior.
Maybe the information from the torch documentation could be added, or linked to, to make clear that it's a shallow copy:
E.g. Returns a new view of the self tensor with singleton dimensions expanded to a larger size
...
Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.
from https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
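
The stride-0 mechanism described in that passage can be checked on a plain tensor. This is only an illustrative sketch with made-up variable names (`same_storage`, `rows_alias`), not something taken from the tensordict code base:

```python
import torch

x = torch.zeros(5)
y = x.expand(10, 5)  # adds a leading dimension of size 10 without copying

print(x.stride())  # stride of the original 1-D tensor
print(y.stride())  # the expanded dimension has stride 0

# The view starts at the same address as the original, and so does every row of it.
same_storage = y.data_ptr() == x.data_ptr()
rows_alias = y[0].data_ptr() == y[7].data_ptr()
print(same_storage, rows_alias)

# Writing through one row therefore changes the single underlying row,
# which is what every other index of the view reads from.
y[0] = torch.ones(5)
print(x)
```

Because every index along the expanded dimension maps to the same five elements, the write in the last step is visible through `x` and through every row of `y`.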
