Skip to content

Commit

Permalink
Fix a critical typing bug that allows incompatible types to be passed…
Browse files Browse the repository at this point in the history
… into functions

The bug affects the union type comparison operation when comparing it to
other union types. It would allow types of unexpected type to be passed
into a function.

To remedy this situation, a new testing harness should be created that
tests the validity of the type system
  • Loading branch information
nielstron committed Feb 1, 2024
1 parent 511c086 commit 608a66e
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 5 deletions.
61 changes: 61 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,3 +2715,64 @@ def validator(b: Dict[int, Dict[bytes, int]]) -> Dict[bytes, int]:
"""
res = eval_uplc_value(source_code, {1: {b"": 0}}, constant_folding=True)
self.assertEqual(res, {})

def test_union_subset_call(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
@dataclass()
class B(PlutusData):
CONSTR_ID = 1
bar: int
@dataclass()
class C(PlutusData):
CONSTR_ID = 2
foobar: int
def fun(x: Union[A, B, C]) -> int:
return 0
def validator(x: Union[A, B]) -> int:
return fun(x)
"""
builder._compile(source_code)

@unittest.expectedFailure
def test_union_superset_call(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass
@dataclass()
class A(PlutusData):
CONSTR_ID = 0
foo: int
@dataclass()
class B(PlutusData):
CONSTR_ID = 1
bar: int
@dataclass()
class C(PlutusData):
CONSTR_ID = 2
foobar: int
def fun(x: Union[A, B]) -> int:
return 0
def validator(x: Union[A, B, C]) -> int:
return fun(x)
"""
builder._compile(source_code)
16 changes: 16 additions & 0 deletions opshin/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from ..types import *


def test_union_type_order():
A = RecordType(Record("A", "A", 0, [("foo", IntegerInstanceType)]))
B = RecordType(Record("B", "B", 1, [("bar", IntegerInstanceType)]))
C = RecordType(Record("C", "C", 2, [("baz", IntegerInstanceType)]))
abc = UnionType([A, B, C])
ab = UnionType([A, B])
a = A

assert a >= a
assert ab >= a
assert not a >= ab
assert abc >= ab
assert not ab >= abc
9 changes: 5 additions & 4 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,11 @@ def visit_Call(self, node: Call) -> TypedCall:
assert len(tc.args) == len(
functyp.argtyps
), f"Signature of function does not match number of arguments. Expected {len(functyp.argtyps)} arguments with these types: {functyp.argtyps} but got {len(tc.args)} arguments."
# all arguments need to be supertypes of the given type
assert all(
ap >= a.typ for a, ap in zip(tc.args, functyp.argtyps)
), f"Signature of function does not match arguments. Expected {len(functyp.argtyps)} arguments with these types: {functyp.argtyps} but got {[a.typ for a in tc.args]}."
# all arguments need to be subtypes of the parameter type
for i, (a, ap) in enumerate(zip(tc.args, functyp.argtyps)):
assert (
ap >= a.typ
), f"Signature of function does not match arguments in argument {i}. Expected this type: {ap} but got {a.typ}."
tc.typ = functyp.rettyp
return tc
raise TypeInferenceError("Could not infer type of call")
Expand Down
10 changes: 9 additions & 1 deletion opshin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ def __ge__(self, other):
@dataclass(frozen=True, unsafe_hash=True)
class ClassType(Type):
def __ge__(self, other):
"""
Returns whether other can be substituted for this type.
In other words this returns whether the interface of this type is a subset of the interface of other.
Note that this is usually <= and not >=, but this needs to be fixed later.
Produces a partial order on types.
The top element is the most generic type and can not substitute for anything.
The bottom element is the most specific type and can be substituted for anything.
"""
raise NotImplementedError("Comparison between raw classtypes impossible")

def copy_only_attributes(self) -> plt.AST:
Expand Down Expand Up @@ -730,7 +738,7 @@ def attribute(self, attr: str) -> plt.AST:

def __ge__(self, other):
if isinstance(other, UnionType):
return all(any(t >= ot for ot in other.typs) for t in self.typs)
return all(self >= ot for ot in other.typs)
return any(t >= other for t in self.typs)

def cmp(self, op: cmpop, o: "Type") -> plt.AST:
Expand Down

0 comments on commit 608a66e

Please sign in to comment.