From 734d9f0b6c5320a5731ec75920e840a57aca49fa Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 19 Jul 2023 16:30:17 -0400
Subject: [PATCH 1/6] init

---
 tensordict/csrc/pybind.cpp |   3 +
 tensordict/csrc/utils.h    |  30 +++
 tensordict/tenosrstack.py  | 385 +++++++++++++++++++++++++++++++++++++
 test/test_tensorstack.py   |  65 +++++++
 4 files changed, 483 insertions(+)
 create mode 100644 tensordict/tenosrstack.py
 create mode 100644 test/test_tensorstack.py

diff --git a/tensordict/csrc/pybind.cpp b/tensordict/csrc/pybind.cpp
index 4e31b629c..bcd09068f 100644
--- a/tensordict/csrc/pybind.cpp
+++ b/tensordict/csrc/pybind.cpp
@@ -5,6 +5,8 @@
 #include
 #include
+#include
+#include
 #include
@@ -18,4 +20,5 @@ PYBIND11_MODULE(_tensordict, m) {
   m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));
   m.def("unravel_key_list", py::overload_cast<const py::list &>(&unravel_key_list), py::arg("keys"));
   m.def("unravel_key_list", py::overload_cast<const py::tuple &>(&unravel_key_list), py::arg("keys"));
+  m.def("_populate_index", &_populate_index, "populate index function");
 }
diff --git a/tensordict/csrc/utils.h b/tensordict/csrc/utils.h
index c8cbfd861..5fc85c6c3 100644
--- a/tensordict/csrc/utils.h
+++ b/tensordict/csrc/utils.h
@@ -5,6 +5,7 @@
 #include
 #include
+#include
 namespace py = pybind11;
@@ -76,3 +77,32 @@ py::list unravel_key_list(const py::list& keys) {
 py::list unravel_key_list(const py::tuple& keys) {
   return unravel_key_list(py::list(keys));
 }
+
+torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs, torch::Tensor index) {
+  int64_t total = 0;
+  for (int _cur = 0; _cur < index.numel(); ++_cur) {
+    int64_t loc = index[_cur].item<int64_t>();
+    int64_t incr = offsets[loc].item<int64_t>();
+    total += incr;
+  }
+  torch::Tensor out = torch::empty({total}, torch::dtype(torch::kLong));
+
+  int64_t* out_data = out.data_ptr<int64_t>();
+  int64_t cur_offset;
+  int64_t count = -1;
+  int64_t maxcount = -1;
+  int64_t cur = -1;
+  int64_t n = offsets.numel();
+  for (int i = 0; i < total; ++i) {
+    if (cur < n && count == maxcount) {
+      cur++;
+      count = -1;
+      int64_t cur_idx = index[cur].item<int64_t>();
+      maxcount = offsets[cur_idx].item<int64_t>() - 1;
+      cur_offset = offsets_cs[cur_idx].item<int64_t>();
+    }
+    count++;
+    out_data[i] = cur_offset + count;
+  }
+  return out;
+}
diff --git a/tensordict/tenosrstack.py b/tensordict/tenosrstack.py
new file mode 100644
index 000000000..34eb09ed9
--- /dev/null
+++ b/tensordict/tenosrstack.py
@@ -0,0 +1,385 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import torch +from tensordict._tensordict import _populate_index + + +class _NestedSize(tuple): + def numel(self): + out = 0 + for elt in self: + if isinstance(elt, (tuple, torch.Size)): + out += elt.numel() + else: + out = max(out * elt, elt) + return out + + def __getitem__(self, index): + out = super().__getitem__(index) + if _NestedSize.is_nested(out): + return out + if isinstance(out, tuple): + return torch.Size(out) + return out + + @classmethod + def is_nested(cls, obj): + return isinstance(obj, tuple) and isinstance(obj[0], tuple) + + @classmethod + def broadcast_shape(cls, other_shape, shapes): + if not cls.is_nested(shapes): + new_shape = _NestedSize(torch.broadcast_shapes(other_shape, shapes)) + else: + new_shape = _NestedSize( + cls.broadcast_shape(other_shape, shape) for shape in shapes + ) + return new_shape + + @classmethod + def from_list(cls, nested_list): + if not isinstance(nested_list, list): + return nested_list + return cls([_NestedSize.from_list(elt) for elt in nested_list]) + + @classmethod + def refine_shapes(cls, first_neg_right_index, shapes, new_shape): + shapes = torch.tensor(shapes) + shapes = shapes.view(new_shape[:first_neg_right_index]) + shapes = shapes.tolist() + return _NestedSize.from_list(shapes) + + +def get_parent_class(f): + return f.__globals__.get(f.__qualname__.split(".")[0], None) + + +def _lazy_init(func): + name = "_" + func.__name__ + + def new_func(self): + if not hasattr(self, name): + self._init() + return getattr(self, name) + + def setter(self, value): + setattr(self, name, value) + + return property(new_func, setter) + + +def _copy_shapes(func): + def new_func(self, *args, **kwargs): + out = getattr(torch.Tensor, func.__name__)(self, *args, **kwargs) + _shapes = self._shapes + out._shapes = _shapes + return out + + return new_func + + +def _broadcast(func): + def new_func(self, other): + out = getattr(torch.Tensor, func.__name__)( + self.view(-1, *self._trailing_dims), other + )._flat + other_shape = getattr(other, "shape", torch.Size([])) + out._shapes = _NestedSize.broadcast_shape(other_shape, self._shapes) + return out + + return new_func + + +class TensorStack(torch.Tensor): + def __new__(cls, tensor, *, shapes=None): + return super().__new__(cls, tensor) + + def __init__(self, tensor, *, shapes=None): + super(TensorStack, self).__init__() + self._shapes = shapes + self._init() + + def _init(self): + self._unique_shape, self._common_shape = self._get_common_shape(self._shapes) + self._get_offsets_() + + @property + def _trailing_dims(self): + dims = [] + for i in reversed(self._common_shape): + if i >= 0: + dims.append(i) + else: + break + return torch.Size(reversed(dims)) + + def _get_offsets_(self): + offsets = tuple(shape.numel() for shape in self._shapes) + n = self._trailing_dims.numel() + offsets = torch.tensor([offset // n for offset in offsets]) + self._offsets = offsets + self._offsets_cs = torch.cumsum(torch.nn.functional.pad(offsets, [1, 0]), 0) + + def _get_common_shape(self, shapes): + common_shape = shapes[0] + if _NestedSize.is_nested(common_shape): + new_shapes = [] + for shape in shapes: + _unique_shape, _common_shape = self._get_common_shape(shape) + new_shapes.append(_common_shape) + # unique_shape = all(new_unique_shape) + return self._get_common_shape(new_shapes) + + for _shape in shapes[1:]: + if len(_shape) != len(common_shape): + raise RuntimeError + if _shape != common_shape: + unique_shape = False + common_shape = torch.Size( + [ + s if s == s_other else -1 + for s, s_other in zip(common_shape, _shape) + 
] + ) + else: + unique_shape = all(s >= 0 for s in common_shape) + return unique_shape, torch.Size([len(shapes), *common_shape]) + + @classmethod + def from_tensors(cls, tensors): + if not len(tensors): + raise RuntimeError + shapes = _NestedSize( + _NestedSize(tensor.shape) + if not isinstance(tensor, TensorStack) + else tensor._shapes + for tensor in tensors + ) + return TensorStack( + torch.cat([tensor.view(-1) for tensor in tensors]), shapes=shapes + ) + + @property + def shape(self): + return self._common_shape + + def unstack(self): + return tuple( + super(TensorStack, self).__getitem__(slice(idx0, idx1)).view(shape) + for idx0, idx1, shape in zip( + self._offsets_cs[:-1], self._offsets_cs[1:], self._shapes + ) + ) + + @property + def _flat(self): + # represents the tensor as a flat one + return super().view(-1) + + @property + def _compact(self): + # represents the tensor with a compact structure (rightmost consistent dims-wise) + return super().view(-1, *self._trailing_dims) + + def view(self, *shape): + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = shape[0] + if isinstance(shape, _NestedSize): + out = TensorStack(self.data, shapes=shape) + return out + n_trailing = len(self._trailing_dims) + if ( + self._trailing_dims + and shape[-n_trailing:] == self._trailing_dims + and shape[-n_trailing - 1] == -1 + ): + # eg, (4, 2, -1, 3, 5) -> (2, 2, 2, -1, 3, 5) + for i in range(-1, -self.ndim - 1, -1): + if self.shape[i] == -1: + break + first_neg_right_index = i + shapes = _NestedSize.refine_shapes( + first_neg_right_index, self._shapes, shape + ) + out = TensorStack(self.data, shapes=shapes) + return out + else: + return super().view(shape) + + @property + def ndim(self): + return len(self.shape) + + def ndimension(self): + return self.ndim + + def __getitem__(self, index): + if isinstance(index, (int,)): + shape = self._shapes[index] + idx0 = self._offsets_cs[index] + idx1 = self._offsets_cs[index + 1] + out = super(TensorStack, self._compact).__getitem__(slice(idx0, idx1))._flat + if isinstance(shape, torch.Size): + return torch.Tensor(out).view(shape) + else: + out._shapes = shape + return out + if isinstance(index, (slice,)): + index = range(*index.indices(self.shape[0])) + if isinstance( + index, + ( + range, + list, + ), + ): + index = torch.tensor(index) + if isinstance(index, torch.Tensor): + index_view = index.view(-1) + shape = _NestedSize([self._shapes[idx] for idx in index_view]) + index_view = _populate_index(self._offsets, self._offsets_cs, index_view) + out = super(TensorStack, self._compact).__getitem__(index_view)._flat + out._shapes = shape + if index.ndim > 1: + out = out.view(*index.shape, *out.shape[1:]) + return out + else: + raise NotImplementedError + + def __repr__(self): + return f"{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + @_copy_shapes + def to(self, *args, **kwargs): + ... + + @_copy_shapes + def cpu(self): + ... + + @_copy_shapes + def bool(self): + ... + + @_copy_shapes + def float(self): + ... + + @_copy_shapes + def double(self): + ... + + @_copy_shapes + def int(self): + ... + + @_copy_shapes + def cuda(self): + ... + + @_copy_shapes + def __neg__(self): + ... + + @_copy_shapes + def __abs__(self): + ... + + @_copy_shapes + def __inv__(self): + ... + + @_copy_shapes + def __invert__(self): + ... + + @_broadcast + def add(self, other): + ... + + @_broadcast + def div(self, other): + ... + + @_broadcast + def rdiv(self, other): + ... + + @_broadcast + def __add__(self, other): + ... 
+ + @_broadcast + def __mod__(self, other): + ... + + @_broadcast + def __pow__(self, other): + ... + + @_broadcast + def __sub__(self, other): + ... + + @_broadcast + def __truediv__(self, other): + ... + + @_broadcast + def __eq__(self, other): + ... + + @_broadcast + def __ne__(self, other): + ... + + @_broadcast + def __div__(self, other): + ... + + @_broadcast + def __floordiv__(self, other): + ... + + @_broadcast + def __lt__(self, other): + ... + + @_broadcast + def __le__(self, other): + ... + + @_broadcast + def __ge__(self, other): + ... + + @_broadcast + def __gt__(self, other): + ... + + @_broadcast + def __rdiv__(self, other): + ... + + @_broadcast + def __mul__(self, other): + ... + + @_lazy_init + def _offsets(self): + ... + + @_lazy_init + def _offsets_cs(self): + ... + + @_lazy_init + def _unique_shape(self): + ... + + @_lazy_init + def _common_shape(self): + ... diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py new file mode 100644 index 000000000..1569d766c --- /dev/null +++ b/test/test_tensorstack.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse + +import pytest +import torch + +from tensordict.tenosrstack import TensorStack + + +@pytest.fixture +def _tensorstack(): + torch.manual_seed(0) + x = torch.randint(10, (3, 1, 5)) + y = torch.randint(10, (3, 2, 5)) + z = torch.randint(10, (3, 3, 5)) + t = TensorStack.from_tensors([x, y, z]) + return t, (x, y, z) + + +class TestTensorStack: + def test_indexing_int(self, _tensorstack): + t, (x, y, z) = _tensorstack + assert (t[0] == x).all() + assert (t[1] == y).all() + assert (t[2] == z).all() + + def test_indexing_slice(self, _tensorstack): + t, (x, y, z) = _tensorstack + + assert (t[:3][0] == x).all() + assert (t[:3][1] == y).all() + assert (t[:3][2] == z).all() + assert (t[-3:][0] == x).all() + assert (t[-3:][1] == y).all() + assert (t[-3:][2] == z).all() + assert (t[::-1][0] == z).all() + assert (t[::-1][1] == y).all() + assert (t[::-1][2] == x).all() + + def test_indexing_range(self, _tensorstack): + t, (x, y, z) = _tensorstack + assert (t[range(3)][0] == x).all() + assert (t[range(3)][1] == y).all() + assert (t[range(3)][2] == z).all() + assert (t[range(1, 3)][0] == y).all() + assert (t[range(1, 3)][1] == z).all() + + def test_indexing_tensor(self, _tensorstack): + t, (x, y, z) = _tensorstack + assert (t[torch.tensor([0, 2])][0] == x).all() + assert (t[torch.tensor([0, 2])][1] == z).all() + assert (t[torch.tensor([0, 2, 0, 2])][2] == x).all() + assert (t[torch.tensor([0, 2, 0, 2])][3] == z).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][0][0] == x).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][0][1] == z).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][1][0] == x).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][1][1] == z).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From ba6c424cc21ad6e89935fa665824d8c0fbe9a868 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 20 Jul 2023 16:30:22 -0400 Subject: [PATCH 2/6] amend --- tensordict/csrc/pybind.cpp | 1 + tensordict/csrc/utils.h | 29 ++- tensordict/tenosrstack.py | 385 ------------------------------ tensordict/tensorstack.py | 470 +++++++++++++++++++++++++++++++++++++ test/test_tensorstack.py | 63 ++++- 5 files changed, 545 insertions(+), 403 deletions(-) delete 
mode 100644 tensordict/tenosrstack.py
 create mode 100644 tensordict/tensorstack.py

diff --git a/tensordict/csrc/pybind.cpp b/tensordict/csrc/pybind.cpp
index bcd09068f..0a0fd6107 100644
--- a/tensordict/csrc/pybind.cpp
+++ b/tensordict/csrc/pybind.cpp
@@ -21,4 +21,5 @@
   m.def("unravel_key_list", py::overload_cast<const py::list &>(&unravel_key_list), py::arg("keys"));
   m.def("unravel_key_list", py::overload_cast<const py::tuple &>(&unravel_key_list), py::arg("keys"));
   m.def("_populate_index", &_populate_index, "populate index function");
+  m.def("_as_shape", &_as_shape, "Converts a het shape to a shape with -1 for het dims.");
 }
diff --git a/tensordict/csrc/utils.h b/tensordict/csrc/utils.h
index 5fc85c6c3..7716ed2ee 100644
--- a/tensordict/csrc/utils.h
+++ b/tensordict/csrc/utils.h
@@ -78,13 +78,8 @@ py::list unravel_key_list(const py::tuple& keys) {
   return unravel_key_list(py::list(keys));
 }

-torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs, torch::Tensor index) {
-  int64_t total = 0;
-  for (int _cur = 0; _cur < index.numel(); ++_cur) {
-    int64_t loc = index[_cur].item<int64_t>();
-    int64_t incr = offsets[loc].item<int64_t>();
-    total += incr;
-  }
+torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs) {
+  int64_t total = offsets.sum().item<int64_t>();
   torch::Tensor out = torch::empty({total}, torch::dtype(torch::kLong));

   int64_t* out_data = out.data_ptr<int64_t>();
@@ -97,12 +92,26 @@ torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs, t
     if (cur < n && count == maxcount) {
       cur++;
       count = -1;
-      int64_t cur_idx = index[cur].item<int64_t>();
-      maxcount = offsets[cur_idx].item<int64_t>() - 1;
-      cur_offset = offsets_cs[cur_idx].item<int64_t>();
+      maxcount = offsets[cur].item<int64_t>() - 1;
+      cur_offset = offsets_cs[cur].item<int64_t>();
     }
     count++;
     out_data[i] = cur_offset + count;
   }
   return out;
 }
+py::list _as_shape(torch::Tensor shape_tensor) {
+  torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)});
+  torch::Tensor out = shape_tensor_view[0].clone();
+  torch::Tensor unique = (shape_tensor_view == out).all(0);
+  out.masked_fill_(torch::logical_not(unique), -1);
+  std::vector<int64_t> shape_vector(shape_tensor.sizes().begin(), shape_tensor.sizes().end() - 1);
+  // Extend 'shape_vector' with the values from 'out'.
+  auto out_accessor = out.accessor<int64_t, 1>();
+  for (int64_t i = 0; i < out_accessor.size(0); ++i) {
+    shape_vector.push_back(out_accessor[i]);
+  }
+
+  py::list shape = py::cast(shape_vector);
+  return shape;
+}
diff --git a/tensordict/tenosrstack.py b/tensordict/tenosrstack.py
deleted file mode 100644
index 34eb09ed9..000000000
--- a/tensordict/tenosrstack.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
- -import torch -from tensordict._tensordict import _populate_index - - -class _NestedSize(tuple): - def numel(self): - out = 0 - for elt in self: - if isinstance(elt, (tuple, torch.Size)): - out += elt.numel() - else: - out = max(out * elt, elt) - return out - - def __getitem__(self, index): - out = super().__getitem__(index) - if _NestedSize.is_nested(out): - return out - if isinstance(out, tuple): - return torch.Size(out) - return out - - @classmethod - def is_nested(cls, obj): - return isinstance(obj, tuple) and isinstance(obj[0], tuple) - - @classmethod - def broadcast_shape(cls, other_shape, shapes): - if not cls.is_nested(shapes): - new_shape = _NestedSize(torch.broadcast_shapes(other_shape, shapes)) - else: - new_shape = _NestedSize( - cls.broadcast_shape(other_shape, shape) for shape in shapes - ) - return new_shape - - @classmethod - def from_list(cls, nested_list): - if not isinstance(nested_list, list): - return nested_list - return cls([_NestedSize.from_list(elt) for elt in nested_list]) - - @classmethod - def refine_shapes(cls, first_neg_right_index, shapes, new_shape): - shapes = torch.tensor(shapes) - shapes = shapes.view(new_shape[:first_neg_right_index]) - shapes = shapes.tolist() - return _NestedSize.from_list(shapes) - - -def get_parent_class(f): - return f.__globals__.get(f.__qualname__.split(".")[0], None) - - -def _lazy_init(func): - name = "_" + func.__name__ - - def new_func(self): - if not hasattr(self, name): - self._init() - return getattr(self, name) - - def setter(self, value): - setattr(self, name, value) - - return property(new_func, setter) - - -def _copy_shapes(func): - def new_func(self, *args, **kwargs): - out = getattr(torch.Tensor, func.__name__)(self, *args, **kwargs) - _shapes = self._shapes - out._shapes = _shapes - return out - - return new_func - - -def _broadcast(func): - def new_func(self, other): - out = getattr(torch.Tensor, func.__name__)( - self.view(-1, *self._trailing_dims), other - )._flat - other_shape = getattr(other, "shape", torch.Size([])) - out._shapes = _NestedSize.broadcast_shape(other_shape, self._shapes) - return out - - return new_func - - -class TensorStack(torch.Tensor): - def __new__(cls, tensor, *, shapes=None): - return super().__new__(cls, tensor) - - def __init__(self, tensor, *, shapes=None): - super(TensorStack, self).__init__() - self._shapes = shapes - self._init() - - def _init(self): - self._unique_shape, self._common_shape = self._get_common_shape(self._shapes) - self._get_offsets_() - - @property - def _trailing_dims(self): - dims = [] - for i in reversed(self._common_shape): - if i >= 0: - dims.append(i) - else: - break - return torch.Size(reversed(dims)) - - def _get_offsets_(self): - offsets = tuple(shape.numel() for shape in self._shapes) - n = self._trailing_dims.numel() - offsets = torch.tensor([offset // n for offset in offsets]) - self._offsets = offsets - self._offsets_cs = torch.cumsum(torch.nn.functional.pad(offsets, [1, 0]), 0) - - def _get_common_shape(self, shapes): - common_shape = shapes[0] - if _NestedSize.is_nested(common_shape): - new_shapes = [] - for shape in shapes: - _unique_shape, _common_shape = self._get_common_shape(shape) - new_shapes.append(_common_shape) - # unique_shape = all(new_unique_shape) - return self._get_common_shape(new_shapes) - - for _shape in shapes[1:]: - if len(_shape) != len(common_shape): - raise RuntimeError - if _shape != common_shape: - unique_shape = False - common_shape = torch.Size( - [ - s if s == s_other else -1 - for s, s_other in zip(common_shape, _shape) - 
] - ) - else: - unique_shape = all(s >= 0 for s in common_shape) - return unique_shape, torch.Size([len(shapes), *common_shape]) - - @classmethod - def from_tensors(cls, tensors): - if not len(tensors): - raise RuntimeError - shapes = _NestedSize( - _NestedSize(tensor.shape) - if not isinstance(tensor, TensorStack) - else tensor._shapes - for tensor in tensors - ) - return TensorStack( - torch.cat([tensor.view(-1) for tensor in tensors]), shapes=shapes - ) - - @property - def shape(self): - return self._common_shape - - def unstack(self): - return tuple( - super(TensorStack, self).__getitem__(slice(idx0, idx1)).view(shape) - for idx0, idx1, shape in zip( - self._offsets_cs[:-1], self._offsets_cs[1:], self._shapes - ) - ) - - @property - def _flat(self): - # represents the tensor as a flat one - return super().view(-1) - - @property - def _compact(self): - # represents the tensor with a compact structure (rightmost consistent dims-wise) - return super().view(-1, *self._trailing_dims) - - def view(self, *shape): - if len(shape) == 1 and isinstance(shape[0], (tuple, list)): - shape = shape[0] - if isinstance(shape, _NestedSize): - out = TensorStack(self.data, shapes=shape) - return out - n_trailing = len(self._trailing_dims) - if ( - self._trailing_dims - and shape[-n_trailing:] == self._trailing_dims - and shape[-n_trailing - 1] == -1 - ): - # eg, (4, 2, -1, 3, 5) -> (2, 2, 2, -1, 3, 5) - for i in range(-1, -self.ndim - 1, -1): - if self.shape[i] == -1: - break - first_neg_right_index = i - shapes = _NestedSize.refine_shapes( - first_neg_right_index, self._shapes, shape - ) - out = TensorStack(self.data, shapes=shapes) - return out - else: - return super().view(shape) - - @property - def ndim(self): - return len(self.shape) - - def ndimension(self): - return self.ndim - - def __getitem__(self, index): - if isinstance(index, (int,)): - shape = self._shapes[index] - idx0 = self._offsets_cs[index] - idx1 = self._offsets_cs[index + 1] - out = super(TensorStack, self._compact).__getitem__(slice(idx0, idx1))._flat - if isinstance(shape, torch.Size): - return torch.Tensor(out).view(shape) - else: - out._shapes = shape - return out - if isinstance(index, (slice,)): - index = range(*index.indices(self.shape[0])) - if isinstance( - index, - ( - range, - list, - ), - ): - index = torch.tensor(index) - if isinstance(index, torch.Tensor): - index_view = index.view(-1) - shape = _NestedSize([self._shapes[idx] for idx in index_view]) - index_view = _populate_index(self._offsets, self._offsets_cs, index_view) - out = super(TensorStack, self._compact).__getitem__(index_view)._flat - out._shapes = shape - if index.ndim > 1: - out = out.view(*index.shape, *out.shape[1:]) - return out - else: - raise NotImplementedError - - def __repr__(self): - return f"{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype}, device={self.device})" - - @_copy_shapes - def to(self, *args, **kwargs): - ... - - @_copy_shapes - def cpu(self): - ... - - @_copy_shapes - def bool(self): - ... - - @_copy_shapes - def float(self): - ... - - @_copy_shapes - def double(self): - ... - - @_copy_shapes - def int(self): - ... - - @_copy_shapes - def cuda(self): - ... - - @_copy_shapes - def __neg__(self): - ... - - @_copy_shapes - def __abs__(self): - ... - - @_copy_shapes - def __inv__(self): - ... - - @_copy_shapes - def __invert__(self): - ... - - @_broadcast - def add(self, other): - ... - - @_broadcast - def div(self, other): - ... - - @_broadcast - def rdiv(self, other): - ... - - @_broadcast - def __add__(self, other): - ... 
- - @_broadcast - def __mod__(self, other): - ... - - @_broadcast - def __pow__(self, other): - ... - - @_broadcast - def __sub__(self, other): - ... - - @_broadcast - def __truediv__(self, other): - ... - - @_broadcast - def __eq__(self, other): - ... - - @_broadcast - def __ne__(self, other): - ... - - @_broadcast - def __div__(self, other): - ... - - @_broadcast - def __floordiv__(self, other): - ... - - @_broadcast - def __lt__(self, other): - ... - - @_broadcast - def __le__(self, other): - ... - - @_broadcast - def __ge__(self, other): - ... - - @_broadcast - def __gt__(self, other): - ... - - @_broadcast - def __rdiv__(self, other): - ... - - @_broadcast - def __mul__(self, other): - ... - - @_lazy_init - def _offsets(self): - ... - - @_lazy_init - def _offsets_cs(self): - ... - - @_lazy_init - def _unique_shape(self): - ... - - @_lazy_init - def _common_shape(self): - ... diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py new file mode 100644 index 000000000..6ddfcf6cd --- /dev/null +++ b/tensordict/tensorstack.py @@ -0,0 +1,470 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict._tensordict import _as_shape, _populate_index +from torch import Tensor +from torch.utils._pytree import tree_flatten, tree_map + + +def _lazy_init(func): + """A caching helper.""" + name = "_" + func.__name__ + + def setter(self, value): + setattr(self, name, value) + + def new_func(self): + if not hasattr(self, name): + r = func(self) + setter(self, r) + return r + return getattr(self, name) + + return property(new_func, setter) + + +def _broadcast_shapes(*shapes): + """A modified version of torch.broadcast_shapes that accepts -1.""" + max_len = 0 + for shape in shapes: + if isinstance(shape, int): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (tuple, list)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + for shape in shapes: + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, (tuple, list)): + for i in range(-1, -1 - len(shape), -1): + cur_shape = shape[i] + if cur_shape == -1: + cur_shape = None # in double we use None as placeholder, which equals nothing + if cur_shape == 1 or cur_shape == result[i]: + continue + if result[i] == -1: + # in this case, we consider this as het dim + continue + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) + return torch.Size(result) + + +class _NestedShape: + def __new__(cls, shapes): + # TODO: if a tensor with nume() == tensor.shape[-1], then return a regular tensor + if isinstance(shapes, _NestedShape): + return shapes + return super().__new__(cls) + + def __init__(self, shapes): + if not isinstance(shapes, torch.Tensor): + shapes = torch.tensor(shapes) + self._shapes = shapes + + @_lazy_init + def _offsets(self): + common_shape = self.common_shape + shapes = self._shapes + if common_shape: + shapes = shapes[..., : -len(common_shape)] + return shapes.prod(-1) + + @_lazy_init + def _offsets_cs(self): + common_shape = self.common_shape + shapes = self._shapes + if common_shape: + shapes = shapes[..., : -len(common_shape)] + cs = shapes.prod(-1).view(-1).cumsum(0) + cs_pad = 
torch.nn.functional.pad(cs[:-1], [1, 0]) + return torch.stack( + [ + cs_pad.view(shapes.shape[:-1]), + cs.view(shapes.shape[:-1]), + ] + ) + + def unfold(self): + """Converts the shape to the maximum-indexable format. + + Examples: + >>> ns = _NestedShape(([11, 2, 3], [11, 5, 3])) + >>> print(ns.batch_dim) + torch.Size([2]) + >>> print(ns.unfold().batch_dim) + torch.Size([2, 11]) + """ + out = _NestedShape(self._shapes.clone()) + is_unique, val = out.is_unique(out._shapes.ndim - 1) + while is_unique: + out._shapes = ( + out._shapes[..., 1:] + .unsqueeze(-2) + .expand(*out._shapes.shape[:-1], val, -1) + ) + is_unique, val = out.is_unique(out._shapes.ndim - 1) + return out + + @_lazy_init + def ndim(self): + return self._shapes.ndim - 1 + self._shapes.shape[-1] + + def is_unique(self, dim): + if dim < 0: + dim = self.ndim + dim + if dim < 0 or dim >= self.ndim: + raise RuntimeError + if dim < self._shapes.ndim - 1: + return (True, self._shapes.shape[dim]) + v = self.as_shape[dim - self._shapes.ndim + 1] + return v != -1, v + + @_lazy_init + def het_dims(self): + return [dim for dim in range(self.ndim) if self.as_shape[dim] == -1] + + def numel(self): + return ( + self._offsets_cs[(1,) + (-1,) * (self._shapes.ndim - 1)] + * self.common_shape.numel() + ) + + @property + def batch_dim(self): + return self._shapes.shape[:-1] + + @_lazy_init + def common_shape(self): + shape = [] + + for v in reversed(self.as_shape): + if v != -1: + shape.append(v) + else: + break + return torch.Size(reversed(shape)) + + @classmethod + def broadcast_shape(cls, shape: torch.Size, nested_shape: _NestedShape): + broadcast_shape = _broadcast_shapes(shape, nested_shape.as_shape) + return nested_shape.expand(broadcast_shape) + + def expand(self, *broadcast_shape): + if len(broadcast_shape) == 1 and isinstance(broadcast_shape[0], (tuple, list)): + broadcast_shape = broadcast_shape[0] + as_shape = self.as_shape + if len(broadcast_shape) == len(as_shape) and all( + s1 == s2 or s1 == -1 or s2 == -1 + for (s1, s2) in zip(broadcast_shape, as_shape) + ): + return self + + # trailing dims, ie dims that are registered + broadcast_shape_trailing = broadcast_shape[self._shapes.shape[-1] :] + broadcast_shape_trailing = _broadcast_shapes( + broadcast_shape_trailing, as_shape[len(self.batch_dim) :] + ) + # replace trailing dims + shapes = self._shapes.clone() + for i in range(-1, len(broadcast_shape_trailing) - 1, -1): + if as_shape[i] != -1: + shapes[..., i] = broadcast_shape_trailing[i] + + # leading dims, ie dims that are not explicitely registered + broadcast_shape_leading = broadcast_shape[: -self.ndim] + + # find first -1 in broadcast_shape + if not len(broadcast_shape_leading): + return _NestedShape(shapes) + + return _NestedShape( + shapes.expand(*broadcast_shape_leading, *self._shapes.shape) + ) + + @_lazy_init + def is_plain(self): + return not self.het_dims + + def __getitem__(self, item): + try: + return _NestedShape(self._shapes[item]) + except IndexError as err: + if "too many indices" in str(err): + raise IndexError( + "Cannot index along dimensions on the right of the heterogeneous dimension." 
+ ) + + @_lazy_init + def as_shape(self): + shape_cpp = torch.Size(_as_shape(self._shapes)) + return shape_cpp + # first_shape = self._shapes[(0,) * (self._shapes.ndim - 1)].clone() + # unique = (self._shapes == first_shape).view(-1, self._shapes.shape[-1]).all(0) + # first_shape[~unique] = -1 + # shape = list(self._shapes.shape[:-1]) + list(first_shape) + # return torch.Size(shape) + + def __repr__(self): + return str(self.as_shape) + + def __eq__(self, other): + return (self._shapes == other).all() + + def __ne__(self, other): + return (self._shapes != other).any() + + +def get_parent_class(f): + return f.__globals__.get(f.__qualname__.split(".")[0], None) + + +def _copy_shapes(func): + def new_func(self, *args, **kwargs): + out = getattr(torch.Tensor, func.__name__)(self, *args, **kwargs) + _shapes = self._shapes + out._shapes = _shapes + return out + + return new_func + + +def _broadcast(func): + def new_func(self, other): + other_shape = getattr(other, "shape", torch.Size([])) + shapes = _NestedShape.broadcast_shape(other_shape, self._shapes) + compact = self._compact + if isinstance(other, TensorStack): + other = other._compact + if shapes != self._shapes: + raise RuntimeError("broadcast between TensorStack not implemented yet.") + # other = other.unsqueeze(-2) + elif isinstance(other, Tensor) and shapes != self._shapes: + # we need to squash + other = other.reshape( + *other.shape[: -self.ndim], -1, *other.shape[-len(compact.shape[1:]) :] + ) + out = getattr(torch.Tensor, func.__name__)(compact, other) + return TensorStack(out, shapes=shapes) + + return new_func + + +class TensorStack(torch.Tensor): + def __new__(cls, tensor, *, shapes): + if shapes.is_plain: + return tensor.reshape(shapes.as_shape) + return super().__new__(cls, tensor) + + def __init__(self, tensor, *, shapes, unfold=False): + super(TensorStack, self).__init__() + if not isinstance(shapes, _NestedShape): + raise ValueError("shapes must be a _NestedShape instance") + if unfold: + shapes = shapes.unfold() + self._shapes = shapes + + @classmethod + def from_tensors(cls, tensors): + if not len(tensors): + raise RuntimeError + shapes = _NestedShape(tree_map(lambda x: x.shape, tensors)) + return TensorStack( + torch.cat([t.view(-1) for t in tree_flatten(tensors)[0]]), shapes=shapes + ) + + def numel(self): + return self._shapes.numel() + + @property + def shape(self): + return self._shapes.as_shape + + def unstack(self): + raise NotImplementedError + + @property + def _flat(self): + # represents the tensor as a flat one + return super().view(-1) + + @property + def _compact(self): + # represents the tensor with a compact structure (rightmost consistent dims-wise) + return torch.Tensor(super().view(-1, *self._shapes.common_shape)) + + def view(self, *shape): + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + shape = shape[0] + if isinstance(shape, _NestedShape): + if shape.numel() != self.numel(): + raise ValueError + out = TensorStack(self, shapes=shape) + return out + if len(shape) == 1 and shape[0] == -1: + return self._flat + n = self.numel() + shape = torch.Size(shape) + common_shape = self._shapes.common_shape + compact_shape = torch.Size([n // common_shape.numel(), *common_shape]) + if shape in (torch.Size([-1, *common_shape]), compact_shape): + return self._compact + raise RuntimeError(shape) + + @property + def ndim(self): + return len(self.shape) + + def ndimension(self): + return self.ndim + + def __getitem__(self, index): + if isinstance(index, (int,)): + idx_beg = self._shapes._offsets_cs[0, 
index] + idx_end = self._shapes._offsets_cs[1, index] + shapes = self._shapes[index] + out = self._compact.__getitem__(slice(idx_beg, idx_end)) + out = TensorStack(out, shapes=shapes) + return out + shapes = self._shapes[index] + # TODO: capture wrong indexing + elts = _populate_index( + self._shapes._offsets[index].view(-1), + self._shapes._offsets_cs[0][index].view(-1), + ) + tensor = self._compact[elts] + return TensorStack(tensor, shapes=shapes) + + def __repr__(self): + return f"{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + @_copy_shapes + def to(self, *args, **kwargs): + ... + + @_copy_shapes + def cpu(self): + ... + + @_copy_shapes + def bool(self): + ... + + @_copy_shapes + def float(self): + ... + + @_copy_shapes + def double(self): + ... + + @_copy_shapes + def int(self): + ... + + @_copy_shapes + def cuda(self): + ... + + @_copy_shapes + def __neg__(self): + ... + + @_copy_shapes + def __abs__(self): + ... + + @_copy_shapes + def __inv__(self): + ... + + @_copy_shapes + def __invert__(self): + ... + + @_broadcast + def add(self, other): + ... + + @_broadcast + def div(self, other): + ... + + @_broadcast + def rdiv(self, other): + ... + + @_broadcast + def __add__(self, other): + ... + + @_broadcast + def __mod__(self, other): + ... + + @_broadcast + def __pow__(self, other): + ... + + @_broadcast + def __sub__(self, other): + ... + + @_broadcast + def __truediv__(self, other): + ... + + @_broadcast + def __eq__(self, other): + ... + + @_broadcast + def __ne__(self, other): + ... + + @_broadcast + def __div__(self, other): + ... + + @_broadcast + def __floordiv__(self, other): + ... + + @_broadcast + def __lt__(self, other): + ... + + @_broadcast + def __le__(self, other): + ... + + @_broadcast + def __ge__(self, other): + ... + + @_broadcast + def __gt__(self, other): + ... + + @_broadcast + def __rdiv__(self, other): + ... + + @_broadcast + def __mul__(self, other): + ... 
diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index 1569d766c..d838e0183 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -7,7 +7,7 @@ import pytest import torch -from tensordict.tenosrstack import TensorStack +from tensordict.tensorstack import TensorStack @pytest.fixture @@ -36,9 +36,10 @@ def test_indexing_slice(self, _tensorstack): assert (t[-3:][0] == x).all() assert (t[-3:][1] == y).all() assert (t[-3:][2] == z).all() - assert (t[::-1][0] == z).all() - assert (t[::-1][1] == y).all() - assert (t[::-1][2] == x).all() + # this breaks because the shape backend is a tensor, which cannot be indexed with neg steps + # assert (t[::-1][0] == z).all() + # assert (t[::-1][1] == y).all() + # assert (t[::-1][2] == x).all() def test_indexing_range(self, _tensorstack): t, (x, y, z) = _tensorstack @@ -54,10 +55,56 @@ def test_indexing_tensor(self, _tensorstack): assert (t[torch.tensor([0, 2])][1] == z).all() assert (t[torch.tensor([0, 2, 0, 2])][2] == x).all() assert (t[torch.tensor([0, 2, 0, 2])][3] == z).all() - assert (t[torch.tensor([[0, 2], [0, 2]])][0][0] == x).all() - assert (t[torch.tensor([[0, 2], [0, 2]])][0][1] == z).all() - assert (t[torch.tensor([[0, 2], [0, 2]])][1][0] == x).all() - assert (t[torch.tensor([[0, 2], [0, 2]])][1][1] == z).all() + + assert (t[torch.tensor([[0, 2], [0, 2]])][0, 0] == x).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][0, 1] == z).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][1, 0] == x).all() + assert (t[torch.tensor([[0, 2], [0, 2]])][1, 1] == z).all() + + def test_indexing_composite(self, _tensorstack): + _, (x, y, z) = _tensorstack + t = TensorStack.from_tensors([[x, y, z], [x, y, z]]) + assert (t[0, 0] == x).all() + assert (t[torch.tensor([0]), torch.tensor([0])] == x).all() + assert (t[torch.tensor([0]), torch.tensor([1])] == y).all() + assert (t[torch.tensor([0]), torch.tensor([2])] == z).all() + assert (t[:, torch.tensor([0])] == x).all() + assert (t[:, torch.tensor([1])] == y).all() + assert (t[:, torch.tensor([2])] == z).all() + assert ( + t[torch.tensor([0]), torch.tensor([1, 2])] + == TensorStack.from_tensors([y, z]) + ).all() + with pytest.raises(IndexError, match="Cannot index along"): + assert ( + t[..., torch.tensor([1, 2]), :, :, :] + == TensorStack.from_tensors([y, z]) + ).all() + + @pytest.mark.parametrize( + "op", + ["__add__", "__truediv__", "__mul__", "__sub__", "__mod__", "__eq__", "__ne__"], + ) + def test_elementwise(self, _tensorstack, op): + t, (x, y, z) = _tensorstack + t2 = getattr(t, op)(2) + torch.testing.assert_close(t2[0], getattr(x, op)(2)) + torch.testing.assert_close(t2[1], getattr(y, op)(2)) + torch.testing.assert_close(t2[2], getattr(z, op)(2)) + t2 = getattr(t, op)(torch.ones(5) * 2) + torch.testing.assert_close(t2[0], getattr(x, op)(torch.ones(5) * 2)) + torch.testing.assert_close(t2[1], getattr(y, op)(torch.ones(5) * 2)) + torch.testing.assert_close(t2[2], getattr(z, op)(torch.ones(5) * 2)) + # check broadcasting + assert t2[0].shape == x.shape + v = torch.ones(2, 1, 1, 1, 5) * 2 + t2 = getattr(t, op)(v) + assert t2.shape == torch.Size([2, 3, 3, -1, 5]) + torch.testing.assert_close(t2[:, 0], getattr(x, op)(v[:, 0])) + torch.testing.assert_close(t2[:, 1], getattr(y, op)(v[:, 0])) + torch.testing.assert_close(t2[:, 2], getattr(z, op)(v[:, 0])) + # check broadcasting + assert t2[:, 0].shape == torch.Size((2, *x.shape)) if __name__ == "__main__": From 57715e4f6aff405e91065a01f4e851dfd74ead8f Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 20 Jul 2023 16:41:23 -0400 Subject: 
[PATCH 3/6] amend --- tensordict/csrc/utils.h | 19 +++++++++++++++++-- tensordict/tensorstack.py | 5 +++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tensordict/csrc/utils.h b/tensordict/csrc/utils.h index 7716ed2ee..b4699001e 100644 --- a/tensordict/csrc/utils.h +++ b/tensordict/csrc/utils.h @@ -103,8 +103,8 @@ torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs) { py::list _as_shape(torch::Tensor shape_tensor) { torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)}); torch::Tensor out = shape_tensor_view[0].clone(); - torch::Tensor unique = (shape_tensor_view == out).all(0); - out.masked_fill_(torch::logical_not(unique), -1); + torch::Tensor not_unique = (shape_tensor_view != out).any(0); + out.masked_fill_(not_unique, -1); std::vector shape_vector(shape_tensor.sizes().begin(), shape_tensor.sizes().end() - 1); // Extend 'shape_vector' with the values from 'out'. auto out_accessor = out.accessor(); @@ -115,3 +115,18 @@ py::list _as_shape(torch::Tensor shape_tensor) { py::list shape = py::cast(shape_vector); return shape; } +//py::list _as_shape(torch::Tensor shape_tensor) { +// torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)}); +// torch::Tensor out = shape_tensor_view[0]; +// auto not_unique = (shape_tensor_view != out).any(0); +// out.masked_fill_(not_unique, -1); +// std::vector shape_vector; +// shape_vector.reserve(shape_tensor.ndimension() + shape_tensor.size(-1) - 1); // Reserve capacity to avoid reallocations. +// shape_vector.insert(shape_vector.end(), shape_tensor.sizes().begin(), shape_tensor.sizes().end() - 1); +// auto out_accessor = out.accessor(); +// for (int64_t i = 0; i < out_accessor.size(0); ++i) { +// shape_vector.push_back(out_accessor[i]); +// } +// py::list shape = py::cast(shape_vector); +// return shape; +//} diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 6ddfcf6cd..0d8a8cc33 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -137,7 +137,8 @@ def is_unique(self, dim): @_lazy_init def het_dims(self): - return [dim for dim in range(self.ndim) if self.as_shape[dim] == -1] + as_shape = self.as_shape + return [dim for dim, s in enumerate(as_shape) if s == -1] def numel(self): return ( @@ -199,7 +200,7 @@ def expand(self, *broadcast_shape): @_lazy_init def is_plain(self): - return not self.het_dims + return (self._shapes.ndim == 1 or not self.het_dims) def __getitem__(self, item): try: From 0cfbda5afdf3f0458c906240870b7e09bc2bf5c0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 20 Jul 2023 17:36:17 -0400 Subject: [PATCH 4/6] init --- tensordict/tensorstack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 0d8a8cc33..24214fcb6 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -299,12 +299,12 @@ def shape(self): def unstack(self): raise NotImplementedError - @property + @_lazy_init def _flat(self): # represents the tensor as a flat one return super().view(-1) - @property + @_lazy_init def _compact(self): # represents the tensor with a compact structure (rightmost consistent dims-wise) return torch.Tensor(super().view(-1, *self._shapes.common_shape)) From d503825c7ba52ec9bf699d3d44ba3d31386d5bf2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 21 Jul 2023 10:32:12 -0400 Subject: [PATCH 5/6] hack3 --- tensordict/csrc/utils.h | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git 
a/tensordict/csrc/utils.h b/tensordict/csrc/utils.h index b4699001e..4b9fe49df 100644 --- a/tensordict/csrc/utils.h +++ b/tensordict/csrc/utils.h @@ -8,7 +8,7 @@ #include namespace py = pybind11; - +using namespace torch::indexing; py::tuple _unravel_key_to_tuple(const py::object& key) { bool is_tuple = py::isinstance(key); @@ -101,9 +101,16 @@ torch::Tensor _populate_index(torch::Tensor offsets, torch::Tensor offsets_cs) { return out; } py::list _as_shape(torch::Tensor shape_tensor) { - torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)}); - torch::Tensor out = shape_tensor_view[0].clone(); - torch::Tensor not_unique = (shape_tensor_view != out).any(0); +// torch::Tensor shape_tensor_view = shape_tensor.reshape({-1, shape_tensor.size(-1)}); + torch::Tensor out = shape_tensor; + for (int64_t i = 0; i < shape_tensor.ndimension() - 1; ++i) { + out = out[0]; + } + out = out.clone(); + torch::Tensor not_unique = shape_tensor != out; + for (int64_t i = 0; i < shape_tensor.ndimension() - 1; ++i) { + not_unique = not_unique.any(0); + } out.masked_fill_(not_unique, -1); std::vector shape_vector(shape_tensor.sizes().begin(), shape_tensor.sizes().end() - 1); // Extend 'shape_vector' with the values from 'out'. From 06a125469172fa1ae14f53b943c3e71222279d02 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 21 Jul 2023 12:52:58 -0400 Subject: [PATCH 6/6] amend --- tensordict/tensorstack.py | 46 ++++++++++++++++++++++++++++++++++----- test/test_tensorstack.py | 24 ++++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/tensordict/tensorstack.py b/tensordict/tensorstack.py index 24214fcb6..cc06b9876 100644 --- a/tensordict/tensorstack.py +++ b/tensordict/tensorstack.py @@ -91,7 +91,7 @@ def _offsets_cs(self): shapes = self._shapes if common_shape: shapes = shapes[..., : -len(common_shape)] - cs = shapes.prod(-1).view(-1).cumsum(0) + cs = shapes.prod(-1).reshape(-1).cumsum(0) cs_pad = torch.nn.functional.pad(cs[:-1], [1, 0]) return torch.stack( [ @@ -200,7 +200,7 @@ def expand(self, *broadcast_shape): @_lazy_init def is_plain(self): - return (self._shapes.ndim == 1 or not self.het_dims) + return self._shapes.ndim == 1 or not self.het_dims def __getitem__(self, item): try: @@ -239,7 +239,7 @@ def _copy_shapes(func): def new_func(self, *args, **kwargs): out = getattr(torch.Tensor, func.__name__)(self, *args, **kwargs) _shapes = self._shapes - out._shapes = _shapes + out = TensorStack(out, shapes=_shapes) return out return new_func @@ -337,16 +337,25 @@ def ndimension(self): def __getitem__(self, index): if isinstance(index, (int,)): idx_beg = self._shapes._offsets_cs[0, index] - idx_end = self._shapes._offsets_cs[1, index] shapes = self._shapes[index] - out = self._compact.__getitem__(slice(idx_beg, idx_end)) + if idx_beg.numel() <= 1: + idx_end = self._shapes._offsets_cs[1, index] + out = self._compact.__getitem__(slice(idx_beg, idx_end)) + else: + elts = _populate_index( + self._shapes._offsets[index].view(-1), + idx_beg.view(-1), + ) + out = self._compact[elts] out = TensorStack(out, shapes=shapes) return out shapes = self._shapes[index] + if not isinstance(index, tuple): + index = (index,) # TODO: capture wrong indexing elts = _populate_index( self._shapes._offsets[index].view(-1), - self._shapes._offsets_cs[0][index].view(-1), + self._shapes._offsets_cs[(0, *index)].view(-1), ) tensor = self._compact[elts] return TensorStack(tensor, shapes=shapes) @@ -354,6 +363,31 @@ def __getitem__(self, index): def __repr__(self): return 
f"{self.__class__.__name__}(shape={self.shape}, dtype={self.dtype}, device={self.device})" + def permute(self, *dims): + if len(dims) == 1 and isinstance(dims[0], (list, tuple)): + dims = tuple(dims[0]) + n_batch_dms = len(self._shapes.batch_dim) + last_dims = [d - n_batch_dms for d in dims[n_batch_dms:]] + if last_dims != list(range(self.ndim - n_batch_dms)): + raise RuntimeError + dims = dims[:n_batch_dms] + out = TensorStack(self, shapes=_NestedShape(self._shapes._shapes)) + out._shapes._shapes = self._shapes._shapes.permute(*dims, n_batch_dms)[ + ..., last_dims + ] + out._shapes._offsets_cs = self._shapes._offsets_cs.permute( + 0, *[dim + 1 for dim in dims] + ) + out._shapes._offsets = self._shapes._offsets.permute(*dims) + return out + + def transpose(self, dim0, dim1): + out = TensorStack(self, shapes=_NestedShape(self._shapes._shapes)) + out._shapes._shapes = self._shapes._shapes.transpose(dim0, dim1) + out._shapes._offsets_cs = self._shapes._offsets_cs.transpose(dim0 + 1, dim1 + 1) + out._shapes._offsets = self._shapes._offsets.transpose(dim0, dim1) + return out + @_copy_shapes def to(self, *args, **kwargs): ... diff --git a/test/test_tensorstack.py b/test/test_tensorstack.py index d838e0183..a96324858 100644 --- a/test/test_tensorstack.py +++ b/test/test_tensorstack.py @@ -106,6 +106,30 @@ def test_elementwise(self, _tensorstack, op): # check broadcasting assert t2[:, 0].shape == torch.Size((2, *x.shape)) + def test_permute(self): + w = torch.randint(10, (3, 5, 5)) + x = torch.randint(10, (3, 4, 5)) + y = torch.randint(10, (3, 5, 5)) + z = torch.randint(10, (3, 4, 5)) + ts = TensorStack.from_tensors([[w, x], [y, z]]) + tst = ts.permute(1, 0, 2, 3, 4) + assert (tst[0, 1] == ts[1, 0]).all() + assert (tst[1, 0] == ts[0, 1]).all() + assert (tst[1, 1] == ts[1, 1]).all() + assert (tst[0, 0] == ts[0, 0]).all() + + def test_transpose(self): + w = torch.randint(10, (3, 5, 5)) + x = torch.randint(10, (3, 4, 5)) + y = torch.randint(10, (3, 5, 5)) + z = torch.randint(10, (3, 4, 5)) + ts = TensorStack.from_tensors([[w, x], [y, z]]) + tst = ts.transpose(1, 0) + assert (tst[0, 1] == ts[1, 0]).all() + assert (tst[1, 0] == ts[0, 1]).all() + assert (tst[1, 1] == ts[1, 1]).all() + assert (tst[0, 0] == ts[0, 0]).all() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()