From e405170248f8b4d20fd0a3926dccdb850cb86660 Mon Sep 17 00:00:00 2001 From: Hang Su Date: Fri, 21 Oct 2022 14:39:18 -0400 Subject: [PATCH] type cast init commit --- pyteal/ast/abi/util.py | 47 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pyteal/ast/abi/util.py b/pyteal/ast/abi/util.py index 4cece8b49..017ff4716 100644 --- a/pyteal/ast/abi/util.py +++ b/pyteal/ast/abi/util.py @@ -9,6 +9,7 @@ get_args, get_origin, ) +from inspect import isabstract import algosdk.abi @@ -593,3 +594,49 @@ def type_spec_is_assignable_to(a: TypeSpec, b: TypeSpec) -> bool: return True return False + + +def type_cast_admissible(a: TypeSpec, b: TypeSpec) -> bool: + from pyteal.ast.abi import ( + TupleTypeSpec, + StaticArrayTypeSpec, + DynamicArrayTypeSpec, + ) + + match a, b: + case TupleTypeSpec(), TupleTypeSpec(): + a, b = cast(TupleTypeSpec, a), cast(TupleTypeSpec, b) + if a.length_static() != b.length_static(): + return False + + return all( + map( + lambda ab: type_spec_is_assignable_to(ab[0], ab[1]), + zip(a.value_type_specs(), b.value_type_specs()), + ) + ) + case DynamicArrayTypeSpec(), DynamicArrayTypeSpec(): + a, b = cast(DynamicArrayTypeSpec, a), cast(DynamicArrayTypeSpec, b) + return type_cast_admissible(a.value_spec, b.value_spec) + case StaticArrayTypeSpec(), StaticArrayTypeSpec(): + a, b = cast(StaticArrayTypeSpec, a), cast(StaticArrayTypeSpec, b) + return a.length_static() == b.length_static() and type_cast_admissible( + a.value_spec, b.value_spec + ) + case _: + return type_spec_is_assignable_to(a, b) + + +def type_cast(value: BaseType, t: type[T]) -> T: + if isabstract(t) or not issubclass(t, BaseType): + raise TealInputError( + f"type cast target class {t} cannot be abstract, and it has to be BaseType" + ) + value_ts: TypeSpec = value.type_spec() + target_ts: TypeSpec = type_spec_from_annotation(t) + + if not type_cast_admissible(value_ts, target_ts): + raise TealInputError(f"casting {value_ts} to {target_ts} is not allowed") + new_instance = target_ts.new_instance() + new_instance.stored_value.slot = value.stored_value.slot + return cast(T, new_instance)