diff --git a/known_sig/orthoses/trace/attribute.rbs b/known_sig/orthoses/trace/attribute.rbs index e96b069..781f4bc 100644 --- a/known_sig/orthoses/trace/attribute.rbs +++ b/known_sig/orthoses/trace/attribute.rbs @@ -2,6 +2,9 @@ module Orthoses class Trace class Attribute include Orthoses::Trace::Targetable + + def initialize: (Orthoses::_Call loader, patterns: Array[String], ?sort_union_types: bool?) -> void + def call: () -> Orthoses::store end end end diff --git a/known_sig/orthoses/trace/method.rbs b/known_sig/orthoses/trace/method.rbs index 9016e6e..f8a168f 100644 --- a/known_sig/orthoses/trace/method.rbs +++ b/known_sig/orthoses/trace/method.rbs @@ -2,6 +2,9 @@ module Orthoses class Trace class Method include Orthoses::Trace::Targetable + + def initialize: (Orthoses::_Call loader, patterns: Array[String], ?sort_union_types: bool?) -> void + def call: () -> Orthoses::store end end end diff --git a/lib/orthoses/trace/attribute.rb b/lib/orthoses/trace/attribute.rb index 4f1280f..56ab1b0 100644 --- a/lib/orthoses/trace/attribute.rb +++ b/lib/orthoses/trace/attribute.rb @@ -23,9 +23,10 @@ def attr_writer(*names) include Targetable - def initialize(loader, patterns:) + def initialize(loader, patterns:, sort_union_types: true) @loader = loader @patterns = patterns + @sort_union_types = sort_union_types @captured_dict = Hash.new { |h, k| h[k] = Hash.new { |hh, kk| hh[kk] = [] } } end @@ -45,6 +46,7 @@ def call @captured_dict.each do |mod_name, captures| captures.each do |(kind, prefix, name), types| + types.sort! if @sort_union_types injected = Utils::TypeList.new(types).inject store[mod_name] << "#{kind} #{prefix}#{name}: #{injected}" end diff --git a/lib/orthoses/trace/attribute_test.rb b/lib/orthoses/trace/attribute_test.rb index 50af341..c36ef07 100644 --- a/lib/orthoses/trace/attribute_test.rb +++ b/lib/orthoses/trace/attribute_test.rb @@ -12,6 +12,10 @@ class Bar attr_accessor :attr_acce_publ end + class Baz + attr_accessor :multi_types + end + attr_accessor :attr_acce_publ attr_reader :attr_read_publ attr_writer :attr_writ_publ @@ -20,6 +24,7 @@ class << self end attr_accessor :attr_acce_priv + private :attr_acce_priv def initialize @@ -52,7 +57,7 @@ def test_attribute(t) expect = <<~RBS class TraceAttributeTest::Foo attr_accessor attr_acce_priv: Integer - attr_accessor attr_acce_publ: Symbol | String + attr_accessor attr_acce_publ: String | Symbol attr_reader attr_read_publ: Symbol attr_writer attr_writ_publ: Integer attr_accessor self.self_attr_acce_publ: Integer? @@ -66,4 +71,52 @@ class TraceAttributeTest::Foo::Bar t.error("expect=\n```rbs\n#{expect}```\n, but got \n```rbs\n#{actual}```\n") end end + + def test_union_sort(t) + store1 = Orthoses::Trace::Attribute.new(->{ + LOADER_ATTRIBUTE.call + baz = Foo::Baz.new + baz.multi_types = 0 + baz.multi_types = '1' + + Orthoses::Utils.new_store + }, patterns: ['TraceAttributeTest::Foo::Baz']).call + + store2 = Orthoses::Trace::Attribute.new(->{ + LOADER_ATTRIBUTE.call + baz = Foo::Baz.new + baz.multi_types = '1' + baz.multi_types = 0 + + Orthoses::Utils.new_store + }, patterns: ['TraceAttributeTest::Foo::Baz']).call + + expect = store1.map { _2.to_rbs }.join("\n") + actual = store2.map { _2.to_rbs }.join("\n") + unless expect == actual + t.error("expect=\n```rbs\n#{expect}```\n, but got \n```rbs\n#{actual}```\n") + end + end + + def test_without_union_sort(t) + store = Orthoses::Trace::Attribute.new(->{ + LOADER_ATTRIBUTE.call + baz = Foo::Baz.new + # The order of the union types will be the following + baz.multi_types = '1' # String + baz.multi_types = 0 # Integer + + Orthoses::Utils.new_store + }, patterns: ['TraceAttributeTest::Foo::Baz'], sort_union_types: false).call + + expect = store.map { _2.to_rbs }.join("\n") + actual = <<~RBS + class TraceAttributeTest::Foo::Baz + attr_accessor multi_types: String | Integer + end + RBS + unless expect == actual + t.error("expect=\n```rbs\n#{expect}```\n, but got \n```rbs\n#{actual}```\n") + end + end end diff --git a/lib/orthoses/trace/method.rb b/lib/orthoses/trace/method.rb index 68b39c4..7a1e8de 100644 --- a/lib/orthoses/trace/method.rb +++ b/lib/orthoses/trace/method.rb @@ -7,9 +7,10 @@ class Method Info = Struct.new(:key, :op_name_types, :raised, keyword_init: true) include Targetable - def initialize(loader, patterns:) + def initialize(loader, patterns:, sort_union_types: true) @loader = loader @patterns = patterns + @sort_union_types = sort_union_types @stack = [] @args_return_map = Hash.new { |h, k| h[k] = [] } @@ -79,6 +80,7 @@ def build_method_definitions @args_return_map.map do |(mod_name, kind, visibility, method_id), type_samples| type_samples.uniq! + type_samples.sort! if @sort_union_types method_types = type_samples.map do |(op_name_types, return_type)| required_positionals = [] optional_positionals = [] diff --git a/lib/orthoses/trace/method_test.rb b/lib/orthoses/trace/method_test.rb index e8e7466..2be18c4 100644 --- a/lib/orthoses/trace/method_test.rb +++ b/lib/orthoses/trace/method_test.rb @@ -44,6 +44,15 @@ def if_raise(a) end end + def multi_types(key) + { + 0 => 0, + 1 => '1', + '2' => '2', + '3' => 3, + }[key] + end + private def priv(bool) @@ -131,4 +140,60 @@ def test_raise_first(t) Orthoses::Utils.new_store }, patterns: ['TraceMethodTest']).call end + + def test_union_sort(t) + store1 = Orthoses::Trace::Method.new(-> { + LOADER_METHOD.call + m = M.new(100) + m.multi_types 0 + m.multi_types 1 + m.multi_types '2' + m.multi_types '3' + + Orthoses::Utils.new_store + }, patterns: ['TraceMethodTest::M']).call + + store2 = Orthoses::Trace::Method.new(-> { + LOADER_METHOD.call + m = M.new(100) + m.multi_types '3' + m.multi_types '2' + m.multi_types 1 + m.multi_types 0 + + Orthoses::Utils.new_store + }, patterns: ['TraceMethodTest::M']).call + + expect = store1.map { _2.to_rbs }.join("\n") + actual = store2.map { _2.to_rbs }.join("\n") + unless expect == actual + t.error("expect=\n```rbs\n#{expect}```\n, but got \n```rbs\n#{actual}```\n") + end + end + + def test_without_union_sort(t) + store = Orthoses::Trace::Method.new(-> { + LOADER_METHOD.call + m = M.new(100) + # The order of the union types will be the following + m.multi_types '3' # (String) -> Integer + m.multi_types '2' # (String) -> String + m.multi_types 1 # (Integer) -> String + m.multi_types 0 # (Integer) -> Integer + + Orthoses::Utils.new_store + }, patterns: ['TraceMethodTest::M'], sort_union_types: false).call + + actual = store.map { _2.to_rbs }.join("\n") + expect = <<~RBS + class TraceMethodTest::M + private def initialize: (Integer a) -> void + def multi_types: (String key) -> (Integer | String) + | (Integer key) -> (String | Integer) + end + RBS + unless expect == actual + t.error("expect=\n```rbs\n#{expect}```\n, but got \n```rbs\n#{actual}```\n") + end + end end diff --git a/sig/orthoses/trace.rbs b/sig/orthoses/trace.rbs index ede5608..dad4105 100644 --- a/sig/orthoses/trace.rbs +++ b/sig/orthoses/trace.rbs @@ -10,9 +10,10 @@ end class Orthoses::Trace::Attribute @loader: untyped @patterns: untyped + @sort_union_types: untyped @captured_dict: untyped - def initialize: (untyped loader, patterns: untyped) -> void - def call: () -> untyped + def initialize: (Orthoses::_Call loader, patterns: Array[String], ?sort_union_types: bool?) -> void + def call: () -> Orthoses::store private def build_trace_hook: () -> untyped include Orthoses::Trace::Targetable end @@ -30,11 +31,12 @@ end class Orthoses::Trace::Method @loader: untyped @patterns: untyped + @sort_union_types: untyped @stack: untyped @args_return_map: untyped @alias_map: untyped - def initialize: (untyped loader, patterns: untyped) -> void - def call: () -> untyped + def initialize: (Orthoses::_Call loader, patterns: Array[String], ?sort_union_types: bool?) -> void + def call: () -> Orthoses::store private def build_trace_point: () -> untyped private def build_members: () -> untyped private def build_method_definitions: () -> untyped