Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fast list access implementation #337

Draft
wants to merge 6 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,12 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
assert (
node.slice.typ == IntegerInstanceType
), "Only single element list index access supported"
if isinstance(node.slice, Constant) and node.slice.value >= 0:
index = node.slice.value
return plt.ConstantIndexAccessListFast(
self.visit(node.value),
index,
)
return OLet(
[
(
Expand All @@ -666,7 +672,7 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
),
],
plt.IndexAccessList(OVar("l"), OVar("i")),
plt.IndexAccessListFast(OVar("l"), OVar("i")),
)
else:
return OLet(
Expand Down
28 changes: 27 additions & 1 deletion opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import frozendict
import frozenlist2
import hypothesis
from hypothesis import given
from hypothesis import given, example
from hypothesis import strategies as st
from opshin import IndefiniteList
from parameterized import parameterized
Expand Down Expand Up @@ -2978,3 +2978,29 @@ def validator(_: None) -> int:
self.assertEqual(
B.CONSTR_ID, res, "Invalid constr id generation (does not match pycardano)"
)

@given(st.data())
def test_constant_index_list(self, data):
xs = data.draw(st.lists(st.integers()))
y = data.draw(
st.one_of(
st.integers(min_value=1 - len(xs), max_value=len(xs) - 1), st.integers()
)
if xs
else st.integers()
)
# test the optimization for list access when the index is known at compile time
source_code = f"""
from typing import Dict, List, Union
def validator(x: List[int]) -> int:
return x[{y}]
"""
try:
exp = xs[y]
except IndexError:
exp = None
try:
ret = eval_uplc_value(source_code, xs)
except Exception as e:
ret = None
self.assertEqual(ret, exp, "list index returned wrong value")
14 changes: 10 additions & 4 deletions opshin/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,16 @@ def validator(x: bytes, y: int) -> int:
ret = None
self.assertEqual(ret, exp, "byte index returned wrong value")

@given(xs=st.lists(st.integers()), y=st.integers())
@example(xs=[0], y=-1)
@example(xs=[0], y=0)
def test_index_list(self, xs, y):
@given(st.data())
def test_index_list(self, data):
xs = data.draw(st.lists(st.integers()))
y = data.draw(
st.one_of(
st.integers(min_value=1 - len(xs), max_value=len(xs) - 1), st.integers()
)
if xs
else st.integers()
)
source_code = """
from typing import Dict, List, Union
def validator(x: List[int], y: int) -> int:
Expand Down
27 changes: 16 additions & 11 deletions opshin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def attribute(self, attr: str) -> plt.AST:
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.ConstantNthField(
plt.ConstantNthFieldFast(
OVar("self"),
pos,
),
Expand Down Expand Up @@ -602,7 +602,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
plt.Apply(
field_type.stringify(recursive=True),
transform_ext_params_map(field_type)(
plt.ConstantNthField(OVar("self"), pos)
plt.ConstantNthFieldFast(OVar("self"), pos)
),
),
map_fields,
Expand All @@ -613,7 +613,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
plt.Apply(
self.record.fields[0][1].stringify(recursive=True),
transform_ext_params_map(self.record.fields[0][1])(
plt.ConstantNthField(OVar("self"), pos)
plt.ConstantNthFieldFast(OVar("self"), pos)
),
),
map_fields,
Expand Down Expand Up @@ -721,8 +721,16 @@ def attribute(self, attr: str) -> plt.AST:
if not pos_constrs:
pos_decisor = plt.TraceError("Invalid constructor")
else:
pos_decisor = plt.Integer(pos_constrs[-1][0])
pos_decisor = plt.ConstantNthFieldFast(OVar("self"), pos_constrs[-1][0])
pos_constrs = pos_constrs[:-1]
# constr is not needed when there is only one position for all constructors
if not pos_constrs:
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
pos_decisor,
),
)
for pos, constrs in pos_constrs:
assert constrs, "Found empty constructors for a position"
constr_check = plt.EqualsInteger(
Expand All @@ -735,18 +743,15 @@ def attribute(self, attr: str) -> plt.AST:
)
pos_decisor = plt.Ite(
constr_check,
plt.Integer(pos),
plt.ConstantNthFieldFast(OVar("self"), pos),
pos_decisor,
)
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
OVar("self"),
OLet(
[("constr", plt.Constructor(OVar("self")))],
pos_decisor,
),
OLet(
[("constr", plt.Constructor(OVar("self")))],
pos_decisor,
),
),
)
Expand Down
Loading