Skip to content

Commit

Permalink
Support default types of generic params
Browse files Browse the repository at this point in the history
  • Loading branch information
soutaro committed Sep 9, 2024
1 parent 6c73671 commit 6798eb0
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 22 deletions.
32 changes: 28 additions & 4 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ def type_1_opt(type)
end
end

def normalize_args(type_name, args)
case
when type_name.class?
if entry = env.normalized_module_class_entry(type_name)
type_params = entry.type_params
end
when type_name.interface?
if entry = env.interface_decls.fetch(type_name, nil)
type_params = entry.decl.type_params
end
when type_name.alias?
if entry = env.type_alias_decls.fetch(type_name, nil)
type_params = entry.decl.type_params
end
end

if type_params && !type_params.empty?
RBS::AST::TypeParam.normalize_args(type_params, args)
else
args
end
end

def type(type)
if ty = type_cache[type]
return ty
Expand Down Expand Up @@ -68,15 +91,15 @@ def type(type)
Name::Singleton.new(name: type_name)
when RBS::Types::ClassInstance
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Instance.new(name: type_name, args: args)
when RBS::Types::Interface
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Interface.new(name: type_name, args: args)
when RBS::Types::Alias
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Alias.new(name: type_name, args: args)
when RBS::Types::Union
Union.build(types: type.types.map {|ty| type(ty) })
Expand Down Expand Up @@ -247,7 +270,8 @@ def type_param(type_param)
name: type_param.name,
upper_bound: type_opt(type_param.upper_bound_type),
variance: type_param.variance,
unchecked: type_param.unchecked?
unchecked: type_param.unchecked?,
default_type: type_opt(type_param.default_type)
)
end

Expand Down
10 changes: 5 additions & 5 deletions lib/steep/interface/builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def validate_fvs(name, type)
end

def upper_bound(a)
variable_bounds.fetch(a, AST::Builtin::Object.instance_type)
variable_bounds.fetch(a, Interface::TypeParam::IMPLICIT_UPPER_BOUND)
end
end

Expand Down Expand Up @@ -494,7 +494,7 @@ def tuple_shape(tuple)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(
required: [
Expand All @@ -508,7 +508,7 @@ def tuple_shape(tuple)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
return_type: AST::Types::Union.build(types: [elem_type, AST::Types::Var.new(name: :T)]),
Expand Down Expand Up @@ -658,7 +658,7 @@ def record_shape(record)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [key_type, AST::Types::Var.new(name: :T)]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
Expand All @@ -667,7 +667,7 @@ def record_shape(record)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [key_type]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
Expand Down
27 changes: 20 additions & 7 deletions lib/steep/interface/type_param.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,30 @@ class TypeParam
attr_reader :variance
attr_reader :unchecked
attr_reader :location
attr_reader :default_type

def initialize(name:, upper_bound:, variance:, unchecked:, location: nil)
def initialize(name:, upper_bound:, variance:, unchecked:, location: nil, default_type:)
@name = name
@upper_bound = upper_bound
@variance = variance
@unchecked = unchecked
@location = location
@default_type = default_type
end

def ==(other)
other.is_a?(TypeParam) &&
other.name == name &&
other.upper_bound == upper_bound &&
other.variance == variance &&
other.unchecked == unchecked
other.unchecked == unchecked &&
other.default_type == default_type
end

alias eql? ==

def hash
name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash
name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash ^ default_type.hash
end

def self.rename(params, conflicting_names = params.map(&:name), new_names = conflicting_names.map {|n| AST::Types::Var.fresh_name(n) })
Expand All @@ -46,7 +49,8 @@ def self.rename(params, conflicting_names = params.map(&:name), new_names = conf
upper_bound: param.upper_bound&.subst(subst),
variance: param.variance,
unchecked: param.unchecked,
location: param.location
location: param.location,
default_type: param.default_type&.subst(subst)
)
else
param
Expand Down Expand Up @@ -82,19 +86,28 @@ def to_s
buf
end

def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location)
def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location, default_type: self.default_type)
TypeParam.new(
name: name,
upper_bound: upper_bound,
variance: variance,
unchecked: unchecked,
location: location
location: location,
default_type: default_type
)
end

def subst(s)
if u = upper_bound
update(upper_bound: u.subst(s))
ub = u.subst(s)
end

if d = default_type
dt = d.subst(s)
end

if ub || dt
update(upper_bound: ub, default_type: dt)
else
self
end
Expand Down
3 changes: 2 additions & 1 deletion lib/steep/signature/validator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def validate_type_application_constraints(type_name, type_params, type_args, loc
name: param.name,
upper_bound: upper_bound_type,
variance: param.variance,
unchecked: param.unchecked?
unchecked: param.unchecked?,
default_type: factory.type_opt(param.default_type)
),
location: location
)
Expand Down
6 changes: 4 additions & 2 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ def for_module(node, new_module_name)
name: param.name,
upper_bound: checker.factory.type_opt(param.upper_bound_type),
variance: param.variance,
unchecked: param.unchecked?
unchecked: param.unchecked?,
default_type: checker.factory.type_opt(param.default_type)
)
end
variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
Expand Down Expand Up @@ -497,7 +498,8 @@ def for_class(node, new_class_name, super_class_name)
upper_bound: type_param.upper_bound_type&.yield_self {|t| checker.factory.type(t) },
variance: type_param.variance,
unchecked: type_param.unchecked?,
location: type_param.location
location: type_param.location,
default_type: checker.factory.type_opt(type_param.default_type)
)
end
variable_context = TypeInference::Context::TypeVariableContext.new(type_params)
Expand Down
2 changes: 2 additions & 0 deletions sig/steep/ast/types/factory.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ module Steep

def type_name_resolver: () -> TypeNameResolver

def normalize_args: (RBS::TypeName type_name, Array[RBS::Types::t]) -> Array[RBS::Types::t]

def type: (RBS::Types::t `type`) -> t

def type_opt: (RBS::Types::t? `type`) -> t?
Expand Down
6 changes: 4 additions & 2 deletions sig/steep/interface/type_param.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ module Steep

attr_reader location: loc?

def initialize: (name: Symbol, upper_bound: AST::Types::t?, variance: variance, unchecked: bool, ?location: loc?) -> void
attr_reader default_type: AST::Types::t?

def initialize: (name: Symbol, upper_bound: AST::Types::t?, variance: variance, unchecked: bool, ?location: loc?, default_type: AST::Types::t?) -> void

def ==: (untyped other) -> bool

Expand All @@ -41,7 +43,7 @@ module Steep

def to_s: () -> String

def update: (?name: Symbol, ?upper_bound: AST::Types::t?, ?variance: variance, ?unchecked: bool, ?location: loc?) -> TypeParam
def update: (?name: Symbol, ?upper_bound: AST::Types::t?, ?variance: variance, ?unchecked: bool, ?location: loc?, ?default_type: AST::Types::t?) -> TypeParam

def subst: (Substitution s) -> TypeParam
end
Expand Down
3 changes: 2 additions & 1 deletion test/subtyping_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,8 @@ def type_params(checker, **params)
name: name,
upper_bound: upper_bound,
variance: :invariant,
unchecked: false
unchecked: false,
default_type: nil
)
end
end
Expand Down
37 changes: 37 additions & 0 deletions test/type_check_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,8 @@ def foo: () -> void
def test_type_assertion__type_error
run_type_check_test(
signatures: {
"a.rbs" => <<~RBS
RBS
},
code: {
"a.rb" => <<~RUBY
Expand Down Expand Up @@ -2216,4 +2218,39 @@ def foo(x)
YAML
)
end

def test_generics_upperbound_default
run_type_check_test(
signatures: {
"a.rbs" => <<~RBS
class Foo[X = Integer]
def foo: (X) -> X
end
RBS
},
code: {
"a.rb" => <<~RUBY
class Foo
def foo(x)
x
end
end
x = Foo.new
x.foo(1) + 1
y = Foo.new #: Foo[String]
y.foo("foo") + ""
z = Foo.new #: Foo
z.foo(1) + 1
RUBY
},
expectations: <<~YAML
---
- file: a.rb
diagnostics: []
YAML
)
end
end

0 comments on commit 6798eb0

Please sign in to comment.