From 59bbb35505fb0ac9a1986612d2d1de5d29b4352a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 12 Jan 2024 18:53:38 +0100 Subject: [PATCH] Implement list index --- opshin/tests/test_stdlib.py | 22 +++++++++++++++++++ opshin/types.py | 42 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/opshin/tests/test_stdlib.py b/opshin/tests/test_stdlib.py index 11954d19..08b5ccb3 100644 --- a/opshin/tests/test_stdlib.py +++ b/opshin/tests/test_stdlib.py @@ -46,6 +46,28 @@ def validator(x: Dict[int, bytes], y: int) -> bytes: exp = None self.assertEqual(ret, exp, "dict[] returned wrong value") + @given(st.data()) + def test_list_index(self, data): + source_code = """ +def validator(x: List[int], z: int) -> int: + return x.index(z) + """ + xs = data.draw(st.lists(st.integers())) + z = data.draw( + st.one_of(st.sampled_from(xs), st.integers()) + if len(xs) > 0 + else st.integers() + ) + try: + ret = eval_uplc_value(source_code, xs, z) + except RuntimeError as e: + ret = None + try: + exp = xs.index(z) + except ValueError: + exp = None + self.assertEqual(ret, exp, "list.index returned wrong value") + @given(xs=st.dictionaries(st.integers(), st.binary())) def test_dict_keys(self, xs): source_code = """ diff --git a/opshin/types.py b/opshin/types.py index 66db543a..6faec70f 100644 --- a/opshin/types.py +++ b/opshin/types.py @@ -843,6 +843,48 @@ class ListType(ClassType): def __ge__(self, other): return isinstance(other, ListType) and self.typ >= other.typ + def attribute_type(self, attr) -> "Type": + if attr == "index": + return InstanceType( + FunctionType(frozenlist([self.typ]), IntegerInstanceType) + ) + super().attribute_type(attr) + + def attribute(self, attr) -> plt.AST: + if attr == "index": + return OLambda( + ["self", "x"], + OLet( + [("x", plt.Force(OVar("x")))], + plt.Apply( + plt.RecFun( + OLambda( + ["index", "xs", "a"], + plt.IteNullList( + OVar("xs"), + plt.TraceError("Did not find element in list"), + plt.Ite( + plt.EqualsInteger( + OVar("x"), plt.HeadList(OVar("xs")) + ), + OVar("a"), + plt.Apply( + OVar("index"), + OVar("index"), + plt.TailList(OVar("xs")), + plt.AddInteger(OVar("a"), plt.Integer(1)), + ), + ), + ), + ), + ), + OVar("self"), + plt.Integer(0), + ), + ), + ) + super().attribute(attr) + def stringify(self, recursive: bool = False) -> plt.AST: return OLambda( ["self"],