fix: various speed ups & fix data_dims attribute error
percevalw committed May 12, 2024
1 parent 23900f7 commit b55460d
Showing 3 changed files with 96 additions and 38 deletions.
6 changes: 6 additions & 0 deletions changelog.md
@@ -1,3 +1,9 @@
# v0.3.4

- Fix a data_dims access issue
- Marginally improve the speed of handling FoldedTensors in standard torch operations
- Use default torch dtypes (e.g. `torch.float32` or `torch.int64`)
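
A minimal sketch of the dtype change (the float case follows `torch.set_default_dtype`, float32 unless overridden; nothing here beyond the public `as_folded_tensor` entry point):

```python
import torch
from foldedtensor import as_folded_tensor

# Floats now follow torch's default dtype (float32 unless overridden),
# instead of a hard-coded torch.float64.
ft = as_folded_tensor([[1.0, 2.0], [3.0]])
assert ft.dtype == torch.get_default_dtype()  # torch.float32 by default

# Ints map to torch's default integer dtype, torch.int64.
fi = as_folded_tensor([[1, 2], [3]])
assert fi.dtype == torch.int64
```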

# v0.3.3

- Handle empty inputs (e.g. `as_folded_tensor([[[], []], [[]]])`) by returning an empty tensor
115 changes: 77 additions & 38 deletions foldedtensor/__init__.py
@@ -5,6 +5,8 @@
import torch
from torch.autograd import Function

from . import _C

np_to_torch_dtype = {
torch.bool: bool,
torch.uint8: np.uint8,
@@ -19,9 +21,29 @@
torch.complex128: np.complex128,
}

from . import _C
pass_through_functions = {
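# Cheap accessors that can never change the folding structure: dispatch them
# directly and skip the FoldedTensor re-wrapping logic below for speed.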
torch.Tensor._grad.__get__,
torch.Tensor.grad.__get__,
torch.Tensor._base.__get__,
torch.Tensor.__repr__,
torch.Tensor.__str__,
torch.Tensor.__format__,
torch.Tensor.shape.__get__,
torch.Tensor.size.__get__,
torch.Tensor.dtype.__get__,
torch.Tensor.device.__get__,
}
if hasattr(torch._C, "TensorBase"):
pass_through_functions.add(torch._C.TensorBase.size)
else:
pass_through_functions.add(torch.Tensor.size)

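# Newer torch releases expose torch._C.DisableTorchFunctionSubclass; older
# ones only have DisableTorchFunction, so fall back for compatibility.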
try:
DisableTorchFunctionSubclass = torch._C.DisableTorchFunctionSubclass
except AttributeError:
DisableTorchFunctionSubclass = torch._C.DisableTorchFunction

__version__ = "0.3.3"
__version__ = "0.3.4"


# noinspection PyMethodOverriding
@@ -71,7 +93,6 @@ def backward(ctx, grad_output):
ctx.lengths,
ctx.old_data_dims,
)
# new_data_flat.index_put_({new_indexer}, old_data_flat.index_select(0, old_indexer));
shape_suffix = grad_output.shape[len(ctx.new_data_dims) :]
grad_input = torch.zeros(
(*shape_prefix, *shape_suffix), dtype=grad_output.dtype, device=device
@@ -80,20 +101,13 @@
-1, *shape_suffix
).index_select(0, ctx.output_indexer)
return grad_input, None
# return FoldedTensor(
# data=refolded_data,
# lengths=ctx.lengths,
# data_dims=ctx.old_data_dims,
# full_names=full_names,
# indexer=indexer,
# )


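# Resolve python scalar types to torch's *runtime default* dtypes instead of
# hard-coded 64-bit ones: torch.tensor([0]).dtype is int64, while
# torch.tensor([0.0]).dtype honors torch.set_default_dtype (float32 by default).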
type_to_dtype_dict = {
int: torch.int64,
float: torch.float64,
int: torch.tensor([0]).dtype,
float: torch.tensor([0.0]).dtype,
bool: torch.bool,
None: torch.float64,
None: torch.tensor([0.0]).dtype,
}


@@ -151,16 +165,18 @@ def as_folded_tensor(
)
if (data_dims[-1] + 1) != len(full_names):
raise ValueError(
"The last dimension of `data_dims` must be the last variable dimension."
"The last dimension of `data_dims` must be the last "
"variable dimension."
)
elif full_names is not None:
data_dims = tuple(range(len(full_names)))
if isinstance(data, torch.Tensor) and lengths is not None:
data_dims = data_dims or tuple(range(len(lengths)))
np_indexer, shape = _C.make_refolding_indexer(lengths, data_dims)
assert shape == list(
data.shape[: len(data_dims)]
), f"Shape inferred from lengths is not compatible with data dims: {shape}, {data.shape}, {len(data_dims)}"
assert shape == list(data.shape[: len(data_dims)]), (
f"Shape inferred from lengths is not compatible with data dims: {shape}, "
f"{data.shape}, {len(data_dims)}"
)
result = FoldedTensor(
data=data,
lengths=lengths,
@@ -208,6 +224,23 @@ def as_folded_tensor(
return result


def _postprocess_func_result(result, input):
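# Re-wrap a raw tensor result as a FoldedTensor only if the operation kept the
# folded (variable-size) leading dimensions intact; otherwise the folding
# metadata no longer applies and the plain tensor is returned as-is.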
if (
input is None
or input.shape[: len(input.data_dims)] != result.shape[: len(input.data_dims)]
):
return result

return FoldedTensor(
data=result,
lengths=input.lengths,
data_dims=input.data_dims,
full_names=input.full_names,
indexer=input.indexer,
mask=input._mask,
)


# noinspection PyUnresolvedReferences,PyInitNewSignature
class FoldedTensor(torch.Tensor):
"""
@@ -296,46 +329,52 @@ def to(self, *args, **kwargs):

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
result = super().__torch_function__(func, types, args, kwargs)
if kwargs is None:
kwargs = {}
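# Fast path: metadata accessors (repr, shape, dtype, device, ...) bypass the
# re-wrapping logic entirely.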
if func in pass_through_functions:
with DisableTorchFunctionSubclass():
return func(*args, **kwargs)
with DisableTorchFunctionSubclass():
result = func(*args, **kwargs)

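# share_memory_ is in-place: the auxiliary indexer and mask tensors must be
# moved to shared memory as well so the FoldedTensor stays consistent.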
if func is torch.Tensor.share_memory_:
self = args[0]
self.indexer.share_memory_()
if self._mask is not None:
self._mask.share_memory_()

if not isinstance(result, torch.Tensor):
return result
return self

ft = None
for arg in (*args, *(kwargs or {}).values()):
for arg in (*args, *kwargs.values()):
if isinstance(arg, FoldedTensor):
assert (
ft is None or ft.data_dims == arg.data_dims
), "Cannot perform operation on FoldedTensors with different structure"
), "Cannot perform operation on FoldedTensors with different structures"
ft = arg
if isinstance(arg, (list, tuple)):
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, FoldedTensor):
assert (
ft is None or ft.data_dims == item.data_dims
), "Cannot perform operation on FoldedTensors with different structure"
assert ft is None or ft.data_dims == item.data_dims, (
"Cannot perform operation on FoldedTensors with "
"different structures"
)
ft = item

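# Propagate the folding metadata (lengths, data_dims, full_names, indexer,
# mask) from the reference FoldedTensor `ft` to the tensor result(s).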
if isinstance(result, torch.Tensor):
return _postprocess_func_result(result, ft)

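# Some ops (e.g. Tensor.max(dim)) return a tuple or list of tensors; re-wrap
# each FoldedTensor element and leave the other items untouched.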
if (
ft is not None
and ft.shape[: len(ft.data_dims)] != result.shape[: len(ft.data_dims)]
isinstance(result, (tuple, list))
and len(result)
and isinstance(result[0], torch.Tensor)
):
return result.as_subclass(torch.Tensor)
return type(result)(
_postprocess_func_result(item, ft)
if isinstance(item, FoldedTensor)
else item
for item in result
)

result = FoldedTensor(
data=result,
lengths=ft.lengths,
data_dims=ft.data_dims,
full_names=ft.full_names,
indexer=ft.indexer,
mask=ft._mask,
)
return result

def clone(self):
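Taken together, a short usage sketch of what this diff changes (a hedged illustration using only the public `as_folded_tensor` entry point; exact dtypes follow torch defaults):

```python
import torch
from foldedtensor import as_folded_tensor

ft = as_folded_tensor([[0.0, 1.0, 2.0], [3.0, 4.0]])

repr(ft)   # metadata accessors now hit the pass-through fast path
ft.shape   # no re-wrapping, just the underlying tensor's answer

# Tuple results (e.g. from max) are now unwrapped element-wise, which is
# what previously triggered the data_dims attribute error.
values, indices = ft.max(-1)
```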
13 changes: 13 additions & 0 deletions tests/test_folded_tensor.py
@@ -404,3 +404,16 @@ def test_imbalanced_sequence_2():
)

assert "'int' object is not iterable" in str(e.value)


def test_max():
ft = as_folded_tensor(
[
[0, 1, 2],
[3, 4],
],
dtype=torch.float,
)
values, indices = ft.max(-1)
assert (values == torch.tensor([2, 4])).all()
assert (indices == torch.tensor([2, 1])).all()
