Skip to content

Commit

Permalink
[Feature] Add missing __torch_function__
Browse files Browse the repository at this point in the history
ghstack-source-id: 3dbefb4f5322a944664bbc2d29af7f862cb92342
Pull Request resolved: #1169
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent b493178 commit bc6390c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
return td.unbind(*args, **kwargs)


@implements_for_td(torch.unflatten)
def _unflatten(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
return td.unflatten(*args, **kwargs)


@implements_for_td(torch.flatten)
def _flatten(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
return td.flatten(*args, **kwargs)


@implements_for_td(torch.transpose)
def _transpose(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]:
return td.transpose(*args, **kwargs)


@implements_for_td(torch.gather)
def _gather(
input: T,
Expand Down
2 changes: 2 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __subclasscheck__(self, subclass):
torch.cat: True,
torch.clone: True,
torch.empty_like: True,
torch.flatten: True,
torch.full_like: True,
torch.gather: True,
torch.ones_like: True,
Expand All @@ -114,6 +115,7 @@ def __subclasscheck__(self, subclass):
torch.squeeze: True,
torch.stack: True,
torch.unbind: True,
torch.unflatten: True,
torch.unsqueeze: True,
torch.zeros_like: True,
}
Expand Down

0 comments on commit bc6390c

Please sign in to comment.