Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ops.getslice for complex indexing by int,slice,None,Ellipsis #555

Merged
merged 5 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _(fn):

@affine_inputs.register(Unary)
def _(fn):
if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp):
if fn.op in (ops.neg, ops.sum) or isinstance(
fn.op, (ops.ReshapeOp, ops.GetsliceOp)
):
return affine_inputs(fn.arg)
return frozenset()

Expand Down
53 changes: 53 additions & 0 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from weakref import WeakValueDictionary

import funsor.ops as ops
from funsor.ops.builtin import parse_ellipsis, parse_slice
from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote

Domain = type
Expand Down Expand Up @@ -331,6 +332,58 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain):
)


@find_domain.register(ops.GetsliceOp)
def _find_domain_getslice(op, domain):
index = op.defaults["index"]
if isinstance(domain, ArrayType):
dtype = domain.dtype
shape = list(domain.shape)
left, right = parse_ellipsis(index)

i = 0
for part in left:
if part is None:
shape.insert(i, 1)
i += 1
elif isinstance(part, int):
del shape[i]
elif isinstance(part, slice):
start, stop, step = parse_slice(part, shape[i])
shape[i] = max(0, (stop - start + step - 1) // step)
i += 1
else:
raise ValueError(part)

i = -1
for part in reversed(right):
if part is None:
shape.insert(len(shape) + i + 1, 1)
i -= 1
elif isinstance(part, int):
del shape[i]
elif isinstance(part, slice):
start, stop, step = parse_slice(part, shape[i])
shape[i] = max(0, (stop - start + step - 1) // step)
i -= 1
else:
raise ValueError(part)

return Array[dtype, tuple(shape)]

if isinstance(domain, ProductDomain):
if isinstance(index, tuple):
assert len(index) == 1
index = index[0]
if isinstance(index, int):
return domain.__args__[index]
elif isinstance(index, slice):
return Product[domain.__args__[index]]
else:
raise ValueError(index)

raise NotImplementedError("TODO")


@find_domain.register(ops.BinaryOp)
def _find_domain_pointwise_binary_generic(op, lhs, rhs):
if (
Expand Down
101 changes: 101 additions & 0 deletions funsor/ops/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UNITS,
BinaryOp,
Op,
OpMeta,
TransformOp,
UnaryOp,
declare_op_types,
Expand Down Expand Up @@ -43,6 +44,105 @@ def getitem(lhs, rhs, offset=0):
return lhs[(slice(None),) * offset + (rhs,)]


class GetsliceMeta(OpMeta):
"""
Works around slice objects not being hashable.
"""

def hash_args_kwargs(cls, args, kwargs):
index = args[0] if args else kwargs["index"]
if not isinstance(index, tuple):
index = (index,)
key = tuple(
(x.start, x.stop, x.step) if isinstance(x, slice) else x for x in index
)
return key


@UnaryOp.make(metaclass=GetsliceMeta)
def getslice(x, index=Ellipsis):
return x[index]


getslice.supported_types = (type(None), type(Ellipsis), int, slice)


def parse_ellipsis(index):
"""
Helper to split a slice into parts left and right of Ellipses.

:param index: A tuple, or other object (None, int, slice, Funsor).
:returns: a pair of tuples ``left, right``.
:rtype: tuple
"""
if not isinstance(index, tuple):
index = (index,)
left = []
i = 0
for part in index:
i += 1
if part is Ellipsis:
break
left.append(part)
right = []
for part in reversed(index[i:]):
if part is Ellipsis:
break
right.append(part)
right.reverse()
return tuple(left), tuple(right)


def normalize_ellipsis(index, size):
"""
Expand Ellipses in an index to fill the given number of dimensions.

This should satisfy the equation::

x[i] == x[normalize_ellipsis(i, len(x.shape))]
"""
left, right = parse_ellipsis(index)
if len(left) + len(right) > size:
raise ValueError(f"Index is too wide: {index}")
middle = (slice(None),) * (size - len(left) - len(right))
return left + middle + right


def parse_slice(s, size):
"""
Helper to determine nonnegative integers (start, stop, step) of a slice.

:param slice s: A slice.
:param int size: The size of the array being indexed into.
:returns: A tuple of nonnegative integers ``start, stop, step``.
:rtype: tuple
"""
start = s.start
if start is None:
start = 0
assert isinstance(start, int)
if start >= 0:
start = min(size, start)
else:
start = max(0, size + start)

stop = s.stop
if stop is None:
stop = size
assert isinstance(stop, int)
if stop >= 0:
stop = min(size, stop)
else:
stop = max(0, size + stop)

step = s.step
if step is None:
step = 1
assert isinstance(step, int)

return start, stop, step


abs = UnaryOp.make(_builtin_abs)
eq = BinaryOp.make(operator.eq)
ge = BinaryOp.make(operator.ge)
Expand Down Expand Up @@ -194,6 +294,7 @@ def sigmoid_log_abs_det_jacobian(x, y):
"floordiv",
"ge",
"getitem",
"getslice",
"gt",
"invert",
"le",
Expand Down
10 changes: 10 additions & 0 deletions funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,16 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
return Tensor(data, inputs, lhs.dtype)


@eager.register(Unary, ops.GetsliceOp, Tensor)
def eager_getslice_tensor(op, x):
index = op.defaults["index"]
if not isinstance(index, tuple):
index = (index,)
index = (slice(None),) * len(x.inputs) + index
data = x.data[index]
return Tensor(data, x.inputs, x.dtype)


@eager.register(
Finitary, ops.StackOp, typing.Tuple[typing.Union[(Number, Tensor)], ...]
)
Expand Down
55 changes: 44 additions & 11 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from funsor.interpreter import PatternMissingError, interpret
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.ops.builtin import normalize_ellipsis, parse_ellipsis
from funsor.syntax import INFIX_OPERATORS, PREFIX_OPERATORS
from funsor.typing import GenericTypeMeta, Variadic, deep_type, get_args, get_origin
from funsor.util import getargspec, lazy_property, pretty, quote
Expand Down Expand Up @@ -730,23 +731,24 @@ def __ge__(self, other):
return Binary(ops.ge, self, to_funsor(other))

def __getitem__(self, other):
"""
Helper to desugar into either ops.getitem (for advanced indexing
involving Funsors as indices) or ops.getslice (for simple indexing
involving only integers, slices, None, and Ellipsis).
"""
if type(other) is not tuple:
if isinstance(other, ops.getslice.supported_types):
return ops.getslice(self, other)
other = to_funsor(other, Bint[self.output.shape[0]])
return Binary(ops.getitem, self, other)

# Handle complex slicing operations involving no funsors.
if all(isinstance(part, ops.getslice.supported_types) for part in other):
return ops.getslice(self, other)

# Handle Ellipsis slicing.
if any(part is Ellipsis for part in other):
left = []
for part in other:
if part is Ellipsis:
break
left.append(part)
right = []
for part in reversed(other):
if part is Ellipsis:
break
right.append(part)
right.reverse()
left, right = parse_ellipsis(other)
missing = len(self.output.shape) - len(left) - len(right)
assert missing >= 0
middle = [slice(None)] * missing
Expand All @@ -756,6 +758,8 @@ def __getitem__(self, other):
result = self
offset = 0
for part in other:
if part is None:
raise NotImplementedError("TODO")
if isinstance(part, slice):
if part != slice(None):
raise NotImplementedError("TODO support nontrivial slicing")
Expand Down Expand Up @@ -1754,6 +1758,20 @@ def eager_getitem_lambda(op, lhs, rhs):
return Lambda(lhs.var, expr)


@eager.register(Unary, ops.GetsliceOp, Lambda)
def eager_getslice_lambda(op, x):
index = normalize_ellipsis(op.defaults["index"], len(x.shape))
head, tail = index[0], index[1:]
expr = x.expr
if head != slice(None):
expr = expr(**{x.var.name: head})
if x.var.name not in expr.inputs:
return expr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if tail is not empty here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, fixed

if tail:
expr = ops.getslice(expr, tail)
return Lambda(x.var, expr)


class Independent(Funsor):
"""
Creates an independent diagonal distribution.
Expand Down Expand Up @@ -1885,6 +1903,21 @@ def eager_getitem_tuple(op, lhs, rhs):
return op(lhs.args, rhs.data)


@lazy.register(Unary, ops.GetsliceOp, Tuple)
@eager.register(Unary, ops.GetsliceOp, Tuple)
def eager_getslice_tuple(op, x):
index = op.defaults["index"]
if isinstance(index, tuple):
assert len(index) == 1
index = index[0]
if isinstance(index, int):
return op(x.args)
elif isinstance(index, slice):
return Tuple(op(x.args))
else:
raise ValueError(index)


def _symbolic(inputs, output, fn):
args, vargs, kwargs, defaults = getargspec(fn)
assert not vargs
Expand Down
17 changes: 17 additions & 0 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,20 @@ def iter_subsets(iterable, *, min_size=None, max_size=None):
max_size = len(iterable)
for size in range(min_size, max_size + 1):
yield from itertools.combinations(iterable, size)


class DesugarGetitem:
"""
Helper to desugar ``.__getitem__()`` syntax.

Example::

>>> desugar_getitem[1:3, ..., None]
(slice(1, 3), Ellipsis, None)
"""

def __getitem__(self, index):
return index


desugar_getitem = DesugarGetitem()
38 changes: 38 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from funsor import ops
from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
from funsor.ops.builtin import parse_ellipsis, parse_slice
from funsor.testing import desugar_getitem
from funsor.util import get_backend


Expand Down Expand Up @@ -36,3 +38,39 @@ def test_transform_op_gc(dist):
assert len(op_set) == 1
del op
assert len(op_set) == 0


@pytest.mark.parametrize(
"index, left, right",
[
(desugar_getitem[()], (), ()),
(desugar_getitem[0], (0,), ()),
(desugar_getitem[...], (), ()),
(desugar_getitem[..., ...], (), ()),
(desugar_getitem[1, ...], (1,), ()),
(desugar_getitem[..., 1], (), (1,)),
(desugar_getitem[:, None, ..., 1, 1:2], (slice(None), None), (1, slice(1, 2))),
],
ids=str,
)
def test_parse_ellipsis(index, left, right):
assert parse_ellipsis(index) == (left, right)


@pytest.mark.parametrize(
"s, size, start, stop, step",
[
(desugar_getitem[:], 5, 0, 5, 1),
(desugar_getitem[:3], 5, 0, 3, 1),
(desugar_getitem[-9:3], 5, 0, 3, 1),
(desugar_getitem[:-2], 5, 0, 3, 1),
(desugar_getitem[2:], 5, 2, 5, 1),
(desugar_getitem[2:9], 5, 2, 5, 1),
(desugar_getitem[-3:], 5, 2, 5, 1),
(desugar_getitem[-3:-2], 5, 2, 3, 1),
],
ids=str,
)
def test_parse_slice(s, size, start, stop, step):
actual = parse_slice(s, size)
assert actual == (start, stop, step)
Loading