Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 3, 2023
1 parent 9865dec commit a3595f4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 47 deletions.
4 changes: 1 addition & 3 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ class PersistentTensorDict(TensorDictBase):
"""

def __new__(cls, *args, **kwargs):
cls._td_dim_names = None
return super().__new__(cls, *args, **kwargs)
_td_dim_names = None

def __init__(
self,
Expand Down
72 changes: 28 additions & 44 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,17 +354,14 @@ class TensorDictBase(MutableMapping):
)
KEY_ERROR = 'key "{}" not found in {} with ' "keys {}"

def __new__(cls, *args: Any, **kwargs: Any) -> T:
self = super().__new__(cls)
self._safe = kwargs.get("_safe", False)
self._lazy = kwargs.get("_lazy", False)
self._inplace_set = kwargs.get("_inplace_set", False)
self.is_meta = kwargs.get("is_meta", False)
self._is_locked = kwargs.get("_is_locked", False)
self._cache = None
self._last_op = None
self.__last_op_queue = None
return self
_safe = False
_lazy = False
_inplace_set = False
is_meta = False
_is_locked = False
_cache = None
_last_op = None
__last_op_queue = None

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
Expand Down Expand Up @@ -4077,26 +4074,11 @@ class TensorDict(TensorDictBase):
"""

__slots__ = (
"_tensordict",
"_batch_size",
"_is_shared",
"_is_memmap",
"_device",
"_is_locked",
"_td_dim_names",
"_lock_id",
"_locked_tensordicts",
"_cache",
"_last_op",
"__last_op_queue",
)

def __new__(cls, *args: Any, **kwargs: Any) -> TensorDict:
cls._is_shared = False
cls._is_memmap = False
cls._td_dim_names = None
return super().__new__(cls, *args, _safe=True, _lazy=False, **kwargs)
_is_shared = False
_is_memmap = False
_td_dim_names = None
_safe = True
_lazy = False

def __init__(
self,
Expand Down Expand Up @@ -5001,11 +4983,12 @@ def _nested_keys(
)

def __getstate__(self):
return {
slot: getattr(self, slot)
for slot in self.__slots__
if slot not in ("_last_op", "_cache", "__last_op_queue")
result = {
key: val
for key, val in self.__dict__.items()
if key not in ("_last_op", "_cache", "__last_op_queue")
}
return result

def __setstate__(self, state):
for slot, value in state.items():
Expand Down Expand Up @@ -5790,10 +5773,11 @@ class SubTensorDict(TensorDictBase):
"""

def __new__(cls, *args: Any, **kwargs: Any) -> SubTensorDict:
cls._is_shared = False
cls._is_memmap = False
return super().__new__(cls, _safe=False, _lazy=True, _inplace_set=True)
_is_shared = False
_is_memmap = False
_safe = False
_lazy = True
_inplace_set = True

def __init__(
self,
Expand Down Expand Up @@ -6414,9 +6398,9 @@ def __torch_function__(
else:
return super().__torch_function__(func, types, args, kwargs)

def __new__(cls, *args: Any, **kwargs: Any) -> LazyStackedTensorDict:
cls._td_dim_name = None
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
_td_dim_name = None
_safe = False
_lazy = True

def __init__(
self,
Expand Down Expand Up @@ -8162,8 +8146,8 @@ def _repr_exclusive_fields(self):
class _CustomOpTensorDict(TensorDictBase):
"""Encodes lazy operations on tensors contained in a TensorDict."""

def __new__(cls, *args: Any, **kwargs: Any) -> _CustomOpTensorDict:
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
_safe = False
_lazy = True

def __init__(
self,
Expand Down

1 comment on commit a3595f4

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a3595f4 Previous: 9865dec Ratio
benchmarks/common/common_ops_test.py::test_plain_set_stack_nested 3419.7137298629254 iter/sec (stddev: 0.0003991740607016415) 7102.635672018859 iter/sec (stddev: 0.000003490104581926313) 2.08
benchmarks/common/common_ops_test.py::test_items_stack_nested_locked 656.343150461759 iter/sec (stddev: 0.0003111802618047426) 1331.741397657034 iter/sec (stddev: 0.000030312527898759866) 2.03
benchmarks/common/common_ops_test.py::test_keys_stack_nested_locked 788.6443206913364 iter/sec (stddev: 0.0003061982749919703) 1626.0407392109466 iter/sec (stddev: 0.000012775485021631076) 2.06
benchmarks/common/common_ops_test.py::test_values_stack_nested_locked 955.063108522064 iter/sec (stddev: 0.0004743697892677344) 2064.2310412001884 iter/sec (stddev: 0.000012551582401382587) 2.16
benchmarks/common/common_ops_test.py::test_setitem_dim[int] 17219.6549339233 iter/sec (stddev: 0.000035878887277692905) 37081.704008478664 iter/sec (stddev: 0.0000019449480172614492) 2.15
benchmarks/common/common_ops_test.py::test_setitem_dim[slice_int] 8719.754485117723 iter/sec (stddev: 0.00006578131347708857) 19100.86607737998 iter/sec (stddev: 0.000007366987788507623) 2.19
benchmarks/common/common_ops_test.py::test_setitem_dim[tuple] 11436.732419311986 iter/sec (stddev: 0.0000488122632059882) 24645.2337086472 iter/sec (stddev: 0.000003389822100687616) 2.15
benchmarks/common/common_ops_test.py::test_unbind_speed_stack1 644461.6098652725 iter/sec (stddev: 0.00003056490032461228) 1670689.808577556 iter/sec (stddev: 2.3238120703907647e-8) 2.59
benchmarks/common/memmap_benchmarks_test.py::test_creation[device0] 1504.858172785846 iter/sec (stddev: 0.0006393396980101482) 3439.973939450101 iter/sec (stddev: 0.00009263139975944315) 2.29
benchmarks/common/memmap_benchmarks_test.py::test_creation_from_tensor 1422.4732601074566 iter/sec (stddev: 0.0003778617584150533) 3130.170212285269 iter/sec (stddev: 0.000013151245167648012) 2.20
benchmarks/common/memmap_benchmarks_test.py::test_add_one[memmap_tensor0] 14247.64631363357 iter/sec (stddev: 0.00007170766678251406) 39703.09957082932 iter/sec (stddev: 0.000005248389106894473) 2.79
benchmarks/common/memmap_benchmarks_test.py::test_contiguous[memmap_tensor0] 70766.9207073084 iter/sec (stddev: 0.00007593331906997162) 173922.50384782036 iter/sec (stddev: 5.62562376421342e-7) 2.46
benchmarks/common/memmap_benchmarks_test.py::test_stack[memmap_tensor0] 20414.372737045724 iter/sec (stddev: 0.00004487886673659763) 54588.99227440141 iter/sec (stddev: 0.0000016336264114157033) 2.67
benchmarks/common/memmap_benchmarks_test.py::test_memmaptd_index_astensor 465.69216765587333 iter/sec (stddev: 0.000585917326871396) 1090.650647915972 iter/sec (stddev: 0.000018371014635300354) 2.34
benchmarks/common/memmap_benchmarks_test.py::test_memmaptd_index_op 172.08119090954335 iter/sec (stddev: 0.0014806015309059626) 446.8523829858738 iter/sec (stddev: 0.00006864059888340528) 2.60
benchmarks/common/pytree_benchmarks_test.py::test_add_pytree 12358.153554517876 iter/sec (stddev: 0.00011198467990816379) 30554.01449889499 iter/sec (stddev: 0.000003242521013973305) 2.47
benchmarks/common/pytree_benchmarks_test.py::test_add_td 8460.561779317395 iter/sec (stddev: 0.0001071044659562639) 17428.648656598485 iter/sec (stddev: 0.0000033848856519431052) 2.06
benchmarks/nn/functional_benchmarks_test.py::test_exec_functorch 3086.258152160973 iter/sec (stddev: 0.00031168353309776514) 6717.08248817304 iter/sec (stddev: 0.000005958601452842365) 2.18
benchmarks/nn/functional_benchmarks_test.py::test_exec_td 2983.778321124548 iter/sec (stddev: 0.00011444980913672227) 6904.0807872751775 iter/sec (stddev: 0.000006460462486537992) 2.31
benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[True-True] 509.5662314569196 iter/sec (stddev: 0.0007016200346563893) 1109.127969972082 iter/sec (stddev: 0.000030517185825485492) 2.18
benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[True-False] 905.4351794999594 iter/sec (stddev: 0.0006324626564469274) 2118.7321959016936 iter/sec (stddev: 0.000019325758496973623) 2.34
benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[False-True] 521.1566809138526 iter/sec (stddev: 0.0011074708019783673) 1287.1211815540744 iter/sec (stddev: 0.000024838200972154112) 2.47
benchmarks/nn/functional_benchmarks_test.py::test_vmap_mlp_speed[False-False] 1132.0401142067333 iter/sec (stddev: 0.0004904316698756558) 2629.0395966063325 iter/sec (stddev: 0.000020510632639301753) 2.32

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.