[shardy] Fix a bug for non-minor-most factors in `BasicFactorPropagation::compatiblePrefixNoConflictsWithinFactor`.

When the following conditions hold:
1. The factor is open.
2. The factor has NO existing overflow axes.
3. The factor is NOT minor-most.
4. The axis to be considered is not contained in the existing sharding axes.

Before this CL, the decision was based on `factorSize % shardedSize == 0`. However, it is possible to append a sub-axis of size `gcd(factorSize / shardedSize, axisSize)`.

For the following reshape example:
```
func.func @reshape_merge_dim_major_factor_overflows(%arg0: tensor<4x4xf32> {sdy.sharding = #sdy.sharding<@mesh_a_4_b_2_c_2, [{"b", "a"}, {}]>}) -> tensor<16xf32> {
  %0 = stablehlo.reshape %arg0 : (tensor<4x4xf32>) -> tensor<16xf32>
  return %0 : tensor<16xf32>
}
```

Result sharding:
* before this CL: `[<@mesh_a_4_b_2_c_2, [{"b", ?}]>]`
* with this CL: `[<@mesh_a_4_b_2_c_2, [{"b", "a":(1)2, ?}]>]`
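
The following sketch (not part of the commit; `factorSize`, `shardedSize`, and `axisSize` are placeholder names for the quantities described above) illustrates the new rule for how large a sub-axis can still be appended to such a factor:

```c++
#include <cstdint>
#include <numeric>

// Size of the sub-axis that a candidate axis of size `axisSize` can still
// contribute to an open, non-minor-most factor of size `factorSize` that is
// already sharded by axes of total size `shardedSize`. A result of 1 means
// the axis contributes nothing.
int64_t appendableSubAxisSize(int64_t factorSize, int64_t shardedSize,
                              int64_t axisSize) {
  if (shardedSize == 0 || factorSize % shardedSize != 0) return 1;
  return std::gcd(factorSize / shardedSize, axisSize);
}
```

In the example above, the major factor of the merged dimension has size 4 and is already sharded by "b" of size 2, so axis "a" of size 4 can contribute `gcd(4 / 2, 4) = 2`, which is exactly the sub-axis `"a":(1)2` in the new result sharding.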

PiperOrigin-RevId: 682091772
ZixuanJiang authored and copybara-github committed Oct 4, 2024
1 parent c464210 commit f669e68
Showing 42 changed files with 2,472 additions and 4,221 deletions.
61 changes: 61 additions & 0 deletions docs/sdy_dialect.md
@@ -215,6 +215,67 @@ Interfaces: `Symbol`
</table>


### `sdy.named_computation` (sdy::NamedComputationOp)

_Named computation operation_


Syntax:

```
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
```

Groups a computation, i.e. a block of operations, and gives it a name.
Propagation will flow in/out of the region as if everything was inlined.

This can be used to handle propagating through call instructions to other
functions. Any users of Shardy should write an import/export pass that
converts their call ops to `sdy.named_computation` ops, duplicating/copying
the body of the called function into the body of the `named_computation`.

The types of the block arguments and returned values in the region must
match the types of the operands and results of the op.

Example:

```mlir
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```

Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `RecursivelySpeculatableImplTrait`, `SingleBlockImplicitTerminator<ReturnOp>`, `SingleBlock`

Interfaces: `ConditionallySpeculatable`, `ShardableDataFlowOpInterface`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>in_shardings</code></td><td>::mlir::sdy::TensorShardingPerValueAttr</td><td>Tensor sharding per operand/result of an op</td></tr>
<tr><td><code>out_shardings</code></td><td>::mlir::sdy::TensorShardingPerValueAttr</td><td>Tensor sharding per operand/result of an op</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `operands` | variadic of any type |

#### Results:

| Result | Description |
| :----: | ----------- |
| &laquo;unnamed&raquo; | variadic of any type |


### `sdy.propagation_barrier` (sdy::PropagationBarrierOp)

_Propagation barrier operation_
49 changes: 49 additions & 0 deletions docs/sdy_export_passes.md
@@ -1,4 +1,53 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-sdy-insert-explicit-reshards`

_Inserts explicit reshards to make all operations have compatible shardings._

A compatible sharding essentially means that the operation can accept the
sharded operands and produce a sharded result without requiring any reshard
communications (note that the operation might still require communication
such as all-reduce or halo-swaps).

After propagation, some operations may still have incompatible shardings.

Note that when an axis (or sub-axis) is used to shard non-corresponding
dimensions (e.g. non-contracting dimensions in matmul) across multiple
tensors, or when an axis shards a dimension in one tensor but not the
corresponding dimension in the other tensor, the operation is said to have
a sharding conflict. Hence, after this pass, the operations become
conflict-free.

This pass injects reshard operations explicitly so that, for each operation,
corresponding dimensions become sharded in the same way across all operands
and results, and every axis (or sub-axis) can only be used to shard a single
dimension type.

A clarifying example:

Input:
```mlir
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

Output:
```mlir
sdy.mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>}
%0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

In the example above, there is a conflict since the `lhs` and `rhs` tensors
are both sharded on axis "x" on their non-contracting dimensions. Here, the
`rhs` tensor is explicitly resharded before the dot operation so that it is
sharded only on its first dimension, on axis "y". This way, the dot
operation becomes compatible.
### `-sdy-sharding-constraint-to-reshard`

_Converts ShardingConstraintOp into ReshardOp._
41 changes: 35 additions & 6 deletions docs/sdy_op_interfaces.md
@@ -57,7 +57,8 @@ NOTE: This method *must* be implemented by the user.
```c++
void setBlockArgumentEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);
```
Sets the shardings of the block argument edge owner with the given index.
Sets the `sharding` of the block argument edge owner with the given
`index`.
NOTE: This method *must* be implemented by the user.
@@ -66,7 +67,34 @@ NOTE: This method *must* be implemented by the user.
```c++
void setBlockArgumentEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);
```
Sets shardings of all block argument edge owners.
Sets `shardings` of all block argument edge owners.

NOTE: This method *must* be implemented by the user.

#### `getOpResultEdgeOwnerShardings`

```c++
mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getOpResultEdgeOwnerShardings();
```
Returns the shardings of all op result data flow edge owners.

NOTE: This method *must* be implemented by the user.

#### `setOpResultEdgeOwnerSharding`

```c++
void setOpResultEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);
```
Sets the `sharding` of the op result edge owner with the given `index`.
NOTE: This method *must* be implemented by the user.
#### `setOpResultEdgeOwnerShardings`
```c++
void setOpResultEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);
```
Sets `shardings` of all op result edge owners.

NOTE: This method *must* be implemented by the user.
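
A minimal usage sketch (not from the generated docs; `op` and `newShardings` are placeholders, and includes are omitted) showing how a caller can dispatch through the interface, mirroring the utility added to `data_flow_utils.cc` in this commit:

```c++
// Sets result edge owner shardings via the interface when the op implements it.
void trySetOpResultEdgeOwnerShardings(
    mlir::Operation* op,
    mlir::ArrayRef<mlir::sdy::TensorShardingAttr> newShardings) {
  if (auto shardableOp =
          mlir::dyn_cast<mlir::sdy::ShardableDataFlowOpInterface>(op)) {
    shardableOp.setOpResultEdgeOwnerShardings(newShardings);
  }
}
```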

@@ -93,7 +121,7 @@ NOTE: This method *must* be implemented by the user.
```c++
mlir::SmallVector<mlir::Value> getEdgeSources(mlir::Value target);
```
Gets the data flow edge sources given a target value.
Gets the data flow edge sources given a `target` value.
NOTE: This method *must* be implemented by the user.
@@ -102,7 +130,8 @@ NOTE: This method *must* be implemented by the user.
```c++
mlir::Value getEdgeOwnerFromTarget(mlir::Value target);
```
Gets the owner target of a data flow edge given a target that may or may not be the owner.
Gets the owner `target` of a data flow edge given a `target` that may or
may not be the owner.

NOTE: This method *must* be implemented by the user.

@@ -111,7 +140,7 @@ NOTE: This method *must* be implemented by the user.
```c++
mlir::Value getEdgeOwnerFromSource(mlir::OpOperand&source);
```
Gets the owner target of a data flow edge given a source.
Gets the owner target of a data flow edge given a `source`.
NOTE: This method *must* be implemented by the user.
@@ -120,7 +149,7 @@ NOTE: This method *must* be implemented by the user.
```c++
mlir::SmallVector<mlir::Value> getNonEdgeOwnerTargets(mlir::Value owner);
```
Gets the non-owner targets of a data flow edge given the edge owner.
Gets the non-owner targets of a data flow edge given the edge `owner`.

NOTE: This method *must* be implemented by the user.

10 changes: 8 additions & 2 deletions shardy/dialect/sdy/ir/attrs.td
@@ -518,11 +518,17 @@ def Sdy_TensorShardingPerValue : AttrDef<Sdy_Dialect, "TensorShardingPerValue">
let assemblyFormat = "`<` `[` (`]`):($shardings^ `]`)? `>`";

let extraClassDeclaration = [{
// Builds a `TensorShardingPerValue` for each type in `types`, with all
// dimension shardings marked open (can be further replicated/sharded).
// Builds a `TensorSharding` for each type in `types`, with all dimension
// shardings marked open (can be further replicated/sharded).
static TensorShardingPerValueAttr getFullyOpen(
MLIRContext* context, TypeRange types, StringRef meshName);

// Builds an open `TensorSharding` for each type in `types`, but
// with the sharding at `index` replaced with `sharding`.
static TensorShardingPerValueAttr getOpenWithShardingAtIndex(
MLIRContext* context, TypeRange types, int64_t index,
TensorShardingAttr sharding);

// Returns whether there are no values.
bool empty() const { return getShardings().empty(); }

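A hypothetical usage sketch of the new helper declared above (`op` and `resultSharding` are placeholders; includes omitted): build per-value shardings that are fully open everywhere except index 0, which receives an explicit sharding.

```c++
// Fully open shardings for all results of `op`, except result 0 which is
// sharded as `resultSharding`.
mlir::sdy::TensorShardingPerValueAttr shardings =
    mlir::sdy::TensorShardingPerValueAttr::getOpenWithShardingAtIndex(
        op->getContext(), op->getResultTypes(), /*index=*/0, resultSharding);
```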
8 changes: 8 additions & 0 deletions shardy/dialect/sdy/ir/data_flow_utils.cc
@@ -97,6 +97,14 @@ void setBlockArgumentEdgeOwnerShardings(
shardings);
}

void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings) {
if (auto shardableDataFlowOp = dyn_cast<ShardableDataFlowOpInterface>(op)) {
return shardableDataFlowOp.setOpResultEdgeOwnerShardings(shardings);
}
setShardings(op, shardings);
}

DataFlowEdgeOp getDataFlowEdge(Value target) {
return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(target));
}
7 changes: 6 additions & 1 deletion shardy/dialect/sdy/ir/data_flow_utils.h
@@ -61,11 +61,16 @@ SmallVector<Value> getDataFlowSources(DataFlowEdgeOp dataFlowEdge);
// Returns all non-edge-owner targets of the given `dataFlowEdge`.
SmallVector<Value> getNonEdgeOwnerTargets(DataFlowEdgeOp dataFlowEdge);

// Sets the block argument edge owner shardings if the `op` is a
// Sets the block argument edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setBlockArgumentEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

// Sets the op result edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

} // namespace sdy
} // namespace mlir

