diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 4ac8368f8..b4ec235d2 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -39,9 +39,9 @@ use crate::hir::DefaultParamSignature; use crate::hir::GlobSignature; use crate::hir::ListWithLength; use crate::hir::{ - Accessor, Args, BinOp, Block, Call, ClassDef, Def, DefBody, Expr, GuardClause, Identifier, - Lambda, List, Literal, NonDefaultParamSignature, Params, PatchDef, PosArg, ReDef, Record, - Signature, SubrSignature, Tuple, UnaryOp, VarSignature, HIR, + Accessor, Args, BinOp, Block, Call, ClassDef, Def, DefBody, Dict, Expr, GuardClause, + Identifier, Lambda, List, Literal, NonDefaultParamSignature, Params, PatchDef, PosArg, ReDef, + Record, Set, Signature, SubrSignature, Tuple, UnaryOp, VarSignature, HIR, }; use crate::ty::codeobj::{CodeObj, CodeObjFlags, MakeFunctionFlags}; use crate::ty::value::{GenTypeObj, ValueObj}; @@ -864,6 +864,51 @@ impl PyCodeGenerator { self.emit_args_311(args, AccessKind::Name); return; } + "list_iterator" => { + let list = Expr::Literal(Literal::new(ValueObj::List(vec![].into()), Token::DUMMY)); + let iter = Identifier::static_public("iter"); + let iter_call = iter.call(Args::single(PosArg::new(list))); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(iter_call.into()))); + self.emit_call(typ_call); + return; + } + "set_iterator" => { + let set = Expr::Set(Set::empty()); + let iter = Identifier::static_public("iter"); + let iter_call = iter.call(Args::single(PosArg::new(set))); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(iter_call.into()))); + self.emit_call(typ_call); + return; + } + "dict_items" => { + let dict = Expr::Dict(Dict::empty()); + let items = Identifier::static_public("iter"); + let items_call = items.call(Args::single(PosArg::new(dict))); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(items_call.into()))); + self.emit_call(typ_call); + return; + } + "dict_keys" => { + let dict = Expr::Dict(Dict::empty()); + let keys = Identifier::static_public("keys"); + let keys_call = dict.method_call(keys, Args::empty()); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(keys_call.into()))); + self.emit_call(typ_call); + return; + } + "dict_values" => { + let dict = Expr::Dict(Dict::empty()); + let values = Identifier::static_public("values"); + let values_call = dict.method_call(values, Args::empty()); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(values_call.into()))); + self.emit_call(typ_call); + return; + } _ => {} } let name = self @@ -2754,6 +2799,46 @@ impl PyCodeGenerator { self.emit_load_name_instr(Identifier::private("#sum")); self.emit_args_311(args, Name); } + "ListIterator" => { + let list = Expr::Literal(Literal::new(ValueObj::List(vec![].into()), Token::DUMMY)); + let iter = Identifier::static_public("iter"); + let iter_call = iter.call(Args::single(PosArg::new(list))); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(iter_call.into()))); + self.emit_call(typ_call); + } + "SetIterator" => { + let set = Expr::Set(Set::empty()); + let iter = Identifier::static_public("iter"); + let iter_call = iter.call(Args::single(PosArg::new(set))); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(iter_call.into()))); + self.emit_call(typ_call); + } + "DictItems" => { + let dict = Expr::Dict(Dict::empty()); + let iter = Identifier::static_public("iter"); + let items_call = iter.call(Args::single(PosArg::new(dict))); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(items_call.into()))); + self.emit_call(typ_call); + } + "DictKeys" => { + let dict = Expr::Dict(Dict::empty()); + let keys = Identifier::static_public("keys"); + let keys_call = dict.method_call(keys, Args::empty()); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(keys_call.into()))); + self.emit_call(typ_call); + } + "DictValues" => { + let dict = Expr::Dict(Dict::empty()); + let values = Identifier::static_public("values"); + let values_call = dict.method_call(values, Args::empty()); + let typ = Identifier::static_public("type"); + let typ_call = typ.call(Args::single(PosArg::new(values_call.into()))); + self.emit_call(typ_call); + } other if local.ref_t().is_poly_meta_type() && other != "classof" => { if self.py_version.minor <= Some(9) { self.load_fake_generic(); diff --git a/crates/erg_compiler/context/initialize/traits.rs b/crates/erg_compiler/context/initialize/traits.rs index 6275f3592..2e1bce0b8 100644 --- a/crates/erg_compiler/context/initialize/traits.rs +++ b/crates/erg_compiler/context/initialize/traits.rs @@ -31,7 +31,8 @@ impl Context { let mut named = Self::builtin_mono_trait(NAMED, 2); named.register_builtin_erg_decl(FUNC_NAME, Str, Visibility::BUILTIN_PUBLIC); let mut sized = Self::builtin_mono_trait(SIZED, 2); - let t = fn0_met(mono(SIZED), Nat).quantify(); + let ret_t = if PYTHON_MODE { Int } else { Nat }; + let t = fn0_met(mono(SIZED), ret_t).quantify(); sized.register_builtin_erg_decl(FUNDAMENTAL_LEN, t, Visibility::BUILTIN_PUBLIC); let mut copy = Self::builtin_mono_trait(COPY, 2); let Slf = mono_q(SELF, subtypeof(mono(COPY))); @@ -227,15 +228,24 @@ impl Context { /* Iterable */ let mut iterable = Self::builtin_poly_trait(ITERABLE, vec![PS::t_nd(TY_T)], 2); iterable.register_superclass(poly(OUTPUT, vec![ty_tp(T.clone())]), &output); - let Slf = mono_q(SELF, subtypeof(poly(ITERABLE, vec![ty_tp(T.clone())]))); - let t = fn0_met(Slf.clone(), proj(Slf, ITER)).quantify(); - iterable.register_builtin_decl( - FUNC_ITER, - t, - Visibility::BUILTIN_PUBLIC, - Some(FUNDAMENTAL_ITER), - ); - iterable.register_builtin_erg_decl(ITER, Type, Visibility::BUILTIN_PUBLIC); + if PYTHON_MODE { + let t = fn0_met( + poly(ITERABLE, vec![ty_tp(T.clone())]), + poly(ITERATOR, vec![ty_tp(T.clone())]), + ) + .quantify(); + iterable.register_builtin_erg_decl(FUNDAMENTAL_ITER, t, Visibility::BUILTIN_PUBLIC); + } else { + let Slf = mono_q(SELF, subtypeof(poly(ITERABLE, vec![ty_tp(T.clone())]))); + let t = fn0_met(Slf.clone(), proj(Slf, ITER)).quantify(); + iterable.register_builtin_decl( + FUNC_ITER, + t, + Visibility::BUILTIN_PUBLIC, + Some(FUNDAMENTAL_ITER), + ); + iterable.register_builtin_erg_decl(ITER, Type, Visibility::BUILTIN_PUBLIC); + } let Slf = poly(ITERABLE, vec![ty_tp(T.clone())]); let U = type_q(TY_U); let t_map = fn1_met( @@ -244,9 +254,10 @@ impl Context { poly(MAP, vec![ty_tp(U.clone())]), ) .quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_MAP, t_map, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_map"), ); @@ -268,9 +279,10 @@ impl Context { ) .quantify(); let t_filter = t_filter.with_default_intersec_index(1); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_FILTER, t_filter, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_filter"), ); @@ -279,9 +291,10 @@ impl Context { vec![TyParam::List(vec![ty_tp(Nat), ty_tp(T.clone())])], ); let t_enumerate = fn0_met(Slf.clone(), poly(ITERATOR, vec![ty_tp(ret_t)])).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_ENUMERATE, t_enumerate, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::enumerate"), ); @@ -291,9 +304,10 @@ impl Context { poly(ZIP, vec![ty_tp(T.clone()), ty_tp(U.clone())]), ) .quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_ZIP, t_zip, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::zip"), ); @@ -304,59 +318,67 @@ impl Context { T.clone(), ) .quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_REDUCE, t_reduce, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_reduce"), ); let t_nth = fn1_met(Slf.clone(), Nat, T.clone()).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_NTH, t_nth, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_nth"), ); let t_skip = fn1_met(Slf.clone(), Nat, poly(ITERATOR, vec![ty_tp(T.clone())])).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_SKIP, t_skip, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_skip"), ); let t_all = fn1_met(Slf.clone(), func1(T.clone(), Bool), Bool).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_ALL, t_all, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_all"), ); let t_any = fn1_met(Slf.clone(), func1(T.clone(), Bool), Bool).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_ANY, t_any, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_any"), ); let t_reversed = fn0_met(Slf.clone(), poly(ITERATOR, vec![ty_tp(T.clone())])).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_REVERSED, t_reversed, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::reversed"), ); let t_position = fn1_met(Slf.clone(), func1(T.clone(), Bool), or(Nat, NoneType)).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_POSITION, t_position, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_position"), ); let t_find = fn1_met(Slf.clone(), func1(T.clone(), Bool), or(T.clone(), NoneType)).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_FIND, t_find, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_find"), ); @@ -369,16 +391,18 @@ impl Context { poly(ITERATOR, vec![ty_tp(T.clone())]), ) .quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_CHAIN, t_chain, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::iterable_chain"), ); let t_to_list = fn0_met(Slf.clone(), unknown_len_list_t(T.clone())).quantify(); - iterable.register_builtin_decl( + iterable.register_builtin_py_impl( FUNC_TO_LIST, t_to_list, + Mutability::Immutable, Visibility::BUILTIN_PUBLIC, Some("Function::list"), ); @@ -396,7 +420,7 @@ impl Context { ); /* Container */ let mut container = Self::builtin_poly_trait(CONTAINER, vec![PS::t_nd(TY_T)], 2); - let op_t = fn1_met(mono(CONTAINER), T.clone(), Bool).quantify(); + let op_t = fn1_met(poly(CONTAINER, vec![ty_tp(T.clone())]), T.clone(), Bool).quantify(); container.register_superclass(poly(OUTPUT, vec![ty_tp(T.clone())]), &output); container.register_builtin_erg_decl(FUNDAMENTAL_CONTAINS, op_t, Visibility::BUILTIN_PUBLIC); /* Collection */ diff --git a/crates/erg_compiler/hir.rs b/crates/erg_compiler/hir.rs index 7b115e976..0bd2d8ddb 100644 --- a/crates/erg_compiler/hir.rs +++ b/crates/erg_compiler/hir.rs @@ -554,6 +554,10 @@ impl Identifier { Call::new(Expr::Accessor(Accessor::Ident(self)), None, args) } + pub fn method_call(self, attr_name: Identifier, args: Args) -> Call { + Call::new(Expr::Accessor(Accessor::Ident(self)), Some(attr_name), args) + } + pub fn is_py_api(&self) -> bool { self.vi.py_name.is_some() } @@ -1095,6 +1099,14 @@ impl_display_for_enum!(Dict; Normal, Comprehension); impl_locational_for_enum!(Dict; Normal, Comprehension); impl_t_for_enum!(Dict; Normal, Comprehension); +impl Dict { + pub fn empty() -> Self { + let l_brace = Token::from_str(TokenKind::LBrace, "{"); + let r_brace = Token::from_str(TokenKind::RBrace, "}"); + Self::Normal(NormalDict::new(l_brace, r_brace, HashMap::new(), vec![])) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct NormalSet { pub l_brace: Token, @@ -1194,6 +1206,19 @@ impl_display_for_enum!(Set; Normal, WithLength); impl_locational_for_enum!(Set; Normal, WithLength); impl_t_for_enum!(Set; Normal, WithLength); +impl Set { + pub fn empty() -> Self { + let l_brace = Token::from_str(TokenKind::LBrace, "{"); + let r_brace = Token::from_str(TokenKind::RBrace, "}"); + Self::Normal(NormalSet::new( + l_brace, + r_brace, + Type::Uninited, + Args::empty(), + )) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RecordAttrs(Vec); @@ -3123,6 +3148,14 @@ impl Expr { )) } + pub fn method_call(self, attr_name: Identifier, args: Args) -> Call { + Call::new(self, Some(attr_name), args) + } + + pub fn method_call_expr(self, attr_name: Identifier, args: Args) -> Self { + Self::Call(self.method_call(attr_name, args)) + } + pub fn attr(self, ident: Identifier) -> Accessor { Accessor::attr(self, ident) } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index aa1a9aa9a..f410d1246 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -2699,6 +2699,7 @@ impl GenericASTLowerer { } }; let mut hir_methods_list = vec![]; + let mut implemented = set! {}; for methods in class_def.methods_list.into_iter() { let mut hir_methods = hir::Block::empty(); let (class, impl_trait) = @@ -2832,7 +2833,7 @@ impl GenericASTLowerer { } else { self.check_override(&class, None); } - if let Err(err) = self.check_trait_impl(impl_trait.clone(), &class) { + if let Err(err) = self.check_trait_impl(impl_trait.clone(), &class, &mut implemented) { errors.push(err); } let impl_trait = impl_trait.map(|(t, _)| t); @@ -3152,6 +3153,7 @@ impl GenericASTLowerer { &mut self, //: methods context impl_trait: Option<(Type, &TypeSpecWithOp)>, class: &Type, + implemented: &mut Set, ) -> SingleLowerResult<()> { if let Some((impl_trait, t_spec)) = impl_trait { if let Some(mut sups) = self.module.context.get_super_traits(&impl_trait) { @@ -3177,7 +3179,7 @@ impl GenericASTLowerer { .unwrap() .get_nominal_type_ctx(&impl_trait) { - self.check_methods_compatibility(&impl_trait, class, typ_ctx, t_spec) + self.check_methods_compatibility(&impl_trait, class, typ_ctx, t_spec, implemented) } else { return Err(LowerError::no_type_error( self.cfg.input.clone(), @@ -3190,17 +3192,19 @@ impl GenericASTLowerer { .get_similar_name(&impl_trait.local_name()), )); }; - for unverified in unverified_names { - errors.push(LowerError::not_in_trait_error( - self.cfg.input.clone(), - line!() as usize, - self.module.context.caused_by(), - unverified.inspect(), - &impl_trait, - class, - None, - unverified.loc(), - )); + if !PYTHON_MODE { + for unverified in unverified_names { + errors.push(LowerError::not_in_trait_error( + self.cfg.input.clone(), + line!() as usize, + self.module.context.caused_by(), + unverified.inspect(), + &impl_trait, + class, + None, + unverified.loc(), + )); + } } self.errs.extend(errors); } @@ -3216,48 +3220,68 @@ impl GenericASTLowerer { ctx: trait_ctx, }: &TypeContext, t_spec: &TypeSpecWithOp, + implemented: &mut Set, ) -> (Set<&VarName>, CompileErrors) { let mut errors = CompileErrors::empty(); let mut unverified_names = self.module.context.locals.keys().collect::>(); - for (decl_name, decl_vi) in trait_ctx.decls.iter() { - if let Some((name, vi)) = self.module.context.get_var_kv(decl_name.inspect()) { - let def_t = &vi.t; - let replaced_decl_t = decl_vi - .t - .clone() - .replace(trait_type, impl_trait) - .replace(impl_trait, class); - unverified_names.remove(name); - if !self.module.context.supertype_of(&replaced_decl_t, def_t) { - let hint = self - .module - .context - .get_simple_type_mismatch_hint(&replaced_decl_t, def_t); - errors.push(LowerError::trait_member_type_error( + let tys_decls = if let Some(sups) = self.module.context.get_super_types(trait_type) { + sups.map(|sup| { + if implemented.linear_contains(&sup) { + return (sup, Dict::new()); + } + let decls = self + .module + .context + .get_nominal_type_ctx(&sup) + .map_or(Dict::new(), |ctx| ctx.decls.clone()); + (sup, decls) + }) + .collect::>() + } else { + vec![(impl_trait.clone(), trait_ctx.decls.clone())] + }; + for (impl_trait, decls) in tys_decls { + for (decl_name, decl_vi) in decls { + if let Some((name, vi)) = self.module.context.get_var_kv(decl_name.inspect()) { + let def_t = &vi.t; + let replaced_decl_t = decl_vi + .t + .clone() + .replace(trait_type, &impl_trait) + .replace(&impl_trait, class); + unverified_names.remove(name); + if !self.module.context.supertype_of(&replaced_decl_t, def_t) { + let hint = self + .module + .context + .get_simple_type_mismatch_hint(&replaced_decl_t, def_t); + errors.push(LowerError::trait_member_type_error( + self.cfg.input.clone(), + line!() as usize, + name.loc(), + self.module.context.caused_by(), + name.inspect(), + &impl_trait, + &decl_vi.t, + &vi.t, + hint, + )); + } + } else { + errors.push(LowerError::trait_member_not_defined_error( self.cfg.input.clone(), line!() as usize, - name.loc(), self.module.context.caused_by(), - name.inspect(), - impl_trait, - &decl_vi.t, - &vi.t, - hint, + decl_name.inspect(), + &impl_trait, + class, + None, + t_spec.loc(), )); } - } else { - errors.push(LowerError::trait_member_not_defined_error( - self.cfg.input.clone(), - line!() as usize, - self.module.context.caused_by(), - decl_name.inspect(), - impl_trait, - class, - None, - t_spec.loc(), - )); } } + implemented.insert(trait_type.clone()); (unverified_names, errors) } diff --git a/examples/impl.er b/examples/impl.er index 15b8fd2ea..782dd1299 100644 --- a/examples/impl.er +++ b/examples/impl.er @@ -24,3 +24,16 @@ s: Int = p * q assert s == 11 assert r == Point.new 4, 6 assert r.norm() == 52 + +MyList = Class { + .list = List(Obj) +} +MyList|<: Iterable(Obj)|. + Iter = ListIterator(Obj) + iter self = self.list.iter() +MyList|<: Sized|. + __len__ self = len self.list +MyList|<: Container(Obj)|. + __contains__ self, x: Obj = x in self.list +MyList|<: Sequence(Obj)|. + __getitem__ self, idx = self.list[idx]