[BUG] tensorclass select breaks non-selected keys #935

Closed
3 tasks done
egorchakov opened this issue Jul 31, 2024 · 8 comments · Fixed by #936

@egorchakov

Hey @vmoens, tensordict has been instrumental for our work and I'm a big fan! Just hit an issue upon upgrading to v0.5.0.

Describe the bug

Calling .select(...) on a tensorclass seems to cause weird behavior with non-selected keys getting None'd.

To Reproduce

import torch
from tensordict import tensorclass


@tensorclass
class Data:
    a: torch.Tensor
    b: torch.Tensor


d = Data(a=torch.ones(10), b=torch.ones(10))

print(f"d:\n{d}")
print("---------")
print(f"d.select('a'):\n{d.select("a")}")
print("---------")
print(f"d.b:\n{d.b}")
print("---------")
print(f"d:\n{d}")

output:

d:
Data(
    a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    b=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
---------
d.select('a'):
Data(
    a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    b=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
---------
d.b:
None
---------
d:
Data(
    a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    b=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

System info

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.5.0 1.26.4 3.12.4 (main, Jul 31 2024, 10:50:51) [GCC 11.4.0] linux 2.4.0+cu121

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@egorchakov added the bug label Jul 31, 2024
@vmoens
Contributor

vmoens commented Jul 31, 2024

That's expected behaviour: select just filters out everything you don't select.
Would you rather expect an attribute error?

@egorchakov
Author

Hmm so is select default inplace now? I'm confused as to why calling d.select('a') would mutate d in a way that affects subsequent access to other keys (e.g. d.b). TensorDict seems to behave as expected in this case (vs tensorclass), without the dot notation obviously.

@vmoens
Contributor

vmoens commented Jul 31, 2024

Oooh ok the problem is that it does it in place, gotcha
That's a bug!
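
To make the in-place vs. out-of-place distinction concrete, here is a minimal sketch that uses a plain dataclass rather than tensordict itself. `Record`, `select_inplace`, and `select_copy` are hypothetical names invented for illustration; they are not part of the tensordict API, and this is not how tensorclass is implemented. It only shows why an in-place select changes what a later `d.b` returns:

```python
from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass
class Record:
    # Hypothetical stand-in for the `Data` tensorclass in the report.
    a: Optional[list] = None
    b: Optional[list] = None


def select_inplace(obj: Record, *keys: str) -> Record:
    """Mutating variant: clears non-selected fields on `obj` itself."""
    for f in fields(obj):
        if f.name not in keys:
            setattr(obj, f.name, None)
    return obj


def select_copy(obj: Record, *keys: str) -> Record:
    """Non-mutating variant: returns a new object and leaves `obj` untouched."""
    cleared = {f.name: None for f in fields(obj) if f.name not in keys}
    return replace(obj, **cleared)


d = Record(a=[1.0] * 10, b=[1.0] * 10)

out = select_copy(d, "a")
print(out.b is None, d.b is not None)  # the copy drops `b`, the original keeps it

select_inplace(d, "a")
print(d.b is None)  # the original has lost `b`, as in the report
```

The behaviour seen in the report matches the mutating variant, while the expectation was the non-mutating one.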

@vmoens linked a pull request Jul 31, 2024 that will close this issue
@egorchakov
Author

@vmoens tensorclass and TensorDict still seem to behave somewhat inconsistently wrt select, which could be surprising: getting a non-selected key from a tensorclass returns None, while for a TensorDict it raises a KeyError, e.g.:

In [1]: import torch
   ...: from tensordict import TensorDict, tensorclass
   ...:
   ...:
   ...: @tensorclass
   ...: class TensorClass:
   ...:     a: torch.Tensor
   ...:     b: torch.Tensor
   ...:
   ...:
   ...: tc = TensorClass(a=torch.ones(10), b=torch.ones(10))
   ...: td = TensorDict(a=torch.ones(10), b=torch.ones(10))
   
In [2]: tc
Out[2]:
TensorClass(
    a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    b=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [3]: td
Out[3]:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [4]: tc.select('a')
Out[4]:
TensorClass(
    a=Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
    b=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [5]: td.select('a')
Out[5]:
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [6]: tc.select('a').get('b')

In [7]: td.select('a').get('b')
---------------------------------------------------------------------------
[...]
KeyError: 'key "b" not found in TensorDict with keys [\'a\']'

Personally I find the TensorDict behavior (KeyError) more appropriate in this case as it indicates select being more "view-like" but 🤷‍♂️
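
As a rough analogy for the "KeyError" side of this comparison, here is a sketch with a plain `dict`. The `select` helper is hypothetical and is not the tensordict implementation. Note that plain `dict.get` returns None for a missing key, so the sketch uses item access to surface the error:

```python
def select(mapping: dict, *keys: str) -> dict:
    """Hypothetical helper: build a new dict holding only the requested keys."""
    return {k: mapping[k] for k in keys}


td_like = {"a": [1.0] * 10, "b": [1.0] * 10}
sub = select(td_like, "a")

try:
    sub["b"]  # the non-selected key is simply absent
except KeyError as err:
    print(f"KeyError: {err}")

print(sorted(td_like))  # the source mapping still has both keys
```

Here a non-selected key is absent from the result rather than present with a None value, which is the behaviour the TensorDict transcript above shows.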

@vmoens
Contributor

vmoens commented Aug 1, 2024

Yep here we need to figure out on which side of the fence this should fall.
@tensorclass is supposed to behave like @dataclass:

from dataclasses import dataclass

@dataclass
class A:
    a: int
    b: float

A(1)
Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
TypeError: A.__init__() missing 1 required positional argument: 'b'

As you can see, if we were to put "a" and "b" and then delete "b" the behaviour would be poorly defined. So we assume that the other value is None.
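
A small sketch of the "poorly defined" case with a standard-library dataclass, under the assumption that "delete" means removing the attribute from an instance. The class names `Strict` and `Lenient` are made up for illustration, and nothing here reflects how tensorclass handles this internally:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Strict:
    a: int
    b: float


x = Strict(1, 2.0)
del x.b  # allowed on a regular (non-frozen, non-slots) dataclass instance

try:
    repr(x)  # the generated __repr__ reads self.b, which no longer exists
    repr_failed = False
except AttributeError:
    repr_failed = True
print(repr_failed)


@dataclass
class Lenient:
    a: int
    b: Optional[float] = None  # a missing value is represented as None


print(Lenient(1))
```

Treating a dropped field as None keeps every declared field readable, which is the trade-off described above.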

@egorchakov
Author

Ah I see now, I guess the semantics are supposed to be slightly different for @tensorclass vs TensorDict.

@vmoens
Contributor

vmoens commented Aug 1, 2024

Tbh I agree with you, there's no select in data class so the question kind of is "what if there was?" Which is very open ended...

@egorchakov
Author

hmm good point

not exactly sure why but I thought of @tensorclass select/exclude as something akin to pydantic's include/exclude, and pydantic models seem dataclass-ish
