Use llvm error instead of asserts.
This ensures the check isn't optimized away by the compiler.
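
A minimal sketch (not part of this change) of the distinction the message refers to: `assert` is stripped from builds compiled with `NDEBUG`, so the compiler is free to optimize the guarded check away entirely, whereas an LLVM fatal error is evaluated in every build mode.

```c++
#include <cassert>

#include "llvm/Support/ErrorHandling.h"

void checkInvariant(bool ok) {
  // Stripped when NDEBUG is defined (typical release builds), so the
  // whole check can vanish.
  assert(ok && "invariant violated");

  // Always evaluated, in debug and release builds alike.
  if (!ok)
    llvm::report_fatal_error("invariant violated");
}
```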

PiperOrigin-RevId: 680546512
Google-ML-Automation authored and copybara-github committed Oct 1, 2024
1 parent c464210 commit f810ee8
Showing 25 changed files with 1,386 additions and 4,159 deletions.
57 changes: 57 additions & 0 deletions docs/sdy_dialect.md
@@ -215,6 +215,63 @@ Interfaces: `Symbol`
</table>


### `sdy.named_computation` (sdy::NamedComputationOp)

_Named computation operation_


Syntax:

```
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
```

Groups a computation, i.e. a block of operations, and gives it a name.
Propagation will flow in/out of the region as if everything were inlined.

This can be used to handle propagation through call instructions to other
functions. Any users of Shardy should write an import/export pass that
converts their call ops to `sdy.named_computation` ops, duplicating/copying
the body of the called function into the body of the `named_computation`.

The type of each block argument and returned value in the region must be
the same as the type of the corresponding operand and result of the op.

Example:

```mlir
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```
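
A hedged sketch of the import/export pattern described above, assuming a hypothetical private function `@bar` whose body is a single `stablehlo.add`; Shardy does not ship this pass, so the conversion below is only illustrative:

```mlir
// Before the user-written import pass: a plain call.
%1 = func.call @bar(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32>

// After the pass: the body of @bar is copied into a named computation so
// that sharding propagation can flow through it as if it were inlined.
%1 = sdy.named_computation<"bar">(%0) (%arg1: tensor<16x32xf32>) {
  %2 = stablehlo.add %arg1, %arg1 : tensor<16x32xf32>
  sdy.return %2 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```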

Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `RecursivelySpeculatableImplTrait`, `SingleBlockImplicitTerminator<ReturnOp>`, `SingleBlock`

Interfaces: `ConditionallySpeculatable`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `operands` | variadic of any type |

#### Results:

| Result | Description |
| :----: | ----------- |
| «unnamed» | variadic of any type |


### `sdy.propagation_barrier` (sdy::PropagationBarrierOp)

_Propagation barrier operation_
49 changes: 49 additions & 0 deletions docs/sdy_export_passes.md
@@ -1,4 +1,53 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-sdy-insert-explicit-reshards`

_Inserts explicit reshards to make all operations have compatible shardings._

A compatible sharding essentially means that the operation can accept the
sharded operands and produce a sharded result without requiring any reshard
communications (note that the operation might still require communication
such as all-reduce or halo-swaps).

After propagation, some operations may still have incompatible shardings.

Note that when an axis (or sub-axis) is used to shard non-corresponding
dimensions (e.g., non-contracting dimensions in matmul) across multiple
tensors, or when an axis shards a dimension in one tensor but not the
corresponding dimension in the other tensor, the operation is said to
have a sharding conflict. Hence, after this pass, the operations become
conflict-free.

This pass injects reshard operations explicitly so that, for each operation,
corresponding dimensions become sharded in the same way across all operands
and results, and every axis (or sub-axis) can only be used to shard a single
dimension type.

A clarifying example:

Input:
```mlir
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

Output:
```mlir
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
%0 = sdy.reshard %rhs <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

In the example above, there is a conflict since the `lhs` and `rhs` tensors
are both sharded on axis "x" on their non-contracting dimensions. Here, the
`rhs` tensor is explicitly resharded before the dot operation so that it is
sharded only on its first dimension, on axis "y". This way, the dot
operation becomes compatible.
### `-sdy-sharding-constraint-to-reshard`

_Converts ShardingConstraintOp into ReshardOp._
4 changes: 2 additions & 2 deletions shardy/dialect/sdy/ir/dialect.cc
@@ -59,8 +59,8 @@ struct ShardyDialectInlinerInterface : public DialectInlinerInterface {
return true;
}

// ManualComputationOp is an op with a region, and it should be allowed to be
// inlined into another op.
// `ManualComputationOp` and `NamedComputationOp` are ops with a region, and
// they should be allowed to be inlined into another op.
bool isLegalToInline(Region*, Region*, bool, IRMapping&) const final {
return true;
}
46 changes: 46 additions & 0 deletions shardy/dialect/sdy/ir/ops.td
@@ -372,4 +372,50 @@ def PropagationBarrierOp : Sdy_Op<"propagation_barrier",
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// NamedComputationOp
//===----------------------------------------------------------------------===//

def NamedComputationOp : Sdy_Op<"named_computation",
[RecursiveMemoryEffects, SingleBlockImplicitTerminator<"ReturnOp">,
RecursivelySpeculatable, IsolatedFromAbove]> {
let summary = "named computation operation";
let description = [{
Groups a computation, i.e. a block of operations, and gives it a name.
Propagation will flow in/out of the region as if everything were inlined.

This can be used to handle propagation through call instructions to other
functions. Any users of Shardy should write an import/export pass that
converts their call ops to `sdy.named_computation` ops, duplicating/copying
the body of the called function into the body of the `named_computation`.

The type of each block argument and returned value in the region must be
the same as the type of the corresponding operand and result of the op.

Example:

```mlir
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```
}];

let arguments = (ins
StrAttr:$name,
Variadic<AnyType>:$operands
);
let results = (outs Variadic<AnyType>);
let regions = (region SizedRegion<1>:$body);

let assemblyFormat = [{
`<`$name`>` `` `(` $operands `)`
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
}];

let hasVerifier = 1;
}

#endif // SDY_OPS
24 changes: 24 additions & 0 deletions shardy/dialect/sdy/ir/test/named_computation_parse_print.mlir
@@ -0,0 +1,24 @@
// RUN: sdy_opt %s 2>&1 | FileCheck %s

// CHECK-LABEL: func @one_input_output
func.func @one_input_output(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// CHECK-NEXT: %0 = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) {
// CHECK-NEXT: sdy.return %arg1 : tensor<8x2xi32>
// CHECK-NEXT: } : (tensor<8x2xi32>) -> tensor<8x2xi32>
%0 = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) {
sdy.return %arg1 : tensor<8x2xi32>
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
return %0 : tensor<8x2xi32>
}


// CHECK-LABEL: func @two_inputs_outputs
func.func @two_inputs_outputs(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
// CHECK-NEXT: %0:2 = sdy.named_computation<"named_computation">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
// CHECK-NEXT: sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
// CHECK-NEXT: } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
%0:2 = sdy.named_computation<"named_computation">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
}
41 changes: 41 additions & 0 deletions shardy/dialect/sdy/ir/test/named_computation_verification.mlir
@@ -0,0 +1,41 @@
// RUN: sdy_opt %s -split-input-file -verify-diagnostics

func.func @invalid_operand_type(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// expected-error@+1 {{expected the type of the 0'th block argument to match the type of the corresponding operand: 'tensor<4x2xi32>' vs 'tensor<8x2xi32>'}}
%0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<4x2xi32>) {
%1 = stablehlo.custom_call @foo(%arg1) : (tensor<4x2xi32>) -> tensor<8x2xi32>
sdy.return %1 : tensor<8x2xi32>
}: (tensor<8x2xi32>) -> tensor<8x2xi32>
return %0 : tensor<8x2xi32>
}

// -----

func.func @operand_count_mismatch(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// expected-error@+1 {{number of block arguments must match the number of operands: 2 != 1}}
%0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>, %arg2: tensor<8x2xi32>) {
sdy.return %arg1 : tensor<8x2xi32>
}: (tensor<8x2xi32>) -> tensor<8x2xi32>
return %0 : tensor<8x2xi32>
}

// -----

func.func @invalid_result_type(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// expected-error@+1 {{expected the type of the 0'th returned value to match the type of the corresponding result: 'tensor<4x2xi32>' vs 'tensor<8x2xi32>'}}
%0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>) {
%1 = stablehlo.custom_call @foo(%arg1) : (tensor<8x2xi32>) -> tensor<4x2xi32>
sdy.return %1 : tensor<4x2xi32>
}: (tensor<8x2xi32>) -> tensor<8x2xi32>
return %0 : tensor<8x2xi32>
}

// -----

func.func @result_count_mismatch(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// expected-error@+1 {{number of returned values must match the number of results: 2 != 1}}
%0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>) {
sdy.return %arg1, %arg1 : tensor<8x2xi32>, tensor<8x2xi32>
}: (tensor<8x2xi32>) -> tensor<8x2xi32>
return %0 : tensor<8x2xi32>
}
40 changes: 40 additions & 0 deletions shardy/dialect/sdy/ir/verifiers.cc
@@ -852,6 +852,46 @@ LogicalResult PropagationBarrierOp::verify() {
return success();
}

namespace {

LogicalResult AllInnerAndOuterTypesMatchInNamedComputation(
NamedComputationOp op, TypeRange innerTypes, TypeRange outerTypes,
StringRef innerName, StringRef outerName) {
if (innerTypes.size() != outerTypes.size()) {
return op.emitError("number of ")
<< innerName << "s must match the number of " << outerName
<< "s: " << innerTypes.size() << " != " << outerTypes.size();
}

for (auto [i, types] :
llvm::enumerate(llvm::zip_equal(innerTypes, outerTypes))) {
auto [innerType, outerType] = types;
if (innerType != outerType) {
return op.emitError("expected the type of the ")
<< i << "'th " << innerName
<< " to match the type of the corresponding " << outerName << ": "
<< innerType << " vs " << outerType;
}
}

return success();
}

} // namespace

LogicalResult NamedComputationOp::verify() {
if (failed(AllInnerAndOuterTypesMatchInNamedComputation(
*this, getBody().getArgumentTypes(), getOperandTypes(),
"block argument", "operand")) ||
failed(AllInnerAndOuterTypesMatchInNamedComputation(
*this, getBodyTerminatorOpOperandTypes(*this), getResultTypes(),
"returned value", "result"))) {
return failure();
}

return success();
}

LogicalResult SdyDialect::verifyRegionArgAttribute(Operation* op,
unsigned regionIndex,
unsigned argIndex,
1 change: 1 addition & 0 deletions shardy/dialect/sdy/transforms/export/BUILD
@@ -36,6 +36,7 @@ cc_library(
name = "passes",
srcs = [
"export_pipeline.cc",
"insert_explicit_reshards.cc",
"sharding_constraint_to_reshard.cc",
"sink_data_flow_edges.cc",
"update_non_divisible_input_output_shardings.cc",
1 change: 1 addition & 0 deletions shardy/dialect/sdy/transforms/export/export_pipeline.cc
@@ -28,6 +28,7 @@ void addExportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
pm.addNestedPass<func::FuncOp>(createShardingConstraintToReshardPass());
pm.addNestedPass<func::FuncOp>(
createUpdateNonDivisibleInputOutputShardingsPass());
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass());
pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory,
"sdy_module_after_sdy_export"));
}
Expand Down
42 changes: 42 additions & 0 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -0,0 +1,42 @@
/* Copyright 2024 The Shardy Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>

#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/ir/dialect.h" // IWYU pragma: keep

namespace mlir {
namespace sdy {

#define GEN_PASS_DEF_INSERTEXPLICITRESHARDSPASS
#include "shardy/dialect/sdy/transforms/export/passes.h.inc"

namespace {

struct InsertExplicitReshardsPass
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;

void runOnOperation() final {
// Not ready yet. It is currently a no-op.
}
};

} // namespace

} // namespace sdy
} // namespace mlir