amend
vmoens committed Jan 13, 2025
1 parent 4713734 commit 98c0a85
Showing 3 changed files with 67 additions and 83 deletions.
20 changes: 8 additions & 12 deletions tensordict/base.py
@@ -872,7 +872,7 @@ def min(
         )
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import min
+            from torch.return_types import min

             values_dict = {}
             indices_dict = {}
@@ -882,8 +882,7 @@ def min(
                 else:
                     indices_dict[key] = key[:-1]
             return min(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result
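The new call site relies on `torch.return_types.min` being a structseq (a named-tuple-like type) that is constructed from a single sequence of `(values, indices)`, which is why the 2-tuple from `split_keys(...)[:2]` is passed as one argument. A minimal sketch of that construction (the tensors here are illustrative):

    import torch

    vals = torch.tensor([1.0, 2.0])
    idx = torch.tensor([0, 1])
    # A structseq is built from one sequence, not keyword arguments.
    r = torch.return_types.min((vals, idx))
    assert r.values is vals and r.indices is idx
    assert r[0] is vals and r[1] is idx  # indexable like a tuple, as with torch.min(x, dim)

Unlike the old tensorclass, the structseq carries no `batch_size`, which is why the `batch_size=result.batch_size` argument is dropped.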

@@ -1006,7 +1005,7 @@ def max(
         )
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import max
+            from torch.return_types import max

             values_dict = {}
             indices_dict = {}
@@ -1016,8 +1015,7 @@ def max(
                 else:
                     indices_dict[key] = key[:-1]
             return max(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result

@@ -1110,7 +1108,7 @@ def cummin(
             return result
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import cummin
+            from torch.return_types import cummin

             values_dict = {}
             indices_dict = {}
@@ -1120,8 +1118,7 @@ def cummin(
                 else:
                     indices_dict[key] = key[:-1]
             return cummin(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
            )
         return result

@@ -1214,7 +1211,7 @@ def cummax(
             return result
         if dim is not NO_DEFAULT and return_indices:
             # Split the tensordict
-            from .return_types import cummax
+            from torch.return_types import cummax

             values_dict = {}
             indices_dict = {}
@@ -1224,8 +1221,7 @@ def cummax(
                 else:
                     indices_dict[key] = key[:-1]
             return cummax(
-                *result.split_keys(values_dict, indices_dict),
-                batch_size=result.batch_size,
+                result.split_keys(values_dict, indices_dict)[:2],
             )
         return result
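All four methods follow the same pattern: entries whose key ends in "indices" are routed into one tensordict and the rest into another via `split_keys`, and the first two results are wrapped in the matching `torch.return_types` class. A sketch of the idea, under the assumption that `split_keys` takes mappings of source keys to destination keys and appends a final tensordict holding any leftover keys (which would be what the `[:2]` discards); the key layout here is illustrative:

    import torch
    from tensordict import TensorDict

    result = TensorDict(
        {
            ("a", "values"): torch.randn(3),
            ("a", "indices"): torch.zeros(3, dtype=torch.long),
        },
        batch_size=[3],
    )
    values_dict = {("a", "values"): ("a", "values")}  # values kept as-is
    indices_dict = {("a", "indices"): ("a",)}         # "indices" suffix dropped, like key[:-1]
    values_td, indices_td = result.split_keys(values_dict, indices_dict)[:2]
    out = torch.return_types.min((values_td, indices_td))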

73 changes: 33 additions & 40 deletions tensordict/return_types.py
@@ -2,74 +2,67 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings

 from tensordict.tensorclass import tensorclass
 from tensordict.tensordict import TensorDict


-@tensorclass(shadow=True)
+@tensorclass
 class min:
     """A `min` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.min` operations."""

-    values: TensorDict
+    vals: TensorDict
     indices: TensorDict

-    def __getitem__(self, item):
-        try:
-            return (self.values, self.indices)[item]
-        except IndexError:
-            raise IndexError(
-                f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
-                f"__getitem__ method API."
-            )
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.min is deprecated and will be removed in v0.9. "
+            f"Use torch.return_types.min instead.",
+            category=DeprecationWarning,
+        )


-@tensorclass(shadow=True)
+@tensorclass
 class max:
     """A `max` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.max` operations."""

-    values: TensorDict
+    vals: TensorDict
     indices: TensorDict

-    def __getitem__(self, item):
-        try:
-            return (self.values, self.indices)[item]
-        except IndexError:
-            raise IndexError(
-                f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
-                f"__getitem__ method API."
-            )
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.max is deprecated and will be removed in v0.9. "
+            f"Use torch.return_types.max instead.",
+            category=DeprecationWarning,
+        )


-@tensorclass(shadow=True)
+@tensorclass
 class cummin:
     """A `cummin` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummin` operations."""

-    values: TensorDict
+    vals: TensorDict
     indices: TensorDict

-    def __getitem__(self, item):
-        try:
-            return (self.values, self.indices)[item]
-        except IndexError:
-            raise IndexError(
-                f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
-                f"__getitem__ method API."
-            )
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.cummin is deprecated and will be removed in v0.9. "
+            f"Use torch.return_types.cummin instead.",
+            category=DeprecationWarning,
+        )


-@tensorclass(shadow=True)
+@tensorclass
 class cummax:
     """A `cummax` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummax` operations."""

-    values: TensorDict
+    vals: TensorDict
     indices: TensorDict

-    def __getitem__(self, item):
-        try:
-            return (self.values, self.indices)[item]
-        except IndexError:
-            raise IndexError(
-                f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
-                f"__getitem__ method API."
-            )
+    def __post_init__(self):
+        warnings.warn(
+            f"{type(self)}.cummax is deprecated and will be removed in v0.9. "
+            f"Use torch.return_types.cummax instead.",
+            category=DeprecationWarning,
+        )
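After this change the tensordict-native classes remain only as deprecated shims: constructing one now emits a `DeprecationWarning` from `__post_init__` instead of providing the old tuple-style `__getitem__`. A sketch of the expected behavior (field names per the new code; the data is illustrative):

    import warnings

    import torch
    from tensordict import TensorDict
    from tensordict.return_types import min as td_min

    values = TensorDict({"a": torch.randn(3)}, batch_size=[3])
    indices = TensorDict({"a": torch.zeros(3, dtype=torch.long)}, batch_size=[3])
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        td_min(vals=values, indices=indices, batch_size=[3])
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

With the custom `__getitem__` gone, indexing presumably falls back to the tensorclass default (batch-dimension indexing) rather than `(values, indices)` tuple access; `torch.return_types.min` keeps the tuple semantics.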
57 changes: 26 additions & 31 deletions test/test_tensordict.py
@@ -5303,7 +5303,10 @@ def test_memmap_threads(self, td_name, device, use_dir, tmpdir, num_threads):
         ],
     )
     def test_min_max_cummin_cummax(self, td_name, device, dim, keepdim, return_indices):
-        import tensordict.return_types as return_types
+        def _get_td(v):
+            if not is_tensor_collection(v):
+                return v.values
+            return v

         td = getattr(self, td_name)(device)
         # min
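The `_get_td` helper replaces the old module import: with `return_indices=True` the result is now a `torch.return_types` structseq rather than a tensor collection, so the values tensordict has to be fetched from `.values` before checking `batch_size`. A sketch of the two branches it normalizes, assuming the post-commit behavior this test exercises:

    import torch
    from tensordict import TensorDict, is_tensor_collection

    td = TensorDict({"a": torch.randn(3)}, batch_size=[3])
    r1 = td.min(dim=0, return_indices=True)   # structseq of two tensordicts
    r2 = td.min(dim=0, return_indices=False)  # plain tensordict of values
    for r in (r1, r2):
        inner = r.values if not is_tensor_collection(r) else r
        assert inner.batch_size == ()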
@@ -5315,71 +5318,63 @@ def test_min_max_cummin_cummax(self, td_name, device, dim, keepdim, return_indices):
         if not return_indices and dim is not None:
             assert_allclose_td(r, td.amin(dim=dim, keepdim=keepdim))
         if return_indices:
-            assert is_tensorclass(r)
-            assert isinstance(r, return_types.min)
-            assert not r.vals.is_empty()
+            # assert is_tensorclass(r)
+            assert isinstance(r, torch.return_types.min)
+            assert not r.values.is_empty()
             assert not r.indices.is_empty()
         else:
             assert not is_tensorclass(r)
         if dim is None:
-            assert r.batch_size == ()
+            assert _get_td(r).batch_size == ()
         elif keepdim:
             s = list(td.batch_size)
             s[dim] = 1
-            assert r.batch_size == tuple(s)
+            assert _get_td(r).batch_size == tuple(s)
         else:
             s = list(td.batch_size)
             s.pop(dim)
-            assert r.batch_size == tuple(s)
+            assert _get_td(r).batch_size == tuple(s)

         r = td.max(**kwargs)
         if not return_indices and dim is not None:
             assert_allclose_td(r, td.amax(dim=dim, keepdim=keepdim))
         if return_indices:
-            assert is_tensorclass(r)
-            assert isinstance(r, return_types.max)
-            assert not r.vals.is_empty()
+            # assert is_tensorclass(r)
+            assert isinstance(r, torch.return_types.max)
+            assert not r.values.is_empty()
             assert not r.indices.is_empty()
         else:
             assert not is_tensorclass(r)
         if dim is None:
-            assert r.batch_size == ()
+            assert _get_td(r).batch_size == ()
         elif keepdim:
             s = list(td.batch_size)
             s[dim] = 1
-            assert r.batch_size == tuple(s)
+            assert _get_td(r).batch_size == tuple(s)
         else:
             s = list(td.batch_size)
             s.pop(dim)
-            assert r.batch_size == tuple(s)
+            assert _get_td(r).batch_size == tuple(s)
         if dim is None:
             return
         kwargs.pop("keepdim")
         r = td.cummin(**kwargs)
         if return_indices:
-            assert is_tensorclass(r)
-            assert isinstance(r, return_types.cummin)
-            assert not r.vals.is_empty()
+            # assert is_tensorclass(r)
+            assert isinstance(r, torch.return_types.cummin)
+            assert not r.values.is_empty()
             assert not r.indices.is_empty()
         else:
             assert not is_tensorclass(r)
         if dim is None:
-            assert r.batch_size == ()
+            assert _get_td(r).batch_size == ()
         else:
-            assert r.batch_size == td.batch_size
+            assert _get_td(r).batch_size == td.batch_size

         r = td.cummax(**kwargs)
         if return_indices:
-            assert is_tensorclass(r)
-            assert isinstance(r, return_types.cummax)
-            assert not r.vals.is_empty()
+            # assert is_tensorclass(r)
+            assert isinstance(r, torch.return_types.cummax)
+            assert not r.values.is_empty()
             assert not r.indices.is_empty()
         else:
             assert not is_tensorclass(r)
         if dim is None:
-            assert r.batch_size == ()
+            assert _get_td(r).batch_size == ()
         else:
-            assert r.batch_size == td.batch_size
+            assert _get_td(r).batch_size == td.batch_size

     @pytest.mark.parametrize("inplace", [False, True])
     def test_named_apply(self, td_name, device, inplace):
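End to end, the updated assertions pin down the shape contract: reducing over `dim` removes that batch dimension unless `keepdim=True` keeps it as a singleton, while the cumulative variants preserve the batch size. A sketch assuming the post-commit API shown in this test:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.randn(4, 5)}, batch_size=[4, 5])
    r = td.min(dim=0, return_indices=True)
    assert isinstance(r, torch.return_types.min)
    assert r.values.batch_size == (5,)       # dim 0 reduced away
    assert r.indices.batch_size == (5,)
    rk = td.min(dim=0, keepdim=True, return_indices=True)
    assert rk.values.batch_size == (1, 5)    # singleton dim kept
    rc = td.cummin(dim=0, return_indices=True)
    assert rc.values.batch_size == (4, 5)    # batch size preserved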
