Pass to wrap StableHLO ops in composite #2722

Draft · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -1239,6 +1239,7 @@ cc_library(
"stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
"stablehlo/transforms/StablehloRefineArguments.cpp",
"stablehlo/transforms/StablehloRefineShapes.cpp",
"stablehlo/transforms/StablehloWrapInComposite.cpp",
"stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
"stablehlo/transforms/VhloToVersion.cpp",
],
53 changes: 53 additions & 0 deletions docs/generated/stablehlo_passes.md
@@ -338,6 +338,59 @@ Modules valid for shape refinement must have the following properties:
* All calls to a single function resolve to the same argument shapes, and no
recursive / co-recursive function calls are made.

### `-stablehlo-wrap-in-composite`

_Wraps a non-composite StableHLO op in a composite op._

Wraps StableHLO ops, as specified by the pass option flag, in a
composite op. The composite op will inherit all attributes of the original
op.

Review comment (Member):

Let's talk about the design more. It's missing a few things to be generally useful:

  • What to do with ops that have attributes?
  • What should the composites be named, or should the name be user-specified as part of the callback fn?

Need to think more; maybe we should have a little brainstorm and run through some examples.

Reply (Member, Author):

All good points! SGTM.

> What to do with ops that have attributes?

All the op attributes are preserved as composite attributes. To demonstrate that, I have used the convolution op as an example in stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir.

For example, using the pass option `--stablehlo-wrap-in-composite=op-names='stablehlo.add'`,

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

will become:

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {
decomposition = @stablehlo.add.impl
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

The pass is also exposed as an API `createStablehloWrapInCompositePass` to
allow for more flexible selection of ops to wrap.
For example, the following will wrap all `stablehlo.add` and
`stablehlo.convolution` ops that are not already composites:

```c++
auto pass = createStablehloWrapInCompositePass(
    [](Operation *op) {
      return (op->getName().getStringRef() == "stablehlo.add" ||
              op->getName().getStringRef() == "stablehlo.convolution") &&
             !isa<stablehlo::CompositeOp>(op);
    });
```

#### Options

```
-op-names : The names of the ops to wrap.
```
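For instance, a minimal invocation might look like the following (a sketch assuming a locally built `stablehlo-opt`; `input.mlir` is a placeholder for your module file):

```shell
# Wrap every stablehlo.add and stablehlo.convolution op in a composite.
stablehlo-opt \
  --stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.convolution' \
  input.mlir
```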

### `-vhlo-legalize-to-stablehlo`

_Legalize VHLO to StableHLO._
93 changes: 93 additions & 0 deletions stablehlo/tests/transforms/stablehlo_wrap_in_composite.mlir
@@ -0,0 +1,93 @@
// RUN: stablehlo-opt --stablehlo-wrap-in-composite=op-names='stablehlo.add,stablehlo.convolution,stablehlo.reduce' --split-input-file --verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: func.func @wrap_in_composite
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>,
// CHECK-SAME: %[[ARG_2:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[CONV:.*]] = stablehlo.composite "stablehlo.convolution" %[[ARG_0]], %[[ARG_1]] {
// CHECK-SAME: composite_attributes = {batch_group_count = 1 : i64,
// CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
// CHECK-SAME: feature_group_count = 1 : i64,
// CHECK-SAME{LITERAL}: padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
// CHECK-SAME{LITERAL}: rhs_dilation = array<i64: 2, 2>,
// CHECK-SAME{LITERAL}: window_strides = array<i64: 1, 1>},
// CHECK-SAME: decomposition = @stablehlo.convolution.impl} : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
// CHECK: %[[ADD:.*]] = stablehlo.composite "stablehlo.add" %[[CONV]], %[[ARG_2]] {decomposition = @stablehlo.add.impl} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
// CHECK-NEXT: %[[ADD1:.*]] = stablehlo.composite "stablehlo.add" %[[ADD]], %[[ADD]] {decomposition = @stablehlo.add.impl1} : (tensor<64x3x3x32xi32>, tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[ADD1]]

// CHECK-LABEL: func.func private @stablehlo.convolution.impl
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x8x8x8xi8>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[VAL:.*]] = stablehlo.convolution(%[[ARG_0]], %[[ARG_1]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
// CHECK-SAME{LITERAL}: stride = [1, 1],
// CHECK-SAME{LITERAL}: pad = [[0, 1], [0, 1]],
// CHECK-SAME{LITERAL}: rhs_dilate = [2, 2]}
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: : (tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[VAL]]

// CHECK-LABEL: func.func private @stablehlo.add.impl
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_0]], %[[ARG_1]] : tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[VAL]]

// CHECK-LABEL: func.func private @stablehlo.add.impl1
// CHECK-SAME: %[[ARG_0:.*]]: tensor<64x3x3x32xi32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
// CHECK: %[[VAL:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_1]] : tensor<64x3x3x32xi32>
// CHECK-NEXT: return %[[VAL]]

func.func @wrap_in_composite(
%arg0: tensor<64x8x8x8xi8>,
%arg1: tensor<4x4x8x32xi8>,
%arg2: tensor<64x3x3x32xi32>) -> tensor<64x3x3x32xi32> {
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [2, 2]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64} :
(tensor<64x8x8x8xi8>, tensor<4x4x8x32xi8>) -> tensor<64x3x3x32xi32>
%1 = stablehlo.add %0, %arg2 : tensor<64x3x3x32xi32>
%2 = stablehlo.add %1, %1 : tensor<64x3x3x32xi32>
func.return %2 : tensor<64x3x3x32xi32>
}

// -----
Review comment (Member):

Can we have an example with regions like reduce?
Should this pass run as part of a pass pipeline? If we run a decomposer pass that introduces stablehlo.convolution, they also need to get wrapped, right?

Reply (Member, Author):

Added.

> Should this pass run as part of a pass pipeline?

Yes, that is the more practical usage of the pass.

> If we run a decomposer pass that introduces stablehlo.convolution, they also need to get wrapped, right?

If we have a decomposer pass that emits a convolution, and the user intends to wrap that convolution in a composite, then yes, it will be wrapped.
Let me know if that addresses your concern.

// CHECK-LABEL: func.func @wrap_in_composite_op_with_region
// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x3xf32>) -> tensor<4xf32>
// CHECK: %[[CONST:.*]] = stablehlo.constant
// CHECK-NEXT: %[[COMPOSITE_REDUCE:.*]] = stablehlo.composite "stablehlo.reduce" %[[ARG_0]], %[[CONST]] {
// CHECK-SAME: composite_attributes = {
// CHECK-SAME: dimensions = array<i64: 1>},
// CHECK-SAME: decomposition = @stablehlo.reduce.impl}
// CHECK-SAME: (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: return %[[COMPOSITE_REDUCE]]

// CHECK-LABEL: func.func private @stablehlo.reduce.impl
// CHECK-SAME: %[[ARG_0:.*]]: tensor<4x3xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<4xf32> {
// CHECK: %[[REDUCE:.*]] = stablehlo.reduce(%[[ARG_0]] init: %[[ARG_1]])
// CHECK-SAME{LITERAL}: applies stablehlo.add across dimensions = [1]
// CHECK-SAME: (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: return %[[REDUCE]]
func.func @wrap_in_composite_op_with_region(%x : tensor<4x3xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<2.7> : tensor<f32>
%res = stablehlo.reduce(%x init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<4x3xf32>, tensor<f32>) -> tensor<4xf32>
func.return %res : tensor<4xf32>
}

// -----

// CHECK-LABEL: func.func @cannot_be_wrapped_ops_does_not_match
// CHECK-SAME: %[[ARG_0:.*]]: tensor<2xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[VAL:.*]] = stablehlo.multiply %[[ARG_0]], %[[ARG_1]] : tensor<2xf32>
// CHECK-NEXT: return %[[VAL]] : tensor<2xf32>
func.func @cannot_be_wrapped_ops_does_not_match(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.multiply %arg0, %arg1 : tensor<2xf32>
func.return %0 : tensor<2xf32>
}
1 change: 1 addition & 0 deletions stablehlo/transforms/CMakeLists.txt
@@ -57,6 +57,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloLegalizeToVhlo.cpp
StablehloRefineArguments.cpp
StablehloRefineShapes.cpp
StablehloWrapInComposite.cpp
VhloLegalizeToStablehlo.cpp
VhloToVersion.cpp
PassUtils.cpp
9 changes: 9 additions & 0 deletions stablehlo/transforms/Passes.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef STABLEHLO_TRANSFORMS_PASSES_H
#define STABLEHLO_TRANSFORMS_PASSES_H

#include <functional>
#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -102,6 +103,14 @@ void populateStablehloCompatibilityExpanderPatterns(
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes);

/// Creates a pass that wraps StableHLO ops in CompositeOp.
///
/// The pass will wrap the StableHLO ops that match the given opPredicate
/// function in CompositeOp. The opPredicate function should return true if the
/// op should be wrapped in CompositeOp.
std::unique_ptr<OperationPass<ModuleOp>> createStablehloWrapInCompositePass(
std::function<bool(Operation *)> opPredicate);
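As a sketch of how this factory might be used in a pass pipeline (the `buildPipeline` helper is hypothetical, and the include paths and `mlir::stablehlo` namespace are assumptions based on the declarations in this PR):

```cpp
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/Passes.h"

// Hypothetical pipeline setup: wrap every non-composite stablehlo.add
// using the predicate-based factory declared above.
void buildPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::stablehlo::createStablehloWrapInCompositePass(
      [](mlir::Operation *op) {
        return op->getName().getStringRef() == "stablehlo.add" &&
               !llvm::isa<mlir::stablehlo::CompositeOp>(op);
      }));
}
```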

//// Pass pipelines ////

// StableHLO consumers can add this pipeline to convert portable artifacts to
56 changes: 56 additions & 0 deletions stablehlo/transforms/Passes.td
@@ -409,3 +409,59 @@ def VhloToVersionPass : Pass<"vhlo-to-version"> {
];
let dependentDialects = ["mlir::vhlo::VhloDialect"];
}

def StablehloWrapInCompositePass : Pass<"stablehlo-wrap-in-composite", "ModuleOp"> {
let summary = "Wraps a non-composite StableHLO op in a composite op.";
let description = [{
Wraps StableHLO ops, as specified by the pass option flag, in a
composite op. The composite op will inherit all attributes of the original
op.

For example, using the pass option `--stablehlo-wrap-in-composite=op-names='stablehlo.add'`,

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

will become:

```mlir
func.func @add(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = stablehlo.composite "stablehlo.add" %arg0, %arg1 {
decomposition = @stablehlo.add.impl
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func.func private @stablehlo.add.impl(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
func.return %0 : tensor<2xf32>
}
```

The pass is also exposed as an API `createStablehloWrapInCompositePass` to
allow for more flexible selection of ops to wrap.
For example, the following will wrap all `stablehlo.add` and
`stablehlo.convolution` ops that are not already composites:

```c++
auto pass = createStablehloWrapInCompositePass(
    [](Operation *op) {
      return (op->getName().getStringRef() == "stablehlo.add" ||
              op->getName().getStringRef() == "stablehlo.convolution") &&
             !isa<stablehlo::CompositeOp>(op);
    });
```
}];
let options = [
ListOption<"opNamesOption", "op-names", "std::string",
"The names of the ops to wrap.">,
];
let dependentDialects = [
"mlir::func::FuncDialect",
"mlir::stablehlo::StablehloDialect",
];
}