diff --git a/lib/steep/source.rb b/lib/steep/source.rb index a77a1da45..64a2e542a 100644 --- a/lib/steep/source.rb +++ b/lib/steep/source.rb @@ -459,12 +459,14 @@ def self.insert_type_node(node, comments) case node.type when :lvasgn, :ivasgn, :gvasgn, :cvasgn, :casgn # Skip + when :return, :break, :next + # Skip + when :def, :defs + # Skip when :masgn lhs, rhs = node.children node = node.updated(nil, [lhs, insert_type_node(rhs, comments)]) return adjust_location(node) - when :return, :break, :next - # Skip when :begin location = node.loc #: Parser::Source::Map & Parser::AST::_Collection if location.begin @@ -549,6 +551,21 @@ def self.insert_type_node(node, comments) ] ) ) + when :def + name, args, body = node.children + assertion_location = args&.location&.expression || (_ = node.location).name + no_assertion_comments = comments.except(assertion_location.last_line) + args = insert_type_node(args, no_assertion_comments) + body = insert_type_node(body, comments) if body + return adjust_location(node.updated(nil, [name, args, body])) + when :defs + object, name, args, body = node.children + assertion_location = args&.location&.expression || (_ = node.location).name + no_assertion_comments = comments.except(assertion_location.last_line) + object = insert_type_node(object, no_assertion_comments) + args = insert_type_node(args, no_assertion_comments) + body = insert_type_node(body, comments) if body + return adjust_location(node.updated(nil, [object, name, args, body])) else adjust_location( map_child_node(node, nil) {|child| insert_type_node(child, comments) } diff --git a/sig/steep/source.rbs b/sig/steep/source.rbs index 6c1ce0927..15f92be6e 100644 --- a/sig/steep/source.rbs +++ b/sig/steep/source.rbs @@ -94,6 +94,8 @@ module Steep def self.adjust_location: (Parser::AST::Node) -> Parser::AST::Node + # Returns an `:assertion` node with `TypeAssertion` + # def self.assertion_node: (Parser::AST::Node, AST::Node::TypeAssertion) -> Parser::AST::Node def self.type_application_node: (Parser::AST::Node, AST::Node::TypeApplication) -> Parser::AST::Node diff --git a/test/source_test.rb b/test/source_test.rb index 6c0d782f6..6ad5dd2fc 100644 --- a/test/source_test.rb +++ b/test/source_test.rb @@ -712,6 +712,73 @@ def test_assertion__no_skip end end + def test_assertion__def + with_factory do |factory| + source = Steep::Source.parse(<<~RUBY, path: Pathname("foo.rb"), factory: factory) + def foo(x) #: String + end + + def bar #: void + end + + def baz(x) = 123 #: Numeric + RUBY + + source.node.children[0].tap do |node| + assert_equal :def, node.type + assert_equal :args, dig(node, 1).type + assert_equal :arg, dig(node, 1, 0).type + assert_nil dig(node, 2) + end + source.node.children[1].tap do |node| + assert_equal :def, node.type + assert_equal :args, dig(node, 1).type + assert_nil dig(node, 2) + end + source.node.children[2].tap do |node| + assert_equal :def, node.type + assert_equal :args, dig(node, 1).type + assert_equal :arg, dig(node, 1, 0).type + assert_equal :assertion, dig(node, 2).type + end + end + end + + def test_assertion__defs + with_factory do |factory| + source = Steep::Source.parse(<<~RUBY, path: Pathname("foo.rb"), factory: factory) + def self.foo(x) #: String + end + + def self.bar #: void + end + + def self.baz(x) = 123 #: Numeric + RUBY + + source.node.children[0].tap do |node| + assert_equal :defs, node.type + assert_equal :self, dig(node, 0).type + assert_equal :args, dig(node, 2).type + assert_equal :arg, dig(node, 2, 0).type + assert_nil dig(node, 3) + end + source.node.children[1].tap do |node| + assert_equal :defs, node.type + assert_equal :self, dig(node, 0).type + assert_equal :args, dig(node, 2).type + assert_nil dig(node, 3) + end + source.node.children[2].tap do |node| + assert_equal :defs, node.type + assert_equal :self, dig(node, 0).type + assert_equal :args, dig(node, 2).type + assert_equal :arg, dig(node, 2, 0).type + assert_equal :assertion, dig(node, 3).type + end + end + end + def test_tapp_send with_factory({ Pathname("foo.rbs") => <<-RBS }) do |factory| class Tapp diff --git a/test/type_check_test.rb b/test/type_check_test.rb index 68783f0ad..99310de85 100644 --- a/test/type_check_test.rb +++ b/test/type_check_test.rb @@ -2077,4 +2077,30 @@ def f(&block) YAML ) end + + def test_args_annotation + run_type_check_test( + signatures: { + "a.rbs" => <<~RBS + class Hello + def foo: (String) -> Integer + end + RBS + }, + code: { + "a.rb" => <<~RUBY + class Hello + def foo(x) #: Integer + 3 + end + end + RUBY + }, + expectations: <<~YAML + --- + - file: a.rb + diagnostics: [] + YAML + ) + end end