Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve generics #1216

Merged
merged 11 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ group :development, optional: true do
gem "majo"
end

gem "rbs"
# gem "rbs", path: "../rbs"
5 changes: 2 additions & 3 deletions Gemfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ PATH
logger (>= 1.3.0)
parser (>= 3.1)
rainbow (>= 2.2.2, < 4.0)
rbs (>= 3.5.0.pre)
rbs (~> 3.6.0.dev)
securerandom (>= 0.1)
strscan (>= 1.0.0)
terminal-table (>= 2, < 4)
Expand Down Expand Up @@ -73,7 +73,7 @@ GEM
rb-fsevent (0.11.2)
rb-inotify (0.11.1)
ffi (~> 1.0)
rbs (3.5.2)
rbs (3.6.0.dev.1)
logger
rdoc (6.7.0)
psych (>= 4.0.0)
Expand Down Expand Up @@ -101,7 +101,6 @@ DEPENDENCIES
minitest-hooks
minitest-slow_test
rake
rbs
stackprof
steep!
vernier (~> 1.0)
Expand Down
5 changes: 5 additions & 0 deletions lib/steep.rb
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def self.ui_logger
def self.new_logger(output, prev_level)
logger = Logger.new(output)
logger.formatter = proc do |severity, datetime, progname, msg|
# @type var severity: String
# @type var datetime: Time
# @type var progname: untyped
# @type var msg: untyped
# @type block: String
"#{datetime.strftime('%Y-%m-%d %H:%M:%S.%L')}: #{severity}: #{msg}\n"
end
ActiveSupport::TaggedLogging.new(logger).tap do |logger|
Expand Down
34 changes: 29 additions & 5 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ def type_1_opt(type)
end
end

def normalize_args(type_name, args)
case
when type_name.class?
if entry = env.normalized_module_class_entry(type_name)
type_params = entry.type_params
end
when type_name.interface?
if entry = env.interface_decls.fetch(type_name, nil)
type_params = entry.decl.type_params
end
when type_name.alias?
if entry = env.type_alias_decls.fetch(type_name, nil)
type_params = entry.decl.type_params
end
end

if type_params && !type_params.empty?
RBS::AST::TypeParam.normalize_args(type_params, args)
else
args
end
end

def type(type)
if ty = type_cache[type]
return ty
Expand Down Expand Up @@ -68,15 +91,15 @@ def type(type)
Name::Singleton.new(name: type_name)
when RBS::Types::ClassInstance
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Instance.new(name: type_name, args: args)
when RBS::Types::Interface
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Interface.new(name: type_name, args: args)
when RBS::Types::Alias
type_name = type.name
args = type.args.map {|arg| type(arg) }
args = normalize_args(type_name, type.args).map {|arg| type(arg) }
Name::Alias.new(name: type_name, args: args)
when RBS::Types::Union
Union.build(types: type.types.map {|ty| type(ty) })
Expand Down Expand Up @@ -245,9 +268,10 @@ def params(type)
def type_param(type_param)
Interface::TypeParam.new(
name: type_param.name,
upper_bound: type_opt(type_param.upper_bound),
upper_bound: type_opt(type_param.upper_bound_type),
variance: type_param.variance,
unchecked: type_param.unchecked?
unchecked: type_param.unchecked?,
default_type: type_opt(type_param.default_type)
)
end

Expand Down
10 changes: 5 additions & 5 deletions lib/steep/interface/builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def validate_fvs(name, type)
end

def upper_bound(a)
variable_bounds.fetch(a, nil)
variable_bounds.fetch(a, Interface::TypeParam::IMPLICIT_UPPER_BOUND)
end
end

Expand Down Expand Up @@ -494,7 +494,7 @@ def tuple_shape(tuple)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(
required: [
Expand All @@ -508,7 +508,7 @@ def tuple_shape(tuple)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
return_type: AST::Types::Union.build(types: [elem_type, AST::Types::Var.new(name: :T)]),
Expand Down Expand Up @@ -658,7 +658,7 @@ def record_shape(record)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [key_type, AST::Types::Var.new(name: :T)]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
Expand All @@ -667,7 +667,7 @@ def record_shape(record)
block: nil
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false, default_type: nil)],
type: Function.new(
params: Function::Params.build(required: [key_type]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
Expand Down
29 changes: 22 additions & 7 deletions lib/steep/interface/type_param.rb
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
module Steep
module Interface
class TypeParam
IMPLICIT_UPPER_BOUND = AST::Builtin.optional(AST::Builtin::Object.instance_type)

attr_reader :name
attr_reader :upper_bound
attr_reader :variance
attr_reader :unchecked
attr_reader :location
attr_reader :default_type

def initialize(name:, upper_bound:, variance:, unchecked:, location: nil)
def initialize(name:, upper_bound:, variance:, unchecked:, location: nil, default_type:)
@name = name
@upper_bound = upper_bound
@variance = variance
@unchecked = unchecked
@location = location
@default_type = default_type
end

def ==(other)
other.is_a?(TypeParam) &&
other.name == name &&
other.upper_bound == upper_bound &&
other.variance == variance &&
other.unchecked == unchecked
other.unchecked == unchecked &&
other.default_type == default_type
end

alias eql? ==

def hash
name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash
name.hash ^ upper_bound.hash ^ variance.hash ^ unchecked.hash ^ default_type.hash
end

def self.rename(params, conflicting_names = params.map(&:name), new_names = conflicting_names.map {|n| AST::Types::Var.fresh_name(n) })
Expand All @@ -44,7 +49,8 @@ def self.rename(params, conflicting_names = params.map(&:name), new_names = conf
upper_bound: param.upper_bound&.subst(subst),
variance: param.variance,
unchecked: param.unchecked,
location: param.location
location: param.location,
default_type: param.default_type&.subst(subst)
)
else
param
Expand Down Expand Up @@ -80,19 +86,28 @@ def to_s
buf
end

def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location)
def update(name: self.name, upper_bound: self.upper_bound, variance: self.variance, unchecked: self.unchecked, location: self.location, default_type: self.default_type)
TypeParam.new(
name: name,
upper_bound: upper_bound,
variance: variance,
unchecked: unchecked,
location: location
location: location,
default_type: default_type
)
end

def subst(s)
if u = upper_bound
update(upper_bound: u.subst(s))
ub = u.subst(s)
end

if d = default_type
dt = d.subst(s)
end

if ub || dt
update(upper_bound: ub, default_type: dt)
else
self
end
Expand Down
4 changes: 2 additions & 2 deletions lib/steep/server/lsp_formatter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def name_and_params(name, params)
end
s << param.name.to_s

if param.upper_bound
s << " < #{param.upper_bound.to_s}"
if param.upper_bound_type
s << " < #{param.upper_bound_type.to_s}"
end

s
Expand Down
15 changes: 8 additions & 7 deletions lib/steep/signature/validator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def validate_type_application_constraints(type_name, type_params, type_args, loc
type_params.zip(type_args).each do |param, arg|
arg or raise

if param.upper_bound
upper_bound_type = factory.type(param.upper_bound).subst(subst)
if param.upper_bound_type
upper_bound_type = factory.type(param.upper_bound_type).subst(subst)
arg_type = factory.type(arg)

constraints = Subtyping::Constraints.empty
Expand All @@ -101,7 +101,8 @@ def validate_type_application_constraints(type_name, type_params, type_args, loc
name: param.name,
upper_bound: upper_bound_type,
variance: param.variance,
unchecked: param.unchecked?
unchecked: param.unchecked?,
default_type: factory.type_opt(param.default_type)
),
location: location
)
Expand Down Expand Up @@ -236,7 +237,7 @@ def each_variable_type(definition)
def validate_definition_type(definition)
each_method_type(definition) do |method_type|
upper_bounds = method_type.type_params.each.with_object({}) do |param, hash|
hash[param.name] = factory.type_opt(param.upper_bound)
hash[param.name] = factory.type_opt(param.upper_bound_type)
end

checker.push_variable_bounds(upper_bounds) do
Expand Down Expand Up @@ -264,7 +265,7 @@ def validate_one_class_decl(name, entry)
Steep.logger.tagged "#{name}" do
builder.build_instance(name).tap do |definition|
upper_bounds = definition.type_params_decl.each.with_object({}) do |param, bounds|
bounds[param.name] = factory.type_opt(param.upper_bound)
bounds[param.name] = factory.type_opt(param.upper_bound_type)
end

self_type = AST::Types::Name::Instance.new(
Expand Down Expand Up @@ -480,7 +481,7 @@ def validate_one_interface(name)
definition = builder.build_interface(name)

upper_bounds = definition.type_params_decl.each.with_object({}) do |param, bounds|
bounds[param.name] = factory.type_opt(param.upper_bound)
bounds[param.name] = factory.type_opt(param.upper_bound_type)
end

self_type = AST::Types::Name::Interface.new(
Expand Down Expand Up @@ -575,7 +576,7 @@ def validate_one_alias(name, entry = env.type_alias_decls[name])
end

upper_bounds = entry.decl.type_params.each.with_object({}) do |param, bounds|
bounds[param.name] = factory.type_opt(param.upper_bound)
bounds[param.name] = factory.type_opt(param.upper_bound_type)
end

validator.validate_type_alias(entry: entry) do |type|
Expand Down
Loading