Skip to content

Commit

Permalink
[Refactor] Add missing functions in tensorclass register
Browse files Browse the repository at this point in the history
ghstack-source-id: 48311d7a98a9895b10e5552e5b4a4f13764607e0
Pull Request resolved: #1153
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent efb89a6 commit 05881f3
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,41 @@ def __subclasscheck__(self, subclass):
_FALLBACK_METHOD_FROM_TD = [
"__abs__",
"__add__",
"__bool__",
"__eq__",
"__ge__",
"__gt__",
"__iadd__",
"__imul__",
"__ipow__",
"__isub__",
"__itruediv__",
"__mul__",
"__ne__",
"__or__",
"__pow__",
"__sub__",
"__truediv__",
"__xor__",
"_add_batch_dim",
"_apply_nest",
"_clone",
"_clone_recurse",
"_data",
"_erase_names", # TODO: must be specialized
"_exclude", # TODO: must be specialized
"_fast_apply",
"_flatten_keys_inplace",
"_flatten_keys_outplace",
"_get_sub_tensordict",
"_grad",
"_map",
"_maybe_remove_batch_dim",
"_memmap_",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild",
"_permute",
"_remove_batch_dim",
"_repeat",
"_select", # TODO: must be specialized
Expand Down Expand Up @@ -217,6 +235,8 @@ def __subclasscheck__(self, subclass):
"clamp_max_",
"clamp_min",
"clamp_min_",
"clear",
"clear_device_",
"consolidate",
"contiguous",
"copy_",
Expand Down Expand Up @@ -251,10 +271,8 @@ def __subclasscheck__(self, subclass):
"frac_",
"from_any",
"from_dataclass",
"to_namedtuple",
"from_namedtuple",
"from_pytree",
"to_pytree",
"gather",
"isfinite",
"isnan",
Expand All @@ -275,6 +293,8 @@ def __subclasscheck__(self, subclass):
"log_",
"map",
"map_iter",
"to_namedtuple",
"to_pytree",
"masked_fill",
"masked_fill_",
"max",
Expand Down

0 comments on commit 05881f3

Please sign in to comment.