Skip to content

Commit

Permalink
[aievec] Add support for 0-rank xfer reads
Browse files Browse the repository at this point in the history
This patch solves issue #584

#584
  • Loading branch information
jsetoain committed Aug 17, 2023
1 parent ddb9d83 commit 4dad574
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 15 deletions.
44 changes: 29 additions & 15 deletions lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,32 @@ struct ConvertSplatTransferReadToBroadcastPattern
if (!map.isConstant())
return failure();

// If the innermost index comes from an `affine.apply` op, take the base
// as the new innermost index for the new `vector.transfer_read`, and the
// offset as the index for the `aievec.broadcast` op.
Value srcMemRef = adaptor.getSource();
SmallVector<Value, 8> indices;
indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
Value innerMostIdx = indices[indices.size() - 1];
Value newIdx = innerMostIdx;
Value newIdx;
int64_t offset = 0;
if (auto defOp = innerMostIdx.getDefiningOp())
if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
// If it's a zero-rank memory access
if (cast<MemRefType>(srcMemRef.getType()).getRank() == 0) {
srcMemRef = rewriter
.create<memref::ExpandShapeOp>(
readOp.getLoc(), SmallVector<int64_t, 1>({1}),
srcMemRef, SmallVector<ReassociationIndices, 1>({}))
.getResult();
newIdx = rewriter.create<arith::ConstantOp>(readOp.getLoc(),
rewriter.getIndexAttr(0L));
indices.push_back(newIdx);
} else {
indices.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
newIdx = indices[indices.size() - 1];
// If the innermost index comes from an `affine.apply` op, take the base
// as the new innermost index for the new `vector.transfer_read`, and the
// offset as the index for the `aievec.broadcast` op.
if (auto applyOp = newIdx.getDefiningOp<AffineApplyOp>())
if (applyOp.getAffineMap().getNumDims() == 1) {
newIdx = applyOp.getMapOperands()[0];
offset = applyOp.getAffineMap().compose(ArrayRef<int64_t>{0})[0];
}
}
// XXX: We assume we are reading 1D vectors
int64_t vlen = readOp.getVector().getType().getShape()[0];
if (offset >= vlen) {
Expand All @@ -175,8 +187,8 @@ struct ConvertSplatTransferReadToBroadcastPattern
}
indices[indices.size() - 1] = newIdx;
auto newReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), readOp.getVector().getType(), adaptor.getSource(),
indices, adaptor.getPadding());
readOp.getLoc(), readOp.getVector().getType(), srcMemRef, indices,
adaptor.getPadding());
auto extractOp = rewriter.create<vector::ExtractOp>(
readOp.getLoc(), newReadOp.getResult(), ArrayRef<int64_t>{offset});
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
Expand Down Expand Up @@ -232,10 +244,9 @@ struct ComputeExpOpByLUTPattern : public OpConversionPattern<math::ExpOp> {
//============================================================================//
static void
configureCommonAIECanonicalizeLegalizations(ConversionTarget &target) {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<AffineDialect>();
target.addLegalDialect<aievec::AIEVecDialect>();
target.addLegalDialect<vector::VectorDialect>();
target.addLegalDialect<arith::ArithDialect, AffineDialect,
aievec::AIEVecDialect, memref::MemRefDialect,
vector::VectorDialect>();
}

static void
Expand Down Expand Up @@ -325,11 +336,14 @@ struct CanonicalizeVectorForAIEVecPass
StringRef getArgument() const final {
return "test-canonicalize-vector-for-aievec";
}

StringRef getDescription() const final {
return "Canonicalize vector operations for AIEVec conversion";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, vector::VectorDialect>();
registry.insert<arith::ArithDialect, memref::MemRefDialect,
vector::VectorDialect>();
}

Option<std::string> aieTarget{
Expand Down
22 changes: 22 additions & 0 deletions lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,26 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}

// Print an expand shape by forwarding the value to the next op
static LogicalResult printOperation(CppEmitter &emitter,
memref::ExpandShapeOp expandShapeOp) {
Value source = expandShapeOp.getSrc();

// If the memref being outputted is not already emitted,
// error out
if (!emitter.hasValueInScope(source))
return failure();

if (failed(emitter.emitAssignPrefix(*expandShapeOp)))
return failure();

raw_indented_ostream &os = emitter.ostream();

os << emitter.getOrCreateName(source);

return success();
}

static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
Attribute value) {
OpResult result = operation->getResult(0);
Expand Down Expand Up @@ -2856,6 +2876,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
// Memref ops.
.Case<memref::StoreOp>(
[&](auto op) { return printOperation(*this, op); })
.Case<memref::ExpandShapeOp>(
[&](auto op) { return printOperation(*this, op); })
.Case<aievec::AddOp, aievec::AddElemOp, aievec::ConcatOp,
aievec::ExtOp, aievec::FMAOp, aievec::MulOp, aievec::PackOp,
aievec::SelectOp, aievec::SRSOp, aievec::SubOp,
Expand Down
17 changes: 17 additions & 0 deletions test/dialect/AIEVec/precanonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,20 @@ func.func @unaligned_transfer_read(%m : memref<1024xi32>, %pos : index) -> vecto
// CHECK: return %[[AV]] : vector<8xi32>
return %v : vector<8xi32>
}

// -----

// CHECK-LABEL: func.func @rank_zero_transfer_read(
// CHECK-SAME: %[[MEM:.*]]: memref<i16>
func.func @rank_zero_transfer_read(%m : memref<i16>) -> vector<16xi16> {
%c0_i16 = arith.constant 0 : i16
// CHECK-DAG: %[[C0idx:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0i16:.*]] = arith.constant 0 : i16
// CHECK-DAG: %[[EXPMEM:.*]] = memref.expand_shape %[[MEM]] [] : memref<i16> into memref<1xi16>
// CHECK: %[[LV:.*]] = vector.transfer_read %[[EXPMEM]][%[[C0idx]]], %[[C0i16]] : memref<1xi16>, vector<16xi16>
// CHECK: %[[E:.*]] = vector.extract %[[LV]][0] : vector<16xi16>
// CHECK: %[[S:.*]] = vector.broadcast %[[E]] : i16 to vector<16xi16>
%v = vector.transfer_read %m[], %c0_i16 {permutation_map = affine_map<()->(0)>} : memref<i16>, vector<16xi16>
// CHECK: return %[[S]] : vector<16xi16>
return %v : vector<16xi16>
}

0 comments on commit 4dad574

Please sign in to comment.