diff --git a/.github/workflows/ci-mlir.yml b/.github/workflows/ci-mlir.yml index 2d93edcf8c..9fe1df4eab 100644 --- a/.github/workflows/ci-mlir.yml +++ b/.github/workflows/ci-mlir.yml @@ -112,6 +112,13 @@ jobs: export PATH=$PATH:${GITHUB_WORKSPACE}/llvm-project/build/bin/ pytest --nbval docs/mlir_interoperation.ipynb --maxfail 1 -vv + - name: Test MLIR dependent marimo notebooks + run: | + cd xdsl + # Add mlir-opt to the path + export PATH=$PATH:${GITHUB_WORKSPACE}/llvm-project/build/bin/ + make tests-marimo-mlir + - name: Combine coverage data run: | cd xdsl diff --git a/Makefile b/Makefile index d368a94ea8..7f6979b578 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ TESTS_COVERAGE_FILE = ${COVERAGE_FILE}.tests .ONESHELL: # these targets don't produce files: -.PHONY: ${VENV_DIR}/ venv clean filecheck pytest pytest-nb tests-toy tests rerun-notebooks precommit-install precommit black pyright tests-marimo +.PHONY: ${VENV_DIR}/ venv clean filecheck pytest pytest-nb tests-toy tests rerun-notebooks precommit-install precommit black pyright tests-marimo tests-marimo-mlir .PHONY: coverage coverage-tests coverage-filecheck-tests coverage-report-html coverage-report-md # set up the venv with all dependencies for development @@ -58,8 +58,15 @@ tests-marimo: done @echo "All marimo tests passed successfully." +tests-marimo-mlir: + @for file in docs/marimo/mlir/*.py; do \ + echo "Running $$file"; \ + python3 "$$file" || exit 1; \ + done + @echo "All marimo mlir tests passed successfully." + # run all tests -tests: pytest tests-toy filecheck pytest-nb tests-marimo pyright +tests: pytest tests-toy filecheck pytest-nb tests-marimo tests-marimo-mlir pyright @echo All tests done. 
 # re-generate the output from all jupyter notebooks in the docs directory
diff --git a/docs/marimo/mlir/README.md b/docs/marimo/mlir/README.md
new file mode 100644
index 0000000000..d6e0d85842
--- /dev/null
+++ b/docs/marimo/mlir/README.md
@@ -0,0 +1,4 @@
+# Marimo notebooks that depend on mlir-opt
+
+For these notebooks to run as intended, `mlir-opt` needs to be in the path.
+Please see the [MLIR Interoperation](../../mlir_interoperation.ipynb) document for more information.
diff --git a/docs/marimo/mlir/onnx-demo.py b/docs/marimo/mlir/onnx-demo.py
new file mode 100644
index 0000000000..d2253a39e2
--- /dev/null
+++ b/docs/marimo/mlir/onnx-demo.py
@@ -0,0 +1,421 @@
+import marimo
+
+__generated_with = "0.6.10"
+app = marimo.App()
+
+
+@app.cell
+def __():
+    import marimo as mo
+
+    mo.md(
+        """
+        # ONNX to Snitch
+
+        This notebook uses Marimo, a Jupyter-like notebook with interactive UI elements and reactive state.
+        """
+    )
+    return mo,
+
+
+@app.cell
+def __(mo):
+    rank = mo.ui.slider(1, 4, value=2, label="Rank")
+
+    mo.md(
+        f"""
+        For example, here is a slider, which can take on values from 1 to 4.
+
+        {rank}
+        """
+    )
+    return rank,
+
+
+@app.cell
+def __(mo, rank):
+    shape = list(range(2, 2 + rank.value))
+
+    mo.md(
+        f"""
+        We use the slider to determine the shape of our inputs and outputs:
+
+        {shape}
+        """
+    )
+    return shape,
+
+
+@app.cell
+def __(mo, shape):
+    import onnx
+    from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper
+
+    # Create the two inputs (ValueInfoProto)
+    X1 = helper.make_tensor_value_info("X1", TensorProto.DOUBLE, shape)
+    X2 = helper.make_tensor_value_info("X2", TensorProto.DOUBLE, shape)
+
+    # Create one output (ValueInfoProto)
+    Y = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, shape)
+
+    # Create a node (NodeProto) computing the elementwise difference of the inputs
+    node_def = helper.make_node(
+        "Sub",  # op type
+        ["X1", "X2"],  # inputs
+        ["Y"],  # outputs
+    )
+
+    # Create the graph (GraphProto)
+    graph_def = helper.make_graph(
+        [node_def],
+        "main_graph",
+        [X1, X2],
+        [Y],
+    )
+
+    # Set opset version to 18
+    opset_import = [helper.make_operatorsetid("", 18)]
+
+    # Create the model (ModelProto)
+    model_def = helper.make_model(
+        graph_def, producer_name="onnx-example", opset_imports=opset_import
+    )
+
+    print(f"The model is:\n{model_def}")
+    onnx.checker.check_model(model_def)
+    # onnx.save(model_def, "add.onnx")
+    print("The model is checked!")
+
+    mo.md(
+        f"""
+        ### The ONNX model
+
+        We use the ONNX API to build a simple function, one that returns the elementwise difference of two arrays of shape {shape}
+        """
+    )
+    return (
+        AttributeProto,
+        GraphProto,
+        TensorProto,
+        ValueInfoProto,
+        X1,
+        X2,
+        Y,
+        graph_def,
+        helper,
+        model_def,
+        node_def,
+        onnx,
+        opset_import,
+    )
+
+
+@app.cell
+def __(mo):
+    from xdsl.ir import Attribute, SSAValue
+
+    mo.md(
+        """
+        We then convert the ONNX Graph to the xDSL representation, in the onnx dialect.
+ """ + ) + return Attribute, SSAValue + + +@app.cell +def __(mo, model_def): + from xdsl.frontend.onnx.ir_builder import build_module + + init_module = build_module(model_def.graph).clone() + + print(init_module) + + mo.md( + """ + Here is the same function, it takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns it. + """ + ) + return build_module, init_module + + +@app.cell +def __(init_module, mo): + from xdsl.context import MLContext + from xdsl.tools.command_line_tool import get_all_dialects + from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass + + ctx = MLContext() + + for dialect_name, dialect_factory in get_all_dialects().items(): + ctx.register_dialect(dialect_name, dialect_factory) + + linalg_module = init_module.clone() + + ConvertOnnxToLinalgPass().apply(ctx, linalg_module) + + print(linalg_module) + + mo.md( + """ + We can use a pass implemented in xDSL to convert the ONNX operations to builtin operations, here we can use the `tensor.empty` op to create our output buffer, and `linalg.add` to represent the addition in destination-passing style. + """ + ) + return ( + ConvertOnnxToLinalgPass, + MLContext, + ctx, + dialect_factory, + dialect_name, + get_all_dialects, + linalg_module, + ) + + +@app.cell +def __(ctx, linalg_module, mo): + from xdsl.transforms.mlir_opt import MLIROptPass + + generalized_module = linalg_module.clone() + + MLIROptPass(generic=False, arguments=["--linalg-generalize-named-ops"]).apply( + ctx, generalized_module + ) + + print(generalized_module) + + mo.md( + """ + We can also call into MLIR, here to convert `linalg.add` to `linalg.generic`, a representation of Einstein summation. 
+ """ + ) + return MLIROptPass, generalized_module + + +@app.cell +def __(MLIROptPass, ctx, generalized_module, mo): + bufferized_module = generalized_module.clone() + + MLIROptPass( + arguments=[ + "--empty-tensor-to-alloc-tensor", + "--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map", + ] + ).apply(ctx, bufferized_module) + + print(bufferized_module) + + mo.md( + """ + We then use MLIR to bufferize our function. + """ + ) + return bufferized_module, + + +@app.cell +def __(MLIROptPass, bufferized_module, ctx): + scf_module = bufferized_module.clone() + + MLIROptPass( + arguments=["--convert-linalg-to-loops", "--lower-affine", "--canonicalize"] + ).apply(ctx, scf_module) + + print(scf_module) + return scf_module, + + +@app.cell +def __(ctx, scf_module): + from xdsl.backend.riscv.lowering import ( + convert_arith_to_riscv, + convert_func_to_riscv_func, + convert_memref_to_riscv, + convert_scf_to_riscv_scf, + ) + from xdsl.passes import PipelinePass + from xdsl.transforms import reconcile_unrealized_casts + + riscv_module = scf_module.clone() + + lower_to_riscv = PipelinePass( + [ + convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(), + convert_memref_to_riscv.ConvertMemrefToRiscvPass(), + convert_arith_to_riscv.ConvertArithToRiscvPass(), + convert_scf_to_riscv_scf.ConvertScfToRiscvPass(), + reconcile_unrealized_casts.ReconcileUnrealizedCastsPass(), + ] + ).apply(ctx, riscv_module) + + print(riscv_module) + return ( + PipelinePass, + convert_arith_to_riscv, + convert_func_to_riscv_func, + convert_memref_to_riscv, + convert_scf_to_riscv_scf, + lower_to_riscv, + reconcile_unrealized_casts, + riscv_module, + ) + + +@app.cell +def __(PipelinePass, ctx, riscv_module): + from xdsl.backend.riscv.lowering.convert_snitch_stream_to_snitch import ( + ConvertSnitchStreamToSnitch, + ) + from xdsl.transforms.canonicalize import CanonicalizePass + from xdsl.transforms.lower_snitch import LowerSnitchPass + from 
xdsl.transforms.riscv_register_allocation import RISCVRegisterAllocation + from xdsl.transforms.riscv_scf_loop_range_folding import ( + RiscvScfLoopRangeFoldingPass, + ) + from xdsl.transforms.snitch_register_allocation import SnitchRegisterAllocation + + regalloc_module = riscv_module.clone() + + PipelinePass( + [ + RISCVRegisterAllocation(), + CanonicalizePass(), + ] + ).apply(ctx, regalloc_module) + + print(regalloc_module) + return ( + CanonicalizePass, + ConvertSnitchStreamToSnitch, + LowerSnitchPass, + RISCVRegisterAllocation, + RiscvScfLoopRangeFoldingPass, + SnitchRegisterAllocation, + regalloc_module, + ) + + +@app.cell +def __(CanonicalizePass, ctx, regalloc_module): + from xdsl.backend.riscv.lowering.convert_riscv_scf_to_riscv_cf import ( + ConvertRiscvScfToRiscvCfPass, + ) + from xdsl.dialects.riscv import riscv_code + + assembly_module = regalloc_module.clone() + + ConvertRiscvScfToRiscvCfPass().apply(ctx, assembly_module) + CanonicalizePass().apply(ctx, assembly_module) + + print(assembly_module) + return ConvertRiscvScfToRiscvCfPass, assembly_module, riscv_code + + +@app.cell +def __(assembly_module, mo, riscv_code): + assembly = riscv_code(assembly_module) + + print(assembly) + + mo.md( + """ + This representation of the program in xDSL corresponds ~1:1 to RISC-V assembly, and we can use a helper function to print that out. 
+ """ + ) + return assembly, + + +@app.cell +def __( + CanonicalizePass, + PipelinePass, + bufferized_module, + convert_arith_to_riscv, + convert_func_to_riscv_func, + convert_memref_to_riscv, + convert_scf_to_riscv_scf, + ctx, + mo, + reconcile_unrealized_casts, +): + from xdsl.transforms import ( + arith_add_fastmath, + convert_linalg_to_memref_stream, + convert_memref_stream_to_loops, + convert_memref_stream_to_snitch_stream, + convert_riscv_scf_for_to_frep, + dead_code_elimination, + loop_hoist_memref, + lower_affine, + memref_streamify, + ) + + snitch_stream_module = bufferized_module.clone() + + pass_pipeline = PipelinePass( + [ + convert_linalg_to_memref_stream.ConvertLinalgToMemrefStreamPass(), + memref_streamify.MemrefStreamifyPass(), + convert_memref_stream_to_loops.ConvertMemrefStreamToLoopsPass(), + convert_memref_stream_to_snitch_stream.ConvertMemrefStreamToSnitch(), + arith_add_fastmath.AddArithFastMathFlagsPass(), + loop_hoist_memref.LoopHoistMemrefPass(), + lower_affine.LowerAffinePass(), + convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(), + convert_memref_to_riscv.ConvertMemrefToRiscvPass(), + convert_arith_to_riscv.ConvertArithToRiscvPass(), + CanonicalizePass(), + convert_scf_to_riscv_scf.ConvertScfToRiscvPass(), + dead_code_elimination.DeadCodeElimination(), + reconcile_unrealized_casts.ReconcileUnrealizedCastsPass(), + convert_riscv_scf_for_to_frep.ConvertRiscvScfForToFrepPass(), + ] + ) + + pass_pipeline.apply(ctx, snitch_stream_module) + + print(snitch_stream_module) + + mo.md( + """ + ### Compiling to Snitch + + xDSL is also capable of targeting Snitch, and making use of its streaming registers and fixed-repetition loop. 
We use a different lowering flow from the linalg.generic representation to represent a high-level, structured, but Snitch-specific representation of the code: + """ + ) + return ( + arith_add_fastmath, + convert_linalg_to_memref_stream, + convert_memref_stream_to_loops, + convert_memref_stream_to_snitch_stream, + convert_riscv_scf_for_to_frep, + dead_code_elimination, + loop_hoist_memref, + lower_affine, + memref_streamify, + pass_pipeline, + snitch_stream_module, + ) + + +@app.cell +def __(ctx, mo, riscv_code, snitch_stream_module): + from xdsl.transforms import test_lower_snitch_stream_to_asm + + snitch_asm_module = snitch_stream_module.clone() + + test_lower_snitch_stream_to_asm.TestLowerSnitchStreamToAsm().apply( + ctx, snitch_asm_module + ) + + print(riscv_code(snitch_asm_module)) + + mo.md( + """ + We can then lower this to assembly that includes assembly instructions from the Snitch-extended ISA: + """ + ) + return snitch_asm_module, test_lower_snitch_stream_to_asm + + +if __name__ == "__main__": + app.run() diff --git a/pyproject.toml b/pyproject.toml index fb94f4b883..db07105e2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,7 @@ target-version = "py310" "tests/test_declarative_assembly_format.py" = ["F811"] "versioneer.py" = ["ALL"] "_version.py" = ["ALL"] +"**/{docs/marimo}/*" = ["E501"] [tool.ruff.mccabe] max-complexity = 10