[Optimizer] Resharding ON by default. (#1163)
nobradovictt authored Nov 6, 2024
1 parent 45b7ee4 commit f91d41e
Showing 7 changed files with 42 additions and 37 deletions.
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -49,17 +49,17 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
        return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
      }

-     uint64_t mutable cNumUsedCores = 0;
-     uint64_t getNumUsedCores() const {
-       if (cNumUsedCores != 0) {
-         return cNumUsedCores;
+     uint64_t mutable cGridVolume = 0;
+     uint64_t getGridVolume() const {
+       if (cGridVolume != 0) {
+         return cGridVolume;
        }

-       cNumUsedCores = 1;
+       cGridVolume = 1;
        for (int64_t dim : getShape()) {
-         cNumUsedCores *= dim;
+         cGridVolume *= dim;
        }
-       return cNumUsedCores;
+       return cGridVolume;
      }
  }];
}
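Note: the renamed helper simply caches the product of the grid's dimensions. A minimal standalone C++ sketch of that behaviour, for illustration only (the real method is generated from GridAttr in TTOpsTypes.td; the struct and field here are stand-ins):

#include <cstdint>
#include <vector>

// Illustrative stand-in for GridAttr; only the caching pattern mirrors the change above.
struct Grid {
  std::vector<std::int64_t> shape;
  mutable std::uint64_t cGridVolume = 0;

  // Grid volume is the product of all grid dimensions, computed once and cached.
  std::uint64_t getGridVolume() const {
    if (cGridVolume != 0) {
      return cGridVolume;
    }
    cGridVolume = 1;
    for (std::int64_t dim : shape) {
      cGridVolume *= dim;
    }
    return cGridVolume;
  }
};

A 2x4 grid, for example, has a volume of 8; the sharding policy further down uses this value to prefer layouts that spread work over more cores.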
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
@@ -289,6 +289,7 @@ class ShardSolver {
               const std::unordered_set<Edge> &overrideReshardEdges);
  RemainingLayoutAttrs at(Operation *operation) const;
  void set(Operation *operation, tt::LayoutAttr const &layout);
+ static bool supportsInterleavedInputShardedOutput(Operation *op);

private:
  const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> *legalLayouts;
6 changes: 2 additions & 4 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -78,10 +78,8 @@ struct TTIRToTTNNBackendPipelineOptions
  //
  Option<bool> memReconfigEnabled{
      *this, "memreconfig-enabled",
-      llvm::cl::desc("Memory layout reconfiguration pass. Temp disabled till "
-                     "we support all types "
-                     "of shard specs."),
-      llvm::cl::init(false)};
+      llvm::cl::desc("Memory layout reconfiguration pass."),
+      llvm::cl::init(true)};

  // Specify policy for memory layout analysis.
  //
6 changes: 2 additions & 4 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -121,10 +121,8 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
      ::llvm::cl::init(false)};
  ::mlir::Pass::Option<bool> memReconfigEnabled{
      *this, "memreconfig-enabled",
-      ::llvm::cl::desc("Memory layout reconfiguration pass. Temp disabled till "
-                       "we support all "
-                       "types of shard specs."),
-      ::llvm::cl::init(false)};
+      ::llvm::cl::desc("Memory layout reconfiguration pass."),
+      ::llvm::cl::init(true)};
  ::mlir::Pass::Option<mlir::tt::MemoryLayoutAnalysisPolicyType,
                       mlir::tt::MemoryLayoutAnalysisPolicyTypeParser>
      memoryLayoutAnalysisPolicy{
22 changes: 13 additions & 9 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -115,15 +115,19 @@ void DFShardingPolicy::run() {

        if (l1UsageValid) {
          // TODO(nobradovic)
-         // It seems that bunch of TTNN ops have constraints which prevent
+         // It seems that some TTNN ops have constraints which prevent
          // them from being sharded if both inputs are interleaved,
          // so proposal for now is starting a shard chain
-         // with reshard op(at later phase only when necessary based on op
-         // type) For this reason we also need to validate that currentOp
-         // can fit into L1 with its first input sharded.
+         // with reshard op. For this reason we also need to validate that
+         // currentOp can fit into L1 with its first input sharded.
          //
          bool firstInputL1UsageValid = true;
-         if (l1ChainConfigs->back().isEmpty()) {
+         if (l1ChainConfigs->back().isEmpty() &&
+             (!ShardSolver::supportsInterleavedInputShardedOutput(
+                  currentOp) ||
+              overrideReshardEdges.count(
+                  Edge(currentOp->getOperand(0).getDefiningOp(), currentOp,
+                       0)) > 0)) {
            RankedTensorType firstOpInputTensorType =
                mlir::cast<RankedTensorType>(currentOp->getOperand(0)
                                                 .getDefiningOp()
@@ -212,11 +216,11 @@ void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
    const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
    for (const tt::LayoutAttr &layout : validLayouts) {

-     if (layout.getGrid().getNumUsedCores() >
-         selectedLayout->getGrid().getNumUsedCores()) {
+     if (layout.getGrid().getGridVolume() >
+         selectedLayout->getGrid().getGridVolume()) {
        selectedLayout = &layout;
-     } else if (layout.getGrid().getNumUsedCores() ==
-                selectedLayout->getGrid().getNumUsedCores()) {
+     } else if (layout.getGrid().getGridVolume() ==
+                selectedLayout->getGrid().getGridVolume()) {
        if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
          selectedLayout = &layout;
        }
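For readers following the policy change: the first-input L1 validation above now runs only when a reshard will actually be inserted at the start of the chain, i.e. when the first op cannot produce a sharded output from interleaved inputs or an explicit override targets its input edge. A simplified, self-contained sketch of that decision, with stand-in types rather than the real MLIR/TTNN classes:

#include <set>
#include <string>
#include <tuple>

// Stand-in for the real Edge (producer op, consumer op, operand index).
struct Edge {
  std::string producer;
  std::string consumer;
  int operandIndex;
  bool operator<(const Edge &other) const {
    return std::tie(producer, consumer, operandIndex) <
           std::tie(other.producer, other.consumer, other.operandIndex);
  }
};

// Mirrors the current ShardSolver stub in this commit: assume every op can
// produce a sharded output from interleaved inputs until real checks land.
bool supportsInterleavedInputShardedOutput(const std::string & /*opType*/) {
  return true;
}

// The chain only starts with a reshard (and therefore needs the extra
// "first input sharded" L1 validation) in these two cases.
bool chainStartsWithReshard(const std::string &producer, const std::string &firstOp,
                            const std::set<Edge> &overrideReshardEdges) {
  return !supportsInterleavedInputShardedOutput(firstOp) ||
         overrideReshardEdges.count(Edge{producer, firstOp, /*operandIndex=*/0}) > 0;
}

With an empty override set and the stub returning true, chainStartsWithReshard() is false, which is why resharding can be on by default without forcing a reshard in front of every chain.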
3 changes: 1 addition & 2 deletions lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
@@ -196,8 +196,7 @@ void LegalGridAnalysis::analysisImplementation() {
  // Pick top largest sharded grids.
  std::sort(shardedResults.begin(), shardedResults.end(),
            [](tt::LayoutAttr a, tt::LayoutAttr b) {
-             return a.getGrid().getNumUsedCores() >
-                    b.getGrid().getNumUsedCores();
+             return a.getGrid().getGridVolume() > b.getGrid().getGridVolume();
            });

  analysisResult.insert(
27 changes: 16 additions & 11 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -21,7 +21,8 @@ ShardSolver::ShardSolver(
    const unsigned usableL1CacheSize,
    const std::unordered_set<Edge> &overrideReshardEdges)
    : legalLayouts(&legalLayouts), shardSpecs(&shardSpecs),
-     shardedOps(&shardedOps), usableL1CacheSize(usableL1CacheSize) {
+     shardedOps(&shardedOps), usableL1CacheSize(usableL1CacheSize),
+     memReconfigEdges(overrideReshardEdges) {
  pathSets.reserve(shardSpecs.size());
  pathSetIds.reserve(shardSpecs.size());
  bitsets.reserve(shardedOps.size());
@@ -46,12 +47,6 @@
    }
  }

-  // Insert override resharding edges
-  //
-  for (const Edge &edge : overrideReshardEdges) {
-    insertReshard(edge);
-  }
-
  // Resolve shard chain.
  //
  resolve();
@@ -181,17 +176,27 @@
  return true;
}

+bool ShardSolver::supportsInterleavedInputShardedOutput(Operation *op) {
+  // TODO(nobradovic,mbezulj): Add check whether this op type can have sharded
+  // output from interleaved inputs. For now assuming it can.
+  //
+  return true;
+}
+
// We need to check if first op requires sharded inputs and if so, insert
// reshard edge, then invalidate all sharding options which would go above L1
// size limits.
//
void ShardSolver::preprocessFirstOp() {
-  // TODO(nobradovic): Add check whether this op type can have sharded output
-  // from interleaved inputs. For now assuming it can not.
-  //
+  Operation *firstOp = shardSpecs->front().op;
+  if (supportsInterleavedInputShardedOutput(firstOp) &&
+      memReconfigEdges.count(
+          Edge(firstOp->getOperand(0).getDefiningOp(), firstOp, 0)) == 0) {
+    return;
+  }
+
  // Insert reshard edge for the first op to start the chain.
  //
-  Operation *firstOp = shardSpecs->front().op;
  Edge shardChainInputEdge =
      Edge(firstOp->getOperand(0).getDefiningOp(), firstOp, 0 /*operandIndex*/);
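In other words, override reshard edges are now remembered by the constructor and consumed later: preprocessFirstOp() bails out unless the first op genuinely needs a reshard to start the chain. A rough standalone sketch of that control flow, using simplified types and names rather than the actual implementation:

#include <set>
#include <utility>

// Simplified stand-in; the real code works on MLIR Operation* and Edge objects.
struct Edge {
  int producerId;
  int consumerId;
  int operandIndex;
  bool operator<(const Edge &o) const {
    if (producerId != o.producerId) return producerId < o.producerId;
    if (consumerId != o.consumerId) return consumerId < o.consumerId;
    return operandIndex < o.operandIndex;
  }
};

class ShardSolverSketch {
public:
  // Override edges are stored up front instead of being resharded immediately.
  explicit ShardSolverSketch(std::set<Edge> overrideReshardEdges)
      : memReconfigEdges(std::move(overrideReshardEdges)) {}

  // Current stub behaviour from this commit: assume every op supports it.
  static bool supportsInterleavedInputShardedOutput(int /*opId*/) { return true; }

  void preprocessFirstOp(int producerId, int firstOpId) {
    Edge inputEdge{producerId, firstOpId, /*operandIndex=*/0};
    if (supportsInterleavedInputShardedOutput(firstOpId) &&
        memReconfigEdges.count(inputEdge) == 0) {
      return; // No reshard needed to start the shard chain.
    }
    insertReshard(inputEdge); // Start the chain with a reshard.
  }

private:
  void insertReshard(const Edge &edge) { memReconfigEdges.insert(edge); }
  std::set<Edge> memReconfigEdges;
};

Compared to the previous behaviour, overrideReshardEdges no longer triggers insertReshard() directly in the constructor; the same information now flows through memReconfigEdges and is acted on in preprocessFirstOp().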
