[BUG] tensorclass select breaks non-selected keys #935
Comments
That's expected behaviour:
Hmm so is
Oooh ok, the problem is that it does it in place, gotcha.
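A minimal plain-Python sketch of why in-place behaviour matters here. The dataclass Pair and the helpers select_inplace and select_copy are hypothetical illustrations, not tensordict's implementation: the in-place variant nulls the non-selected fields on the caller's own object, while the copying variant builds a new instance with dataclasses.replace and leaves the original untouched.

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional


@dataclass
class Pair:
    # Hypothetical stand-in for a tensorclass with two fields.
    a: Optional[float]
    b: Optional[float]


def select_inplace(obj, *keys):
    """Hypothetical: set every non-selected field to None on obj itself."""
    for f in dataclasses.fields(obj):
        if f.name not in keys:
            setattr(obj, f.name, None)
    return obj


def select_copy(obj, *keys):
    """Hypothetical: return a new instance with non-selected fields set to None."""
    dropped = {f.name: None for f in dataclasses.fields(obj) if f.name not in keys}
    return dataclasses.replace(obj, **dropped)


original = Pair(a=1.0, b=2.0)
selected = select_copy(original, "a")
print(selected)  # Pair(a=1.0, b=None)
print(original)  # Pair(a=1.0, b=2.0): the original keeps "b"

other = Pair(a=1.0, b=2.0)
select_inplace(other, "a")
print(other)  # Pair(a=1.0, b=None): the caller's object lost "b"
```

Under these assumptions, only the in-place variant makes later reads of the original object return None for fields that were never selected.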
@vmoens
In [1]: import torch
...: from tensordict import TensorDict, tensorclass
...:
...:
...: @tensorclass
...: class TensorClass:
...: a: torch.Tensor
...: b: torch.Tensor
...:
...:
...: tc = TensorClass(a=torch.ones(10), b=torch.ones(10))
...: td = TensorDict(a=torch.ones(10), b=torch.ones(10))
In [2]: tc
Out[2]:
TensorClass(
a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
b=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([]),
device=None,
is_shared=False)
In [3]: td
Out[3]:
TensorDict(
fields={
a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
In [4]: tc.select('a')
Out[4]:
TensorClass(
a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
b=None,
batch_size=torch.Size([]),
device=None,
is_shared=False)
In [5]: td.select('a')
Out[5]:
TensorDict(
fields={
a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
In [6]: tc.select('a').get('b')
In [7]: td.select('a').get('b')
---------------------------------------------------------------------------
[...]
KeyError: 'key "b" not found in TensorDict with keys [\'a\']'
Personally I find the
Yep, here we need to figure out on which side of the fence this should fall. Consider a plain dataclass:
from dataclasses import dataclass

@dataclass
class A:
    a: int
    b: float

A(1)
Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
TypeError: A.__init__() missing 1 required positional argument: 'b'
As you can see, a dataclass requires every declared field, so if we were to populate "a" and "b" and then delete "b", the behaviour would be poorly defined. So we assume that the other value is
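The "poorly defined" case can be reproduced in plain Python. The sketch below redefines A with b typed as Optional[float] (an assumption for illustration; this is not tensordict code) and contrasts deleting a key from a dict, which leaves a valid smaller mapping, with deleting a field's attribute from a dataclass instance, which leaves an object its own generated __repr__ cannot handle. Filling the field with None instead, as tensorclass.select does in the transcript above, keeps the instance well-formed.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class A:
    a: int
    b: Optional[float]  # Optional only so that None is a legal placeholder


# Mapping semantics: dropping a key leaves a perfectly valid, smaller dict.
d = {"a": 1, "b": 2.0}
del d["b"]
print(d)  # {'a': 1}
print("b" in d)  # False

# Dataclass semantics: the class still declares "b", so deleting the
# instance attribute leaves an object that its generated methods cannot handle.
obj = A(a=1, b=2.0)
del obj.b
print([f.name for f in fields(obj)])  # ['a', 'b']: "b" is still a declared field
try:
    repr(obj)  # the generated __repr__ reads self.b
except AttributeError as err:
    print("repr failed:", err)

# Filling the field with a placeholder keeps the instance well-formed.
filled = A(a=1, b=None)
print(filled)  # A(a=1, b=None)
```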
Ah I see now, I guess the semantics are supposed to be slightly different for
Tbh I agree with you: there's no select in dataclasses, so the question is really "what if there were?", which is very open-ended...
Hmm, good point. Not exactly sure why, but I thought of
Hey @vmoens,
tensordict has been instrumental for our work and I'm a big fan! I just hit an issue upon upgrading to v0.5.0.
Describe the bug
Calling .select(...) on a tensorclass seems to cause weird behavior, with non-selected keys getting None'd.
To Reproduce
output:
System info
Checklist