diff --git a/tensordict/persistent.py b/tensordict/persistent.py index cf75d3ecf..c3103c524 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -124,9 +124,7 @@ class PersistentTensorDict(TensorDictBase): """ - def __new__(cls, *args, **kwargs): - cls._td_dim_names = None - return super().__new__(cls, *args, **kwargs) + _td_dim_names = None def __init__( self, diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index b1a459e42..3a89e2ea0 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -354,17 +354,14 @@ class TensorDictBase(MutableMapping): ) KEY_ERROR = 'key "{}" not found in {} with ' "keys {}" - def __new__(cls, *args: Any, **kwargs: Any) -> T: - self = super().__new__(cls) - self._safe = kwargs.get("_safe", False) - self._lazy = kwargs.get("_lazy", False) - self._inplace_set = kwargs.get("_inplace_set", False) - self.is_meta = kwargs.get("is_meta", False) - self._is_locked = kwargs.get("_is_locked", False) - self._cache = None - self._last_op = None - self.__last_op_queue = None - return self + _safe = False + _lazy = False + _inplace_set = False + is_meta = False + _is_locked = False + _cache = None + _last_op = None + __last_op_queue = None def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() @@ -4077,26 +4074,11 @@ class TensorDict(TensorDictBase): """ - __slots__ = ( - "_tensordict", - "_batch_size", - "_is_shared", - "_is_memmap", - "_device", - "_is_locked", - "_td_dim_names", - "_lock_id", - "_locked_tensordicts", - "_cache", - "_last_op", - "__last_op_queue", - ) - - def __new__(cls, *args: Any, **kwargs: Any) -> TensorDict: - cls._is_shared = False - cls._is_memmap = False - cls._td_dim_names = None - return super().__new__(cls, *args, _safe=True, _lazy=False, **kwargs) + _is_shared = False + _is_memmap = False + _td_dim_names = None + _safe = True + _lazy = False def __init__( self, @@ -5001,11 +4983,12 @@ def _nested_keys( ) def __getstate__(self): - return { - slot: getattr(self, slot) - for slot in self.__slots__ - if slot not in ("_last_op", "_cache", "__last_op_queue") + result = { + key: val + for key, val in self.__dict__.items() + if key not in ("_last_op", "_cache", "__last_op_queue") } + return result def __setstate__(self, state): for slot, value in state.items(): @@ -5790,10 +5773,11 @@ class SubTensorDict(TensorDictBase): """ - def __new__(cls, *args: Any, **kwargs: Any) -> SubTensorDict: - cls._is_shared = False - cls._is_memmap = False - return super().__new__(cls, _safe=False, _lazy=True, _inplace_set=True) + _is_shared = False + _is_memmap = False + _safe = False + _lazy = True + _inplace_set = True def __init__( self, @@ -6414,9 +6398,9 @@ def __torch_function__( else: return super().__torch_function__(func, types, args, kwargs) - def __new__(cls, *args: Any, **kwargs: Any) -> LazyStackedTensorDict: - cls._td_dim_name = None - return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs) + _td_dim_name = None + _safe = False + _lazy = True def __init__( self, @@ -8162,8 +8146,8 @@ def _repr_exclusive_fields(self): class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" - def __new__(cls, *args: Any, **kwargs: Any) -> _CustomOpTensorDict: - return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs) + _safe = False + _lazy = True def __init__( self,