diff --git a/docs/using-modules.rst b/docs/using-modules.rst index 7d63eb6617..4400a8dfa8 100644 --- a/docs/using-modules.rst +++ b/docs/using-modules.rst @@ -62,6 +62,21 @@ The ``_times_two()`` helper function in the above module can be immediately used The other functions cannot be used yet, because they touch the ``ownable`` module's state. There are two ways to declare a module so that its state can be used. +Using a module as an interface +============================== + +A module can be used as an interface with the ``__at__`` syntax. + +.. code-block:: vyper + + import ownable + + an_ownable: ownable.__interface__ + + def call_ownable(addr: address): + self.an_ownable = ownable.__at__(addr) + self.an_ownable.transfer_ownership(...) + Initializing a module ===================== diff --git a/tests/functional/codegen/modules/test_exports.py b/tests/functional/codegen/modules/test_exports.py index 93f4fe6c2f..3cc21d61a9 100644 --- a/tests/functional/codegen/modules/test_exports.py +++ b/tests/functional/codegen/modules/test_exports.py @@ -440,3 +440,26 @@ def __init__(): # call `c.__default__()` env.message_call(c.address) assert c.counter() == 6 + + +def test_inline_interface_export(make_input_bundle, get_contract): + lib1 = """ +interface IAsset: + def asset() -> address: view + +implements: IAsset + +@external +@view +def asset() -> address: + return self + """ + main = """ +import lib1 + +exports: lib1.IAsset + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.asset() == c.address diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py index 3f0f8cb010..af9f9b5e68 100644 --- a/tests/functional/codegen/modules/test_interface_imports.py +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -1,3 +1,6 @@ +import pytest + + def test_import_interface_types(make_input_bundle, get_contract): ifaces = """ interface IFoo: @@ -50,9 +53,10 @@ def foo() -> bool: # check that this typechecks both directions a: lib1.IERC20 = IERC20(msg.sender) b: lib2.IERC20 = IERC20(msg.sender) + c: IERC20 = lib1.IERC20(msg.sender) # allowed in call position # return the equality so we can sanity check it - return a == b + return a == b and b == c """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) c = get_contract(main, input_bundle=input_bundle) @@ -60,6 +64,36 @@ def foo() -> bool: assert c.foo() is True +@pytest.mark.parametrize("interface_syntax", ["__at__", "__interface__"]) +def test_intrinsic_interface(get_contract, make_input_bundle, interface_syntax): + lib = """ +@external +@view +def foo() -> uint256: + # detect self call + if msg.sender == self: + return 4 + else: + return 5 + """ + + main = f""" +import lib + +exports: lib.__interface__ + +@external +@view +def bar() -> uint256: + return staticcall lib.{interface_syntax}(self).foo() + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.foo() == 5 + assert c.bar() == 4 + + def test_import_interface_flags(make_input_bundle, get_contract): ifaces = """ flag Foo: diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 8887bf07cb..31475a3bc0 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -774,3 +774,92 @@ def foo(s: MyStruct) -> MyStruct: assert "b: uint256" in out assert "struct Voter:" in out assert "voted: bool" in out + + +def test_intrinsic_interface_instantiation(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 + +i: lib1.__interface__ + +@external +def bar() -> lib1.__interface__: + self.i = lib1.__at__(self) + return self.i + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == c.address + + +def test_intrinsic_interface_converts(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 + +@external +def bar() -> lib1.__interface__: + return lib1.__at__(self) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + + assert c.bar() == c.address + + +def test_intrinsic_interface_kws(env, make_input_bundle, get_contract): + value = 10**5 + lib1 = f""" +@external +@payable +def foo(a: address): + send(a, {value}) + """ + main = f""" +import lib1 + +exports: lib1.__interface__ + +@external +def bar(a: address): + extcall lib1.__at__(self).foo(a, value={value}) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + env.set_balance(c.address, value) + original_balance = env.get_balance(env.deployer) + c.bar(env.deployer) + assert env.get_balance(env.deployer) == original_balance + value + + +def test_intrinsic_interface_defaults(env, make_input_bundle, get_contract): + lib1 = """ +@external +@payable +def foo(i: uint256=1) -> uint256: + return i + """ + main = """ +import lib1 + +exports: lib1.__interface__ + +@external +def bar() -> uint256: + return extcall lib1.__at__(self).foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + assert c.bar() == 1 diff --git a/tests/functional/syntax/modules/test_deploy_visibility.py b/tests/functional/syntax/modules/test_deploy_visibility.py index f51bf9575b..c908d4adae 100644 --- a/tests/functional/syntax/modules/test_deploy_visibility.py +++ b/tests/functional/syntax/modules/test_deploy_visibility.py @@ -1,7 +1,7 @@ import pytest from vyper.compiler import compile_code -from vyper.exceptions import CallViolation +from vyper.exceptions import CallViolation, UnknownAttribute def test_call_deploy_from_external(make_input_bundle): @@ -25,3 +25,35 @@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value.message == "Cannot call an @deploy function from an @external function!" + + +@pytest.mark.parametrize("interface_syntax", ["__interface__", "__at__"]) +def test_module_interface_init(make_input_bundle, tmp_path, interface_syntax): + lib1 = """ +#lib1.vy +k: uint256 + +@external +def bar(): + pass + +@deploy +def __init__(): + self.k = 10 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + code = f""" +import lib1 + +@deploy +def __init__(): + lib1.{interface_syntax}(self).__init__() + """ + + with pytest.raises(UnknownAttribute) as e: + compile_code(code, input_bundle=input_bundle) + + # as_posix() for windows tests + lib1_path = (tmp_path / "lib1.vy").as_posix() + assert e.value.message == f"interface {lib1_path} has no member '__init__'." diff --git a/tests/functional/syntax/modules/test_exports.py b/tests/functional/syntax/modules/test_exports.py index 7b00d29c98..4314c1bbf0 100644 --- a/tests/functional/syntax/modules/test_exports.py +++ b/tests/functional/syntax/modules/test_exports.py @@ -385,6 +385,28 @@ def do_xyz(): assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" +def test_no_export_unimplemented_inline_interface(make_input_bundle): + lib1 = """ +interface ifoo: + def do_xyz(): nonpayable + +# technically implements ifoo, but missing `implements: ifoo` + +@external +def do_xyz(): + pass + """ + main = """ +import lib1 + +exports: lib1.ifoo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(InterfaceViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + def test_export_selector_conflict(make_input_bundle): ifoo = """ @external @@ -444,3 +466,87 @@ def __init__(): with pytest.raises(InterfaceViolation) as e: compile_code(main, input_bundle=input_bundle) assert e.value._message == "requested `lib1.ifoo` but `lib1` does not implement `lib1.ifoo`!" + + +def test_export_empty_interface(make_input_bundle, tmp_path): + lib1 = """ +def an_internal_function(): + pass + """ + main = """ +import lib1 + +exports: lib1.__interface__ + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + + # as_posix() for windows + lib1_path = (tmp_path / "lib1.vy").as_posix() + assert e.value._message == f"lib1 (located at `{lib1_path}`) has no external functions!" + + +def test_invalid_export(make_input_bundle): + lib1 = """ +@external +def foo(): + pass + """ + main = """ +import lib1 +a: address + +exports: lib1.__interface__(self.a).foo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + with pytest.raises(StructureException) as e: + compile_code(main, input_bundle=input_bundle) + + assert e.value._message == "invalid export of a value" + assert e.value._hint == "exports should look like ." + + main = """ +interface Foo: + def foo(): nonpayable + +exports: Foo + """ + with pytest.raises(StructureException) as e: + compile_code(main) + + assert e.value._message == "invalid export" + assert e.value._hint == "exports should look like ." + + +@pytest.mark.parametrize("exports_item", ["__at__", "__at__(self)", "__at__(self).__interface__"]) +def test_invalid_at_exports(get_contract, make_input_bundle, exports_item): + lib = """ +@external +@view +def foo() -> uint256: + return 5 + """ + + main = f""" +import lib + +exports: lib.{exports_item} + +@external +@view +def bar() -> uint256: + return staticcall lib.__at__(self).foo() + """ + input_bundle = make_input_bundle({"lib.vy": lib}) + + with pytest.raises(Exception) as e: + compile_code(main, input_bundle=input_bundle) + + if exports_item == "__at__": + assert "not a function or interface" in str(e.value) + if exports_item == "__at__(self)": + assert "invalid exports" in str(e.value) + if exports_item == "__at__(self).__interface__": + assert "has no member '__interface__'" in str(e.value) diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 86ea4bcfd0..baf0c73c30 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -571,3 +571,53 @@ def bar(): compiler.compile_code(code, input_bundle=input_bundle) assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()" + + +def test_intrinsic_interfaces_different_types(make_input_bundle, get_contract): + lib1 = """ +@external +@view +def foo(): + pass + """ + lib2 = """ +@external +@view +def foo(): + pass + """ + main = """ +import lib1 +import lib2 + +@external +def bar(): + assert lib1.__at__(self) == lib2.__at__(self) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + + with pytest.raises(TypeMismatch): + compiler.compile_code(main, input_bundle=input_bundle) + + +@pytest.mark.xfail +def test_intrinsic_interfaces_default_function(make_input_bundle, get_contract): + lib1 = """ +@external +@payable +def __default__(): + pass + """ + main = """ +import lib1 + +@external +def bar(): + extcall lib1.__at__(self).__default__() + + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + # TODO make the exception more precise once fixed + with pytest.raises(Exception): # noqa: B017 + compiler.compile_code(main, input_bundle=input_bundle) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index cd51966710..3a09bbe6c0 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -51,6 +51,7 @@ FlagT, HashMapT, InterfaceT, + ModuleT, SArrayT, StringT, StructT, @@ -680,7 +681,8 @@ def parse_Call(self): # TODO fix cyclic import from vyper.builtins._signatures import BuiltinFunctionT - func_t = self.expr.func._metadata["type"] + func = self.expr.func + func_t = func._metadata["type"] if isinstance(func_t, BuiltinFunctionT): return func_t.build_IR(self.expr, self.context) @@ -691,8 +693,14 @@ def parse_Call(self): return self.handle_struct_literal() # Interface constructor. Bar(
). - if is_type_t(func_t, InterfaceT): + if is_type_t(func_t, InterfaceT) or func.get("attr") == "__at__": assert not self.is_stmt # sanity check typechecker + + # magic: do sanity checks for module.__at__ + if func.get("attr") == "__at__": + assert isinstance(func_t, MemberFunctionT) + assert isinstance(func.value._metadata["type"], ModuleT) + (arg0,) = self.expr.args arg_ir = Expr(arg0, self.context).ir_node @@ -702,16 +710,16 @@ def parse_Call(self): return arg_ir if isinstance(func_t, MemberFunctionT): - darray = Expr(self.expr.func.value, self.context).ir_node + # TODO consider moving these to builtins or a dedicated file + darray = Expr(func.value, self.context).ir_node assert isinstance(darray.typ, DArrayT) args = [Expr(x, self.context).ir_node for x in self.expr.args] - if self.expr.func.attr == "pop": - # TODO consider moving this to builtins - darray = Expr(self.expr.func.value, self.context).ir_node + if func.attr == "pop": + darray = Expr(func.value, self.context).ir_node assert len(self.expr.args) == 0 return_item = not self.is_stmt return pop_dyn_array(darray, return_popped_item=return_item) - elif self.expr.func.attr == "append": + elif func.attr == "append": (arg,) = args check_assign( dummy_node_for_type(darray.typ.value_type), dummy_node_for_type(arg.typ) @@ -726,6 +734,8 @@ def parse_Call(self): ret.append(append_dyn_array(darray, arg)) return IRnode.from_list(ret) + raise CompilerPanic("unreachable!") # pragma: nocover + assert isinstance(func_t, ContractFunctionT) assert func_t.is_internal or func_t.is_constructor diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index f5f99a0bc3..e0eea293bc 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -268,6 +268,9 @@ def build_abi_output(compiler_data: CompilerData) -> list: _ = compiler_data.ir_runtime # ensure _ir_info is generated abi = module_t.interface.to_toplevel_abi_dict() + if module_t.init_function: + abi += module_t.init_function.to_toplevel_abi_dict() + if compiler_data.show_gas_estimates: # Add gas estimates for each function to ABI gas_estimates = build_gas_estimates(compiler_data.function_signatures) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index e275930fa0..adfc7540a0 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -96,6 +96,7 @@ class AnalysisResult: class ModuleInfo(AnalysisResult): module_t: "ModuleT" alias: str + # import_node: vy_ast._ImportStmt # maybe could be useful ownership: ModuleOwnership = ModuleOwnership.NO_OWNERSHIP ownership_decl: Optional[vy_ast.VyperNode] = None diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8a2beb61e6..737f675b7c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -40,7 +40,7 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace -from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT +from vyper.semantics.types import TYPE_T, EventT, FlagT, InterfaceT, StructT, is_type_t from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation @@ -499,9 +499,19 @@ def visit_ExportsDecl(self, node): raise StructureException("not a public variable!", decl, item) funcs = [decl._expanded_getter._metadata["func_type"]] elif isinstance(info.typ, ContractFunctionT): + # e.g. lib1.__interface__(self._addr).foo + if not isinstance(get_expr_info(item.value).typ, (ModuleT, TYPE_T)): + raise StructureException( + "invalid export of a value", + item.value, + hint="exports should look like .", + ) + # regular function funcs = [info.typ] - elif isinstance(info.typ, InterfaceT): + elif is_type_t(info.typ, InterfaceT): + interface_t = info.typ.typedef + if not isinstance(item, vy_ast.Attribute): raise StructureException( "invalid export", @@ -512,7 +522,7 @@ def visit_ExportsDecl(self, node): if module_info is None: raise StructureException("not a valid module!", item.value) - if info.typ not in module_info.typ.implemented_interfaces: + if interface_t not in module_info.typ.implemented_interfaces: iface_str = item.node_source_code module_str = item.value.node_source_code msg = f"requested `{iface_str}` but `{module_str}`" @@ -523,9 +533,15 @@ def visit_ExportsDecl(self, node): # find the specific implementation of the function in the module funcs = [ module_exposed_fns[fn.name] - for fn in info.typ.functions.values() + for fn in interface_t.functions.values() if fn.is_external ] + + if len(funcs) == 0: + path = module_info.module_node.path + msg = f"{module_info.alias} (located at `{path}`) has no external functions!" + raise StructureException(msg, item) + else: raise StructureException( f"not a function or interface: `{info.typ}`", info.typ.decl_node, item diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 9734087fc3..a31ce7acc1 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -199,7 +199,7 @@ def _raise_invalid_reference(name, node): try: s = t.get_member(name, node) - if isinstance(s, (VyperType, TYPE_T)): + if isinstance(s, VyperType): # ex. foo.bar(). bar() is a ContractFunctionT return [s] diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 59a20dd99f..b881f52b2b 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -1,8 +1,8 @@ from . import primitives, subscriptable, user from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void from .bytestrings import BytesT, StringT, _BytestringT -from .function import MemberFunctionT -from .module import InterfaceT +from .function import ContractFunctionT, MemberFunctionT +from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 128ede0d5b..aca37b33a3 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -114,8 +114,13 @@ def __eq__(self, other): ) def __lt__(self, other): + # CMC 2024-10-20 what is this for? return self.abi_type.selector_name() < other.abi_type.selector_name() + def __repr__(self): + # TODO: add `pretty()` to the VyperType API? + return self._id + # return a dict suitable for serializing in the AST def to_dict(self): ret = {"name": self._id} @@ -362,10 +367,7 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": raise StructureException(f"{self} instance does not have members", node) hint = get_levenshtein_error_suggestions(key, self.members, 0.3) - raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint) - - def __repr__(self): - return self._id + raise UnknownAttribute(f"{repr(self)} has no member '{key}'.", node, hint=hint) class KwargSettings: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7a56b01281..ffeb5b7299 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -874,7 +874,7 @@ def _id(self): return self.name def __repr__(self): - return f"{self.underlying_type._id} member function '{self.name}'" + return f"{self.underlying_type} member function '{self.name}'" def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: validate_call_args(node, len(self.arg_types)) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index dabeaf21b6..498757b94e 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -19,7 +19,7 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, FlagT, StructT, _UserType from vyper.utils import OrderedSet @@ -240,9 +240,6 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": for fn_t in module_t.exposed_functions: funcs.append((fn_t.name, fn_t)) - if (fn_t := module_t.init_function) is not None: - funcs.append((fn_t.name, fn_t)) - event_set: OrderedSet[EventT] = OrderedSet() event_set.update([node._metadata["event_type"] for node in module_t.event_defs]) event_set.update(module_t.used_events) @@ -273,6 +270,19 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": return cls._from_lists(node.name, node, functions) +def _module_at(module_t): + return MemberFunctionT( + # set underlying_type to a TYPE_T as a bit of a kludge, since it's + # kind of like a class method (but we don't have classmethod + # abstraction) + underlying_type=TYPE_T(module_t), + name="__at__", + arg_types=[AddressT()], + return_type=module_t.interface, + is_modifying=False, + ) + + # Datatype to store all module information. class ModuleT(VyperType): typeclass = "module" @@ -330,16 +340,28 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for i in self.import_stmts: import_info = i._metadata["import_info"] - self.add_member(import_info.alias, import_info.typ) if hasattr(import_info.typ, "module_t"): - self._helper.add_member(import_info.alias, TYPE_T(import_info.typ)) + module_info = import_info.typ + # get_expr_info uses ModuleInfo + self.add_member(import_info.alias, module_info) + # type_from_annotation uses TYPE_T + self._helper.add_member(import_info.alias, TYPE_T(module_info.module_t)) + else: # interfaces + assert isinstance(import_info.typ, InterfaceT) + self.add_member(import_info.alias, TYPE_T(import_info.typ)) for name, interface_t in self.interfaces.items(): # can access interfaces in type position self._helper.add_member(name, TYPE_T(interface_t)) - self.add_member("__interface__", self.interface) + # module.__at__(addr) + self.add_member("__at__", _module_at(self)) + + # allow `module.__interface__` (in exports declarations) + self.add_member("__interface__", TYPE_T(self.interface)) + # allow `module.__interface__` (in type position) + self._helper.add_member("__interface__", TYPE_T(self.interface)) # __eq__ is very strict on ModuleT - object equality! this is because we # don't want to reason about where a module came from (i.e. input bundle,