diff --git a/src/compiler/crystal/macros/interpreter.cr b/src/compiler/crystal/macros/interpreter.cr index 8db46bd118cf..267d97881c1a 100644 --- a/src/compiler/crystal/macros/interpreter.cr +++ b/src/compiler/crystal/macros/interpreter.cr @@ -74,23 +74,57 @@ module Crystal record MacroVarKey, name : String, exps : Array(ASTNode)? + class BlockScope + enum State + Continue + FoundNext + FoundBreak + end + + property state : State = State::Continue + end + def initialize(@program : Program, @scope : Type, @path_lookup : Type, @location : Location?, @vars = {} of String => ASTNode, @block : Block? = nil, @def : Def? = nil, @in_macro = false, @call : Call? = nil) @str = IO::Memory.new(512) # Can't be String::Builder because of `{{debug}}` @last = Nop.new + @block_scopes = [] of BlockScope end def define_var(name : String, value : ASTNode) : Nil @vars[name] = value end + def new_block_scope(& : BlockScope ->) + block_scope = BlockScope.new + @block_scopes << block_scope + begin + yield block_scope + ensure + @block_scopes.pop + end + end + def accept(node) node.accept self @last end + def visit_any(node) + # upon encountering a `next` or `break` expression, this pauses evaluation + # until `interpret_block` in `./methods.cr` drops a pending `next`, or + # `new_block_scope` drops the innermost block scope by returning; while + # evaluation is paused, `@last` holds the argument to the most recent + # `next` or `break` expression + if block_scope = @block_scopes.last? + return false unless block_scope.state.continue? + end + + true + end + def visit(node : Expressions) node.expressions.each &.accept self false @@ -408,6 +442,36 @@ module Crystal false end + def visit(node : Next) + unless block_scope = @block_scopes.last? + node.raise "invalid next" + end + + if exp = node.exp + exp.accept self + else + @last = Nop.new + end + + block_scope.state = :found_next + false + end + + def visit(node : Break) + unless block_scope = @block_scopes.last? + node.raise "invalid break" + end + + if exp = node.exp + exp.accept self + else + @last = Nop.new + end + + block_scope.state = :found_break + false + end + def visit(node : Path) @last = resolve(node) false diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 8a7aa569fa95..b0f3bf9eff39 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -910,23 +910,25 @@ module Crystal block_arg_key = block.args[0]? block_arg_value = block.args[1]? - entries.each do |entry| - interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + entries.each do |entry| + interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end || NilLiteral.new end - - NilLiteral.new end when "map" interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? - ArrayLiteral.map(entries) do |entry| - interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + ArrayLiteral.map(entries) do |entry| + interpreter.define_var(block_arg_key.name, entry.key) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end end end when "double_splat" @@ -1004,23 +1006,25 @@ module Crystal block_arg_key = block.args[0]? block_arg_value = block.args[1]? - entries.each do |entry| - interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + entries.each do |entry| + interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end || NilLiteral.new end - - NilLiteral.new end when "map" interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? - ArrayLiteral.map(entries) do |entry| - interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key - interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + ArrayLiteral.map(entries) do |entry| + interpreter.define_var(block_arg_key.name, MacroId.new(entry.key)) if block_arg_key + interpreter.define_var(block_arg_value.name, entry.value) if block_arg_value + interpret_block(block_scope) + end end end when "double_splat" @@ -1109,20 +1113,22 @@ module Crystal interpret_check_args(uses_block: true) do block_arg = block.args.first? - interpret_to_range(interpreter).each do |num| - interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + interpret_to_range(interpreter).each do |num| + interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg + interpret_block(block_scope) + end || NilLiteral.new end - - NilLiteral.new end when "map" interpret_check_args(uses_block: true) do block_arg = block.args.first? - interpret_map(interpreter) do |num| - interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + interpret_map(interpreter) do |num| + interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg + interpret_block(block_scope) + end end end when "to_a" @@ -2809,19 +2815,25 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - Crystal::BoolLiteral.new(object.elements.any? do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept(block.body).truthy? - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.any? do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope).truthy? + end + result.is_a?(Bool) ? Crystal::BoolLiteral.new(result) : result + end end when "all?" interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - Crystal::BoolLiteral.new(object.elements.all? do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept(block.body).truthy? - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.all? do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope).truthy? + end + result.is_a?(Bool) ? Crystal::BoolLiteral.new(result) : result + end end when "splat" interpret_check_args(node: object, min_count: 0) do |arg| @@ -2845,11 +2857,13 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - found = object.elements.find do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept(block.body).truthy? + interpreter.new_block_scope do |block_scope| + found = object.elements.find do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope).truthy? + end + found ? found : Crystal::NilLiteral.new end - found ? found : Crystal::NilLiteral.new end when "first" interpret_check_args(node: object) { object.elements.first? || Crystal::NilLiteral.new } @@ -2869,33 +2883,35 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - object.elements.each do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + object.elements.each do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope) + end || Crystal::NilLiteral.new end - - Crystal::NilLiteral.new end when "each_with_index" interpret_check_args(node: object, uses_block: true) do block_arg = block.args[0]? index_arg = block.args[1]? - object.elements.each_with_index do |elem, idx| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + object.elements.each_with_index do |elem, idx| + interpreter.define_var(block_arg.name, elem) if block_arg + interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg + interpret_block(block_scope) + end || Crystal::NilLiteral.new end - - Crystal::NilLiteral.new end when "map" interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? - klass.map(object.elements) do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + klass.map(object.elements) do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + interpret_block(block_scope) + end end end when "map_with_index" @@ -2903,10 +2919,12 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a block_arg = block.args[0]? index_arg = block.args[1]? - klass.map_with_index(object.elements) do |elem, idx| - interpreter.define_var(block_arg.name, elem) if block_arg - interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + klass.map_with_index(object.elements) do |elem, idx| + interpreter.define_var(block_arg.name, elem) if block_arg + interpreter.define_var(index_arg.name, Crystal::NumberLiteral.new idx) if index_arg + interpret_block(block_scope) + end end end when "select" @@ -2922,17 +2940,19 @@ private def interpret_array_or_tuple_method(object, klass, method, args, named_a accumulate_arg = block.args.first? value_arg = block.args[1]? - if memo - object.elements.reduce(memo) do |accumulate, elem| - interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg - interpreter.define_var(value_arg.name, elem) if value_arg - interpreter.accept block.body - end - else - object.elements.reduce do |accumulate, elem| - interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg - interpreter.define_var(value_arg.name, elem) if value_arg - interpreter.accept block.body + interpreter.new_block_scope do |block_scope| + if memo + object.elements.reduce(memo) do |accumulate, elem| + interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg + interpreter.define_var(value_arg.name, elem) if value_arg + interpret_block(block_scope) + end + else + object.elements.reduce do |accumulate, elem| + interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg + interpreter.define_var(value_arg.name, elem) if value_arg + interpret_block(block_scope) + end end end end @@ -3110,6 +3130,24 @@ private macro interpret_check_args_toplevel(*, min_count = nil, uses_block = fal interpret_check_args(node: node, min_count: {{ min_count }}, uses_block: {{ uses_block }}, top_level: true) {{ block }} end +# Returns the result of evaluating the current macro method call's block, except +# that, upon encountering a `break` expression, literally triggers a `break` in +# the block enclosing this `interpret_block` call, forwarding the evaluated +# argument to that encountered expression. +# +# When used together with `new_block_scope` and methods from `Enumerable`, this +# allows macro methods to be implemented more or less like their non-macro +# equivalents. +# +# Accesses the `interpreter` and `block` variables in the current scope. +private macro interpret_block(block_scope) + %block_scope = {{ block_scope }} + %block_scope.state = :continue + %result = interpreter.accept block.body + break %result if %block_scope.state.found_break? + %result +end + private def full_macro_name(node, method, top_level) if top_level "macro '::#{method}'" @@ -3186,11 +3224,14 @@ end private def filter(object, klass, block, interpreter, keep = true) block_arg = block.args.first? - klass.new(object.elements.select do |elem| - interpreter.define_var(block_arg.name, elem) if block_arg - block_result = interpreter.accept(block.body).truthy? - keep ? block_result : !block_result - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.select do |elem| + interpreter.define_var(block_arg.name, elem) if block_arg + block_result = interpret_block(block_scope).truthy? + keep ? block_result : !block_result + end + result.is_a?(Crystal::ASTNode) ? result : klass.new(result) + end end private def fetch_annotation(node, method, args, named_args, block, &) @@ -3232,11 +3273,13 @@ end private def sort_by(object, klass, block, interpreter) block_arg = block.args.first? - klass.new(object.elements.sort_by do |elem| - block_arg.try { |arg| interpreter.define_var(arg.name, elem) } - result = interpreter.accept(block.body) - InterpretCompareWrapper.new(result) - end) + interpreter.new_block_scope do |block_scope| + result = object.elements.sort_by do |elem| + block_arg.try { |arg| interpreter.define_var(arg.name, elem) } + InterpretCompareWrapper.new(interpret_block(block_scope)) + end + result.is_a?(Crystal::ASTNode) ? result : klass.new(result) + end end private record InterpretCompareWrapper, node : Crystal::ASTNode do