Skip to content

Commit

Permalink
Merge pull request #66 from OpShin/feat/polymorphic_int_constr
Browse files Browse the repository at this point in the history
Polymorphic integer constructor
  • Loading branch information
nielstron authored Jul 2, 2023
2 parents b8bf70f + f9015f3 commit aad7984
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 103 deletions.
47 changes: 36 additions & 11 deletions opshin/tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ def validator(x: int) -> str:
ret = uplc_eval(f).value.decode("utf8")
self.assertEqual(ret, hex(x), "hex returned wrong value")

@unittest.skip("Integer stripping is currently broken")
@given(xs=st.one_of(st.builds(lambda x: str(x), st.integers()), st.text()))
@given(
xs=st.one_of(
st.builds(lambda x: str(x), st.integers()),
st.from_regex(r"\A(?!\s).*(?<!\s)\Z"),
)
)
@example("")
@example("10_00")
@example("_")
@example("_1")
@example("0\n")
# @example("0\n") # stripping is broken
def test_int_string(self, xs: str):
# this tests that errors that are caused by assignments are actually triggered at the time of assigning
source_code = """
Expand All @@ -163,15 +167,13 @@ def validator(x: str) -> int:
ret = uplc_eval(f).value
except:
ret = None
self.assertEqual(ret, exp, "str (integer) returned wrong value")
self.assertEqual(ret, exp, "int (str) returned wrong value")

@parameterized.parameterized.expand(
["10_00", "00", "_", "_1", "-10238", "19293812983721837981", "jakjsdh"]
)
def test_int_string(self, xs: str):
@given(xs=st.booleans())
def test_int_bool(self, xs: bool):
# this tests that errors that are caused by assignments are actually triggered at the time of assigning
source_code = """
def validator(x: str) -> int:
def validator(x: bool) -> int:
return int(x)
"""
ast = compiler.parse(source_code)
Expand All @@ -183,12 +185,35 @@ def validator(x: str) -> int:
except ValueError:
exp = None
try:
for d in [uplc.PlutusByteString(xs.encode("utf8"))]:
for d in [uplc.PlutusInteger(int(xs))]:
f = uplc.Apply(f, d)
ret = uplc_eval(f).value
except:
ret = None
self.assertEqual(ret, exp, "int (bool) returned wrong value")

@given(xs=st.integers())
def test_int_int(self, xs: int):
# this tests that errors that are caused by assignments are actually triggered at the time of assigning
source_code = """
def validator(x: int) -> int:
return int(x)
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast)
code = code.compile()
f = code.term
try:
exp = int(xs)
except ValueError:
exp = None
try:
for d in [uplc.PlutusInteger(int(xs))]:
f = uplc.Apply(f, d)
ret = uplc_eval(f).value
except:
ret = None
self.assertEqual(ret, exp, "str (integer) returned wrong value")
self.assertEqual(ret, exp, "int (int) returned wrong value")

@given(i=st.binary())
def test_len_bytestring(self, i):
Expand Down
210 changes: 118 additions & 92 deletions opshin/typed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,95 +954,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
@dataclass(frozen=True, unsafe_hash=True)
class IntegerType(AtomicType):
def constr_type(self) -> InstanceType:
return InstanceType(FunctionType([StringInstanceType], InstanceType(self)))

def constr(self) -> plt.AST:
# TODO we need to strip the string implicitely before parsing it
return plt.Lambda(
["x", "_"],
plt.Let(
[
("e", plt.EncodeUtf8(plt.Var("x"))),
("first_int", plt.IndexByteString(plt.Var("e"), plt.Integer(0))),
("len", plt.LengthOfByteString(plt.Var("e"))),
(
"fold_start",
plt.Lambda(
["start"],
plt.FoldList(
plt.Range(plt.Var("len"), plt.Var("start")),
plt.Lambda(
["s", "i"],
plt.Let(
[
(
"b",
plt.IndexByteString(
plt.Var("e"), plt.Var("i")
),
)
],
plt.Ite(
plt.EqualsInteger(
plt.Var("b"), plt.Integer(ord("_"))
),
plt.Var("s"),
plt.Ite(
plt.Or(
plt.LessThanInteger(
plt.Var("b"),
plt.Integer(ord("0")),
),
plt.LessThanInteger(
plt.Integer(ord("9")),
plt.Var("b"),
),
),
plt.TraceError(
"ValueError: invalid literal for int() with base 10"
),
plt.AddInteger(
plt.SubtractInteger(
plt.Var("b"),
plt.Integer(ord("0")),
),
plt.MultiplyInteger(
plt.Var("s"), plt.Integer(10)
),
),
),
),
),
),
plt.Integer(0),
),
),
),
],
plt.Ite(
plt.Or(
plt.EqualsInteger(plt.Var("len"), plt.Integer(0)),
plt.EqualsInteger(
plt.Var("first_int"),
plt.Integer(ord("_")),
),
),
plt.TraceError(
"ValueError: invalid literal for int() with base 10"
),
plt.Ite(
plt.EqualsInteger(
plt.Var("first_int"),
plt.Integer(ord("-")),
),
plt.Negate(
plt.Apply(plt.Var("fold_start"), plt.Integer(1)),
),
plt.Apply(plt.Var("fold_start"), plt.Integer(0)),
),
),
),
)
return InstanceType(PolymorphicFunctionType(IntImpl()))

def cmp(self, op: cmpop, o: "Type") -> plt.AST:
"""The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison."""
Expand Down Expand Up @@ -1189,9 +1101,6 @@ class StringType(AtomicType):
def constr_type(self) -> InstanceType:
return InstanceType(PolymorphicFunctionType(StrImpl()))

def constr(self) -> plt.AST:
return InstanceType(PolymorphicFunctionType(StrImpl()))

def attribute_type(self, attr) -> Type:
if attr == "encode":
return InstanceType(FunctionType([], ByteStringInstanceType))
Expand Down Expand Up @@ -1670,6 +1579,123 @@ def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
return arg.typ.stringify()


class IntImpl(PolymorphicFunction):
def type_from_args(self, args: typing.List[Type]) -> FunctionType:
assert (
len(args) == 1
), f"'int' takes only one argument, but {len(args)} were given"
typ = args[0]
assert isinstance(typ, InstanceType), "Can only create ints from instances"
assert any(
isinstance(typ.typ, t) for t in (IntegerType, StringType, BoolType)
), "Can only create integers from int, str or bool"
return FunctionType(args, IntegerInstanceType)

def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
arg = args[0]
assert isinstance(arg, InstanceType), "Can only create ints from instances"
if isinstance(arg.typ, IntegerType):
return plt.Lambda(["x", "_"], plt.Var("x"))
elif isinstance(arg.typ, BoolType):
return plt.Lambda(
["x", "_"], plt.IfThenElse(plt.Var("x"), plt.Integer(1), plt.Integer(0))
)
elif isinstance(arg.typ, StringType):
return plt.Lambda(
["x", "_"],
plt.Let(
[
("e", plt.EncodeUtf8(plt.Var("x"))),
(
"first_int",
plt.IndexByteString(plt.Var("e"), plt.Integer(0)),
),
("len", plt.LengthOfByteString(plt.Var("e"))),
(
"fold_start",
plt.Lambda(
["start"],
plt.FoldList(
plt.Range(plt.Var("len"), plt.Var("start")),
plt.Lambda(
["s", "i"],
plt.Let(
[
(
"b",
plt.IndexByteString(
plt.Var("e"), plt.Var("i")
),
)
],
plt.Ite(
plt.EqualsInteger(
plt.Var("b"), plt.Integer(ord("_"))
),
plt.Var("s"),
plt.Ite(
plt.Or(
plt.LessThanInteger(
plt.Var("b"),
plt.Integer(ord("0")),
),
plt.LessThanInteger(
plt.Integer(ord("9")),
plt.Var("b"),
),
),
plt.TraceError(
"ValueError: invalid literal for int() with base 10"
),
plt.AddInteger(
plt.SubtractInteger(
plt.Var("b"),
plt.Integer(ord("0")),
),
plt.MultiplyInteger(
plt.Var("s"),
plt.Integer(10),
),
),
),
),
),
),
plt.Integer(0),
),
),
),
],
plt.Ite(
plt.Or(
plt.EqualsInteger(plt.Var("len"), plt.Integer(0)),
plt.EqualsInteger(
plt.Var("first_int"),
plt.Integer(ord("_")),
),
),
plt.TraceError(
"ValueError: invalid literal for int() with base 10"
),
plt.Ite(
plt.EqualsInteger(
plt.Var("first_int"),
plt.Integer(ord("-")),
),
plt.Negate(
plt.Apply(plt.Var("fold_start"), plt.Integer(1)),
),
plt.Apply(plt.Var("fold_start"), plt.Integer(0)),
),
),
),
)
else:
raise NotImplementedError(
f"Can not derive integer from type {arg.typ.__name__}"
)


@dataclass(frozen=True, unsafe_hash=True)
class PolymorphicFunctionType(ClassType):
"""A special type of builtin that may act differently on different parameters"""
Expand Down

0 comments on commit aad7984

Please sign in to comment.