From 392c0b01c30963f47fb3cae377c2f748b9a72fae Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta
Date: Thu, 27 Feb 2025 01:02:18 +0000
Subject: [PATCH] Pass to wrap stablehlo ops in composite

---
 BUILD.bazel                                   |   1 +
 docs/generated/stablehlo_passes.md            |  53 +++++
 .../stablehlo_wrap_in_composite.mlir          |  93 ++++++++
 stablehlo/transforms/CMakeLists.txt           |   1 +
 stablehlo/transforms/Passes.h                 |   9 +
 stablehlo/transforms/Passes.td                |  56 +++++
 .../transforms/StablehloWrapInComposite.cpp   | 216 ++++++++++++++++++
 7 files changed, 429 insertions(+)
 create mode 100644 stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir
 create mode 100644 stablehlo/transforms/StablehloWrapInComposite.cpp

diff --git a/BUILD.bazel b/BUILD.bazel
index 6eafc665da..33d6d0b774 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1239,6 +1239,7 @@ cc_library(
         "stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
         "stablehlo/transforms/StablehloRefineArguments.cpp",
         "stablehlo/transforms/StablehloRefineShapes.cpp",
+        "stablehlo/transforms/StablehloWrapInComposite.cpp",
         "stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
         "stablehlo/transforms/VhloToVersion.cpp",
     ],
diff --git a/docs/generated/stablehlo_passes.md b/docs/generated/stablehlo_passes.md
index e261743930..c06f4ad1e1 100755
--- a/docs/generated/stablehlo_passes.md
+++ b/docs/generated/stablehlo_passes.md
@@ -338,6 +338,59 @@ Modules valid for shape refinement must have the following properties:
 * All calls to a single function resolve to the same argument shapes, and
   no recursive / co-recursive function calls are made.
 
+### `-stablehlo-wrap-in-composite`
+
+_Wraps a non-composite StableHLO op in a composite op._
+
+Wraps the StableHLO ops named by the pass option in composite ops. The
+composite op inherits all attributes of the original op (carried as its
+`composite_attributes`).
+
+For example, with the pass option `--stablehlo-wrap-in-composite=op-names='stablehlo.add'`,
+
+```mlir
+func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+```
+
+will become:
+
+```mlir
+func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {
+    decomposition = @stablehlo.add.impl
+  } : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
+```
+
+The pass is also exposed as a C++ API, `createStablehloWrapInCompositePass`,
+which allows a more flexible selection of the ops to wrap.
+For example, the following wraps all non-composite `stablehlo.add` and
+`stablehlo.convolution` ops:
+
+```c++
+auto pass = createStablehloWrapInCompositePass(
+    [](Operation *op) {
+      return (op->getName().getStringRef() == "stablehlo.add" ||
+              op->getName().getStringRef() == "stablehlo.convolution") &&
+             !isa<stablehlo::CompositeOp>(op);
+    });
+```
+
+#### Options
+
+```
+-op-names : The names of the ops to wrap.
+```
+
 ### `-vhlo-legalize-to-stablehlo`
 
 _Legalize VHLO to StableHLO._
diff --git a/stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir b/stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir
new file mode 100644
index 0000000000..0a5d6551eb
--- /dev/null
+++ b/stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir
@@ -0,0 +1,93 @@
+// RUN: stablehlo-opt --stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.convolution,stablehlo.reduce' --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func.func @wrap_in_composite
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>,
+// CHECK-SAME: %[[ARG_2:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[CONV:.*]] = stablehlo.composite "stablehlo.convolution" %[[ARG_0]], %[[ARG_1]] {
+// CHECK-SAME: composite_attributes = {batch_group_count = 1 : i64,
+// CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+// CHECK-SAME: feature_group_count = 1 : i64,
+// CHECK-SAME{LITERAL}: padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
+// CHECK-SAME{LITERAL}: rhs_dilation = array<i64: 2, 2>,
+// CHECK-SAME{LITERAL}: window_strides = array<i64: 1, 1>},
+// CHECK-SAME: decomposition = @stablehlo.convolution.impl} : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
+// CHECK: %[[ADD:.*]] = stablehlo.composite "stablehlo.add" %[[CONV]], %[[ARG_2]] {decomposition = @stablehlo.add.impl} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
+// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.composite "stablehlo.add" %[[ADD]], %[[ADD]] {decomposition = @stablehlo.add.impl1} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[ADD1]]
+
+// CHECK-LABEL: func.func private @stablehlo.convolution.impl
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[VAL:.*]] = stablehlo.convolution(%[[ARG_0]], %[[ARG_1]])
+// CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+// CHECK-SAME{LITERAL}: stride = [1, 1],
+// CHECK-SAME{LITERAL}: pad = [[0, 1], [0, 1]],
+// CHECK-SAME{LITERAL}: rhs_dilate = [2, 2]}
+// CHECK-SAME: batch_group_count = 1 : i64
+// CHECK-SAME: feature_group_count = 1 : i64
+// CHECK-SAME: : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[VAL]]
+
+// CHECK-LABEL: func.func private @stablehlo.add.impl
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_0]], %[[ARG_1]] : tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[VAL]]
+
+// CHECK-LABEL: func.func private @stablehlo.add.impl1
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_1]] : tensor<64x3x3x32xi32>
+// CHECK-NEXT: return %[[VAL]]
+
+func.func @wrap_in_composite(
+    %arg0: tensor<64x8x8x8xi8>,
+    %arg1: tensor<4x4x8x32xi8>,
+    %arg2: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
+  %0 = stablehlo.convolution(%arg0, %arg1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]}
+    {batch_group_count = 1 : i64, feature_group_count = 1 : i64} :
+    (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
+  %1 = stablehlo.add %0, %arg2 : tensor<64x3x3x32xi32>
+  %2 = stablehlo.add %1, %1 : tensor<64x3x3x32xi32>
+  func.return %2 : tensor<64x3x3x32xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @wrap_in_composite_op_with_region
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x3xf32>) -> tensor<4xf32>
+// CHECK: %[[CONST:.*]] = stablehlo.constant
+// CHECK-NEXT: %[[COMPOSITE_REDUCE:.*]] = stablehlo.composite "stablehlo.reduce" %[[ARG_0]], %[[CONST]] {
+// CHECK-SAME: composite_attributes = {
+// CHECK-SAME: dimensions = array<i64: 1>},
+// CHECK-SAME: decomposition = @stablehlo.reduce.impl}
+// CHECK-SAME: (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
+// CHECK-NEXT: return %[[COMPOSITE_REDUCE]]
+
+// CHECK-LABEL: func.func private @stablehlo.reduce.impl
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x3xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<4xf32> {
+// CHECK: %[[REDUCE:.*]] = stablehlo.reduce(%[[ARG_0]] init: %[[ARG_1]])
+// CHECK-SAME{LITERAL}: applies stablehlo.add across dimensions = [1]
+// CHECK-SAME: (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
+// CHECK-NEXT: return %[[REDUCE]]
+func.func @wrap_in_composite_op_with_region(%x : tensor<4x3xf32>) -> tensor<4xf32> {
+  %cst = stablehlo.constant dense<2.7> : tensor<f32>
+  %res = stablehlo.reduce(%x init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
+  func.return %res : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @cannot_be_wrapped_ops_does_not_match
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<2xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<2xf32>) -> tensor<2xf32> {
+// CHECK: %[[VAL:.*]] = stablehlo.multiply %[[ARG_0]], %[[ARG_1]] : tensor<2xf32>
+// CHECK-NEXT: return %[[VAL]] : tensor<2xf32>
+func.func @cannot_be_wrapped_ops_does_not_match(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  %0 = stablehlo.multiply %arg0, %arg1 : tensor<2xf32>
+  func.return %0 : tensor<2xf32>
+}
diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt
index 4787369d37..a7fada9b3c 100644
--- a/stablehlo/transforms/CMakeLists.txt
+++ b/stablehlo/transforms/CMakeLists.txt
@@ -57,6 +57,7 @@ add_mlir_dialect_library(StablehloPasses
   StablehloLegalizeToVhlo.cpp
   StablehloRefineArguments.cpp
   StablehloRefineShapes.cpp
+  StablehloWrapInComposite.cpp
   VhloLegalizeToStablehlo.cpp
   VhloToVersion.cpp
   PassUtils.cpp
diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h
index 055768cacd..e8060e5043 100644
--- a/stablehlo/transforms/Passes.h
+++ b/stablehlo/transforms/Passes.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef STABLEHLO_TRANSFORMS_PASSES_H
 #define STABLEHLO_TRANSFORMS_PASSES_H
 
+#include <functional>
 #include <memory>
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -102,6 +103,14 @@ void populateStablehloCompatibilityExpanderPatterns(
 std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
     TypeRange refinedTypes);
 
+/// Creates a pass that wraps StableHLO ops in CompositeOp.
+///
+/// The pass wraps every StableHLO op that matches the given opPredicate in a
+/// CompositeOp; opPredicate should return true for each op that is to be
+/// wrapped.
+std::unique_ptr<OperationPass<ModuleOp>> createStablehloWrapInCompositePass(
+    std::function<bool(Operation *)> opPredicate);
+
 //// Pass pipelines ////
 
 // StableHLO consumers can add this pipeline to convert portable artifacts to
diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td
index e0d9f317e1..abf090a758 100644
--- a/stablehlo/transforms/Passes.td
+++ b/stablehlo/transforms/Passes.td
@@ -409,3 +409,59 @@ def VhloToVersionPass : Pass<"vhlo-to-version"> {
   ];
   let dependentDialects = ["mlir::vhlo::VhloDialect"];
 }
+
+def StablehloWrapInCompositePass : Pass<"stablehlo-wrap-in-composite", "ModuleOp"> {
+  let summary = "Wraps a non-composite StableHLO op in a composite op.";
+  let description = [{
+    Wraps the StableHLO ops named by the pass option in composite ops. The
+    composite op inherits all attributes of the original op (carried as its
+    `composite_attributes`).
+
+    For example, with the pass option `--stablehlo-wrap-in-composite=op-names='stablehlo.add'`,
+
+    ```mlir
+    func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+      %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+      func.return %0 : tensor<2xf32>
+    }
+    ```
+
+    will become:
+
+    ```mlir
+    func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+      %0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {
+        decomposition = @stablehlo.add.impl
+      } : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+      return %0 : tensor<2xf32>
+    }
+
+    func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+      %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+      func.return %0 : tensor<2xf32>
+    }
+    ```
+
+    The pass is also exposed as a C++ API, `createStablehloWrapInCompositePass`,
+    which allows a more flexible selection of the ops to wrap.
+    For example, the following wraps all non-composite `stablehlo.add` and
+    `stablehlo.convolution` ops:
+
+    ```c++
+    auto pass = createStablehloWrapInCompositePass(
+        [](Operation *op) {
+          return (op->getName().getStringRef() == "stablehlo.add" ||
+                  op->getName().getStringRef() == "stablehlo.convolution") &&
+                 !isa<stablehlo::CompositeOp>(op);
+        });
+    ```
+  }];
+  let options = [
+    ListOption<"opNamesOption", "op-names", "std::string",
+               "The names of the ops to wrap.">,
+  ];
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
diff --git a/stablehlo/transforms/StablehloWrapInComposite.cpp b/stablehlo/transforms/StablehloWrapInComposite.cpp
new file mode 100644
index 0000000000..762d6a6473
--- /dev/null
+++ b/stablehlo/transforms/StablehloWrapInComposite.cpp
@@ -0,0 +1,216 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOWRAPINCOMPOSITEPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { + +// Generates a unique function name based on the given `baseFuncName` within +// the provided `module`. Ensures the generated name does not clash with any +// existing symbols by appending a counter if necessary. +std::string generateUniqueFunctionName(StringRef baseFuncName, + mlir::ModuleOp module) { + mlir::SymbolTable symbolTable(module); + int counter = 0; + std::string baseName = baseFuncName.str() + ".impl"; + std::string funcName = baseName; + while (symbolTable.lookup(funcName)) { + counter++; + funcName = (baseName + std::to_string(counter)); + } + return funcName; +} + +// Builds a new function within the given `module` that encapsulates the +// functionality of the provided `implOp`. The new function is named uniquely +// and is set to private visibility. +mlir::func::FuncOp buildStableHLOCompositeImplFunc(mlir::ModuleOp module, + mlir::Operation* implOp) { + // Create an OpBuilder, insertion point at the end of module's body. + mlir::OpBuilder builder(module); + builder.setInsertionPointToEnd(&module.getBodyRegion().back()); + + // Prepare argument types and locations for the new function. + llvm::SmallVector argLocs; + llvm::SmallVector argTypes; + for (auto& operand : implOp->getOpOperands()) { + argTypes.push_back( + operand.get().getType()); // Get the type of each operand. + argLocs.push_back( + operand.get().getLoc()); // Get the location of each operand. + } + + // Prepare result types for the new function. + llvm::SmallVector resultTypes; + for (auto result : implOp->getResults()) { + resultTypes.push_back(result.getType()); // Get the type of each result. + } + + // Create the function operation. + auto uniqueFuncName = + generateUniqueFunctionName(implOp->getName().getStringRef(), module); + mlir::func::FuncOp implFunc = builder.create( + module.getLoc(), uniqueFuncName, + builder.getFunctionType(argTypes, + resultTypes)); // Set arg and result types + + // Create a block in the function body representing the function's content + // and map the arguments from the original op to the new function. + mlir::IRMapping mapping; // Maps values from the old op to the new function. + builder.createBlock(&implFunc.getBody(), implFunc.begin(), argTypes, argLocs); + + // Map the operands of the original op to the arguments of the newly created + // function. 
+  for (const auto& operand : llvm::enumerate(implOp->getOperands())) {
+    mapping.map(operand.value(), implFunc.getArgument(operand.index()));
+  }
+
+  // Clone the original operation into the body of the new function,
+  // using the value mapping to remap operands.
+  mlir::Operation* clonedOp = builder.clone(*implOp, mapping);
+
+  // Create the return operation, returning the results of the cloned op.
+  llvm::SmallVector<Value> results;
+  results.append(clonedOp->getResults().begin(),
+                 clonedOp->getResults().end());
+  builder.create<mlir::func::ReturnOp>(implFunc.getBody().getLoc(), results);
+
+  // Add the newly created function to the module's symbol table and make it
+  // private.
+  mlir::SymbolTable symbolTable(module);
+  implFunc.setPrivate();
+  symbolTable.insert(implFunc);
+
+  return implFunc;
+}
+
+// A ConversionPattern that matches any operation and rewrites it as a
+// stablehlo::CompositeOp. The original operation's functionality is
+// encapsulated within a newly created private function.
+class ConvertGenericOp : public ConversionPattern {
+ public:
+  explicit ConvertGenericOp(MLIRContext* ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(
+      Operation* op, ArrayRef<Value> operands,
+      ConversionPatternRewriter& rewriter) const override {
+    // Get the enclosing module.
+    auto module = op->getParentOfType<ModuleOp>();
+    MLIRContext* context = op->getContext();
+    if (module == nullptr) {
+      return rewriter.notifyMatchFailure(op, "Failed to find enclosing module");
+    }
+    auto implFunc = buildStableHLOCompositeImplFunc(module, op);
+    auto name = op->getName().getStringRef();
+
+    llvm::SmallVector<Value> compositeOperands(op->operand_begin(),
+                                               op->operand_end());
+    auto compositeOp = rewriter.create<stablehlo::CompositeOp>(
+        op->getLoc(), op->getResultTypes(), compositeOperands, name,
+        DictionaryAttr::get(context, op->getAttrs()), implFunc.getSymName());
+    rewriter.replaceOp(op, compositeOp.getResults());
+    return success();
+  }
+};
+
+class StablehloWrapInCompositePass
+    : public impl::StablehloWrapInCompositePassBase<
+          StablehloWrapInCompositePass> {
+ public:
+  StablehloWrapInCompositePass() : StablehloWrapInCompositePassBase() {}
+  StablehloWrapInCompositePass(const StablehloWrapInCompositePassOptions& opts)
+      : StablehloWrapInCompositePassBase(opts) {}
+  explicit StablehloWrapInCompositePass(
+      std::function<bool(Operation*)> opPredicate) {
+    this->opPredicate = std::move(opPredicate);
+  }
+
+  LogicalResult initialize(MLIRContext* context) override {
+    RewritePatternSet patterns_(context);
+    patterns_.add<ConvertGenericOp>(context);
+    patterns = std::move(patterns_);
+
+    // If no predicate was supplied via the C++ API, build one from the
+    // `op-names` option: wrap exactly the ops whose names are listed, or
+    // nothing if the list is empty.
+    if (!opPredicate) {
+      if (!opNamesOption.empty()) {
+        DenseSet<StringRef> opNames(opNamesOption.begin(),
+                                    opNamesOption.end());
+        opPredicate = [opNames](Operation* op) {
+          return opNames.contains(op->getName().getStringRef());
+        };
+      } else {
+        opPredicate = [](Operation* op) { return false; };
+      }
+    }
+    return success();
+  }
+
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+    target.addLegalDialect<mlir::func::FuncDialect>();
+    target.addDynamicallyLegalDialect<stablehlo::StablehloDialect>(
+        [this](Operation* op) { return !opPredicate(op); });
+
+    Operation* op = getOperation();
+    if (failed(applyPartialConversion(op, target, patterns))) {
+      getOperation().emitError("Wrap in composite pass failed.");
+      signalPassFailure();
+    }
+  }
+
+ private:
+  // FrozenRewritePatternSet for the pass.
+  FrozenRewritePatternSet patterns;
+  // Predicate function to determine which operations should be wrapped.
+  std::function<bool(Operation*)> opPredicate = nullptr;
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createStablehloWrapInCompositePass(
+    std::function<bool(Operation*)> opPredicate) {
+  return std::make_unique<StablehloWrapInCompositePass>(
+      std::move(opPredicate));
+}
+
+}  // namespace stablehlo
+}  // namespace mlir
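
---

For reviewers who want to exercise the new C++ entry point end to end, here is a minimal consumer-side sketch. It is not part of the patch: it assumes an already-loaded `ModuleOp`, and the helper name `wrapDotGeneralInComposites` and the `stablehlo.dot_general` predicate are purely illustrative. Only `createStablehloWrapInCompositePass` comes from this change; `PassManager` is standard MLIR.

```c++
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/transforms/Passes.h"

// Hypothetical helper: wraps every stablehlo.dot_general op in `module` in a
// stablehlo.composite whose decomposition is a private impl function.
mlir::LogicalResult wrapDotGeneralInComposites(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::stablehlo::createStablehloWrapInCompositePass(
      [](mlir::Operation *op) {
        return op->getName().getStringRef() == "stablehlo.dot_general";
      }));
  return pm.run(module);
}
```

For the common by-name case, the textual equivalent is `stablehlo-opt --stablehlo-wrap-in-composite=op-names='stablehlo.dot_general' input.mlir`; the predicate form is only needed for criteria the `op-names` option cannot express (e.g., attribute-based selection).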