Skip to content

Commit

Permalink
Fix redundant traspose in TCtoTTGT pass (#36)
Browse files Browse the repository at this point in the history
This commit partially addresses issue #29.

  When one or more tensors are transposed before a contraction (C=A*B), the shape of the contraction result
  might be different than the expected output C. Thus this intermediate results is stored in a temporary tensor,
  which will later be tranposed in the C, taking the expected shape.
  However, instead of just allocating this intermediate tensor we would also transpose tensor C into it, making
  a redundant trasposition. This commit removes this extra transposition and maintains it only for the case of C+=,-= 
  A*B.
  • Loading branch information
pthomadakis committed Oct 22, 2023
1 parent e1010e2 commit 37f8cd9
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,17 @@ namespace
MemRefType::get(lhsDims, lhsMemrefType.getElementType()), loc,
rewriter);
useLHSTranspose = true;
// TODO(gkestor): we might need this copy if we support update C[] += A[] * B[]
rewriter.create<linalg::TransposeOp>(loc, lhsMemref, lhsAlloc, llvm::ArrayRef<int64_t>(lhsOutPerm_int64));
double beta_val = betaAttr.cast<FloatAttr>().getValueAsDouble();

if(beta_val == 0)
{
Value constantOp = rewriter.create<ConstantOp>(loc, rewriter.getF64FloatAttr(0.0));
rewriter.create<linalg::FillOp>(loc, constantOp, lhsAlloc);
}
else
{
rewriter.create<linalg::TransposeOp>(loc, lhsMemref, lhsAlloc, llvm::ArrayRef<int64_t>(lhsOutPerm_int64));
}
}

RankedTensorType collapsedTensorType;
Expand Down

0 comments on commit 37f8cd9

Please sign in to comment.