Skip to content

Commit

Permalink
Merge pull request #295 from OpShin/feat/slice_list
Browse files Browse the repository at this point in the history
Add support for builtin list slicing
  • Loading branch information
nielstron authored Dec 15, 2023
2 parents fc5b477 + d47fe75 commit 8193fd6
Show file tree
Hide file tree
Showing 5 changed files with 379 additions and 99 deletions.
126 changes: 103 additions & 23 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,32 +629,112 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
)
if isinstance(node.value.typ.typ, ListType):
assert (
node.slice.typ == IntegerInstanceType
), "Only single element list index access supported"
return plt.Lambda(
[STATEMONAD],
plt.Let(
[
("l", plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))),
(
"raw_i",
plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
),
(
"i",
plt.Ite(
plt.LessThanInteger(plt.Var("raw_i"), plt.Integer(0)),
plt.AddInteger(
plt.Var("raw_i"), plt.LengthList(plt.Var("l"))
if not isinstance(node.slice, Slice):
assert (
node.slice.typ == IntegerInstanceType
), "Only single element list index access supported"
return plt.Lambda(
[STATEMONAD],
plt.Let(
[
(
"l",
plt.Apply(self.visit(node.value), plt.Var(STATEMONAD)),
),
(
"raw_i",
plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
),
(
"i",
plt.Ite(
plt.LessThanInteger(
plt.Var("raw_i"), plt.Integer(0)
),
plt.AddInteger(
plt.Var("raw_i"), plt.LengthList(plt.Var("l"))
),
plt.Var("raw_i"),
),
),
],
plt.IndexAccessList(plt.Var("l"), plt.Var("i")),
),
)
else:
return plt.Lambda(
[STATEMONAD],
plt.Let(
[
(
"xs",
plt.Apply(self.visit(node.value), plt.Var(STATEMONAD)),
),
(
"raw_i",
plt.Apply(
self.visit(node.slice.lower), plt.Var(STATEMONAD)
),
plt.Var("raw_i"),
),
(
"i",
plt.Ite(
plt.LessThanInteger(
plt.Var("raw_i"), plt.Integer(0)
),
plt.AddInteger(
plt.Var("raw_i"),
plt.LengthList(plt.Var("xs")),
),
plt.Var("raw_i"),
),
),
(
"raw_j",
plt.Apply(
self.visit(node.slice.upper), plt.Var(STATEMONAD)
),
),
(
"j",
plt.Ite(
plt.LessThanInteger(
plt.Var("raw_j"), plt.Integer(0)
),
plt.AddInteger(
plt.Var("raw_j"),
plt.LengthList(plt.Var("xs")),
),
plt.Var("raw_j"),
),
),
(
"drop",
plt.Ite(
plt.LessThanEqualsInteger(
plt.Var("i"), plt.Integer(0)
),
plt.Integer(0),
plt.Var("i"),
),
),
(
"take",
plt.SubtractInteger(plt.Var("j"), plt.Var("drop")),
),
],
plt.Ite(
plt.LessThanEqualsInteger(plt.Var("j"), plt.Var("i")),
empty_list(node.value.typ.typ.typ),
plt.SliceList(
plt.Var("drop"),
plt.Var("take"),
plt.Var("xs"),
empty_list(node.value.typ.typ.typ),
),
),
],
plt.IndexAccessList(plt.Var("l"), plt.Var("i")),
),
)
),
)
elif isinstance(node.value.typ.typ, DictType):
dict_typ = node.value.typ.typ
if not isinstance(node.slice, Slice):
Expand Down
170 changes: 170 additions & 0 deletions opshin/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,72 @@ def validator(x: bytes, y: int, z: int) -> bytes:
ret = None
self.assertEqual(ret, exp, "byte slice returned wrong value")

@given(x=st.binary(), y=st.integers())
@example(b"\x00", -2)
@example(b"1234", 1)
@example(b"1234", 2)
@example(b"1234", 2)
@example(b"1234", 3)
@example(b"1234", 3)
def test_slice_bytes_lower(self, x, y):
source_code = """
def validator(x: bytes, y: int) -> bytes:
return x[y:]
"""
try:
exp = x[y:]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x, y)
except:
ret = None
self.assertEqual(ret, exp, "byte slice returned wrong value")

@given(x=st.binary(), y=st.integers())
@example(b"\x00", 0)
@example(b"1234", 2)
@example(b"1234", 4)
@example(b"1234", 2)
@example(b"1234", 3)
@example(b"1234", 1)
def test_slice_bytes_upper(self, x, y):
source_code = """
def validator(x: bytes, y: int) -> bytes:
return x[:y]
"""
try:
exp = x[:y]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x, y)
except:
ret = None
self.assertEqual(ret, exp, "byte slice returned wrong value")

@given(x=st.binary())
@example(b"\x00")
@example(b"1234")
@example(b"1234")
@example(b"1234")
@example(b"1234")
@example(b"1234")
def test_slice_bytes_full(self, x):
source_code = """
def validator(x: bytes) -> bytes:
return x[:]
"""
try:
exp = x[:]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x)
except:
ret = None
self.assertEqual(ret, exp, "byte slice returned wrong value")

@given(x=st.binary(), y=st.integers())
@example(b"1234", 0)
@example(b"1234", 1)
Expand Down Expand Up @@ -260,6 +326,110 @@ def validator(x: List[int], y: int) -> int:
ret = None
self.assertEqual(ret, exp, "list index returned wrong value")

@given(x=st.lists(st.integers(), max_size=20), y=st.integers(), z=st.integers())
@example([0], -2, 0)
@example([1, 2, 3, 4], 1, 2)
@example([1, 2, 3, 4], 2, 4)
@example([1, 2, 3, 4], 2, 2)
@example([1, 2, 3, 4], 3, 3)
@example([1, 2, 3, 4], 3, 1)
def test_slice_list(self, x, y, z):
source_code = """
def validator(x: List[int], y: int, z: int) -> List[int]:
return x[y:z]
"""
try:
exp = x[y:z]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x, y, z)
except:
ret = None
self.assertEqual(
ret,
[PlutusInteger(x) for x in exp] if exp is not None else exp,
"list slice returned wrong value",
)

@given(x=st.lists(st.integers(), max_size=20), y=st.integers())
@example([0], -2)
@example([1, 2, 3, 4], 1)
@example([1, 2, 3, 4], 2)
@example([1, 2, 3, 4], 2)
@example([1, 2, 3, 4], 3)
@example([1, 2, 3, 4], 3)
def test_slice_list_lower(self, x, y):
source_code = """
def validator(x: List[int], y: int) -> List[int]:
return x[y:]
"""
try:
exp = x[y:]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x, y)
except:
ret = None
self.assertEqual(
ret,
[PlutusInteger(x) for x in exp] if exp is not None else exp,
"list slice returned wrong value",
)

@given(x=st.lists(st.integers(), max_size=20), y=st.integers())
@example([0], 0)
@example([1, 2, 3, 4], 2)
@example([1, 2, 3, 4], 4)
@example([1, 2, 3, 4], 2)
@example([1, 2, 3, 4], 3)
@example([1, 2, 3, 4], 1)
def test_slice_list_upper(self, x, y):
source_code = """
def validator(x: List[int], y: int) -> List[int]:
return x[:y]
"""
try:
exp = x[:y]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x, y)
except:
ret = None
self.assertEqual(
ret,
[PlutusInteger(x) for x in exp] if exp is not None else exp,
"list slice returned wrong value",
)

@given(x=st.lists(st.integers(), max_size=20))
@example([0])
@example([1, 2, 3, 4])
@example([1, 2, 3, 4])
@example([1, 2, 3, 4])
@example([1, 2, 3, 4])
@example([1, 2, 3, 4])
def test_slice_list_full(self, x):
source_code = """
def validator(x: List[int]) -> List[int]:
return x[:]
"""
try:
exp = x[:]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, x)
except:
ret = None
self.assertEqual(
ret,
[PlutusInteger(x) for x in exp] if exp is not None else exp,
"list slice returned wrong value",
)

@given(xs=st.lists(st.integers()), y=st.integers())
@example(xs=[0, 1], y=-1)
@example(xs=[0, 1], y=0)
Expand Down
32 changes: 23 additions & 9 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,17 +674,36 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
f"Could not infer type of subscript of typ {ts.value.typ.typ.__class__}"
)
elif isinstance(ts.value.typ.typ, ListType):
ts.typ = ts.value.typ.typ.typ
ts.slice = self.visit(node.slice)
assert ts.slice.typ == IntegerInstanceType, "List indices must be integers"
if not isinstance(ts.slice, Slice):
ts.typ = ts.value.typ.typ.typ
ts.slice = self.visit(node.slice)
assert (
ts.slice.typ == IntegerInstanceType
), "List indices must be integers"
else:
ts.typ = ts.value.typ
if ts.slice.lower is None:
ts.slice.lower = Constant(0)
ts.slice.lower = self.visit(node.slice.lower)
assert (
ts.slice.lower.typ == IntegerInstanceType
), "lower slice indices for lists must be integers"
if ts.slice.upper is None:
ts.slice.upper = Call(
func=Name(id="len", ctx=Load()), args=[ts.value], keywords=[]
)
ts.slice.upper = self.visit(node.slice.upper)
assert (
ts.slice.upper.typ == IntegerInstanceType
), "upper slice indices for lists must be integers"
elif isinstance(ts.value.typ.typ, ByteStringType):
if not isinstance(ts.slice, Slice):
ts.typ = IntegerInstanceType
ts.slice = self.visit(node.slice)
assert (
ts.slice.typ == IntegerInstanceType
), "bytes indices must be integers"
elif isinstance(ts.slice, Slice):
else:
ts.typ = ByteStringInstanceType
if ts.slice.lower is None:
ts.slice.lower = Constant(0)
Expand All @@ -700,12 +719,7 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
assert (
ts.slice.upper.typ == IntegerInstanceType
), "upper slice indices for bytes must be integers"
else:
raise TypeInferenceError(
f"Could not infer type of subscript of typ {ts.value.typ.__class__}"
)
elif isinstance(ts.value.typ.typ, DictType):
# TODO could be implemented with potentially just erroring. It might be desired to avoid this though.
if not isinstance(ts.slice, Slice):
ts.slice = self.visit(node.slice)
assert (
Expand Down
Loading

0 comments on commit 8193fd6

Please sign in to comment.