From adc291615c4894df58f0d22f5d9261bf1da4dbf5 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Mon, 27 Jan 2025 02:14:59 +0000 Subject: [PATCH] Fix remaining issues on first coursework --- docs/marimo/mlir_winter_school/4_ir_gen.py | 98 ++++++++++++++-------- 1 file changed, 65 insertions(+), 33 deletions(-) diff --git a/docs/marimo/mlir_winter_school/4_ir_gen.py b/docs/marimo/mlir_winter_school/4_ir_gen.py index bebe20cb15..00b3dba26b 100644 --- a/docs/marimo/mlir_winter_school/4_ir_gen.py +++ b/docs/marimo/mlir_winter_school/4_ir_gen.py @@ -1,8 +1,9 @@ # /// script # requires-python = ">=3.12" # dependencies = [ -# "sympy", -# "xdsl", +# "marimo", +# "sympy==1.13.3", +# "xdsl==0.27.0", # ] # /// @@ -30,8 +31,8 @@ def _(): from xdsl.irdl import irdl_op_definition, traits_def, IRDLOperation, irdl_attr_definition, operand_def, result_def from xdsl.dialects.builtin import ModuleOp, Float64Type, FloatAttr, IntegerType, IntegerAttr from xdsl.dialects.func import FuncOp, ReturnOp - from xdsl.dialects.arith import AddfOp, SubfOp, MulfOp, ConstantOp, AddiOp, MuliOp, SIToFPOp, FloatingPointLikeBinaryOperation, DivfOp - from xdsl.dialects.scf import ForOp, YieldOp + from xdsl.dialects.arith import AddfOp, SubfOp, MulfOp, ConstantOp, AddiOp, MuliOp, SIToFPOp, FloatingPointLikeBinaryOperation, DivfOp, SelectOp, CmpfOp + from xdsl.dialects.scf import ForOp, YieldOp, IfOp from xdsl.dialects.math import PowFOp, SqrtOp from xdsl.builder import Builder, InsertPoint return ( @@ -41,6 +42,7 @@ def _(): AddiOp, Block, Builder, + CmpfOp, ConstantOp, DivfOp, E, @@ -55,6 +57,7 @@ def _(): GreedyRewritePatternApplier, I, IRDLOperation, + IfOp, InsertPoint, Integer, IntegerAttr, @@ -77,6 +80,7 @@ def _(): S, SIToFPOp, SSAValue, + SelectOp, SqrtOp, SubfOp, Sum, @@ -165,6 +169,12 @@ def _(mo): return +@app.cell +def _(mo): + mo.md("""You will have to use the following operations:""") + return + + @app.cell(hide_code=True) def _(mo): mo.md( @@ -223,14 +233,11 @@ def print_ir(expr: Expr): print(expr) # Converts the SymPy expression to an MLIR `builtin.module` operation - try: - op = emit_ir(expr) + op = emit_ir(expr) - # Check that the operation verifies, and prints the operation - op.verify() - print(op) - except Exception as e: - print("Error while converting expression: ", e) + # Check that the operation verifies, and prints the operation + op.verify() + print(op) # Print a separator print("\n\n") @@ -239,21 +246,33 @@ def print_ir(expr: Expr): @app.cell(hide_code=True) def _(mo): - mo.md(r"""This function takes a SymPy expression, creates a module and a function, and calls the main recursive function to convert SymPy AST.""") + mo.md( + """ + The following function returns the MLIR type of the expression result. As we are only handling integer and real types, we only returns either the `i64` or `f64` types, which correspond to 64bits integers and floating points. + + For instance, the type of `x + y` is `f64`, as `x` and `y` are floating points. The type of `a * b` is `i64`, as both `a` and `b` are integers. + """ + ) return @app.cell def _(Attribute, Expr, Float64Type, IntegerType): # Get the MLIR type for a SymPy expression - def get_type(expr: Expr) -> Attribute: + def get_mlir_type(expr: Expr) -> Attribute: if expr.is_integer: return IntegerType(64) elif expr.is_extended_real: return Float64Type() else: raise Exception(f"Unknown MLIR type for expression {expr}. Please make sure there cannot be a division by zero, or a power of a negative value.") - return (get_type,) + return (get_mlir_type,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r"""The following function (`emit_ir`) should be called to emit MLIR IR from a SymPy expression. It takes a SymPy expression, creates a module and a function, and starts the recursion on the SymPy AST to emit MLIR IR.""") + return @app.cell @@ -265,7 +284,7 @@ def _( ModuleOp, ReturnOp, emit_op, - get_type, + get_mlir_type, ): def emit_ir(expr: Expr) -> ModuleOp: # Create a module, and create a builder at the beginning of its only block @@ -273,10 +292,10 @@ def emit_ir(expr: Expr) -> ModuleOp: builder = Builder(InsertPoint.at_end(module.body.block)) # Create the MLIR types for each symbol. - arg_types = [get_type(arg) for arg in expr.free_symbols] + arg_types = [get_mlir_type(arg) for arg in expr.free_symbols] # Create a new function and inserts it inside the module. - func = FuncOp("main", (arg_types, [get_type(expr)])) + func = FuncOp("main", (arg_types, [get_mlir_type(expr)])) builder.insert(func) # Associate each symbol with its MLIR name. @@ -299,7 +318,7 @@ def emit_ir(expr: Expr) -> ModuleOp: return (emit_ir,) -@app.cell +@app.cell(hide_code=True) def _(mo): mo.md(r"""Finally, here are the functions that you should complete. `emit_op` is fully complete, and emits the necessary IR for a SymPy expression. `emit_integer_op` and `emit_float_op` emits the operations for integer and float operations, and are only partially implemented.""") return @@ -316,14 +335,14 @@ def _( SIToFPOp, SSAValue, Symbol, - get_type, + get_mlir_type, ): def emit_op( expr: Expr, builder: Builder, args: dict[Symbol, SSAValue], ): - type = get_type(expr) + type = get_mlir_type(expr) if isinstance(type, IntegerType): return emit_integer_op(expr, builder, args) elif isinstance(type, Float64Type): @@ -342,8 +361,10 @@ def emit_integer_op( # Handle constants if isinstance(expr, Integer): - # int(expr) returns the value of the `expr` constant - raise NotImplementedError("Constants are not implemented") + # Hint: int(expr) returns the value of the `expr` constant + raise NotImplementedError("Integer constants are not implemented") + + # Hint: Implement here support for Add and Mul raise NotImplementedError(f"No IR emitter for integer function {expr.func}") @@ -356,23 +377,25 @@ def emit_real_op( # back to a float expression. if expr.is_integer: res = emit_integer_op(expr, builder, args) - op = builder.insert(SIToFPOp(res)) + op = builder.insert(SIToFPOp(res, Float64Type())) return op.result # Handle constants if isinstance(expr, Float): - # float(expr) returns the value of the `expr` constant - raise NotImplementedError("Constants are not implemented") + # Hint: float(expr) returns the value of the `expr` constant + raise NotImplementedError("Float constants are not implemented") # Handle symbolic values if isinstance(expr, Symbol): return args[expr] + # Hint: Implement here support for Add, Mul, and Pow (and later Abs and Sum) + raise NotImplementedError(f"No IR emitter for float function {expr.func}") return emit_integer_op, emit_op, emit_real_op -@app.cell +@app.cell(hide_code=True) def _(mo): mo.md("""Here are a few simple examples that you should support first. For each test, the expression is printed, then either the MLIR code, or an error. Each of the operators used in these tests should only be converted to a single MLIR operation.""") return @@ -383,20 +406,29 @@ def _(Float, Integer, a, b, print_ir, x, y): print_ir(Float(2)) print_ir(Integer(2)) + # Adds two integers print_ir(a + b) + + # Adds two integers with a real print_ir(a + b + x) - print_ir(x + 2) + + # Multiplies two reals print_ir(x * y) + + # Multiplies three reals print_ir(a * b + x) + + # Add two reals print_ir(x + x) - print_ir(x / y) - print_ir(x - y) + + # Square a real + print_ir(x ** 4) return -@app.cell +@app.cell(hide_code=True) def _(mo): - mo.md(r"""The following expression requires to handle the AST node `Abs`. Instead of converting it to `math.absf` operation, we taks you to write it using the formula `x < 0 ? -x : x` using only `arith` operations.""") + mo.md(r"""The following expression requires to handle the AST node `Abs`. Instead of converting it to `math.absf` operation, we taks you to write it using the formula `x < 0 ? -x : x` using only `arith` operations. Hint, you should use `arith.select` for expressing the conditional.""") return @@ -407,7 +439,7 @@ def _(Abs, print_ir, x, y): return -@app.cell +@app.cell(hide_code=True) def _(mo): mo.md( r""" @@ -439,7 +471,7 @@ def _(Abs, Sum, UnevaluatedExpr, a, b, c, print_ir, x, y): # for c in range(0, b): # result += 1 # We use an UnevaluatedExpr so that SymPy doesn't combine both sums - print_ir(Sum(UnevaluatedExpr(Sum(1, (c, 0, b))), (b, 0, a))) + print_ir(Sum(UnevaluatedExpr(Sum(x, (c, 0, b))), (b, 0, a))) return