From a2e9ac581e53b58118a879eea326edb11090a104 Mon Sep 17 00:00:00 2001 From: luc10921 Date: Mon, 27 May 2024 10:58:04 -0300 Subject: [PATCH] #86dtfv7aw - Error when compiling smart contract with custom class that has a method that returns an explicitly typed dict --- .../model/type/collection/icollection.py | 7 ++- .../ReturnDictWithClassAttributes.py | 53 +++++++++++++++++++ boa3_test/tests/compiler_tests/test_class.py | 30 +++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 boa3_test/test_sc/class_test/ReturnDictWithClassAttributes.py diff --git a/boa3/internal/model/type/collection/icollection.py b/boa3/internal/model/type/collection/icollection.py index dae440eb9..515753a95 100644 --- a/boa3/internal/model/type/collection/icollection.py +++ b/boa3/internal/model/type/collection/icollection.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from typing import Any, Self +from boa3.internal.model.expression import IExpression from boa3.internal.model.type.annotation.uniontype import UnionType from boa3.internal.model.type.classes.pythonclass import PythonClass from boa3.internal.model.type.itype import IType @@ -73,7 +74,11 @@ def get_types(cls, value: Any) -> set[IType]: if not isinstance(value, Iterable): value = {value} - types: set[IType] = {val if isinstance(val, IType) else Type.get_type(val) for val in value} + types: set[IType] = { + val if isinstance(val, IType) + else Type.get_type(val.type if isinstance(val, IExpression) else val) + for val in value + } return cls.filter_types(types) def get_item_type(self, index: tuple): diff --git a/boa3_test/test_sc/class_test/ReturnDictWithClassAttributes.py b/boa3_test/test_sc/class_test/ReturnDictWithClassAttributes.py new file mode 100644 index 000000000..837f327b5 --- /dev/null +++ b/boa3_test/test_sc/class_test/ReturnDictWithClassAttributes.py @@ -0,0 +1,53 @@ +from typing import Any + +from boa3.builtin.compile_time import public + + +class Example: + def __init__(self, shape: str, color: str, background: str, size: str): + self.shape = shape + self.color = color + self.background = background + self.size = size + + def test_value(self) -> dict[str, str]: + return { + 'shape': self.shape, + 'color': self.color, + 'background': self.background, + 'size': self.size + } + + def test_keys(self) -> dict[str, str]: + return { + self.shape: 'shape', + self.color: 'color', + self.background: 'background', + self.size: 'size' + } + + def test_pair(self) -> dict[str, str]: + return { + self.shape: self.shape, + self.color: self.color, + self.background: self.background, + self.size: self.size + } + + +@public +def test_only_values() -> Any: + example = Example('Rectangle', 'Blue', 'Black', 'Small') + return example.test_value() + + +@public +def test_only_keys() -> Any: + example = Example('Rectangle', 'Blue', 'Black', 'Small') + return example.test_keys() + + +@public +def test_pair() -> Any: + example = Example('Rectangle', 'Blue', 'Black', 'Small') + return example.test_pair() diff --git a/boa3_test/tests/compiler_tests/test_class.py b/boa3_test/tests/compiler_tests/test_class.py index 5bd09ff57..34a520471 100644 --- a/boa3_test/tests/compiler_tests/test_class.py +++ b/boa3_test/tests/compiler_tests/test_class.py @@ -459,3 +459,33 @@ async def test_class_property_and_parameter_with_same_name(self): result, _ = await self.call('main', ['unit test'], return_type=str) self.assertEqual('unit test', result) + + async def test_return_dict_with_class_attributes(self): + await self.set_up_contract('ReturnDictWithClassAttributes.py') + + expected_result = { + 'shape': 'Rectangle', + 'color': 'Blue', + 'background': 'Black', + 'size': 'Small' + } + result, _ = await self.call('test_only_values', [], return_type=dict[str,str]) + self.assertEqual(expected_result, result) + + expected_result = { + 'Rectangle': 'shape', + 'Blue': 'color', + 'Black': 'background', + 'Small': 'size' + } + result, _ = await self.call('test_only_keys', [], return_type=dict[str,str]) + self.assertEqual(expected_result, result) + + expected_result = { + 'Rectangle': 'Rectangle', + 'Blue': 'Blue', + 'Black': 'Black', + 'Small': 'Small' + } + result, _ = await self.call('test_pair', [], return_type=dict[str,str]) + self.assertEqual(expected_result, result)