Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 10, 2025
1 parent 47bddb7 commit 8cf88d1
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,8 +917,14 @@ def __torch_function__(
setattr(cls, method_name, _wrap_td_method(method_name))
for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP:
if not hasattr(cls, method_name) and method_name not in expected_keys:
is_property = isinstance(getattr(TensorDictBase, method_name, None), property)
setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True, is_property=is_property))
is_property = isinstance(
getattr(TensorDictBase, method_name, None), property
)
setattr(
cls,
method_name,
_wrap_td_method(method_name, no_wrap=True, is_property=is_property),
)

for method_name in _FALLBACK_METHOD_FROM_TD_COPY:
if not hasattr(cls, method_name):
Expand Down Expand Up @@ -1535,7 +1541,9 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
return wrapper


def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False, is_property=False):
def _wrap_td_method(
funcname, *, copy_non_tensor=False, no_wrap=False, is_property=False
):
def deliver_result(self, result, kwargs):
if result is None:
return
Expand All @@ -1554,6 +1562,7 @@ def deliver_result(self, result, kwargs):
return result

if not is_property:

def wrapped_func(self, *args, **kwargs):
if not is_compiling():
td = super(type(self), self).__getattribute__("_tensordict")
Expand Down

0 comments on commit 8cf88d1

Please sign in to comment.