Skip to content

Commit

Permalink
Implement tanh(x) based on linear approximation lookup tables (#639)
Browse files Browse the repository at this point in the history
  • Loading branch information
linay-xsj authored Sep 18, 2023
1 parent f90379d commit bb7653e
Show file tree
Hide file tree
Showing 11 changed files with 384 additions and 6 deletions.
12 changes: 12 additions & 0 deletions aie_runtime_lib/AIE/tanh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//===--- tanh.cpp - tanh lookup tables ---===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 Advanced Micro Devices, Inc.
//
//
//===----------------------------------------------------------------------===//
// These are tanh lookup tables for bfloat16 type
//===----------------------------------------------------------------------===//
19 changes: 19 additions & 0 deletions aie_runtime_lib/AIE/tanh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- tanh.h - get hyperbolic tangent values based on linear approximation
//-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 Advanced Micro Devices, Inc.
//
//
//===----------------------------------------------------------------------===//
// This file implements the computation of hyperbolic tangent values based on
// linear approximation
//===----------------------------------------------------------------------===//

#ifndef __TANH__
#define __TANH__

#endif //__TANH__
148 changes: 148 additions & 0 deletions aie_runtime_lib/AIE2/tanh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
//===--- tanh.cpp - tanh lookup tables ---===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 Advanced Micro Devices, Inc.
//
//
//===----------------------------------------------------------------------===//
// These are tanh lookup tables for bfloat16 type
//===----------------------------------------------------------------------===//

// First of the two lookup tables used for the piecewise-linear tanh
// approximation.
//
// The input range [-4, 4] is divided into 32 segments. Each row below holds
// two floats that appear to be the (slope, offset) pair of one segment: the
// outermost rows are (0, -1) and (0, +1), and the rows around the middle are
// (1, 0).
//
// Every group of two rows is written twice in a row, so this table holds
//   32 segments * 2 values * 2 copies = 128 floats = 512 bytes,
// and the two tables together occupy (32*2*2*4)*2 = 1 KiB ("bank size").
//
// NOTE(review): the duplicated groups presumably serve the parallel-access
// layout expected by the consumer of this table -- confirm against the AIE
// API documentation before reordering or de-duplicating anything here.
// NOTE(review): chess_storage(% chess_alignof(v32int8)) presumably aligns the
// array to the alignment of v32int8 -- confirm against the chess compiler
// documentation.
float chess_storage(% chess_alignof(v32int8)) tanh_lut_ab[128] = {
0.00000000000000000000000000000000, -1.00000000000000000000000000000000,
0.00283813476562500000000000000000, -0.98828125000000000000000000000000,
0.00000000000000000000000000000000, -1.00000000000000000000000000000000,
0.00283813476562500000000000000000, -0.98828125000000000000000000000000,
0.00509643554687500000000000000000, -0.98046875000000000000000000000000,
0.00750732421875000000000000000000, -0.97265625000000000000000000000000,
0.00509643554687500000000000000000, -0.98046875000000000000000000000000,
0.00750732421875000000000000000000, -0.97265625000000000000000000000000,
0.01269531250000000000000000000000, -0.95703125000000000000000000000000,
0.02124023437500000000000000000000, -0.93359375000000000000000000000000,
0.01269531250000000000000000000000, -0.95703125000000000000000000000000,
0.02124023437500000000000000000000, -0.93359375000000000000000000000000,
0.03540039062500000000000000000000, -0.89843750000000000000000000000000,
0.05639648437500000000000000000000, -0.85156250000000000000000000000000,
0.03540039062500000000000000000000, -0.89843750000000000000000000000000,
0.05639648437500000000000000000000, -0.85156250000000000000000000000000,
0.09179687500000000000000000000000, -0.78125000000000000000000000000000,
0.14550781250000000000000000000000, -0.68750000000000000000000000000000,
0.09179687500000000000000000000000, -0.78125000000000000000000000000000,
0.14550781250000000000000000000000, -0.68750000000000000000000000000000,
0.22949218750000000000000000000000, -0.56250000000000000000000000000000,
0.34765625000000000000000000000000, -0.41601562500000000000000000000000,
0.22949218750000000000000000000000, -0.56250000000000000000000000000000,
0.34765625000000000000000000000000, -0.41601562500000000000000000000000,
0.50390625000000000000000000000000, -0.25976562500000000000000000000000,
0.69140625000000000000000000000000, -0.11962890625000000000000000000000,
0.50390625000000000000000000000000, -0.25976562500000000000000000000000,
0.69140625000000000000000000000000, -0.11962890625000000000000000000000,
0.86718750000000000000000000000000, -0.03076171875000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
0.86718750000000000000000000000000, -0.03076171875000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
0.86718750000000000000000000000000, 0.03076171875000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
0.86718750000000000000000000000000, 0.03076171875000000000000000000000,
0.69140625000000000000000000000000, 0.11962890625000000000000000000000,
0.50390625000000000000000000000000, 0.25976562500000000000000000000000,
0.69140625000000000000000000000000, 0.11962890625000000000000000000000,
0.50390625000000000000000000000000, 0.25976562500000000000000000000000,
0.34765625000000000000000000000000, 0.41601562500000000000000000000000,
0.22949218750000000000000000000000, 0.56250000000000000000000000000000,
0.34765625000000000000000000000000, 0.41601562500000000000000000000000,
0.22949218750000000000000000000000, 0.56250000000000000000000000000000,
0.14550781250000000000000000000000, 0.68750000000000000000000000000000,
0.09179687500000000000000000000000, 0.78125000000000000000000000000000,
0.14550781250000000000000000000000, 0.68750000000000000000000000000000,
0.09179687500000000000000000000000, 0.78125000000000000000000000000000,
0.05639648437500000000000000000000, 0.85156250000000000000000000000000,
0.03540039062500000000000000000000, 0.89843750000000000000000000000000,
0.05639648437500000000000000000000, 0.85156250000000000000000000000000,
0.03540039062500000000000000000000, 0.89843750000000000000000000000000,
0.02124023437500000000000000000000, 0.93359375000000000000000000000000,
0.01269531250000000000000000000000, 0.95703125000000000000000000000000,
0.02124023437500000000000000000000, 0.93359375000000000000000000000000,
0.01269531250000000000000000000000, 0.95703125000000000000000000000000,
0.00750732421875000000000000000000, 0.97265625000000000000000000000000,
0.00509643554687500000000000000000, 0.98046875000000000000000000000000,
0.00750732421875000000000000000000, 0.97265625000000000000000000000000,
0.00509643554687500000000000000000, 0.98046875000000000000000000000000,
0.00283813476562500000000000000000, 0.98828125000000000000000000000000,
0.00000000000000000000000000000000, 1.00000000000000000000000000000000,
0.00283813476562500000000000000000, 0.98828125000000000000000000000000,
0.00000000000000000000000000000000, 1.00000000000000000000000000000000,
};

// Second of the two lookup tables used for the piecewise-linear tanh
// approximation: 128 floats (512 bytes), written two per row.
//
// Each row appears to be a (slope, offset) pair, running from (0, -1) at the
// start through (1, 0) in the middle to (0, +1) at the end, and every group
// of two rows is written twice in a row.
//
// NOTE(review): the values appear to be the same as those in tanh_lut_ab;
// presumably both copies are required by the consumer of the tables -- confirm
// against the AIE API documentation before merging them.
float chess_storage(% chess_alignof(v32int8)) tanh_lut_cd[128] = {
0.00000000000000000000000000000000, -1.00000000000000000000000000000000,
0.00283813476562500000000000000000, -0.98828125000000000000000000000000,
0.00000000000000000000000000000000, -1.00000000000000000000000000000000,
0.00283813476562500000000000000000, -0.98828125000000000000000000000000,
0.00509643554687500000000000000000, -0.98046875000000000000000000000000,
0.00750732421875000000000000000000, -0.97265625000000000000000000000000,
0.00509643554687500000000000000000, -0.98046875000000000000000000000000,
0.00750732421875000000000000000000, -0.97265625000000000000000000000000,
0.01269531250000000000000000000000, -0.95703125000000000000000000000000,
0.02124023437500000000000000000000, -0.93359375000000000000000000000000,
0.01269531250000000000000000000000, -0.95703125000000000000000000000000,
0.02124023437500000000000000000000, -0.93359375000000000000000000000000,
0.03540039062500000000000000000000, -0.89843750000000000000000000000000,
0.05639648437500000000000000000000, -0.85156250000000000000000000000000,
0.03540039062500000000000000000000, -0.89843750000000000000000000000000,
0.05639648437500000000000000000000, -0.85156250000000000000000000000000,
0.09179687500000000000000000000000, -0.78125000000000000000000000000000,
0.14550781250000000000000000000000, -0.68750000000000000000000000000000,
0.09179687500000000000000000000000, -0.78125000000000000000000000000000,
0.14550781250000000000000000000000, -0.68750000000000000000000000000000,
0.22949218750000000000000000000000, -0.56250000000000000000000000000000,
0.34765625000000000000000000000000, -0.41601562500000000000000000000000,
0.22949218750000000000000000000000, -0.56250000000000000000000000000000,
0.34765625000000000000000000000000, -0.41601562500000000000000000000000,
0.50390625000000000000000000000000, -0.25976562500000000000000000000000,
0.69140625000000000000000000000000, -0.11962890625000000000000000000000,
0.50390625000000000000000000000000, -0.25976562500000000000000000000000,
0.69140625000000000000000000000000, -0.11962890625000000000000000000000,
0.86718750000000000000000000000000, -0.03076171875000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
0.86718750000000000000000000000000, -0.03076171875000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
0.86718750000000000000000000000000, 0.03076171875000000000000000000000,
1.00000000000000000000000000000000, 0.00000000000000000000000000000000,
0.86718750000000000000000000000000, 0.03076171875000000000000000000000,
0.69140625000000000000000000000000, 0.11962890625000000000000000000000,
0.50390625000000000000000000000000, 0.25976562500000000000000000000000,
0.69140625000000000000000000000000, 0.11962890625000000000000000000000,
0.50390625000000000000000000000000, 0.25976562500000000000000000000000,
0.34765625000000000000000000000000, 0.41601562500000000000000000000000,
0.22949218750000000000000000000000, 0.56250000000000000000000000000000,
0.34765625000000000000000000000000, 0.41601562500000000000000000000000,
0.22949218750000000000000000000000, 0.56250000000000000000000000000000,
0.14550781250000000000000000000000, 0.68750000000000000000000000000000,
0.09179687500000000000000000000000, 0.78125000000000000000000000000000,
0.14550781250000000000000000000000, 0.68750000000000000000000000000000,
0.09179687500000000000000000000000, 0.78125000000000000000000000000000,
0.05639648437500000000000000000000, 0.85156250000000000000000000000000,
0.03540039062500000000000000000000, 0.89843750000000000000000000000000,
0.05639648437500000000000000000000, 0.85156250000000000000000000000000,
0.03540039062500000000000000000000, 0.89843750000000000000000000000000,
0.02124023437500000000000000000000, 0.93359375000000000000000000000000,
0.01269531250000000000000000000000, 0.95703125000000000000000000000000,
0.02124023437500000000000000000000, 0.93359375000000000000000000000000,
0.01269531250000000000000000000000, 0.95703125000000000000000000000000,
0.00750732421875000000000000000000, 0.97265625000000000000000000000000,
0.00509643554687500000000000000000, 0.98046875000000000000000000000000,
0.00750732421875000000000000000000, 0.97265625000000000000000000000000,
0.00509643554687500000000000000000, 0.98046875000000000000000000000000,
0.00283813476562500000000000000000, 0.98828125000000000000000000000000,
0.00000000000000000000000000000000, 1.00000000000000000000000000000000,
0.00283813476562500000000000000000, 0.98828125000000000000000000000000,
0.00000000000000000000000000000000, 1.00000000000000000000000000000000,
};
49 changes: 49 additions & 0 deletions aie_runtime_lib/AIE2/tanh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===- tanh.h - get hyperbolic tangent values based on linear approximation
//-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023 Advanced Micro Devices, Inc.
//
//
//===----------------------------------------------------------------------===//
// This file implements the computation of hyperbolic tangent values based on
// linear approximation
//===----------------------------------------------------------------------===//

#ifndef __TANH__
#define __TANH__

#include "aie_api/aie.hpp"
#include <aie_api/aie_adf.hpp>
#include <aie_api/utils.hpp>

extern float tanh_lut_ab[];
extern float tanh_lut_cd[];

// Computes tanh for a vector of 16 bfloat16 values using a linear
// approximation driven by the tanh_lut_ab / tanh_lut_cd lookup tables.
//
// vInput:  the 16 bfloat16 input values.
// Returns: the 16 bfloat16 values produced by aie::linear_approx::compute()
//          for that input.
//
// The function is declared `inline` because it is defined in a header:
// `__attribute__((always_inline))` on its own does not give the definition
// inline linkage, so including this header from more than one translation
// unit would otherwise produce a multiple-definition link error.
inline v16bfloat16 __attribute__((always_inline))
getTanhBf16(v16bfloat16 vInput) {
  aie::vector<bfloat16, 16> input = vInput;

  // Parameters handed to aie::linear_approx.
  // NOTE(review): presumably step_bits = -2 selects a segment width of 0.25
  // and bias = 16 shifts the [-4, 4] input range onto LUT indices 0..31 --
  // confirm against the aie::linear_approx documentation.
  const int step_bits = -2;
  const int bias = 16;
  const int LUT_elems = 32;
  const int shift_offset = 0; // flagged "unused" by the original author

  using lut_type = aie::lut<4, float, bfloat16>;

  // Both tables are declared as float arrays; the lut constructor is given
  // them through bfloat16 pointers, exactly as in the original code.
  lut_type test_lut(LUT_elems, (bfloat16 *)tanh_lut_ab,
                    (bfloat16 *)tanh_lut_cd);

  aie::linear_approx<bfloat16, lut_type> lin_aprox(test_lut, step_bits, bias,
                                                   shift_offset);

  aie::vector<bfloat16, 16> output =
      lin_aprox.compute(input).to_vector<bfloat16>();

  return (v16bfloat16)output;
}

#endif //__TANH__
4 changes: 3 additions & 1 deletion aie_runtime_lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ function(add_aie_runtime_libs arch)
set(INSTALLS
chess_intrinsic_wrapper.cpp
lut_based_ops.cpp
lut_based_ops.h)
lut_based_ops.h
tanh.h
tanh.cpp)

foreach(file ${INSTALLS})
add_custom_target(aie-copy-${arch}-runtime-libs-${file} ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${file})
Expand Down
56 changes: 54 additions & 2 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1791,14 +1791,49 @@ struct ComputeInvOpByLUTPattern : public OpConversionPattern<arith::DivFOp> {
arith::TruncFOp truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());

rewriter.setInsertionPoint(truncOp);
auto funcOp = rewriter.replaceOpWithNewOp<emitc::CallOp>(
rewriter.replaceOpWithNewOp<emitc::CallOp>(
truncOp, TypeRange{truncOp.getResult().getType()}, "getInvBf16",
nullptr, nullptr, invOperands);
rewriter.eraseOp(divOp);
moduleOp = funcOp->getParentOfType<mlir::ModuleOp>();
return success();
}
};

// Convert math.tanh to a function call that computes tanh(x) through lookup
// tables.
//
// The pattern only matches a math.tanh whose operand is a vector with a float
// element type of width 16 and a lane size of 16 (as reported by
// getVectorLaneSize); every other case returns failure(). On a match it
// creates an emitc include of "tanh.h" at the start of the enclosing module's
// first block and replaces the op with an emitc call to "getTanhBf16".
struct ComputeTanhOpByLUTPattern : public OpConversionPattern<math::TanhOp> {
  using OpConversionPattern<math::TanhOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The dyn_cast yields a null VectorType for a non-vector operand, so it
    // must be checked before the element type is queried; asking a null type
    // for its element type would dereference a null pointer.
    VectorType srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
    if (!srcType)
      return failure();

    Type scalarType = srcType.getElementType();
    if (!isa<FloatType>(scalarType))
      return failure();

    unsigned laneSize = getVectorLaneSize(srcType);
    unsigned elWidth = scalarType.getIntOrFloatBitWidth();
    if (elWidth != 16 || laneSize != 16)
      return failure();

    // Put the include at the top of the module so that it precedes every use
    // of getTanhBf16 in the emitted code.
    StringRef includeName = "tanh.h";
    ModuleOp moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
    rewriter.setInsertionPointToStart(
        &moduleOp.getRegion().getBlocks().front());
    rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);

    // Replace the tanh op in place with the call to the LUT-based helper.
    rewriter.setInsertionPoint(tanhOp);
    SmallVector<Value> tanhOperands = {adaptor.getOperand()};
    rewriter.replaceOpWithNewOp<emitc::CallOp>(
        tanhOp, TypeRange{tanhOp.getResult().getType()}, "getTanhBf16", nullptr,
        nullptr, tanhOperands);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pattern collection
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1830,6 +1865,7 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
LowerVectorSubIOpToAIEVecSubElemOp,
ComputeExpOpByLUTPattern,
ComputeInvOpByLUTPattern,
ComputeTanhOpByLUTPattern,
ConvertMulIToAIEVecMulElemOpPattern,
LowerVectorAddFOpToAIEVecAddElemOp,
LowerVectorSubFOpToAIEVecSubElemOp,
Expand Down Expand Up @@ -1903,6 +1939,22 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target,
return false;
});

// math.tanh is illegal (i.e. must be converted) only when its operand is a
// vector with a float element type of width 16 and a lane size of 16; every
// other math.tanh is left legal.
target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
  // Check the dyn_cast result before using it: a non-vector operand gives a
  // null VectorType, and querying its element type would dereference null.
  VectorType srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
  if (!srcType)
    return true;

  Type scalarType = srcType.getElementType();
  if (!isa<FloatType>(scalarType))
    return true;

  unsigned laneSize = getVectorLaneSize(srcType);
  unsigned elWidth = scalarType.getIntOrFloatBitWidth();
  if (elWidth != 16 || laneSize != 16)
    return true;

  return false;
});

target.addDynamicallyLegalOp<arith::AddIOp>(
[](arith::AddIOp op) { return !isa<VectorType>(op.getType()); });
target.addDynamicallyLegalOp<arith::AddFOp>(
Expand Down
10 changes: 7 additions & 3 deletions lib/Targets/AIEVecToCpp/TranslateAIEVecToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2045,9 +2045,13 @@ static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *callOp.getOperation();

if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true)))
return failure();
if (callOp.getCallee() == "getTanhBf16") {
if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ false)))
return failure();
} else {
if (failed(emitter.emitAssignPrefix(op, /*isAcc*/ true)))
return failure();
}
os << callOp.getCallee();

auto emitArgs = [&](Attribute attr) -> LogicalResult {
Expand Down
16 changes: 16 additions & 0 deletions test/unit_tests/aievec_tests/bf16_tanh/bf16_tanh.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// REQUIRES: valid_xchess_license
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg))" -o linalg.mlir
// RUN: mlir-opt linalg.mlir --linalg-fuse-elementwise-ops --eliminate-empty-tensors --empty-tensor-to-alloc-tensor --one-shot-bufferize="allow-return-allocs allow-unknown-ops bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" --drop-equivalent-buffer-results --buffer-results-to-out-params --buffer-deallocation --canonicalize --cse --convert-linalg-to-affine-loops --affine-super-vectorize="virtual-vector-size=16" -o affine.mlir
// RUN: aie-opt affine.mlir --convert-vector-to-aievec="aie-target=aieml" -lower-affine -o aievec.mlir
// RUN: aie-translate aievec.mlir -aieml=true --aievec-to-cpp -o dut.cc
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I%aie_runtime_lib%/AIE2 %aie_runtime_lib%/AIE2/tanh.cpp -I %aietools/include -D__AIEARCH__=20 -D__AIENGINE__ -I. %S/testbench.cc dut.cc
// RUN: mkdir -p data
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED
// Cycle count: 807

// Device-under-test entry point: applies tosa.tanh to a 1024-element bf16
// tensor and returns the resulting 1024-element bf16 tensor.
func.func @dut(%arg0: tensor<1024xbf16>) -> (tensor<1024xbf16>) {
%0 = "tosa.tanh"(%arg0) : (tensor<1024xbf16>) -> tensor<1024xbf16>
return %0 : tensor<1024xbf16>
}
3 changes: 3 additions & 0 deletions test/unit_tests/aievec_tests/bf16_tanh/defines.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#pragma once
// Element counts of the test's input and output buffers.
constexpr unsigned int IN0_SIZE = 1024U;
constexpr unsigned int OUT0_SIZE = 1024U;
13 changes: 13 additions & 0 deletions test/unit_tests/aievec_tests/bf16_tanh/dut.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "tanh.h"
// Applies getTanhBf16 to 1024 bfloat16 values read from v1, 16 at a time,
// and stores each 16-element result at the same offset in v2.
void dut(bfloat16 *restrict v1, bfloat16 *restrict v2) {
  size_t first = 0;
  size_t limit = 1024;
  size_t stride = 16;
  for (size_t idx = first; idx < limit; idx += stride)
    chess_prepare_for_pipelining chess_loop_range(64, 64) {
      v16bfloat16 chunk = *(v16bfloat16 *)(v1 + idx);
      v16bfloat16 result = getTanhBf16(chunk);
      *(v16bfloat16 *)(v2 + idx) = result;
    }
}
Loading

0 comments on commit bb7653e

Please sign in to comment.