From 2a735168f14d4afe05e838bca48da3a4e61582cc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Apr 2024 10:21:13 +0100 Subject: [PATCH] [Feature] Type casting for tensorclass (#735) --- tensordict/tensorclass.py | 91 +++++++++++++++++++++++++++++++++++---- tensordict/utils.py | 19 ++++++++ test/test_tensorclass.py | 65 +++++++++++++++++++++++++--- 3 files changed, 160 insertions(+), 15 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index bb28c9af5..e0b1e1454 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -17,6 +17,8 @@ import os import pickle import shutil + +import sys import warnings from copy import copy, deepcopy from dataclasses import dataclass @@ -47,12 +49,23 @@ IndexType, is_non_tensor, is_tensorclass, + KeyDependentDefaultDict, NestedKey, ) from torch import multiprocessing as mp, Tensor from torch.multiprocessing import Manager T = TypeVar("T", bound=TensorDictBase) +# We use an abstract AnyType instead of Any because Any isn't recognised as a type for python < 3.10 +major, minor = sys.version_info[:2] +if (major, minor) < (3, 11): + + class _AnyType: + def __subclasscheck__(self, subclass): + return False + +else: + _AnyType = Any # methods where non_tensordict data should be cleared in the return value _CLEAR_METADATA = {"all", "any"} @@ -77,7 +90,7 @@ } -def tensorclass(cls: T) -> T: +class tensorclass: """A decorator to create :obj:`tensorclass` classes. :obj:`tensorclass` classes are specialized :obj:`dataclass` instances that @@ -134,6 +147,24 @@ def tensorclass(cls: T) -> T: """ + def __new__(cls, autocast: bool = False): + if not isinstance(autocast, bool): + clz = autocast + self = super().__new__(cls) + self.__init__(autocast=False) + return self.__call__(clz) + return super().__new__(cls) + + def __init__(self, autocast: bool): + self.autocast = autocast + + def __call__(self, clz): + clz = _tensorclass(clz) + clz.autocast = self.autocast + return clz + + +def _tensorclass(cls: T) -> T: def __torch_function__( cls, func: Callable, @@ -368,6 +399,11 @@ def wrapper( return wrapper +_cast_funcs = KeyDependentDefaultDict(lambda cls: cls) +_cast_funcs[torch.Tensor] = torch.as_tensor +_cast_funcs[np.ndarray] = np.asarray + + def _get_type_hints(cls, with_locals=False): ####### # Set proper type annotations for autocasting to tensordict/tensorclass @@ -423,6 +459,10 @@ def get_parent_locals(cls, localns=localns): localns=localns, # globalns=globals(), ) + cls._type_hints = { + key: val if isinstance(val, type) else _AnyType + for key, val in cls._type_hints.items() + } except NameError: if not with_locals: return _get_type_hints(cls, with_locals=True) @@ -1057,19 +1097,22 @@ def _set( """ if isinstance(key, str): + cls = type(self) __dict__ = self.__dict__ if __dict__["_tensordict"].is_locked: raise RuntimeError(_LOCK_ERROR) if key in ("batch_size", "names", "device"): # handled by setattr return - expected_keys = self.__dataclass_fields__ + expected_keys = cls.__dataclass_fields__ if key not in expected_keys: raise AttributeError( f"Cannot set the attribute '{key}', expected attributes are {expected_keys}." ) - if isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)): + def set_tensor( + key=key, value=value, inplace=inplace, non_blocking=non_blocking + ): # Avoiding key clash, honoring the user input to assign tensor type data to the key if key in self._non_tensordict.keys(): if inplace: @@ -1079,18 +1122,48 @@ def _set( del self._non_tensordict[key] self._tensordict.set(key, value, inplace=inplace, non_blocking=non_blocking) return self - if isinstance(value, dict): - type_hints = self._type_hints + + def _is_castable(datatype): + return issubclass(datatype, (int, float, np.ndarray)) + + if cls.autocast: + type_hints = cls._type_hints if type_hints is not None: - target_cls = type_hints.get(key, None) - if isinstance(target_cls, type) and _is_tensor_collection(target_cls): + target_cls = type_hints.get(key, _AnyType) + else: + warnings.warn("type_hints are none, cannot perform auto-casting") + target_cls = _AnyType + + if isinstance(value, dict): + if _is_tensor_collection(target_cls): value = target_cls.from_dict(value) self._tensordict.set( key, value, inplace=inplace, non_blocking=non_blocking ) return self - else: - warnings.warn(self._set_dict_warn_msg) + elif type_hints is None: + warnings.warn(type(self)._set_dict_warn_msg) + elif ( + value is not None + and target_cls in tensordict_lib.base._ACCEPTED_CLASSES + ): + try: + if not issubclass(type(value), target_cls): + cast_val = _cast_funcs[target_cls](value) + else: + cast_val = value + except TypeError: + raise TypeError( + f"Failed to cast the value {key} to the type annotation {target_cls}." + ) + return set_tensor(value=cast_val) + elif value is not None and target_cls is not _AnyType: + value = _cast_funcs[target_cls](value) + elif target_cls is _AnyType and _is_castable(type(value)): + return set_tensor() + else: + if isinstance(value, tuple(tensordict_lib.base._ACCEPTED_CLASSES)): + return set_tensor() # Avoiding key clash, honoring the user input to assign non-tensor data to the key if key in self._tensordict.keys(): diff --git a/tensordict/utils.py b/tensordict/utils.py index bac125377..cb2afba19 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2226,3 +2226,22 @@ class Tensor: def dims(self, *args, **kwargs): raise ImportError("functorch.dim not found") + + +class KeyDependentDefaultDict(collections.defaultdict): + """A key-dependent default dict. + + Examples: + >>> my_dict = KeyDependentDefaultDict(lambda key: "foo_" + key) + >>> print(my_dict["bar"]) + foo_bar + """ + + def __init__(self, fun): + self.fun = fun + super().__init__() + + def __missing__(self, key): + value = self.fun(key) + self[key] = value + return value diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 7ec0844b2..505a8ffd6 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -17,6 +17,7 @@ from tempfile import TemporaryDirectory from typing import Any, Optional, Tuple, Union +import numpy as np import pytest import torch @@ -1019,7 +1020,7 @@ class MyDataParent: data.set("k", torch.zeros(3, 4, 5)) def test_set_dict(self): - @tensorclass + @tensorclass(autocast=True) class MyClass: x: torch.Tensor y: MyClass = None @@ -1881,7 +1882,7 @@ class TestClass: assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all() -@tensorclass +@tensorclass(autocast=True) class AutoCast: tensor: torch.Tensor non_tensor: str @@ -1889,7 +1890,7 @@ class AutoCast: tc: AutoCast -@tensorclass +@tensorclass(autocast=True) class AutoCastOr: tensor: torch.Tensor non_tensor: str @@ -1897,7 +1898,7 @@ class AutoCastOr: tc: AutoCast | None = None -@tensorclass +@tensorclass(autocast=True) class AutoCastOptional: tensor: torch.Tensor non_tensor: str @@ -1906,8 +1907,18 @@ class AutoCastOptional: tc: Optional[AutoCast] = None +@tensorclass(autocast=True) +class AutoCastTensor: + tensor: torch.Tensor + integer: int + string: str + floating: float + numpy: np.ndarray + anything: Any + + class TestAutoCasting: - @tensorclass + @tensorclass(autocast=True) class ClsAutoCast: tensor: torch.Tensor non_tensor: str @@ -1915,6 +1926,48 @@ class ClsAutoCast: tc: "ClsAutoCast" # noqa: F821 tc_global: AutoCast + def test_autocast_attr(self): + @tensorclass(autocast=False) + class T: + X: torch.Tensor + + assert not T.autocast + + @tensorclass + class T: + X: torch.Tensor + + assert not T.autocast + + @tensorclass(autocast=True) + class T: + X: torch.Tensor + + assert T.autocast + + def test_autocast_simple(self): + obj = AutoCastTensor( + tensor=1, + integer=1, + string=1, + floating=1, + numpy=1, + anything=1, + ) + assert isinstance(obj.tensor, torch.Tensor) + assert isinstance(obj.integer, int) + assert isinstance(obj.string, str) + assert isinstance(obj.floating, float) + assert isinstance(obj.numpy, np.ndarray) + assert isinstance(obj.anything, torch.Tensor) + obj.tensor = 1.0 + assert isinstance(obj.tensor, torch.Tensor) + with pytest.raises(TypeError): + obj.tensor = "str" + obj.anything = 1.0 + assert isinstance(obj.anything, torch.Tensor) + obj.anything = "str" + def test_autocast(self): # Autocasting is implemented only for tensordict / tensorclasses. # Since some type annotations are not supported such as `Tensor | None`, @@ -2028,7 +2081,7 @@ def test_autocast_optional(self): assert obj.tc["tc"] is None def test_autocast_func(self): - @tensorclass + @tensorclass(autocast=True) class FuncAutoCast: tensor: torch.Tensor non_tensor: str