From 646683c65a9a235e3639311598ee529fe53d1442 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 7 Jan 2025 17:53:33 +0000 Subject: [PATCH] [Feature] TensorDict.clamp ghstack-source-id: 44f0937c195d969055de10709402af7c4473df32 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1165 --- tensordict/_td.py | 16 ++++----- tensordict/base.py | 75 ++++++++++++++++++++++++++++++++++++++--- test/test_tensordict.py | 12 +++++++ 3 files changed, 91 insertions(+), 12 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 7a7364d4b..ccc5ffbe3 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -611,7 +611,7 @@ def _quick_set(swap_dict, swap_td): else: return TensorDict._new_unsafe(_swap, batch_size=torch.Size(())) - def __ne__(self, other: object) -> T | bool: + def __ne__(self, other: Any) -> T | bool: if is_tensorclass(other): return other != self if isinstance(other, (dict,)): @@ -635,7 +635,7 @@ def __ne__(self, other: object) -> T | bool: ) return True - def __xor__(self, other: object) -> T | bool: + def __xor__(self, other: Any) -> T | bool: if is_tensorclass(other): return other ^ self if isinstance(other, (dict,)): @@ -659,7 +659,7 @@ def __xor__(self, other: object) -> T | bool: ) return True - def __or__(self, other: object) -> T | bool: + def __or__(self, other: Any) -> T | bool: if is_tensorclass(other): return other | self if isinstance(other, (dict,)): @@ -683,7 +683,7 @@ def __or__(self, other: object) -> T | bool: ) return False - def __eq__(self, other: object) -> T | bool: + def __eq__(self, other: Any) -> T | bool: if is_tensorclass(other): return other == self if isinstance(other, (dict,)): @@ -705,7 +705,7 @@ def __eq__(self, other: object) -> T | bool: ) return False - def __ge__(self, other: object) -> T | bool: + def __ge__(self, other: Any) -> T | bool: if is_tensorclass(other): return other <= self if isinstance(other, (dict,)): @@ -727,7 +727,7 @@ def __ge__(self, other: object) -> T | bool: ) return False - def __gt__(self, other: object) -> T | bool: + def __gt__(self, other: Any) -> T | bool: if is_tensorclass(other): return other < self if isinstance(other, (dict,)): @@ -749,7 +749,7 @@ def __gt__(self, other: object) -> T | bool: ) return False - def __le__(self, other: object) -> T | bool: + def __le__(self, other: Any) -> T | bool: if is_tensorclass(other): return other >= self if isinstance(other, (dict,)): @@ -771,7 +771,7 @@ def __le__(self, other: object) -> T | bool: ) return False - def __lt__(self, other: object) -> T | bool: + def __lt__(self, other: Any) -> T | bool: if is_tensorclass(other): return other > self if isinstance(other, (dict,)): diff --git a/tensordict/base.py b/tensordict/base.py index 4efb362d8..5c0476560 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -82,6 +82,7 @@ convert_ellipsis_to_idx, DeviceType, erase_cache, + expand_as_right, implement_for, IndexType, infer_size_impl, @@ -10513,7 +10514,14 @@ def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T: else: vals = self._values_list(True, True) other_val = other - torch._foreach_clamp_max_(vals, other_val) + try: + torch._foreach_clamp_max_(vals, other_val) + except RuntimeError as err: + if "isDifferentiableType" in str(err): + raise RuntimeError( + "Attempted to execute _foreach_clamp_max_ with a differentiable tensor. " + "Use `td.apply(lambda x: x.clamp_max_(val)` instead." + ) return self def clamp_max( @@ -10547,7 +10555,14 @@ def clamp_max( keys = new_keys else: other_val = other - vals = torch._foreach_clamp_max(vals, other_val) + try: + vals = torch._foreach_clamp_max(vals, other_val) + except RuntimeError as err: + if "isDifferentiableType" in str(err): + raise RuntimeError( + "Attempted to execute _foreach_clamp_max with a differentiable tensor. " + "Use `td.apply(lambda x: x.clamp_max(val)` instead." + ) items = dict(zip(keys, vals)) def pop(name, val): @@ -10579,7 +10594,15 @@ def clamp_min_(self, other: TensorDictBase | torch.Tensor) -> T: else: vals = self._values_list(True, True) other_val = other - torch._foreach_clamp_min_(vals, other_val) + try: + torch._foreach_clamp_min_(vals, other_val) + except RuntimeError as err: + if "isDifferentiableType" in str(err): + raise RuntimeError( + "Attempted to execute _foreach_clamp_min_ with a differentiable tensor. " + "Use `td.apply(lambda x: x.clamp_min_(val)` instead." + ) + return self def clamp_min( @@ -10612,7 +10635,15 @@ def clamp_min( keys = new_keys else: other_val = other - vals = torch._foreach_clamp_min(vals, other_val) + try: + vals = torch._foreach_clamp_min(vals, other_val) + except RuntimeError as err: + if "isDifferentiableType" in str(err): + raise RuntimeError( + "Attempted to execute _foreach_clamp_min with a differentiable tensor. " + "Use `td.apply(lambda x: x.clamp_min(val)` instead." + ) + items = dict(zip(keys, vals)) def pop(name, val): @@ -10631,6 +10662,42 @@ def pop(name, val): result.update(items) return result + def clamp(self, min=None, max=None, *, out=None): # noqa: W605 + r"""Clamps all elements in :attr:`self` into the range `[` :attr:`min`, :attr:`max` `]`. + + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + """ + if min is None: + if out is not None: + raise ValueError( + "clamp() with min/max=None isn't implemented with specified output." + ) + return self.clamp_max(max) + if max is None: + if out is not None: + raise ValueError( + "clamp() with min/max=None isn't implemented with specified output." + ) + return self.clamp_min(min) + if out is None: + return self._fast_apply(lambda x: x.clamp(min, max)) + result = self._fast_apply( + lambda x, y: x.clamp(min, max, out=y), out, default=None + ) + with out.unlock_() if out.is_locked else contextlib.nullcontext(): + return out.update(result) + def pow_(self, other: TensorDictBase | torch.Tensor) -> T: """In-place version of :meth:`~.pow`. diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 22582441b..97191db16 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3995,6 +3995,18 @@ def test_chunk(self, td_name, device, dim, chunks): assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim] assert (torch.cat(td_chunks, dim) == td).all() + def test_clamp(self, td_name, device): + td = getattr(self, td_name)(device) + tdc = td.clamp(-1, 1) + assert (tdc <= 1).all() + assert (tdc >= -1).all() + if td.requires_grad: + td = td.detach() + tdc = td.clamp(None, 1) + assert (tdc <= 1).all() + tdc = td.clamp(-1) + assert (tdc >= -1).all() + def test_clear(self, td_name, device): td = getattr(self, td_name)(device) with td.unlock_():