Skip to content

Commit

Permalink
docs: (mlir-winter-school) Add descriptions for func and scf (#3791)
Browse files Browse the repository at this point in the history
Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
Moxinilian and superlopuh authored Jan 27, 2025
1 parent adc2916 commit 4f0209d
Showing 1 changed file with 269 additions and 16 deletions.
285 changes: 269 additions & 16 deletions docs/marimo/mlir_winter_school/1_mlir_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def _(mo):

@app.cell(hide_code=True)
def _(mo, triangle_text):
mo.md(fr"""
mo.md(
rf"""
MLIR and xDSL use an encoding of the IR as a textual format for debugging, testing, and storing intermediate representations of programs.
It can be very useful to take a program at some stage of compilation, and inspect it.
The textual format makes this easy to do.
Expand Down Expand Up @@ -64,7 +65,8 @@ def _(mo):

@app.cell(hide_code=True)
def _(builtin, mo):
mo.md(fr"""
mo.md(
rf"""
A module is a unit of code in xDSL and MLIR.
It is an operation in the [`builtin` dialect](https://mlir.llvm.org/docs/Dialects/Builtin/), and holds a single _region_.
Expand All @@ -87,7 +89,8 @@ def _(mo):

@app.cell(hide_code=True)
def _(builtin, mo, print_generic):
mo.md(fr"""
mo.md(
rf"""
The IRs above are in what's called the _custom format_, a format that allows functions to specify a pretty and concise representation.
The _generic format_ is a more uniform and verbose representation that unambiguously shows the structure of an operation.
Here is the above minimal module in generic format:
Expand Down Expand Up @@ -138,7 +141,8 @@ def _(mo):

@app.cell(hide_code=True)
def _(builtin, mo):
mo.md(fr"""
mo.md(
rf"""
Attributes hold compile-time data, such as constants, types, and other information.
The IR above contains four attributes: `@triangle`, `0`, `1` and `index`.
`index` is the type of integer values that fit in a register on the target.
Expand Down Expand Up @@ -177,7 +181,8 @@ def _(builtin, mo):
_module_op = builtin.ModuleOp([])
_module_op.attributes = {"my_key": builtin.StringAttr("my_value")}

mo.md(fr"""
mo.md(
rf"""
Operations can be supplemented with arbitrary information via their attribute dictionary.
Here's a module with some extra information:
Expand All @@ -197,9 +202,78 @@ def _(mo):
return


@app.cell
def _():
return
@app.cell(hide_code=True)
def _(func, mo):
from xdsl.ir import Region, Block
from xdsl.dialects.builtin import i32

_func_op = func.FuncOp("hello", ((), ()), Region([Block([func.ReturnOp()])]))
_func_op_with_args = func.FuncOp(
"hello", ((i32,), ()), Region([Block([func.ReturnOp()], arg_types=[i32])])
)
_func_op_with_args.body.block.args[0].name_hint = "x"
_func_op_with_return_value = func.FuncOp(
"swap", ((i32, i32), (i32, i32)), Region([Block([], arg_types=[i32, i32])])
)
_func_op_with_return_value.body.block.add_op(
func.ReturnOp(
_func_op_with_return_value.body.block.args[1],
_func_op_with_return_value.body.block.args[0],
)
)
_func_op_with_return_value.body.block.args[0].name_hint = "a"
_func_op_with_return_value.body.block.args[1].name_hint = "b"

_func_op_calling_example = func.FuncOp(
"uses_swap",
((i32, i32), (i32, i32)),
Region([Block([], arg_types=[i32, i32])]),
)
_call_op = func.CallOp(
"swap",
[
_func_op_calling_example.body.block.args[0],
_func_op_calling_example.body.block.args[1],
],
[i32, i32],
)
_func_op_calling_example.body.block.add_op(_call_op)
_func_op_calling_example.body.block.add_op(
func.ReturnOp(
_call_op.results[0],
_call_op.results[1],
)
)
_func_op_calling_example.body.block.args[0].name_hint = "a"
_func_op_calling_example.body.block.args[1].name_hint = "b"
_call_op.results[0].name_hint = "res0"
_call_op.results[1].name_hint = "res1"

mo.md(
rf"""
The `func` dialect contains building blocks to model functions and function calls. It contains the following important operations:
- **`func.func`**: This operation is used to model function definition. They contain the symbolic name of the function to be defined, along with an inner region representing the body of the function.
```
{str(_func_op).replace("\n", "\n ")}
```
In order to model function parameters, the entry block of the body region has **block arguments** corresponding to each function argument. In the context of `func.func`, these arguments represent values that will be filled by the caller. For readability, the custom format of `func.func` prints them next to the function name.
```
{str(_func_op_with_args).replace("\n", "\n ")}
```
- **`func.return`**: This operation represents a return statement, taking as parameters the values that should be returned. `func.return` is a terminator, meaning that it must be the last operation in its block.
```
{str(_func_op_with_return_value).replace("\n", "\n ")}
```
- **`func.call`**: This operation allows calling a function by its symbol name. `func.call` takes as operands the values of the function parameters, and its results are the return values of the function. Like all operations in MLIR, the operand and result types must be locally inferable from syntax, and thus the call operation makes the function argument and result types explicit.
```
{str(_func_op_calling_example).replace("\n", "\n ")}
```
"""
)
return Block, Region, i32


@app.cell(hide_code=True)
Expand Down Expand Up @@ -275,9 +349,7 @@ def _(Parser, ctx, first_text_area, run_func):

@app.cell(hide_code=True)
def _(first_info_text, first_text_area, mo):
mo.vstack(
(first_text_area, mo.md(first_info_text))
)
mo.vstack((first_text_area, mo.md(first_info_text)))
return


Expand All @@ -288,9 +360,92 @@ def _(mo):


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""## Block Arguments""")
return
def _(Block, Region, arith, func, i32, mo, scf):
from xdsl.dialects import test
from xdsl.dialects.builtin import i1

_dummies = test.TestOp(result_types=[i32, i32, i1, i32, i32, i32, i32])
_a = _dummies.results[0]
_b = _dummies.results[1]
_c = _dummies.results[2]
_lb = _dummies.results[3]
_ub = _dummies.results[4]
_step = _dummies.results[5]
_zero = _dummies.results[6]
_a.name_hint = "a"
_b.name_hint = "b"
_c.name_hint = "c"
_lb.name_hint = "start"
_ub.name_hint = "end"
_step.name_hint = "step"
_zero.name_hint = "zero"


_if_op = scf.IfOp(_c, [], Region([Block([func.CallOp("foo", [], [])])]))
_if_op_with_else = scf.IfOp(
_c,
[],
Region([Block([func.CallOp("foo", [], [])])]),
Region([Block([func.CallOp("bar", [], [])])]),
)

_if_op_with_yield = scf.IfOp(
_c,
[i32],
Region([Block([scf.YieldOp(_a)])]),
Region([Block([scf.YieldOp(_b)])]),
)
_if_op_with_yield.results[0].name_hint = "res"
_yield_op_multiple = scf.YieldOp(_a, _b)

_for_op = scf.ForOp(_lb, _ub, _step, [], Region([Block([], arg_types=[i32])]))
_for_op.body.block.add_op(func.CallOp("foo", [_for_op.body.block.args[0]], []))
_for_op.body.block.args[0].name_hint = "i"

_for_op_with_yield = scf.ForOp(_lb, _ub, _step, [_zero], Region([Block([], arg_types=[i32, i32])]))
_add_op = arith.AddiOp(_for_op_with_yield.body.block.args[1], _for_op_with_yield.body.block.args[0])
_for_op_with_yield.body.block.add_op(_add_op)
_for_op_with_yield.body.block.add_op(scf.YieldOp(_add_op.results[0]))
_add_op.results[0].name_hint = "acc_next"
_for_op_with_yield.body.block.args[0].name_hint = "i"
_for_op_with_yield.body.block.args[1].name_hint = "acc"
_for_op_with_yield.results[0].name_hint = "sum"

mo.md(
rf"""
The `scf` dialect contains building blocks to model Structured Control-Flow (SCF). In contrast to LLVM-like Control-Flow Graphs (CFG), Structured Control-Flow is a model of control-flow based on regions. This model of control-flow is similar in many ways to the ones in imperative languages. It contains the following important operations:
- **`scf.if`**: This operation represents a an if-statement. It takes in a boolean value, and if that value is true, it steps inside its inner region (the "then" region), skipping it otherwise.
```
{str(_if_op).replace("\n", "\n ")}
```
An additional region can be added (the "else" region) that is stepped inside only when the boolean value is false.
```
{str(_if_op_with_else).replace("\n", "\n ")}
```
If two regions are specified, `scf.if` can have result values of which the value is defined differently in each region. This feature will be presented with the next operation.
- **`scf.yield`**: This operation is a terminator allowing to yield values from SCF constructs. For example, in the context of an `scf.if`, one may want to declare a single value `%res` that has different content depending on which branch of the `scf.if` is taken. In order to do this, one can add `%res` as a result value to the `scf.if`. Then, `scf.yield` is used in each of the regions to define the content of `%res`.
```
{str(_if_op_with_yield).replace("\n", "\n ")}
```
A single `scf.yield` can yield multiple values.
```
{str(_yield_op_multiple).replace("\n", "\n ")}
```
In this example, `%res` will have the content of `%a` if `%c` is true, and of `%b` is `%c` is false.
- **`scf.for`**: This operation models a for loop over a range of integers. It takes in a start value, an end value, and a step value for the iteration variable, and declares a value as a block argument containing the iteration variable. For readability, the declaration site of the iteration value is printed in the `scf.for` operation itself.
```
{str(_for_op).replace("\n", "\n ")}
```
Aditionally, `scf.for` can expose more iteration variables and return them similarly to `scf.if`. Instead of being incremented automatically, these additional iteration variables are initialized to a certain value, updated at the end of the loop body via `scf.yield`, and passed outside of the loop as result values of `scf.for`. In the summation example below, the state of the sum is accumulated in an additional iteration variable `%acc` initialized with `%zero` before being returned as `%sum`.
```
{str(_for_op_with_yield).replace("\n", "\n ")}
```
"""
)
return i1, test


@app.cell(hide_code=True)
Expand All @@ -302,7 +457,9 @@ def _(mo):
@app.cell(hide_code=True)
def _(mo, triangle_text):
second_input_text = mo.ui.text("5")
second_text_area = mo.ui.code_editor(triangle_text.replace("triangle", "second"), language="javascript")
second_text_area = mo.ui.code_editor(
triangle_text.replace("triangle", "second"), language="javascript"
)
return second_input_text, second_text_area


Expand Down Expand Up @@ -360,6 +517,102 @@ def _(mo, second_info_text, second_input_text, second_text_area):
return


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""## A note on control-flow terminators""")
return


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""
It may be tempting to nest control-flow terminators like `func.return` or `scf.yield` within other operations. For example, a typical use case would be to have an early return from within the regions of an `scf.if`, like in the following:
```
// THIS EXAMPLE DOES NOT COMPILE
func.func @swap_or_not(%c : i1, %a : i32, %b : i32) -> (i32, i32) {
scf.if %c {
func.return %b, %a : i32, i32
} else {
func.return %a, %b : i32, i32
}
}
```
However, **this is not legal in MLIR** with the standard control-flow semantics. Upstream transformations in MLIR assume that when a block is entered, all operations will be executed to the end of the block. If early-return was possible, this expectation would be violated.
Instead, you must formulate your IR such that those terminators are immediate children of the operation they have an effect on. If you accidentally break this constraint, the IR will not validate and you will receive a compile-time error.
""")
return


@app.cell(hide_code=True)
def _(mo):
mo.md("""## Exercise 3: Swap or not""")
return


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""Fix the example above to create a `@swap_or_not` function that takes in a condition of type `i1`, and two values of type `i32`, and returns the two `i32` values, swapping them if the condition is true.""")
return


@app.cell
def _(Parser, ctx, run_func, swap_or_not_text_area):
swap_or_not_error_text = ""
swap_or_not_results_12_text_swap = ""
swap_or_not_results_34_text_swap = ""
swap_or_not_results_12_text_noswap = ""
swap_or_not_results_34_text_noswap = ""
try:
swap_or_not_module = Parser(ctx, swap_or_not_text_area.value).parse_module()
swap_or_not_results_12 = run_func(swap_or_not_module, "swap_or_not", (1, 2))
swap_or_not_results_34 = run_func(swap_or_not_module, "swap_or_not", (3, 4))
swap_or_not_results_12_noswap = run_func(swap_or_not_module, "swap_or_not", (0, 1, 2))
swap_or_not_results_12_text_noswap = f"swap_or_not(false, 1, 2) = {swap_or_not_results_12_noswap}"
swap_or_not_results_12_swap = run_func(swap_or_not_module, "swap_or_not", (1, 1, 2))
swap_or_not_results_12_text_swap = f"swap_or_not(true, 1, 2) = {swap_or_not_results_12_swap}"
swap_or_not_results_34_text_noswap = f"swap_or_not(false, 3, 4) = {run_func(swap_or_not_module, "swap_or_not", (0, 3, 4))}"
swap_or_not_results_34_text_swap = f"swap_or_not(true, 3, 4) = {run_func(swap_or_not_module, "swap_or_not", (1, 3, 4))}"
except Exception as e:
error_text = str(e)
if swap_or_not_error_text:
swap_or_not_info_text = f"""
Error:
```
{swap_or_not_error_text}
```
"""
else:
swap_or_not_info_text = f"""\
Here are a few tests:
```
{swap_or_not_results_12_text_swap}
{swap_or_not_results_34_text_swap}
{swap_or_not_results_12_text_noswap}
{swap_or_not_results_34_text_noswap}
```
"""
return (
error_text,
swap_or_not_error_text,
swap_or_not_info_text,
swap_or_not_module,
swap_or_not_results_12,
swap_or_not_results_12_noswap,
swap_or_not_results_12_swap,
swap_or_not_results_12_text_noswap,
swap_or_not_results_12_text_swap,
swap_or_not_results_34,
swap_or_not_results_34_text_noswap,
swap_or_not_results_34_text_swap,
)


@app.cell(hide_code=True)
def _():
import marimo as mo
Expand Down Expand Up @@ -433,7 +686,7 @@ def run_func(module: ModuleOp, name: str, args: tuple[Any, ...]):
from xdsl.interpreter import Interpreter
from xdsl.interpreters import scf, arith, func

interpreter = Interpreter(module)
interpreter = Interpreter(module)
interpreter.register_implementations(scf.ScfFunctions)
interpreter.register_implementations(arith.ArithFunctions)
interpreter.register_implementations(func.FuncFunctions)
Expand Down

0 comments on commit 4f0209d

Please sign in to comment.