From f495cd4527fbf1c51df39472b1accba3dc2553c3 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 14 Aug 2024 02:27:31 +0800 Subject: [PATCH 1/5] Support `next` and `break` in macro methods --- src/compiler/crystal/macros/interpreter.cr | 64 +++++++ src/compiler/crystal/macros/methods.cr | 213 +++++++++++++-------- 2 files changed, 195 insertions(+), 82 deletions(-) diff --git a/src/compiler/crystal/macros/interpreter.cr b/src/compiler/crystal/macros/interpreter.cr index 8db46bd118cf..89a4d948755b 100644 --- a/src/compiler/crystal/macros/interpreter.cr +++ b/src/compiler/crystal/macros/interpreter.cr @@ -74,23 +74,57 @@ module Crystal record MacroVarKey, name : String, exps : Array(ASTNode)? + class BlockScope + enum State + Continue + FoundNext + FoundBreak + end + + property state = State::Continue + end + def initialize(@program : Program, @scope : Type, @path_lookup : Type, @location : Location?, @vars = {} of String => ASTNode, @block : Block? = nil, @def : Def? = nil, @in_macro = false, @call : Call? = nil) @str = IO::Memory.new(512) # Can't be String::Builder because of `{{debug}}` @last = Nop.new + @block_scopes = [] of BlockScope end def define_var(name : String, value : ASTNode) : Nil @vars[name] = value end + def new_block_scope(& : BlockScope ->) + block_scope = BlockScope.new + @block_scopes << block_scope + begin + yield block_scope + ensure + @block_scopes.pop + end + end + def accept(node) node.accept self @last end + def visit_any(node) + # upon encountering a `next` or `break` expression, this pauses evaluation + # until `interpret_block` in `./methods.cr` drops a pending `next`, or + # `new_block_scope` drops the innermost block scope by returning; while + # evaluation is paused, `@last` holds the argument to the most recent + # `next` or `break` expression + if block_scope = @block_scopes.last? + return false unless block_scope.state.continue? + end + + true + end + def visit(node : Expressions) node.expressions.each &.accept self false @@ -408,6 +442,36 @@ module Crystal false end + def visit(node : Next) + unless block_scope = @block_scopes.last? + node.raise "invalid next" + end + + if exp = node.exp + exp.accept self + else + @last = Nop.new + end + + block_scope.state = :found_next + false + end + + def visit(node : Break) + unless block_scope = @block_scopes.last? + node.raise "invalid next" + end + + if exp = node.exp + exp.accept self + else + @last = Nop.new + end + + block_scope.state = :found_break + false + end + def visit(node : Path) @last = resolve(node) false diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index d3a1a1cc15a6..ba8910478d4b 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -554,7 +554,8 @@ module Crystal end def interpret_compare(other : NumberLiteral) - to_number <=> other.to_number + # it should not be possible to obtain a NaN number literal + (to_number <=> other.to_number).not_nil! end def bool_bin_op(method, args, named_args, block, &) @@ -910,23 +911,25 @@ module Crystal block_arg_key = block.args[0]? block_arg_value = block.args[1]? - entries.each do |entry| - interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + entries.each do |entry| + interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end || NilLiteral.new end - - NilLiteral.new end when "map" interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? - ArrayLiteral.map(entries) do |entry| - interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + ArrayLiteral.map(entries) do |entry| + interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end end end when "double_splat" @@ -1004,23 +1007,25 @@ module Crystal block_arg_key = block.args[0]? block_arg_value = block.args[1]? - entries.each do |entry| - interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + entries.each do |entry| + interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end || NilLiteral.new end - - NilLiteral.new end when "map" interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? - ArrayLiteral.map(entries) do |entry| - interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + ArrayLiteral.map(entries) do |entry| + interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end end end when "double_splat" @@ -1109,20 +1114,22 @@ module Crystal interpret_check_args(uses_block: true) do block_arg = block.args.first? - interpret_to_range(interpreter).each do |num| - interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + interpret_to_range(interpreter).each do |num| + interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg + interpret_block(block_scope) + end || NilLiteral.new end - - NilLiteral.new end when "map" interpret_check_args(uses_block: true) do block_arg = block.args.first? - interpret_map(interpreter) do |num| - interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + interpret_map(interpreter) do |num| + interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg + interpret_block(block_scope) + end end end when "to_a" @@ -2809,19 +2816,25 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - Crystal::BoolLiteral.new(object.elements.any? do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept(block.body).truthy? - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.any? do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope).truthy? + end + result.is_a?(Bool) ? Crystal::BoolLiteral.new(result) : result + end end when "all?" interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - Crystal::BoolLiteral.new(object.elements.all? do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept(block.body).truthy? - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.all? do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope).truthy? + end + result.is_a?(Bool) ? Crystal::BoolLiteral.new(result) : result + end end when "splat" interpret_check_args(node: object, min_count: 0) do |arg| @@ -2845,11 +2858,13 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - found = object.elements.find do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept(block.body).truthy? + interpreter.new_block_scope do |block_scope| + found = object.elements.find do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope).truthy? + end + found ? found : Crystal::NilLiteral.new end - found ? found : Crystal::NilLiteral.new end when "first" interpret_check_args(node: object) { object.elements.first? || Crystal::NilLiteral.new } @@ -2869,33 +2884,35 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - object.elements.each do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + object.elements.each do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope) + end || Crystal::NilLiteral.new end - - Crystal::NilLiteral.new end when "each_with_index" interpret_check_args(node: object, uses_block: true) do block_arg = block.args[0]? index_arg = block.args[1]? - object.elements.each_with_index do |elem, idx| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + object.elements.each_with_index do |elem, idx| + interpreter.define_var(block_arg.name, elem) if block_arg + interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg + interpret_block(block_scope) + end || Crystal::NilLiteral.new end - - Crystal::NilLiteral.new end when "map" interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - klass.map(object.elements) do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + klass.map(object.elements) do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope) + end end end when "map_with_index" @@ -2903,10 +2920,12 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a block_arg = block.args[0]? index_arg = block.args[1]? - klass.map_with_index(object.elements) do |elem, idx| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + klass.map_with_index(object.elements) do |elem, idx| + interpreter.define_var(block_arg.name, elem) if block_arg + interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg + interpret_block(block_scope) + end end end when "select" @@ -2922,17 +2941,19 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a accumulate_arg = block.args.first? value_arg = block.args[1]? - if memo - object.elements.reduce(memo) do |accumulate, elem| - interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg - interpreter.define_var(value_arg.name, elem) if value_arg - interpreter.accept block.body - end - else - object.elements.reduce do |accumulate, elem| - interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg - interpreter.define_var(value_arg.name, elem) if value_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + if memo + object.elements.reduce(memo) do |accumulate, elem| + interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg + interpreter.define_var(value_arg.name, elem) if value_arg + interpret_block(block_scope) + end + else + object.elements.reduce do |accumulate, elem| + interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg + interpreter.define_var(value_arg.name, elem) if value_arg + interpret_block(block_scope) + end end end end @@ -3110,6 +3131,24 @@ private macro interpret_check_args_toplevel(*, min_count = nil, uses_block = fal interpret_check_args(node: node, min_count: {{ min_count }}, uses_block: {{ uses_block }}, top_level: true) {{ block }} end +# Returns the result of evaluating the current macro method call's block, except +# that, upon encountering a `break` expression, literally triggers a `break` in +# the block enclosing this `interpret_block` call, with the AST node +# corresponding to the result from the encountered expression as the argument. +# +# When used together with `new_block_scope` and methods from `Enumerable`, this +# allows macro methods to be implemented more or less like their non-macro +# equivalents. +# +# Accesses the `interpreter` and `block` variables in the current scope. +private macro interpret_block(block_scope) + %block_scope = {{ block_scope }} + %block_scope.state = :continue + %result = interpreter.accept block.body + break %result if %block_scope.state.found_break? + %result +end + private def full_macro_name(node, method, top_level) if top_level "macro '::#{method}'" @@ -3186,11 +3225,14 @@ end private def filter(object, klass, block, interpreter, keep = true) block_arg = block.args.first? - klass.new(object.elements.select do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - block_result = interpreter.accept(block.body).truthy? - keep ? block_result : !block_result - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.select do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + block_result = interpret_block(block_scope).truthy? + keep ? block_result : !block_result + end + result.is_a?(Crystal::ASTNode) ? result : klass.new(result) + end end private def fetch_annotation(node, method, args, named_args, block, &) @@ -3232,12 +3274,19 @@ end private def sort_by(object, klass, block, interpreter) block_arg = block.args.first? - klass.new(object.elements.sort { |x, y| - block_arg.try { |arg| interpreter.define_var(arg.name, x) } - x_result = interpreter.accept(block.body) - block_arg.try { |arg| interpreter.define_var(arg.name, y) } - y_result = interpreter.accept(block.body) + interpreter.new_block_scope do |block_scope| + result = object.elements.sort_by do |elem| + block_arg.try { |arg| interpreter.define_var(arg.name, elem) } + InterpretCompareWrapper.new(interpret_block(block_scope)) + end + result.is_a?(Crystal::ASTNode) ? result : klass.new(result) + end +end + +private record InterpretCompareWrapper, node : Crystal::ASTNode do + include Comparable(self) - x_result.interpret_compare(y_result) - }) + def <=>(other : self) : Int + node.interpret_compare(other.node) + end end From 88bbb87fa0e162403985cc5a9f0db0aecff33712 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 14 Aug 2024 02:42:44 +0800 Subject: [PATCH 2/5] fixup --- src/compiler/crystal/macros/methods.cr | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index ba8910478d4b..b261a854bfce 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -554,8 +554,7 @@ module Crystal end def interpret_compare(other : NumberLiteral) - # it should not be possible to obtain a NaN number literal - (to_number <=> other.to_number).not_nil! + to_number <=> other.to_number end def bool_bin_op(method, args, named_args, block, &) @@ -3286,7 +3285,7 @@ end private record InterpretCompareWrapper, node : Crystal::ASTNode do include Comparable(self) - def <=>(other : self) : Int + def <=>(other : self) node.interpret_compare(other.node) end end From 9d3f6c1b26fa2b68136d0958acca57f2506613be Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 14 Aug 2024 02:55:10 +0800 Subject: [PATCH 3/5] fixup autocast --- src/compiler/crystal/macros/interpreter.cr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/crystal/macros/interpreter.cr b/src/compiler/crystal/macros/interpreter.cr index 89a4d948755b..7dc50d9ad4d1 100644 --- a/src/compiler/crystal/macros/interpreter.cr +++ b/src/compiler/crystal/macros/interpreter.cr @@ -81,7 +81,7 @@ module Crystal FoundBreak end - property state = State::Continue + property state : State = State::Continue end def initialize(@program : Program, From 7a8feca8f03a0c9648ba4634daef31516b99262b Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 14 Aug 2024 02:56:01 +0800 Subject: [PATCH 4/5] fixup error message --- src/compiler/crystal/macros/interpreter.cr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/crystal/macros/interpreter.cr b/src/compiler/crystal/macros/interpreter.cr index 7dc50d9ad4d1..267d97881c1a 100644 --- a/src/compiler/crystal/macros/interpreter.cr +++ b/src/compiler/crystal/macros/interpreter.cr @@ -459,7 +459,7 @@ module Crystal def visit(node : Break) unless block_scope = @block_scopes.last? - node.raise "invalid next" + node.raise "invalid break" end if exp = node.exp From 117ac0cefcb9c26d541e57db3385c7dce8dc669b Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 14 Aug 2024 20:06:04 +0800 Subject: [PATCH 5/5] a --- src/compiler/crystal/macros/methods.cr | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index b261a854bfce..b0f3bf9eff39 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -3132,8 +3132,8 @@ end # Returns the result of evaluating the current macro method call's block, except # that, upon encountering a `break` expression, literally triggers a `break` in -# the block enclosing this `interpret_block` call, with the AST node -# corresponding to the result from the encountered expression as the argument. +# the block enclosing this `interpret_block` call, forwarding the evaluated +# argument to that encountered expression. # # When used together with `new_block_scope` and methods from `Enumerable`, this # allows macro methods to be implemented more or less like their non-macro