Skip to content

Commit

Permalink
Fix remaining issues on first coursework
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 27, 2025
1 parent ce9bd48 commit adc2916
Showing 1 changed file with 65 additions and 33 deletions.
98 changes: 65 additions & 33 deletions docs/marimo/mlir_winter_school/4_ir_gen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "sympy",
# "xdsl",
# "marimo",
# "sympy==1.13.3",
# "xdsl==0.27.0",
# ]
# ///

Expand Down Expand Up @@ -30,8 +31,8 @@ def _():
from xdsl.irdl import irdl_op_definition, traits_def, IRDLOperation, irdl_attr_definition, operand_def, result_def
from xdsl.dialects.builtin import ModuleOp, Float64Type, FloatAttr, IntegerType, IntegerAttr
from xdsl.dialects.func import FuncOp, ReturnOp
from xdsl.dialects.arith import AddfOp, SubfOp, MulfOp, ConstantOp, AddiOp, MuliOp, SIToFPOp, FloatingPointLikeBinaryOperation, DivfOp
from xdsl.dialects.scf import ForOp, YieldOp
from xdsl.dialects.arith import AddfOp, SubfOp, MulfOp, ConstantOp, AddiOp, MuliOp, SIToFPOp, FloatingPointLikeBinaryOperation, DivfOp, SelectOp, CmpfOp
from xdsl.dialects.scf import ForOp, YieldOp, IfOp
from xdsl.dialects.math import PowFOp, SqrtOp
from xdsl.builder import Builder, InsertPoint
return (
Expand All @@ -41,6 +42,7 @@ def _():
AddiOp,
Block,
Builder,
CmpfOp,
ConstantOp,
DivfOp,
E,
Expand All @@ -55,6 +57,7 @@ def _():
GreedyRewritePatternApplier,
I,
IRDLOperation,
IfOp,
InsertPoint,
Integer,
IntegerAttr,
Expand All @@ -77,6 +80,7 @@ def _():
S,
SIToFPOp,
SSAValue,
SelectOp,
SqrtOp,
SubfOp,
Sum,
Expand Down Expand Up @@ -165,6 +169,12 @@ def _(mo):
return


@app.cell
def _(mo):
mo.md("""You will have to use the following operations:""")
return


@app.cell(hide_code=True)
def _(mo):
mo.md(
Expand Down Expand Up @@ -223,14 +233,11 @@ def print_ir(expr: Expr):
print(expr)

# Converts the SymPy expression to an MLIR `builtin.module` operation
try:
op = emit_ir(expr)
op = emit_ir(expr)

# Check that the operation verifies, and prints the operation
op.verify()
print(op)
except Exception as e:
print("Error while converting expression: ", e)
# Check that the operation verifies, and prints the operation
op.verify()
print(op)

# Print a separator
print("\n\n")
Expand All @@ -239,21 +246,33 @@ def print_ir(expr: Expr):

@app.cell(hide_code=True)
def _(mo):
mo.md(r"""This function takes a SymPy expression, creates a module and a function, and calls the main recursive function to convert SymPy AST.""")
mo.md(
"""
The following function returns the MLIR type of the expression result. As we are only handling integer and real types, we only returns either the `i64` or `f64` types, which correspond to 64bits integers and floating points.
For instance, the type of `x + y` is `f64`, as `x` and `y` are floating points. The type of `a * b` is `i64`, as both `a` and `b` are integers.
"""
)
return


@app.cell
def _(Attribute, Expr, Float64Type, IntegerType):
# Get the MLIR type for a SymPy expression
def get_type(expr: Expr) -> Attribute:
def get_mlir_type(expr: Expr) -> Attribute:
if expr.is_integer:
return IntegerType(64)
elif expr.is_extended_real:
return Float64Type()
else:
raise Exception(f"Unknown MLIR type for expression {expr}. Please make sure there cannot be a division by zero, or a power of a negative value.")
return (get_type,)
return (get_mlir_type,)


@app.cell(hide_code=True)
def _(mo):
mo.md(r"""The following function (`emit_ir`) should be called to emit MLIR IR from a SymPy expression. It takes a SymPy expression, creates a module and a function, and starts the recursion on the SymPy AST to emit MLIR IR.""")
return


@app.cell
Expand All @@ -265,18 +284,18 @@ def _(
ModuleOp,
ReturnOp,
emit_op,
get_type,
get_mlir_type,
):
def emit_ir(expr: Expr) -> ModuleOp:
# Create a module, and create a builder at the beginning of its only block
module = ModuleOp([])
builder = Builder(InsertPoint.at_end(module.body.block))

# Create the MLIR types for each symbol.
arg_types = [get_type(arg) for arg in expr.free_symbols]
arg_types = [get_mlir_type(arg) for arg in expr.free_symbols]

# Create a new function and inserts it inside the module.
func = FuncOp("main", (arg_types, [get_type(expr)]))
func = FuncOp("main", (arg_types, [get_mlir_type(expr)]))
builder.insert(func)

# Associate each symbol with its MLIR name.
Expand All @@ -299,7 +318,7 @@ def emit_ir(expr: Expr) -> ModuleOp:
return (emit_ir,)


@app.cell
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""Finally, here are the functions that you should complete. `emit_op` is fully complete, and emits the necessary IR for a SymPy expression. `emit_integer_op` and `emit_float_op` emits the operations for integer and float operations, and are only partially implemented.""")
return
Expand All @@ -316,14 +335,14 @@ def _(
SIToFPOp,
SSAValue,
Symbol,
get_type,
get_mlir_type,
):
def emit_op(
expr: Expr,
builder: Builder,
args: dict[Symbol, SSAValue],
):
type = get_type(expr)
type = get_mlir_type(expr)
if isinstance(type, IntegerType):
return emit_integer_op(expr, builder, args)
elif isinstance(type, Float64Type):
Expand All @@ -342,8 +361,10 @@ def emit_integer_op(

# Handle constants
if isinstance(expr, Integer):
# int(expr) returns the value of the `expr` constant
raise NotImplementedError("Constants are not implemented")
# Hint: int(expr) returns the value of the `expr` constant
raise NotImplementedError("Integer constants are not implemented")

# Hint: Implement here support for Add and Mul

raise NotImplementedError(f"No IR emitter for integer function {expr.func}")

Expand All @@ -356,23 +377,25 @@ def emit_real_op(
# back to a float expression.
if expr.is_integer:
res = emit_integer_op(expr, builder, args)
op = builder.insert(SIToFPOp(res))
op = builder.insert(SIToFPOp(res, Float64Type()))
return op.result

# Handle constants
if isinstance(expr, Float):
# float(expr) returns the value of the `expr` constant
raise NotImplementedError("Constants are not implemented")
# Hint: float(expr) returns the value of the `expr` constant
raise NotImplementedError("Float constants are not implemented")

# Handle symbolic values
if isinstance(expr, Symbol):
return args[expr]

# Hint: Implement here support for Add, Mul, and Pow (and later Abs and Sum)

raise NotImplementedError(f"No IR emitter for float function {expr.func}")
return emit_integer_op, emit_op, emit_real_op


@app.cell
@app.cell(hide_code=True)
def _(mo):
mo.md("""Here are a few simple examples that you should support first. For each test, the expression is printed, then either the MLIR code, or an error. Each of the operators used in these tests should only be converted to a single MLIR operation.""")
return
Expand All @@ -383,20 +406,29 @@ def _(Float, Integer, a, b, print_ir, x, y):
print_ir(Float(2))
print_ir(Integer(2))

# Adds two integers
print_ir(a + b)

# Adds two integers with a real
print_ir(a + b + x)
print_ir(x + 2)

# Multiplies two reals
print_ir(x * y)

# Multiplies three reals
print_ir(a * b + x)

# Add two reals
print_ir(x + x)
print_ir(x / y)
print_ir(x - y)

# Square a real
print_ir(x ** 4)
return


@app.cell
@app.cell(hide_code=True)
def _(mo):
mo.md(r"""The following expression requires to handle the AST node `Abs`. Instead of converting it to `math.absf` operation, we taks you to write it using the formula `x < 0 ? -x : x` using only `arith` operations.""")
mo.md(r"""The following expression requires to handle the AST node `Abs`. Instead of converting it to `math.absf` operation, we taks you to write it using the formula `x < 0 ? -x : x` using only `arith` operations. Hint, you should use `arith.select` for expressing the conditional.""")
return


Expand All @@ -407,7 +439,7 @@ def _(Abs, print_ir, x, y):
return


@app.cell
@app.cell(hide_code=True)
def _(mo):
mo.md(
r"""
Expand Down Expand Up @@ -439,7 +471,7 @@ def _(Abs, Sum, UnevaluatedExpr, a, b, c, print_ir, x, y):
# for c in range(0, b):
# result += 1
# We use an UnevaluatedExpr so that SymPy doesn't combine both sums
print_ir(Sum(UnevaluatedExpr(Sum(1, (c, 0, b))), (b, 0, a)))
print_ir(Sum(UnevaluatedExpr(Sum(x, (c, 0, b))), (b, 0, a)))
return


Expand Down

0 comments on commit adc2916

Please sign in to comment.