[Feature] Type casting for tensorclass (#735)
vmoens authored Apr 24, 2024

1 parent 357a981 commit 2a73516
Showing 3 changed files with 160 additions and 15 deletions.
91 changes: 82 additions & 9 deletions tensordict/tensorclass.py
@@ -17,6 +17,8 @@
import os
import pickle
import shutil

import sys
import warnings
from copy import copy, deepcopy
from dataclasses import dataclass
@@ -47,12 +49,23 @@
IndexType,
is_non_tensor,
is_tensorclass,
KeyDependentDefaultDict,
NestedKey,
)
from torch import multiprocessing as mp, Tensor
from torch.multiprocessing import Manager

T = TypeVar("T", bound=TensorDictBase)
# We use an abstract AnyType instead of Any because Any isn't recognised as a type for python < 3.11
major, minor = sys.version_info[:2]
if (major, minor) < (3, 11):

class _AnyType:
def __subclasscheck__(self, subclass):
return False

else:
_AnyType = Any

# methods where non_tensordict data should be cleared in the return value
_CLEAR_METADATA = {"all", "any"}
@@ -77,7 +90,7 @@
}


def tensorclass(cls: T) -> T:
class tensorclass:
"""A decorator to create :obj:`tensorclass` classes.
:obj:`tensorclass` classes are specialized :obj:`dataclass` instances that
@@ -134,6 +147,24 @@ def tensorclass(cls: T) -> T:
"""

def __new__(cls, autocast: bool = False):
if not isinstance(autocast, bool):
clz = autocast
self = super().__new__(cls)
self.__init__(autocast=False)
return self.__call__(clz)
return super().__new__(cls)

def __init__(self, autocast: bool):
self.autocast = autocast

def __call__(self, clz):
clz = _tensorclass(clz)
clz.autocast = self.autocast
return clz


def _tensorclass(cls: T) -> T:
def __torch_function__(
cls,
func: Callable,
@@ -368,6 +399,11 @@ def wrapper(
return wrapper


_cast_funcs = KeyDependentDefaultDict(lambda cls: cls)
_cast_funcs[torch.Tensor] = torch.as_tensor
_cast_funcs[np.ndarray] = np.asarray


def _get_type_hints(cls, with_locals=False):
#######
# Set proper type annotations for autocasting to tensordict/tensorclass
@@ -423,6 +459,10 @@ def get_parent_locals(cls, localns=localns):
localns=localns,
# globalns=globals(),
)
cls._type_hints = {
key: val if isinstance(val, type) else _AnyType
for key, val in cls._type_hints.items()
}
except NameError:
if not with_locals:
return _get_type_hints(cls, with_locals=True)
@@ -1057,19 +1097,22 @@ def _set(
"""
if isinstance(key, str):
cls = type(self)
__dict__ = self.__dict__
if __dict__["_tensordict"].is_locked:
raise RuntimeError(_LOCK_ERROR)
if key in ("batch_size", "names", "device"):
# handled by setattr
return
expected_keys = self.__dataclass_fields__
expected_keys = cls.__dataclass_fields__
if key not in expected_keys:
raise AttributeError(
f"Cannot set the attribute '{key}', expected attributes are {expected_keys}."
)

if isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)):
def set_tensor(
key=key, value=value, inplace=inplace, non_blocking=non_blocking
):
# Avoiding key clash, honoring the user input to assign tensor type data to the key
if key in self._non_tensordict.keys():
if inplace:
@@ -1079,18 +1122,48 @@ def _set(
del self._non_tensordict[key]
self._tensordict.set(key, value, inplace=inplace, non_blocking=non_blocking)
return self
if isinstance(value, dict):
type_hints = self._type_hints

def _is_castable(datatype):
return issubclass(datatype, (int, float, np.ndarray))

if cls.autocast:
type_hints = cls._type_hints
if type_hints is not None:
target_cls = type_hints.get(key, None)
if isinstance(target_cls, type) and _is_tensor_collection(target_cls):
target_cls = type_hints.get(key, _AnyType)
else:
warnings.warn("type_hints are none, cannot perform auto-casting")
target_cls = _AnyType

if isinstance(value, dict):
if _is_tensor_collection(target_cls):
value = target_cls.from_dict(value)
self._tensordict.set(
key, value, inplace=inplace, non_blocking=non_blocking
)
return self
else:
warnings.warn(self._set_dict_warn_msg)
elif type_hints is None:
warnings.warn(type(self)._set_dict_warn_msg)
elif (
value is not None
and target_cls in tensordict_lib.base._ACCEPTED_CLASSES
):
try:
if not issubclass(type(value), target_cls):
cast_val = _cast_funcs[target_cls](value)
else:
cast_val = value
except TypeError:
raise TypeError(
f"Failed to cast the value {key} to the type annotation {target_cls}."
)
return set_tensor(value=cast_val)
elif value is not None and target_cls is not _AnyType:
value = _cast_funcs[target_cls](value)
elif target_cls is _AnyType and _is_castable(type(value)):
return set_tensor()
else:
if isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)):
return set_tensor()

# Avoiding key clash, honoring the user input to assign non-tensor data to the key
if key in self._tensordict.keys():
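
A minimal usage sketch of the two decorator forms enabled by the `__new__`/`__call__` logic above. The `Plain` and `Cast` class names are illustrative, not part of the commit; the behavior follows `test_autocast_attr` and `test_autocast_simple` in the test changes below:

>>> import torch
>>> from tensordict import tensorclass
>>> @tensorclass                     # bare form: autocast defaults to False
... class Plain:
...     x: torch.Tensor
>>> Plain.autocast
False
>>> @tensorclass(autocast=True)      # parametrized form enables casting in _set
... class Cast:
...     x: torch.Tensor
...     f: float
>>> obj = Cast(x=1, f=1, batch_size=[])
>>> isinstance(obj.x, torch.Tensor)  # 1 is converted with torch.as_tensor
True
>>> isinstance(obj.f, float)         # 1 is converted with float(...)
True

Assigning a value that cannot be cast (for instance a string to the `torch.Tensor` field) raises the `TypeError` emitted in `_set`, as exercised by `test_autocast_simple`.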
19 changes: 19 additions & 0 deletions tensordict/utils.py
@@ -2226,3 +2226,22 @@ class Tensor:

def dims(self, *args, **kwargs):
raise ImportError("functorch.dim not found")


class KeyDependentDefaultDict(collections.defaultdict):
"""A key-dependent default dict.
Examples:
>>> my_dict = KeyDependentDefaultDict(lambda key: "foo_" + key)
>>> print(my_dict["bar"])
foo_bar
"""

def __init__(self, fun):
self.fun = fun
super().__init__()

def __missing__(self, key):
value = self.fun(key)
self[key] = value
return value
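
This helper backs the `_cast_funcs` table added in tensorclass.py above: unlike a plain `collections.defaultdict`, the factory receives the missing key (here, the annotated target class) and its result is cached. A minimal sketch of that pattern, assuming `KeyDependentDefaultDict` is importable from `tensordict.utils` as added in this commit; the `table` name is illustrative:

>>> import numpy as np
>>> import torch
>>> from tensordict.utils import KeyDependentDefaultDict
>>> table = KeyDependentDefaultDict(lambda cls: cls)  # default cast: call the class itself
>>> table[torch.Tensor] = torch.as_tensor
>>> table[np.ndarray] = np.asarray
>>> table[torch.Tensor](1.0)                          # registered cast
tensor(1.)
>>> table[int]("3")      # missing key: __missing__ builds and caches int as the cast
3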
65 changes: 59 additions & 6 deletions test/test_tensorclass.py
@@ -17,6 +17,7 @@
from tempfile import TemporaryDirectory
from typing import Any, Optional, Tuple, Union

import numpy as np
import pytest
import torch

@@ -1019,7 +1020,7 @@ class MyDataParent:
data.set("k", torch.zeros(3, 4, 5))

def test_set_dict(self):
@tensorclass
@tensorclass(autocast=True)
class MyClass:
x: torch.Tensor
y: MyClass = None
@@ -1881,23 +1882,23 @@ class TestClass:
assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all()


@tensorclass
@tensorclass(autocast=True)
class AutoCast:
tensor: torch.Tensor
non_tensor: str
td: TensorDict
tc: AutoCast


@tensorclass
@tensorclass(autocast=True)
class AutoCastOr:
tensor: torch.Tensor
non_tensor: str
td: TensorDict
tc: AutoCast | None = None


@tensorclass
@tensorclass(autocast=True)
class AutoCastOptional:
tensor: torch.Tensor
non_tensor: str
@@ -1906,15 +1907,67 @@ class AutoCastOptional:
tc: Optional[AutoCast] = None


@tensorclass(autocast=True)
class AutoCastTensor:
tensor: torch.Tensor
integer: int
string: str
floating: float
numpy: np.ndarray
anything: Any


class TestAutoCasting:
@tensorclass
@tensorclass(autocast=True)
class ClsAutoCast:
tensor: torch.Tensor
non_tensor: str
td: TensorDict
tc: "ClsAutoCast" # noqa: F821
tc_global: AutoCast

def test_autocast_attr(self):
@tensorclass(autocast=False)
class T:
X: torch.Tensor

assert not T.autocast

@tensorclass
class T:
X: torch.Tensor

assert not T.autocast

@tensorclass(autocast=True)
class T:
X: torch.Tensor

assert T.autocast

def test_autocast_simple(self):
obj = AutoCastTensor(
tensor=1,
integer=1,
string=1,
floating=1,
numpy=1,
anything=1,
)
assert isinstance(obj.tensor, torch.Tensor)
assert isinstance(obj.integer, int)
assert isinstance(obj.string, str)
assert isinstance(obj.floating, float)
assert isinstance(obj.numpy, np.ndarray)
assert isinstance(obj.anything, torch.Tensor)
obj.tensor = 1.0
assert isinstance(obj.tensor, torch.Tensor)
with pytest.raises(TypeError):
obj.tensor = "str"
obj.anything = 1.0
assert isinstance(obj.anything, torch.Tensor)
obj.anything = "str"

def test_autocast(self):
# Autocasting is implemented only for tensordict / tensorclasses.
# Since some type annotations are not supported such as `Tensor | None`,
@@ -2028,7 +2081,7 @@ def test_autocast_optional(self):
assert obj.tc["tc"] is None

def test_autocast_func(self):
@tensorclass
@tensorclass(autocast=True)
class FuncAutoCast:
tensor: torch.Tensor
non_tensor: str
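
For completeness, a hedged sketch of the dict branch in `_set` that the updated `test_set_dict` relies on: with `autocast=True`, a dict assigned to a field annotated as a tensor collection is converted through `target_cls.from_dict`. The `Outer` class below is illustrative, not part of the commit:

>>> import torch
>>> from tensordict import TensorDict, tensorclass
>>> @tensorclass(autocast=True)
... class Outer:
...     x: torch.Tensor
...     sub: TensorDict = None
>>> obj = Outer(x=torch.zeros(3), batch_size=[])
>>> obj.sub = {"y": torch.ones(3)}   # routed through TensorDict.from_dict
>>> isinstance(obj.sub, TensorDict)
True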
