Skip to content

Commit

Permalink
Add support of broadcast with vector width = 256 or 1024 and fix TOSA…
Browse files Browse the repository at this point in the history
… tests (#653)

*Add support of broadcast_elem/broadcast_to_vxx for vector width == 256 (e.g. v16bf16) or 1024 (e.g. v32int32).
*Since we lower vector.broadcast op to multiple aievec ops, we have to fix FoldMulAddChainToConv pass to recognize the new aievec.broadcast patterns
*Add the following list of PASS tests for implicit broadcast:
i32xi32_sub_elem_16x1024_broadcast_1
i32xi32_sub_elem_2d_broadcast_1d_unit_dim_v16 (out=i32, lane=16)
i32xi32_sub_elem_2d_broadcast_1d_unit_dim_v32 (out=i32, lane=32)
i32xi32_sub_elem_2d_broadcast_scalar_v16 (out=i32, lane=16)
i32xi32_sub_elem_2d_broadcast_scalar_v32 (out=i32, lane=32)
i32xi32_sub_elem_16x1024_broadcast_1024
i32xi32_sub_elem_2d_broadcast_1d_reshape_v16 (out=i32, lane=16)
i32xi32_sub_elem_2d_broadcast_1d_reshape_v32 (out=i32, lane=32)
i32xi32_sub_elem_2d_broadcast_1d_v16 (out=i32, lane=16)
i32xi32_sub_elem_2d_broadcast_1d_v32 (out=i32, lane=32)
i32xi32_sub_elem_2d_broadcast_2d_v16 (out=i32, lane=16)
i32xi32_sub_elem_2d_broadcast_2d_v32 (out=i32, lane=32)
*Add dut.cc reference for bf16xbf16_sub_elem_16x1024_broadcast_1 tests. The resulting dut.cc is legal, but it's blocked by "broadcast_elem() of v32bfloat16" bug. Hence, the tests are still marked XFAIL.
*Add conversion test coverage for aievec.broadcast and aievec.broadcast_scalar in test_broadcast.mlir
*Fix i8xi16_mul_elem_v32 mlir script
  • Loading branch information
jamestcl-amd authored Sep 25, 2023
1 parent afe87cb commit cd3f907
Show file tree
Hide file tree
Showing 29 changed files with 831 additions and 21 deletions.
67 changes: 58 additions & 9 deletions lib/Dialect/AIEVec/Transforms/FoldMulAddChainToConvOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
#include "aie/Dialect/AIEVec/AIEVecUtils.h"
#include "aie/Dialect/AIEVec/Analysis/Passes.h"
#include "aie/Dialect/AIEVec/IR/AIEVecOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include <tuple>
#include <utility>

#include "FoldMulAddChainToConvOp.h"

#define DEBUG_TYPE "fold-mul-add-chain-to-conv"

using namespace mlir;
using namespace arith;
using namespace vector;
Expand Down Expand Up @@ -228,13 +232,58 @@ struct LongestConvMACChainAnalysis {
if (!mulOpLhsDefOp || !mulOpRhsDefOp)
return nullptr;

Value convMacRhs = nullptr;
uint8_t convMacBcastIdx = 0;

auto getConvMacRhs = [&](Operation *mulOpOperand) -> bool {
SetVector<Operation *> opBwdSlices;
auto opFilter = [](Operation *op) {
return isa<aievec::BroadcastOp>(op) || isa<aievec::ExtOp>(op) ||
isa<aievec::ConcatOp>(op);
};
BackwardSliceOptions backwardSliceOptions;
backwardSliceOptions.filter = opFilter;

getBackwardSlice(mulOpOperand, &opBwdSlices, backwardSliceOptions);
opBwdSlices.insert(mulOpOperand);

LLVM_DEBUG(llvm::dbgs() << "opBwdSlices = [\n");
for (auto op : opBwdSlices) {
LLVM_DEBUG(llvm::dbgs() << *op << "\n");
}
LLVM_DEBUG(llvm::dbgs() << "]\n");

if (opBwdSlices.size() == 1) {
if (auto bcastOp = dyn_cast<aievec::BroadcastOp>(opBwdSlices[0])) {
convMacRhs = bcastOp.getSource();
convMacBcastIdx = bcastOp.getIdx();
return true;
}
} else if (opBwdSlices.size() >= 3) {
auto sliceSz = opBwdSlices.size();
if ((isa<aievec::ExtOp>(opBwdSlices[sliceSz - 3]) &&
isa<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]) &&
isa<aievec::ConcatOp>(opBwdSlices[sliceSz - 1])) ||
(isa<aievec::ConcatOp>(opBwdSlices[sliceSz - 3]) &&
isa<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]) &&
isa<aievec::ExtOp>(opBwdSlices[sliceSz - 1]))) {
convMacRhs = opBwdSlices[sliceSz - 3]->getOperand(0);
convMacBcastIdx =
dyn_cast<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]).getIdx();
return true;
}
}

return false;
};

// Obtain the broadcast operation feeding into the MulIOp
auto bcastOp = dyn_cast<aievec::BroadcastOp>(mulOpRhsDefOp);
if (!bcastOp) {
bcastOp = dyn_cast<aievec::BroadcastOp>(mulOpLhsDefOp);
std::swap(mulOpLhsDefOp, mulOpRhsDefOp);
if (!getConvMacRhs(mulOpRhsDefOp)) {
if (getConvMacRhs(mulOpLhsDefOp)) {
std::swap(mulOpLhsDefOp, mulOpRhsDefOp);
}
}
if (!bcastOp)
if (!convMacRhs)
return nullptr;

// Obtain the ext or ext->shift op feeding into the MulIOp
Expand All @@ -251,8 +300,7 @@ struct LongestConvMACChainAnalysis {
if (!extOp)
return nullptr;

Value lhs = extOp.getSource();
Value rhs = bcastOp.getSource();
Value convMacLhs = extOp.getSource();
uint8_t shift = 0;
if (shiftOp) {
auto shiftConstDefOp =
Expand All @@ -263,8 +311,9 @@ struct LongestConvMACChainAnalysis {
shift = 8 * shiftAttr.getInt() / getElementSizeInBits(vType);
}
}
uint8_t bcastIdx = bcastOp.getIdx();
return std::make_unique<ConvMac>(lhs, rhs, shift, bcastIdx);

return std::make_unique<ConvMac>(convMacLhs, convMacRhs, shift,
convMacBcastIdx);
}

std::unique_ptr<ConvMac> getConvMacFromAddOp(arith::AddIOp addOp) {
Expand Down
56 changes: 48 additions & 8 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,41 @@ struct FoldVectorExtractAndBroadcastToAIEBroadcast
rewriter.getI8IntegerAttr(half))
.getResult();
}
rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(bcastOp, resultType, src,
posVal);

unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();
unsigned laneSize = getVectorLaneSize(resultType);

if (laneSize * elWidth == 512) {
// Common use case for the broadcast_elem intrinsic
rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(bcastOp, resultType, src,
posVal);
} else if (laneSize * elWidth == 256) {
// e.g. need v16bf16 due to the subsequent v16accfloat operation
VectorType aievecBcastType =
createVectorType(512 / elWidth, resultType.getElementType());
auto concatOp = rewriter.create<aievec::ConcatOp>(
bcastOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
bcastOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
rewriter.replaceOpWithNewOp<aievec::ExtOp>(bcastOp, resultType,
aieBcastOp.getResult(), 0);
} else if (laneSize * elWidth == 1024) {
// e.g. need v32int32 due to the subsequent v32acc32 operation
VectorType aievecBcastType =
createVectorType(512 / elWidth, resultType.getElementType());
int8_t half = static_cast<int8_t>(posVal / resultType.getNumElements());
posVal -= half * resultType.getNumElements();
auto extOp =
rewriter.create<aievec::ExtOp>(bcastOp.getLoc(), aievecBcastType, src,
rewriter.getI8IntegerAttr(half));
auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
bcastOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
bcastOp, resultType,
SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
} else {
return failure();
}

return success();
}
Expand All @@ -491,10 +524,11 @@ struct ConvertBroadcastToAIEBroadcast
matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto extOp =
dyn_cast<vector::ExtractOp>(adaptor.getSource().getDefiningOp());
if (auto extOp = adaptor.getSource().getDefiningOp<vector::ExtractOp>())
return failure();

if (extOp)
// Only support broadcasting a single element for now
if (!isa<IntegerType, IndexType, FloatType>(adaptor.getSource().getType()))
return failure();

VectorType resultType = cast<VectorType>(bcastOp.getResult().getType());
Expand All @@ -507,15 +541,21 @@ struct ConvertBroadcastToAIEBroadcast
rewriter.replaceOpWithNewOp<aievec::BroadcastScalarOp>(bcastOp,
resultType, src);
return success();
}

if (laneSize * elWidth == 256) {
} else if (laneSize * elWidth == 256) {
VectorType vecType = createVectorType(512 / elWidth, scalarType);
auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
bcastOp.getLoc(), vecType, src);
rewriter.replaceOpWithNewOp<aievec::ExtOp>(bcastOp, resultType,
aieBcastOp.getResult(), 0);
return success();
} else if (laneSize * elWidth == 1024) {
VectorType vecType = createVectorType(512 / elWidth, scalarType);
auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
bcastOp.getLoc(), vecType, src);
rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
bcastOp, resultType,
SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
return success();
}

return failure();
Expand Down
48 changes: 46 additions & 2 deletions test/Conversion/VectorToAIEVec/test_broadcast.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: aie-opt %s --convert-vector-to-aievec="aie-target=aieml" | FileCheck %s

// CHECK-LABEL: func @vector_extract_broadcast_to_aievec(
// CHECK-LABEL: func @vector_extract_broadcast_to_aievec_512(
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: vector<16xi32>
func.func @vector_extract_broadcast_to_aievec(%a : vector<16xi32>) -> (vector<16xi32>, vector<16xi32>) {
func.func @vector_extract_broadcast_to_aievec_512(%a : vector<16xi32>) -> (vector<16xi32>, vector<16xi32>) {
// CHECK: aievec.broadcast %[[A]] {idx = 0 : i8} : vector<16xi32>, vector<16xi32>
%0 = vector.extract %a[0] : vector<16xi32>
%1 = vector.broadcast %0 : i32 to vector<16xi32>
Expand All @@ -11,3 +11,47 @@ func.func @vector_extract_broadcast_to_aievec(%a : vector<16xi32>) -> (vector<16
%3 = vector.broadcast %2 : i32 to vector<16xi32>
return %1, %3 : vector<16xi32>, vector<16xi32>
}

// CHECK-LABEL: func @vector_extract_broadcast_to_aievec_256(
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: vector<16xbf16>
func.func @vector_extract_broadcast_to_aievec_256(%a : vector<16xbf16>) -> (vector<16xbf16>, vector<16xbf16>) {
// CHECK: %[[CC1:.*]] = aievec.concat %[[A]], %[[A]] : vector<16xbf16>, vector<32xbf16>
// CHECK: %[[BCAST1:.*]] = aievec.broadcast %[[CC1]] {idx = 0 : i8} : vector<32xbf16>, vector<32xbf16>
// CHECK: %[[EXT1:.*]] = aievec.ext %[[BCAST1]] {index = 0 : i8} : vector<32xbf16>, vector<16xbf16>
%0 = vector.extract %a[0] : vector<16xbf16>
%1 = vector.broadcast %0 : bf16 to vector<16xbf16>
// CHECK: %[[BCAST2:.*]] = aievec.broadcast %[[CC1]] {idx = 2 : i8} : vector<32xbf16>, vector<32xbf16>
// CHECK: %[[EXT2:.*]] = aievec.ext %[[BCAST2]] {index = 0 : i8} : vector<32xbf16>, vector<16xbf16>
%2 = vector.extract %a[2] : vector<16xbf16>
%3 = vector.broadcast %2 : bf16 to vector<16xbf16>
return %1, %3 : vector<16xbf16>, vector<16xbf16>
}

// CHECK-LABEL: func @vector_extract_broadcast_to_aievec_1024(
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: vector<32xi32>
func.func @vector_extract_broadcast_to_aievec_1024(%a : vector<32xi32>) -> (vector<32xi32>, vector<32xi32>) {
// CHECK: %[[EXT1:.*]] = aievec.ext %[[A]] {index = 0 : i8} : vector<32xi32>, vector<16xi32>
// CHECK: %[[BCAST1:.*]] = aievec.broadcast %[[EXT1]] {idx = 0 : i8} : vector<16xi32>, vector<16xi32>
// CHECK: %[[CC1:.*]] = aievec.concat %[[BCAST1]], %[[BCAST1]] : vector<16xi32>, vector<32xi32>
%0 = vector.extract %a[0] : vector<32xi32>
%1 = vector.broadcast %0 : i32 to vector<32xi32>
// CHECK: %[[BCAST2:.*]] = aievec.broadcast %[[EXT1]] {idx = 2 : i8} : vector<16xi32>, vector<16xi32>
// CHECK: %[[CC2:.*]] = aievec.concat %[[BCAST2]], %[[BCAST2]] : vector<16xi32>, vector<32xi32>
%2 = vector.extract %a[2] : vector<32xi32>
%3 = vector.broadcast %2 : i32 to vector<32xi32>
return %1, %3 : vector<32xi32>, vector<32xi32>
}

// CHECK-LABEL: func @vector_broadcast_from_scalar(
func.func @vector_broadcast_from_scalar(%a : i32, %b :bf16) -> (vector<16xi32>, vector<32xi32>, vector<16xbf16>, vector<32xbf16>) {
// CHECK: %[[BCAST1:.*]] = aievec.broadcast_scalar %arg0 : i32, vector<16xi32>
%0 = vector.broadcast %a : i32 to vector<16xi32>
// CHECK: %[[CC:.*]] = aievec.concat %[[BCAST1]], %[[BCAST1]] : vector<16xi32>, vector<32xi32>
%1 = vector.broadcast %a : i32 to vector<32xi32>
// CHECK: %[[BCAST2:.*]] = aievec.broadcast_scalar %arg1 : bf16, vector<32xbf16>
%3 = vector.broadcast %b : bf16 to vector<32xbf16>
// CHECK: %[[EXT1:.*]] = aievec.ext %[[BCAST2]] {index = 0 : i8} : vector<32xbf16>, vector<16xbf16>
%2 = vector.broadcast %b : bf16 to vector<16xbf16>
return %0, %1, %2, %3 : vector<16xi32>, vector<32xi32>, vector<16xbf16>, vector<32xbf16>
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Cycle count: 1111
// clang-format off
void dut(bfloat16 * restrict v1, bfloat16 * restrict v2, bfloat16 * restrict v3) {
size_t v4 = 0;
bfloat16 * restrict v5 = v2;
v16bfloat16 v6 = *(v16bfloat16 *)(v5 + v4+v4);
v32bfloat16 v7 = concat(v6, v6);
v32bfloat16 v8 = broadcast_elem(v7, 0);
v16bfloat16 v9 = extract_v16bfloat16(v8, 0);
v16accfloat v10 = ups_to_v16accfloat(v9);
size_t v11 = 0;
size_t v12 = 16;
size_t v13 = 1;
for (size_t v14 = v11; v14 < v12; v14 += v13)
chess_prepare_for_pipelining
chess_loop_range(16, 16)
{
size_t v15 = 0;
size_t v16 = 1024;
size_t v17 = 16;
for (size_t v18 = v15; v18 < v16; v18 += v17)
chess_prepare_for_pipelining
chess_loop_range(64, 64)
{
v16bfloat16 v19 = *(v16bfloat16 *)(v1 + 1024*v14+v18);
v16accfloat v20 = ups_to_v16accfloat(v19);
v16accfloat v21 = sub(v20, v10);
v16bfloat16 v22 = to_v16bfloat16(v21);
*(v16bfloat16 *)(v3 + 1024*v14+v18) = v22;
}
}
return;
}
// clang-format on
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once
constexpr unsigned const IN0_SIZE = 16 * 1024;
constexpr unsigned const IN1_SIZE = 1;
constexpr unsigned const OUT0_SIZE = 16 * 1024;
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Cycle count: 2131
// clang-format off
void dut(int32_t * restrict v1, int32_t * restrict v2, int32_t * restrict v3) {
size_t v4 = 0;
int32_t * restrict v5 = v2;
v16int32 v6 = *(v16int32 *)(v5 + v4+v4);
v16int32 v7 = broadcast_elem(v6, 0);
size_t v8 = 0;
size_t v9 = 16;
size_t v10 = 1;
for (size_t v11 = v8; v11 < v9; v11 += v10)
chess_prepare_for_pipelining
chess_loop_range(16, 16)
{
size_t v12 = 0;
size_t v13 = 1024;
size_t v14 = 16;
for (size_t v15 = v12; v15 < v13; v15 += v14)
chess_prepare_for_pipelining
chess_loop_range(64, 64)
{
v16int32 v16 = *(v16int32 *)(v1 + 1024*v11+v15);
v16int32 v17 = sub(v16, v7);
*(v16int32 *)(v3 + 1024*v11+v15) = v17;
}
}
return;
}
// clang-format on
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (C) 2023, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(tosa-make-broadcastable, tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor))" -o linalg.mlir
// RUN: mlir-opt linalg.mlir --linalg-fuse-elementwise-ops --eliminate-empty-tensors --empty-tensor-to-alloc-tensor --one-shot-bufferize="allow-return-allocs allow-unknown-ops bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" --drop-equivalent-buffer-results --buffer-results-to-out-params --buffer-deallocation --canonicalize --cse --convert-linalg-to-affine-loops --affine-super-vectorize="virtual-vector-size=16" -o affine.mlir
// RUN: aie-opt affine.mlir --convert-vector-to-aievec="aie-target=aieml" -lower-affine -o aievec.mlir
// RUN: aie-translate aievec.mlir -aieml=true --aievec-to-cpp -o dut.cc
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I. %S/../testbench.cc dut.cc
// RUN: mkdir -p data
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED

module {
func.func @dut(%arg0: tensor<16x1024xi32>, %arg1: tensor<1xi32>) -> (tensor<16x1024xi32>) {
%0 = "tosa.reshape"(%arg1) { new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> (tensor<1x1xi32>)
%1 = "tosa.sub"(%arg0,%0) : (tensor<16x1024xi32>, tensor<1x1xi32>) -> (tensor<16x1024xi32>)
return %1 : tensor<16x1024xi32>
}
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Cycle count: 2148
// clang-format off
void dut(int32_t * restrict v1, int32_t * restrict v2, int32_t * restrict v3) {
size_t v4 = 0;
int32_t * restrict v5 = v2;
v16int32 v6 = *(v16int32 *)(v5 + v4+v4);
v16int32 v7 = broadcast_elem(v6, 0);
v32int32 v8 = concat(v7, v7);
v32acc32 v9 = v32acc32(v8);
size_t v10 = 0;
size_t v11 = 16;
size_t v12 = 1;
for (size_t v13 = v10; v13 < v11; v13 += v12)
chess_prepare_for_pipelining
chess_loop_range(16, 16)
{
size_t v14 = 0;
size_t v15 = 1024;
size_t v16 = 32;
for (size_t v17 = v14; v17 < v15; v17 += v16)
chess_prepare_for_pipelining
chess_loop_range(32, 32)
{
v32int32 v18 = *(v32int32 *)(v1 + 1024*v13+v17);
v32acc32 v19 = v32acc32(v18);
v32acc32 v20 = sub(v19, v9);
v32int32 v21 = v32int32(v20);
*(v32int32 *)(v3 + 1024*v13+v17) = v21;
}
}
return;
}
// clang-format on
Loading

0 comments on commit cd3f907

Please sign in to comment.