diff --git a/lib/steep/ast/types/factory.rb b/lib/steep/ast/types/factory.rb index 9b133c02a..20aeccca7 100644 --- a/lib/steep/ast/types/factory.rb +++ b/lib/steep/ast/types/factory.rb @@ -36,6 +36,29 @@ def type_1_opt(type) end end + def normalize_args(type_name, args) + case + when type_name.class? + if entry = env.normalized_module_class_entry(type_name) + type_params = entry.type_params + end + when type_name.interface? + if entry = env.interface_decls.fetch(type_name, nil) + type_params = entry.decl.type_params + end + when type_name.alias? + if entry = env.type_alias_decls.fetch(type_name, nil) + type_params = entry.decl.type_params + end + end + + if type_params && !type_params.empty? + RBS::AST::TypeParam.normalize_args(type_params, args) + else + args + end + end + def type(type) if ty = type_cache[type] return ty @@ -68,15 +91,15 @@ def type(type) Name::Singleton.new(name: type_name) when RBS::Types::ClassInstance type_name = type.name - args = type.args.map {|arg| type(arg) } + args = normalize_args(type_name, type.args).map {|arg| type(arg) } Name::Instance.new(name: type_name, args: args) when RBS::Types::Interface type_name = type.name - args = type.args.map {|arg| type(arg) } + args = normalize_args(type_name, type.args).map {|arg| type(arg) } Name::Interface.new(name: type_name, args: args) when RBS::Types::Alias type_name = type.name - args = type.args.map {|arg| type(arg) } + args = normalize_args(type_name, type.args).map {|arg| type(arg) } Name::Alias.new(name: type_name, args: args) when RBS::Types::Union Union.build(types: type.types.map {|ty| type(ty) }) @@ -247,7 +270,8 @@ def type_param(type_param) name: type_param.name, upper_bound: type_opt(type_param.upper_bound_type), variance: type_param.variance, - unchecked: type_param.unchecked? + unchecked: type_param.unchecked?, + default_type: type_opt(type_param.default_type) ) end diff --git a/lib/steep/interface/builder.rb b/lib/steep/interface/builder.rb index 82e6088c7..486c284dc 100644 --- a/lib/steep/interface/builder.rb +++ b/lib/steep/interface/builder.rb @@ -50,7 +50,7 @@ def validate_fvs(name, type) end def upper_bound(a) - variable_bounds.fetch(a, AST::Builtin::Object.instance_type) + variable_bounds.fetch(a, Interface::TypeParam::IMPLICIT_UPPER_BOUND) end end @@ -494,7 +494,7 @@ def tuple_shape(tuple) block: nil ), MethodType.new( - type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)], + type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)], type: Function.new( params: Function::Params.build( required: [ @@ -508,7 +508,7 @@ def tuple_shape(tuple) block: nil ), MethodType.new( - type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)], + type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)], type: Function.new( params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]), return_type: AST::Types::Union.build(types: [elem_type, AST::Types::Var.new(name: :T)]), @@ -658,7 +658,7 @@ def record_shape(record) block: nil ), MethodType.new( - type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)], + type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)], type: Function.new( params: Function::Params.build(required: [key_type, AST::Types::Var.new(name: :T)]), return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]), @@ -667,7 +667,7 @@ def record_shape(record) block: nil ), MethodType.new( - type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)], + type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)], type: Function.new( params: Function::Params.build(required: [key_type]), return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]), diff --git a/lib/steep/interface/type_param.rb b/lib/steep/interface/type_param.rb index e848f01a5..486f2377f 100644 --- a/lib/steep/interface/type_param.rb +++ b/lib/steep/interface/type_param.rb @@ -8,13 +8,15 @@ class TypeParam attr_reader :variance attr_reader :unchecked attr_reader :location + attr_reader :default_type - def initialize(name:, upper_bound:, variance:, unchecked:, location: nil) + def initialize(name:, upper_bound:, variance:, unchecked:, location: nil, default_type:) @name = name @upper_bound = upper_bound @variance = variance @unchecked = unchecked @location = location + @default_type = default_type end def ==(other) @@ -22,13 +24,14 @@ def ==(other) other.name == name && other.upper_bound == upper_bound && other.variance == variance && - other.unchecked == unchecked + other.unchecked == unchecked && + other.default_type == default_type end alias eql? == def hash - name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash + name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash ^ default_type.hash end def self.rename(params, conflicting_names = params.map(&:name), new_names = conflicting_names.map {|n| AST::Types::Var.fresh_name(n) }) @@ -46,7 +49,8 @@ def self.rename(params, conflicting_names = params.map(&:name), new_names = conf upper_bound: param.upper_bound&.subst(subst), variance: param.variance, unchecked: param.unchecked, - location: param.location + location: param.location, + default_type: param.default_type&.subst(subst) ) else param @@ -82,19 +86,28 @@ def to_s buf end - def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location) + def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location, default_type: self.default_type) TypeParam.new( name: name, upper_bound: upper_bound, variance: variance, unchecked: unchecked, - location: location + location: location, + default_type: default_type ) end def subst(s) if u = upper_bound - update(upper_bound: u.subst(s)) + ub = u.subst(s) + end + + if d = default_type + dt = d.subst(s) + end + + if ub || dt + update(upper_bound: ub, default_type: dt) else self end diff --git a/lib/steep/signature/validator.rb b/lib/steep/signature/validator.rb index a0a8c19fd..3c47166d4 100644 --- a/lib/steep/signature/validator.rb +++ b/lib/steep/signature/validator.rb @@ -101,7 +101,8 @@ def validate_type_application_constraints(type_name, type_params, type_args, loc name: param.name, upper_bound: upper_bound_type, variance: param.variance, - unchecked: param.unchecked? + unchecked: param.unchecked?, + default_type: factory.type_opt(param.default_type) ), location: location ) diff --git a/lib/steep/type_construction.rb b/lib/steep/type_construction.rb index 725bbde45..b3ee5cadf 100644 --- a/lib/steep/type_construction.rb +++ b/lib/steep/type_construction.rb @@ -407,7 +407,8 @@ def for_module(node, new_module_name) name: param.name, upper_bound: checker.factory.type_opt(param.upper_bound_type), variance: param.variance, - unchecked: param.unchecked? + unchecked: param.unchecked?, + default_type: checker.factory.type_opt(param.default_type) ) end variable_context = TypeInference::Context::TypeVariableContext.new(type_params) @@ -497,7 +498,8 @@ def for_class(node, new_class_name, super_class_name) upper_bound: type_param.upper_bound_type&.yield_self {|t| checker.factory.type(t) }, variance: type_param.variance, unchecked: type_param.unchecked?, - location: type_param.location + location: type_param.location, + default_type: checker.factory.type_opt(type_param.default_type) ) end variable_context = TypeInference::Context::TypeVariableContext.new(type_params) diff --git a/sig/steep/ast/types/factory.rbs b/sig/steep/ast/types/factory.rbs index 42f9f6d8b..f085e9407 100644 --- a/sig/steep/ast/types/factory.rbs +++ b/sig/steep/ast/types/factory.rbs @@ -20,6 +20,8 @@ module Steep def type_name_resolver: () -> TypeNameResolver + def normalize_args: (RBS::TypeName type_name, Array[RBS::Types::t]) -> Array[RBS::Types::t] + def type: (RBS::Types::t `type`) -> t def type_opt: (RBS::Types::t? `type`) -> t? diff --git a/sig/steep/interface/type_param.rbs b/sig/steep/interface/type_param.rbs index 06fa1f666..b08edac40 100644 --- a/sig/steep/interface/type_param.rbs +++ b/sig/steep/interface/type_param.rbs @@ -21,7 +21,9 @@ module Steep attr_reader location: loc? - def initialize: (name: Symbol, upper_bound: AST::Types::t?, variance: variance, unchecked: bool, ?location: loc?) -> void + attr_reader default_type: AST::Types::t? + + def initialize: (name: Symbol, upper_bound: AST::Types::t?, variance: variance, unchecked: bool, ?location: loc?, default_type: AST::Types::t?) -> void def ==: (untyped other) -> bool @@ -41,7 +43,7 @@ module Steep def to_s: () -> String - def update: (?name: Symbol, ?upper_bound: AST::Types::t?, ?variance: variance, ?unchecked: bool, ?location: loc?) -> TypeParam + def update: (?name: Symbol, ?upper_bound: AST::Types::t?, ?variance: variance, ?unchecked: bool, ?location: loc?, ?default_type: AST::Types::t?) -> TypeParam def subst: (Substitution s) -> TypeParam end diff --git a/test/subtyping_test.rb b/test/subtyping_test.rb index 739be661e..0e5541148 100644 --- a/test/subtyping_test.rb +++ b/test/subtyping_test.rb @@ -1009,7 +1009,8 @@ def type_params(checker, **params) name: name, upper_bound: upper_bound, variance: :invariant, - unchecked: false + unchecked: false, + default_type: nil ) end end diff --git a/test/type_check_test.rb b/test/type_check_test.rb index 795737a9d..676702bc0 100644 --- a/test/type_check_test.rb +++ b/test/type_check_test.rb @@ -1433,6 +1433,8 @@ def foo: () -> void def test_type_assertion__type_error run_type_check_test( signatures: { + "a.rbs" => <<~RBS + RBS }, code: { "a.rb" => <<~RUBY @@ -2216,4 +2218,39 @@ def foo(x) YAML ) end + + def test_generics_upperbound_default + run_type_check_test( + signatures: { + "a.rbs" => <<~RBS + class Foo[X = Integer] + def foo: (X) -> X + end + RBS + }, + code: { + "a.rb" => <<~RUBY + class Foo + def foo(x) + x + end + end + + x = Foo.new + x.foo(1) + 1 + + y = Foo.new #: Foo[String] + y.foo("foo") + "" + + z = Foo.new #: Foo + z.foo(1) + 1 + RUBY + }, + expectations: <<~YAML + --- + - file: a.rb + diagnostics: [] + YAML + ) + end end