Skip to content

Commit

Permalink
Merge pull request #1216 from soutaro/any_type_upperbound
Browse files Browse the repository at this point in the history
Improve generics
  • Loading branch information
soutaro authored Sep 10, 2024
2 parents 95da23f + 9ec57bb commit 115a5d0
Show file tree
Hide file tree
Showing 20 changed files with 284 additions and 104 deletions.
2 changes: 1 addition & 1 deletion Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ group :development, optional: true do
gem "majo"
end

gem "rbs"
# gem "rbs", path: "../rbs"
5 changes: 2 additions & 3 deletions Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ PATH
logger (>= 1.3.0)
parser (>= 3.1)
rainbow (>= 2.2.2, < 4.0)
rbs (>= 3.5.0.pre)
rbs (~> 3.6.0.dev)
securerandom (>= 0.1)
strscan (>= 1.0.0)
terminal-table (>= 2, < 4)
Expand Down Expand Up @@ -73,7 +73,7 @@ GEM
rb-fsevent (0.11.2)
rb-inotify (0.11.1)
ffi (~> 1.0)
rbs (3.5.2)
rbs (3.6.0.dev.1)
logger
rdoc (6.7.0)
psych (>= 4.0.0)
Expand Down Expand Up @@ -101,7 +101,6 @@ DEPENDENCIES
minitest-hooks
minitest-slow_test
rake
rbs
stackprof
steep!
vernier (~> 1.0)
Expand Down
5 changes: 5 additions & 0 deletions lib/steep.rb
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def self.ui_logger
def self.new_logger(output, prev_level)
logger = Logger.new(output)
logger.formatter = proc do |severity, datetime, progname, msg|
# @type var severity: String
# @type var datetime: Time
# @type var progname: untyped
# @type var msg: untyped
# @type block: String
"#{datetime.strftime('%Y-%m-%d %H:%M:%S.%L')}: #{severity}: #{msg}\n"
end
ActiveSupport::TaggedLogging.new(logger).tap do |logger|
Expand Down
34 changes: 29 additions & 5 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ def type_1_opt(type)
end
end

def normalize_args(type_name, args)
case
when type_name.class?
if entry = env.normalized_module_class_entry(type_name)
type_params = entry.type_params
end
when type_name.interface?
if entry = env.interface_decls.fetch(type_name, nil)
type_params = entry.decl.type_params
end
when type_name.alias?
if entry = env.type_alias_decls.fetch(type_name, nil)
type_params = entry.decl.type_params
end
end

if type_params && !type_params.empty?
RBS::AST::TypeParam.normalize_args(type_params, args)
else
args
end
end

def type(type)
if ty = type_cache[type]
return ty
Expand Down Expand Up @@ -68,15 +91,15 @@ def type(type)
Name::Singleton.new(name: type_name)
when RBS::Types::ClassInstance
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Instance.new(name: type_name, args: args)
when RBS::Types::Interface
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Interface.new(name: type_name, args: args)
when RBS::Types::Alias
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Alias.new(name: type_name, args: args)
when RBS::Types::Union
Union.build(types: type.types.map {|ty| type(ty) })
Expand Down Expand Up @@ -245,9 +268,10 @@ def params(type)
def type_param(type_param)
Interface::TypeParam.new(
name: type_param.name,
upper_bound: type_opt(type_param.upper_bound),
upper_bound: type_opt(type_param.upper_bound_type),
variance: type_param.variance,
unchecked: type_param.unchecked?
unchecked: type_param.unchecked?,
default_type: type_opt(type_param.default_type)
)
end

Expand Down
10 changes: 5 additions & 5 deletions lib/steep/interface/builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def validate_fvs(name, type)
end

def upper_bound(a)
variable_bounds.fetch(a, nil)
variable_bounds.fetch(a, Interface::TypeParam::IMPLICIT_UPPER_BOUND)
end
end

Expand Down Expand Up @@ -494,7 +494,7 @@ def tuple_shape(tuple)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(
required: [
Expand All @@ -508,7 +508,7 @@ def tuple_shape(tuple)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
return_type: AST::Types::Union.build(types: [elem_type, AST::Types::Var.new(name: :T)]),
Expand Down Expand Up @@ -658,7 +658,7 @@ def record_shape(record)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [key_type, AST::Types::Var.new(name: :T)]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
Expand All @@ -667,7 +667,7 @@ def record_shape(record)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [key_type]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
Expand Down
29 changes: 22 additions & 7 deletions lib/steep/interface/type_param.rb
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
module Steep
module Interface
class TypeParam
IMPLICIT_UPPER_BOUND = AST::Builtin.optional(AST::Builtin::Object.instance_type)

attr_reader :name
attr_reader :upper_bound
attr_reader :variance
attr_reader :unchecked
attr_reader :location
attr_reader :default_type

def initialize(name:, upper_bound:, variance:, unchecked:, location: nil)
def initialize(name:, upper_bound:, variance:, unchecked:, location: nil, default_type:)
@name = name
@upper_bound = upper_bound
@variance = variance
@unchecked = unchecked
@location = location
@default_type = default_type
end

def ==(other)
other.is_a?(TypeParam) &&
other.name == name &&
other.upper_bound == upper_bound &&
other.variance == variance &&
other.unchecked == unchecked
other.unchecked == unchecked &&
other.default_type == default_type
end

alias eql? ==

def hash
name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash
name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash ^ default_type.hash
end

def self.rename(params, conflicting_names = params.map(&:name), new_names = conflicting_names.map {|n| AST::Types::Var.fresh_name(n) })
Expand All @@ -44,7 +49,8 @@ def self.rename(params, conflicting_names = params.map(&:name), new_names = conf
upper_bound: param.upper_bound&.subst(subst),
variance: param.variance,
unchecked: param.unchecked,
location: param.location
location: param.location,
default_type: param.default_type&.subst(subst)
)
else
param
Expand Down Expand Up @@ -80,19 +86,28 @@ def to_s
buf
end

def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location)
def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location, default_type: self.default_type)
TypeParam.new(
name: name,
upper_bound: upper_bound,
variance: variance,
unchecked: unchecked,
location: location
location: location,
default_type: default_type
)
end

def subst(s)
if u = upper_bound
update(upper_bound: u.subst(s))
ub = u.subst(s)
end

if d = default_type
dt = d.subst(s)
end

if ub || dt
update(upper_bound: ub, default_type: dt)
else
self
end
Expand Down
4 changes: 2 additions & 2 deletions lib/steep/server/lsp_formatter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def name_and_params(name, params)
end
s << param.name.to_s

if param.upper_bound
s << " < #{param.upper_bound.to_s}"
if param.upper_bound_type
s << " < #{param.upper_bound_type.to_s}"
end

s
Expand Down
15 changes: 8 additions & 7 deletions lib/steep/signature/validator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def validate_type_application_constraints(type_name, type_params, type_args, loc
type_params.zip(type_args).each do |param, arg|
arg or raise

if param.upper_bound
upper_bound_type = factory.type(param.upper_bound).subst(subst)
if param.upper_bound_type
upper_bound_type = factory.type(param.upper_bound_type).subst(subst)
arg_type = factory.type(arg)

constraints = Subtyping::Constraints.empty
Expand All @@ -101,7 +101,8 @@ def validate_type_application_constraints(type_name, type_params, type_args, loc
name: param.name,
upper_bound: upper_bound_type,
variance: param.variance,
unchecked: param.unchecked?
unchecked: param.unchecked?,
default_type: factory.type_opt(param.default_type)
),
location: location
)
Expand Down Expand Up @@ -236,7 +237,7 @@ def each_variable_type(definition)
def validate_definition_type(definition)
each_method_type(definition) do |method_type|
upper_bounds = method_type.type_params.each.with_object({}) do |param, hash|
hash[param.name] = factory.type_opt(param.upper_bound)
hash[param.name] = factory.type_opt(param.upper_bound_type)
end

checker.push_variable_bounds(upper_bounds) do
Expand Down Expand Up @@ -264,7 +265,7 @@ def validate_one_class_decl(name, entry)
Steep.logger.tagged "#{name}" do
builder.build_instance(name).tap do |definition|
upper_bounds = definition.type_params_decl.each.with_object({}) do |param, bounds|
bounds[param.name] = factory.type_opt(param.upper_bound)
bounds[param.name] = factory.type_opt(param.upper_bound_type)
end

self_type = AST::Types::Name::Instance.new(
Expand Down Expand Up @@ -480,7 +481,7 @@ def validate_one_interface(name)
definition = builder.build_interface(name)

upper_bounds = definition.type_params_decl.each.with_object({}) do |param, bounds|
bounds[param.name] = factory.type_opt(param.upper_bound)
bounds[param.name] = factory.type_opt(param.upper_bound_type)
end

self_type = AST::Types::Name::Interface.new(
Expand Down Expand Up @@ -575,7 +576,7 @@ def validate_one_alias(name, entry = env.type_alias_decls[name])
end

upper_bounds = entry.decl.type_params.each.with_object({}) do |param, bounds|
bounds[param.name] = factory.type_opt(param.upper_bound)
bounds[param.name] = factory.type_opt(param.upper_bound_type)
end

validator.validate_type_alias(entry: entry) do |type|
Expand Down
Loading

0 comments on commit 115a5d0

Please sign in to comment.