From 4f0209ddda1fa33befd542105ad3154a761040e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= <30992420+Moxinilian@users.noreply.github.com> Date: Mon, 27 Jan 2025 06:13:24 +0100 Subject: [PATCH] docs: (mlir-winter-school) Add descriptions for func and scf (#3791) Co-authored-by: Sasha Lopoukhine --- docs/marimo/mlir_winter_school/1_mlir_ir.py | 285 ++++++++++++++++++-- 1 file changed, 269 insertions(+), 16 deletions(-) diff --git a/docs/marimo/mlir_winter_school/1_mlir_ir.py b/docs/marimo/mlir_winter_school/1_mlir_ir.py index c18633dc42..db194ae820 100644 --- a/docs/marimo/mlir_winter_school/1_mlir_ir.py +++ b/docs/marimo/mlir_winter_school/1_mlir_ir.py @@ -25,7 +25,8 @@ def _(mo): @app.cell(hide_code=True) def _(mo, triangle_text): - mo.md(fr""" + mo.md( + rf""" MLIR and xDSL use an encoding of the IR as a textual format for debugging, testing, and storing intermediate representations of programs. It can be very useful to take a program at some stage of compilation, and inspect it. The textual format makes this easy to do. @@ -64,7 +65,8 @@ def _(mo): @app.cell(hide_code=True) def _(builtin, mo): - mo.md(fr""" + mo.md( + rf""" A module is a unit of code in xDSL and MLIR. It is an operation in the [`builtin` dialect](https://mlir.llvm.org/docs/Dialects/Builtin/), and holds a single _region_. @@ -87,7 +89,8 @@ def _(mo): @app.cell(hide_code=True) def _(builtin, mo, print_generic): - mo.md(fr""" + mo.md( + rf""" The IRs above are in what's called the _custom format_, a format that allows functions to specify a pretty and concise representation. The _generic format_ is a more uniform and verbose representation that unambiguously shows the structure of an operation. Here is the above minimal module in generic format: @@ -138,7 +141,8 @@ def _(mo): @app.cell(hide_code=True) def _(builtin, mo): - mo.md(fr""" + mo.md( + rf""" Attributes hold compile-time data, such as constants, types, and other information. 
The IR above contains four attributes: `@triangle`, `0`, `1` and `index`. `index` is the type of integer values that fit in a register on the target. @@ -177,7 +181,8 @@ def _(builtin, mo): _module_op = builtin.ModuleOp([]) _module_op.attributes = {"my_key": builtin.StringAttr("my_value")} - mo.md(fr""" + mo.md( + rf""" Operations can be supplemented with arbitrary information via their attribute dictionary. Here's a module with some extra information: @@ -197,9 +202,78 @@ def _(mo): return -@app.cell -def _(): - return +@app.cell(hide_code=True) +def _(func, mo): + from xdsl.ir import Region, Block + from xdsl.dialects.builtin import i32 + + _func_op = func.FuncOp("hello", ((), ()), Region([Block([func.ReturnOp()])])) + _func_op_with_args = func.FuncOp( + "hello", ((i32,), ()), Region([Block([func.ReturnOp()], arg_types=[i32])]) + ) + _func_op_with_args.body.block.args[0].name_hint = "x" + _func_op_with_return_value = func.FuncOp( + "swap", ((i32, i32), (i32, i32)), Region([Block([], arg_types=[i32, i32])]) + ) + _func_op_with_return_value.body.block.add_op( + func.ReturnOp( + _func_op_with_return_value.body.block.args[1], + _func_op_with_return_value.body.block.args[0], + ) + ) + _func_op_with_return_value.body.block.args[0].name_hint = "a" + _func_op_with_return_value.body.block.args[1].name_hint = "b" + + _func_op_calling_example = func.FuncOp( + "uses_swap", + ((i32, i32), (i32, i32)), + Region([Block([], arg_types=[i32, i32])]), + ) + _call_op = func.CallOp( + "swap", + [ + _func_op_calling_example.body.block.args[0], + _func_op_calling_example.body.block.args[1], + ], + [i32, i32], + ) + _func_op_calling_example.body.block.add_op(_call_op) + _func_op_calling_example.body.block.add_op( + func.ReturnOp( + _call_op.results[0], + _call_op.results[1], + ) + ) + _func_op_calling_example.body.block.args[0].name_hint = "a" + _func_op_calling_example.body.block.args[1].name_hint = "b" + _call_op.results[0].name_hint = "res0" + _call_op.results[1].name_hint = "res1" + 
+ mo.md( + rf""" + The `func` dialect contains building blocks to model functions and function calls. It contains the following important operations: + + - **`func.func`**: This operation is used to model a function definition. It contains the symbolic name of the function to be defined, along with an inner region representing the body of the function. + ``` + {str(_func_op).replace("\n", "\n ")} + ``` + In order to model function parameters, the entry block of the body region has **block arguments** corresponding to each function argument. In the context of `func.func`, these arguments represent values that will be filled by the caller. For readability, the custom format of `func.func` prints them next to the function name. + ``` + {str(_func_op_with_args).replace("\n", "\n ")} + ``` + + - **`func.return`**: This operation represents a return statement, taking as parameters the values that should be returned. `func.return` is a terminator, meaning that it must be the last operation in its block. + ``` + {str(_func_op_with_return_value).replace("\n", "\n ")} + ``` + + - **`func.call`**: This operation allows calling a function by its symbol name. `func.call` takes as operands the values of the function parameters, and its results are the return values of the function. Like all operations in MLIR, the operand and result types must be locally inferable from syntax, and thus the call operation makes the function argument and result types explicit. 
+ ``` + {str(_func_op_calling_example).replace("\n", "\n ")} + ``` + """ + ) + return Block, Region, i32 @app.cell(hide_code=True) @@ -275,9 +349,7 @@ def _(Parser, ctx, first_text_area, run_func): @app.cell(hide_code=True) def _(first_info_text, first_text_area, mo): - mo.vstack( - (first_text_area, mo.md(first_info_text)) - ) + mo.vstack((first_text_area, mo.md(first_info_text))) return @@ -288,9 +360,92 @@ def _(mo): @app.cell(hide_code=True) -def _(mo): - mo.md(r"""## Block Arguments""") - return +def _(Block, Region, arith, func, i32, mo, scf): + from xdsl.dialects import test + from xdsl.dialects.builtin import i1 + + _dummies = test.TestOp(result_types=[i32, i32, i1, i32, i32, i32, i32]) + _a = _dummies.results[0] + _b = _dummies.results[1] + _c = _dummies.results[2] + _lb = _dummies.results[3] + _ub = _dummies.results[4] + _step = _dummies.results[5] + _zero = _dummies.results[6] + _a.name_hint = "a" + _b.name_hint = "b" + _c.name_hint = "c" + _lb.name_hint = "start" + _ub.name_hint = "end" + _step.name_hint = "step" + _zero.name_hint = "zero" + + + _if_op = scf.IfOp(_c, [], Region([Block([func.CallOp("foo", [], [])])])) + _if_op_with_else = scf.IfOp( + _c, + [], + Region([Block([func.CallOp("foo", [], [])])]), + Region([Block([func.CallOp("bar", [], [])])]), + ) + + _if_op_with_yield = scf.IfOp( + _c, + [i32], + Region([Block([scf.YieldOp(_a)])]), + Region([Block([scf.YieldOp(_b)])]), + ) + _if_op_with_yield.results[0].name_hint = "res" + _yield_op_multiple = scf.YieldOp(_a, _b) + + _for_op = scf.ForOp(_lb, _ub, _step, [], Region([Block([], arg_types=[i32])])) + _for_op.body.block.add_op(func.CallOp("foo", [_for_op.body.block.args[0]], [])) + _for_op.body.block.args[0].name_hint = "i" + + _for_op_with_yield = scf.ForOp(_lb, _ub, _step, [_zero], Region([Block([], arg_types=[i32, i32])])) + _add_op = arith.AddiOp(_for_op_with_yield.body.block.args[1], _for_op_with_yield.body.block.args[0]) + _for_op_with_yield.body.block.add_op(_add_op) + 
_for_op_with_yield.body.block.add_op(scf.YieldOp(_add_op.results[0])) + _add_op.results[0].name_hint = "acc_next" + _for_op_with_yield.body.block.args[0].name_hint = "i" + _for_op_with_yield.body.block.args[1].name_hint = "acc" + _for_op_with_yield.results[0].name_hint = "sum" + + mo.md( + rf""" + The `scf` dialect contains building blocks to model Structured Control-Flow (SCF). In contrast to LLVM-like Control-Flow Graphs (CFG), Structured Control-Flow is a model of control-flow based on regions. This model of control-flow is similar in many ways to the ones in imperative languages. It contains the following important operations: + + - **`scf.if`**: This operation represents an if-statement. It takes in a boolean value, and if that value is true, it steps inside its inner region (the "then" region), skipping it otherwise. + ``` + {str(_if_op).replace("\n", "\n ")} + ``` + An additional region can be added (the "else" region) that is stepped inside only when the boolean value is false. + ``` + {str(_if_op_with_else).replace("\n", "\n ")} + ``` + If two regions are specified, `scf.if` can have result values of which the value is defined differently in each region. This feature will be presented with the next operation. + + - **`scf.yield`**: This operation is a terminator allowing to yield values from SCF constructs. For example, in the context of an `scf.if`, one may want to declare a single value `%res` that has different content depending on which branch of the `scf.if` is taken. In order to do this, one can add `%res` as a result value to the `scf.if`. Then, `scf.yield` is used in each of the regions to define the content of `%res`. + ``` + {str(_if_op_with_yield).replace("\n", "\n ")} + ``` + A single `scf.yield` can yield multiple values. + ``` + {str(_yield_op_multiple).replace("\n", "\n ")} + ``` + In the `scf.if` example above, `%res` will have the content of `%a` if `%c` is true, and of `%b` if `%c` is false. 
+ + - **`scf.for`**: This operation models a for loop over a range of integers. It takes in a start value, an end value, and a step value for the iteration variable, and declares a value as a block argument containing the iteration variable. For readability, the declaration site of the iteration value is printed in the `scf.for` operation itself. + ``` + {str(_for_op).replace("\n", "\n ")} + ``` + Additionally, `scf.for` can expose more iteration variables and return them similarly to `scf.if`. Instead of being incremented automatically, these additional iteration variables are initialized to a certain value, updated at the end of the loop body via `scf.yield`, and passed outside of the loop as result values of `scf.for`. In the summation example below, the state of the sum is accumulated in an additional iteration variable `%acc` initialized with `%zero` before being returned as `%sum`. + ``` + {str(_for_op_with_yield).replace("\n", "\n ")} + ``` + """ + ) + return i1, test @app.cell(hide_code=True) @@ -302,7 +457,9 @@ def _(mo): @app.cell(hide_code=True) def _(mo, triangle_text): second_input_text = mo.ui.text("5") - second_text_area = mo.ui.code_editor(triangle_text.replace("triangle", "second"), language="javascript") + second_text_area = mo.ui.code_editor( + triangle_text.replace("triangle", "second"), language="javascript" + ) return second_input_text, second_text_area @@ -360,6 +517,102 @@ def _(mo, second_info_text, second_input_text, second_text_area): return +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""## A note on control-flow terminators""") + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r""" + It may be tempting to nest control-flow terminators like `func.return` or `scf.yield` within other operations. 
For example, a typical use case would be to have an early return from within the regions of an `scf.if`, like in the following: + + ``` + // THIS EXAMPLE DOES NOT COMPILE + func.func @swap_or_not(%c : i1, %a : i32, %b : i32) -> (i32, i32) { + scf.if %c { + func.return %b, %a : i32, i32 + } else { + func.return %a, %b : i32, i32 + } + } + ``` + + However, **this is not legal in MLIR** with the standard control-flow semantics. Upstream transformations in MLIR assume that when a block is entered, all operations will be executed to the end of the block. If early-return was possible, this expectation would be violated. + + Instead, you must formulate your IR such that those terminators are immediate children of the operation they have an effect on. If you accidentally break this constraint, the IR will not validate and you will receive a compile-time error. + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md("""## Exercise 3: Swap or not""") + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""Fix the example above to create a `@swap_or_not` function that takes in a condition of type `i1`, and two values of type `i32`, and returns the two `i32` values, swapping them if the condition is true.""") + return + + +@app.cell +def _(Parser, ctx, run_func, swap_or_not_text_area): + swap_or_not_error_text = "" + swap_or_not_results_12_text_swap = "" + swap_or_not_results_34_text_swap = "" + swap_or_not_results_12_text_noswap = "" + swap_or_not_results_34_text_noswap = "" + try: + swap_or_not_module = Parser(ctx, swap_or_not_text_area.value).parse_module() + swap_or_not_results_12 = run_func(swap_or_not_module, "swap_or_not", (1, 1, 2)) + swap_or_not_results_34 = run_func(swap_or_not_module, "swap_or_not", (1, 3, 4)) + swap_or_not_results_12_noswap = run_func(swap_or_not_module, "swap_or_not", (0, 1, 2)) + swap_or_not_results_12_text_noswap = f"swap_or_not(false, 1, 2) = {swap_or_not_results_12_noswap}" + swap_or_not_results_12_swap =
run_func(swap_or_not_module, "swap_or_not", (1, 1, 2)) + swap_or_not_results_12_text_swap = f"swap_or_not(true, 1, 2) = {swap_or_not_results_12_swap}" + swap_or_not_results_34_text_noswap = f"swap_or_not(false, 3, 4) = {run_func(swap_or_not_module, 'swap_or_not', (0, 3, 4))}" + swap_or_not_results_34_text_swap = f"swap_or_not(true, 3, 4) = {run_func(swap_or_not_module, 'swap_or_not', (1, 3, 4))}" + except Exception as e: # the message is rendered by the `if swap_or_not_error_text:` branch below + swap_or_not_error_text = error_text = str(e) + if swap_or_not_error_text: + swap_or_not_info_text = f""" + Error: + + ``` + {swap_or_not_error_text} + ``` + """ + else: + swap_or_not_info_text = f"""\ + Here are a few tests: + + ``` + {swap_or_not_results_12_text_swap} + {swap_or_not_results_34_text_swap} + + {swap_or_not_results_12_text_noswap} + {swap_or_not_results_34_text_noswap} + ``` + """ + return ( + error_text, + swap_or_not_error_text, + swap_or_not_info_text, + swap_or_not_module, + swap_or_not_results_12, + swap_or_not_results_12_noswap, + swap_or_not_results_12_swap, + swap_or_not_results_12_text_noswap, + swap_or_not_results_12_text_swap, + swap_or_not_results_34, + swap_or_not_results_34_text_noswap, + swap_or_not_results_34_text_swap, + ) + + @app.cell(hide_code=True) def _(): import marimo as mo @@ -433,7 +686,7 @@ def run_func(module: ModuleOp, name: str, args: tuple[Any, ...]): from xdsl.interpreter import Interpreter from xdsl.interpreters import scf, arith, func - interpreter = Interpreter(module) + interpreter = Interpreter(module) interpreter.register_implementations(scf.ScfFunctions) interpreter.register_implementations(arith.ArithFunctions) interpreter.register_implementations(func.FuncFunctions)