diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 4fa5daa91..beeb35883 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -23,6 +23,7 @@ - [Project Structure](./project-structure.md) - [Dialects Overview](./dialects-overview.md) - [Adding an op](./adding-an-op.md) +- [Decomposing an op in TTIR](./decomposing-an-op-in-ttir.md) - [Doxygen](./doxygen.md) - [Specifications](./specs/specs.md) - [Runtime Stitching](./specs/runtime-stitching.md) diff --git a/docs/src/decomposing-an-op-in-ttir.md b/docs/src/decomposing-an-op-in-ttir.md new file mode 100644 index 000000000..ba74d94ba --- /dev/null +++ b/docs/src/decomposing-an-op-in-ttir.md @@ -0,0 +1,53 @@ +# Decomposing an Op in TTIR + +This guide explains how to add and decompose a new operation in the TTIR dialect. We’ll focus on adding an `Index` operation, which will be decomposed into the `Slice` operation. The decomposition is implemented as a conversion pass in MLIR since it allows us to mark operations or dialects as legal or illegal, type conversion... + +This guide will cover the following steps: +- [Decomposing an Op in TTIR](#decomposing-an-op-in-ttir) + - [1. Define the Op in the TTIR frontend dialect](#1-define-the-op-in-the-ttir-frontend-dialect) + - [2. Create a conversion pattern](#2-create-a-conversion-pattern) + - [`C++ conversion pattern`](#c-conversion-pattern) + - [`Tablegen conversion pattern`](#tablegen-conversion-pattern) + - [3. Register the created conversion pattern](#3-register-the-created-conversion-pattern) + +## 1. Define the Op in the TTIR frontend dialect + +The more information regarding this step can be found here: [Define the Op in the TTIR frontend dialect](./adding-an-op.md#1-define-the-op-in-the-ttir-frontend-dialect) + +I updated the `TTIROps.td` as following: + +```td +{{#include ../../../include/ttmlir/Dialect/TTIR/IR/TTIROps.td:adding_an_op_index_ttir}} +``` + +The verification function has been added as well: + +```cpp +{{#include ../../../lib/Dialect/TTIR/IR/TTIROps.cpp:adding_an_op_index_ttir}} +``` + +## 2. Create a conversion pattern + +A conversion pattern defines how MLIR should rewrite the Op. It can be implemented in either C++ or TableGen. Currently, we only have the C++ implementation; TableGen format will be added in the future. + +#### `C++ conversion pattern` + +For the `Index` operation, we use the C++ conversion pattern because it involves changing the Op’s input types from integers to arrays, which TableGen lacks flexibility for. + +``` +{{#include ../../../include/ttmlir/Dialect/TTNN/IR/TTNNOps.td:adding_an_op_index_ttir}} +``` + +The `matchAndRewrite` method from `OpConversionPattern` is implemented to replace the matched Op with the newly created Op. Since decomposition is implemented as a conversion pass, `OpAdaptor` is used to access the attributes of the original Op in their converted types. Finally, we instantiate the new Op and call the `replaceOp` method on `ConversionPatternRewriter` to replace the original Op. + +#### `Tablegen conversion pattern` +TODO + +## 3. Register the created conversion pattern + +To register the new pattern, go to the `populateTTIRToTTIRDecompositionPatterns` function in `TTIRToTTIRDecomposition.cpp` and add it to `RewritePatternSet` using the add method. After that is done you should mark the decomposed op as illegal in `runOnOperation` method of `TTIRToTTIRDecompositionPass` in `TTIRToTTIRDecompositionPass.cpp`. + +You should also add a silicon test like described here: [Add a silicon unit test for the Op](./adding-an-op.md##8-add-a-silicon-unit-test-for-the-op). This is how the silicon test for the `Index` operation looks like: +```mlir +{{#include ../../../test/ttmlir/Silicon/TTNN/simple_index.mlir}} +``` diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 381e3750f..d26a3f6c0 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -785,6 +785,7 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> { let hasVerifier = 1; } +// ANCHOR: adding_an_op_index_ttir def TTIR_IndexOp: TTIR_DPSOp<"index"> { let summary = "Index op."; let description = [{ @@ -809,6 +810,7 @@ def TTIR_IndexOp: TTIR_DPSOp<"index"> { let hasVerifier = 1; } +// ANCHOR: adding_an_op_index_ttir def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> { let summary = "Squeeze op."; diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index d361fce1f..f91664616 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -30,8 +30,11 @@ using namespace mlir::tt; namespace mlir::tt { -// Decompose IndexOp into SliceOp -// +//===----------------------------------------------------------------------===// +// IndexOp decomposition +//===----------------------------------------------------------------------===// + +// ANCHOR: adding_an_op_index_ttir // This transformation adjusts IndexOp attributes so that `begin`, `end`, and // `step` become arrays, where each array element corresponds to a dimension of // the input tensor. For dimensions other than the sliced dimension, default @@ -73,6 +76,7 @@ struct IndexToSliceConversionPattern return success(); } }; +// ANCHOR_END: adding_an_op_index_ttir //===----------------------------------------------------------------------===// // Convolution passes diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 3ae2a7bad..83a193d36 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -394,6 +394,7 @@ ::mlir::LogicalResult mlir::tt::ttir::SliceOp::verify() { // IndexOp //===----------------------------------------------------------------------===// +// ANCHOR: adding_an_op_index_ttir // IndexOp verification ::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); @@ -499,6 +500,7 @@ ::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() { return success(); } +// ANCHOR: adding_an_op_index_ttir //===----------------------------------------------------------------------===// // SqueezeOp