diff --git a/CMakeLists.txt b/CMakeLists.txt index 02ca0c2..8ed63c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ set(CMAKE_CXX_STANDARD 17) set(APFP_PLATFORM "xilinx_u250_gen3x16_xdma_3_1_202020_1" CACHE STRING "Platform string for Vitis.") set(APFP_BITS 1024 CACHE STRING "Number of bits to use for a floating point number, including mantissa, exponent, and sign.") set(APFP_MULT_BASE_BITS 18 CACHE STRING "Number of bits to bottom out the multiplication at and use native multiplication.") +set(APFP_STREAMING_BASE_BITS 256 CACHE STRING "Bit width where Karatsuba will be implemented as a single pipeline.") set(APFP_TILE_SIZE_N 32 CACHE STRING "Tile size in the N-dimension when running matrix-matrix multiplication.") set(APFP_TILE_SIZE_M 32 CACHE STRING "Tile size in the M-dimension when running matrix-matrix multiplication.") set(APFP_COMPUTE_UNITS 1 CACHE STRING "Number of replications of the kernel to instantiate.") @@ -36,20 +37,26 @@ include_directories(${CMAKE_BINARY_DIR} include SYSTEM hlslib/include ${Vitis_IN configure_file(include/Config.h.in Config.h) set(APFP_KERNEL_FILES device/MatrixMultiplication.cpp - device/ArithmeticOperations.cpp - device/Karatsuba.cpp) + device/ArithmeticOperations.cpp) # Mapping to DDR ports set(APFP_PORT_MAPPING MatrixMultiplication_1.m_axi_a:DDR[1] MatrixMultiplication_1.m_axi_b:DDR[1] MatrixMultiplication_1.m_axi_c_read:DDR[1] MatrixMultiplication_1.m_axi_c_write:DDR[1]) +set(APFP_CONNECTIVITY MatrixMultiplication_1.a_to_kernel:FreeRunningMultiplication_1.a_to_kernel + MatrixMultiplication_1.b_to_kernel:FreeRunningMultiplication_1.b_to_kernel + FreeRunningMultiplication_1.ab_from_kernel:MatrixMultiplication_1.ab_from_kernel) if(${APFP_COMPUTE_UNITS} GREATER 1) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} MatrixMultiplication_2.m_axi_a:DDR[0] MatrixMultiplication_2.m_axi_b:DDR[0] MatrixMultiplication_2.m_axi_c_read:DDR[0] MatrixMultiplication_2.m_axi_c_write:DDR[0]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_2.a_to_kernel:FreeRunningMultiplication_2.a_to_kernel + MatrixMultiplication_2.b_to_kernel:FreeRunningMultiplication_2.b_to_kernel + FreeRunningMultiplication_2.ab_from_kernel:MatrixMultiplication_2.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 2) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} @@ -57,6 +64,10 @@ if(${APFP_COMPUTE_UNITS} GREATER 2) MatrixMultiplication_3.m_axi_b:DDR[2] MatrixMultiplication_3.m_axi_c_read:DDR[2] MatrixMultiplication_3.m_axi_c_write:DDR[2]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_3.a_to_kernel:FreeRunningMultiplication_3.a_to_kernel + MatrixMultiplication_3.b_to_kernel:FreeRunningMultiplication_3.b_to_kernel + FreeRunningMultiplication_3.ab_from_kernel:MatrixMultiplication_3.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 3) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} @@ -64,6 +75,10 @@ if(${APFP_COMPUTE_UNITS} GREATER 3) MatrixMultiplication_4.m_axi_b:DDR[3] MatrixMultiplication_4.m_axi_c_read:DDR[3] MatrixMultiplication_4.m_axi_c_write:DDR[3]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_4.a_to_kernel:FreeRunningMultiplication_4.a_to_kernel + MatrixMultiplication_4.b_to_kernel:FreeRunningMultiplication_4.b_to_kernel + FreeRunningMultiplication_4.ab_from_kernel:MatrixMultiplication_4.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 4) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} @@ -71,6 +86,10 @@ if(${APFP_COMPUTE_UNITS} GREATER 4) MatrixMultiplication_5.m_axi_b:DDR[1] MatrixMultiplication_5.m_axi_c_read:DDR[1] MatrixMultiplication_5.m_axi_c_write:DDR[1]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_5.a_to_kernel:FreeRunningMultiplication_5.a_to_kernel + MatrixMultiplication_5.b_to_kernel:FreeRunningMultiplication_5.b_to_kernel + FreeRunningMultiplication_5.ab_from_kernel:MatrixMultiplication_5.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 5) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} @@ -78,6 +97,10 @@ if(${APFP_COMPUTE_UNITS} GREATER 5) MatrixMultiplication_6.m_axi_b:DDR[0] MatrixMultiplication_6.m_axi_c_read:DDR[0] MatrixMultiplication_6.m_axi_c_write:DDR[0]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_6.a_to_kernel:FreeRunningMultiplication_6.a_to_kernel + MatrixMultiplication_6.b_to_kernel:FreeRunningMultiplication_6.b_to_kernel + FreeRunningMultiplication_6.ab_from_kernel:MatrixMultiplication_6.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 6) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} @@ -85,6 +108,10 @@ if(${APFP_COMPUTE_UNITS} GREATER 6) MatrixMultiplication_7.m_axi_b:DDR[2] MatrixMultiplication_7.m_axi_c_read:DDR[2] MatrixMultiplication_7.m_axi_c_write:DDR[2]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_7.a_to_kernel:FreeRunningMultiplication_7.a_to_kernel + MatrixMultiplication_7.b_to_kernel:FreeRunningMultiplication_7.b_to_kernel + FreeRunningMultiplication_7.ab_from_kernel:MatrixMultiplication_7.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 7) set(APFP_PORT_MAPPING ${APFP_PORT_MAPPING} @@ -92,26 +119,42 @@ if(${APFP_COMPUTE_UNITS} GREATER 7) MatrixMultiplication_8.m_axi_b:DDR[3] MatrixMultiplication_8.m_axi_c_read:DDR[3] MatrixMultiplication_8.m_axi_c_write:DDR[3]) + set(APFP_CONNECTIVITY ${APFP_CONNECTIVITY} + MatrixMultiplication_8.a_to_kernel:FreeRunningMultiplication_8.a_to_kernel + MatrixMultiplication_8.b_to_kernel:FreeRunningMultiplication_8.b_to_kernel + FreeRunningMultiplication_8.ab_from_kernel:MatrixMultiplication_8.ab_from_kernel) endif() if(${APFP_COMPUTE_UNITS} GREATER 8) message(FATAL_ERROR "More than 8 compute units is not supported.") endif() # Setup FPGA kernel targets +set(APFP_HLS_FLAGS "-DAP_INT_MAX_W=${APFP_MAX_BITS} -DAPFP_${APFP_SEMANTICS}_SEMANTICS") +set(APFP_HLS_CONFIG "config_compile -pipeline_style frp\nconfig_dataflow -fifo_depth 16") +set(APFP_INCLUDE_DIRS include hlslib/include ${CMAKE_BINARY_DIR}) +set(APFP_DEPENDS ${CMAKE_BINARY_DIR}/Config.h + include/ArithmeticOperations.h + include/DeviceTypes.h + include/Karatsuba.h + include/MatrixMultiplication.h + include/PackedFloat.h + include/PipelinedAdd.h) add_vitis_kernel(MatrixMultiplication FILES ${APFP_KERNEL_FILES} COMPUTE_UNITS ${APFP_COMPUTE_UNITS} - INCLUDE_DIRS include hlslib/include ${CMAKE_BINARY_DIR} - HLS_FLAGS "-DAP_INT_MAX_W=${APFP_MAX_BITS} -DAPFP_${APFP_SEMANTICS}_SEMANTICS" - HLS_CONFIG "config_compile -pipeline_style frp\nconfig_dataflow -fifo_depth 16" - DEPENDS ${CMAKE_BINARY_DIR}/Config.h - include/ArithmeticOperations.h - include/DeviceTypes.h - include/Karatsuba.h - include/MatrixMultiplication.h - include/PackedFloat.h - include/PipelinedAdd.h + INCLUDE_DIRS ${APFP_INCLUDE_DIRS} + HLS_FLAGS ${APFP_HLS_FLAGS} + HLS_CONFIG ${APFP_HLS_CONFIG} + DEPENDS ${APFP_DEPENDS} PORT_MAPPING ${APFP_PORT_MAPPING}) +add_vitis_kernel(FreeRunningMultiplication FILES ${APFP_KERNEL_FILES} + COMPUTE_UNITS ${APFP_COMPUTE_UNITS} + INCLUDE_DIRS ${APFP_INCLUDE_DIRS} + HLS_FLAGS ${APFP_HLS_FLAGS} + HLS_CONFIG ${APFP_HLS_CONFIG} + DEPENDS ${APFP_DEPENDS}) add_vitis_program(MatrixMultiplication ${APFP_PLATFORM} + KERNELS MatrixMultiplication FreeRunningMultiplication + CONNECTIVITY ${APFP_CONNECTIVITY} PROFILING ${APFP_PROFILING} DEBUGGING ${APFP_DEBUGGING} SAVE_TEMPS ${APFP_SAVE_TEMPS}) diff --git a/device/ArithmeticOperations.cpp b/device/ArithmeticOperations.cpp index 57e6fa8..51115c7 100644 --- a/device/ArithmeticOperations.cpp +++ b/device/ArithmeticOperations.cpp @@ -6,12 +6,6 @@ #include "Karatsuba.h" #include "PipelinedAdd.h" -template -inline bool IsMostSignificantBitSet(ap_uint const &num) { -#pragma HLS INLINE - return num.test(bits - 1); -} - template inline int CountLeadingZeros(ap_uint const &num) { #pragma HLS INLINE @@ -28,14 +22,6 @@ PackedFloat Multiply(PackedFloat const &a, PackedFloat const &b) { // Pad mantissas to avoid passing awkward sizes to Karatsuba const ap_uint a_mantissa_padded(a.GetMantissa()); const ap_uint b_mantissa_padded(b.GetMantissa()); -#ifdef APFP_GMP_SEMANTICS // Use GMP semantics - constexpr auto kLimbBits = 8 * sizeof(mp_limb_t); - // Meat of the computation. Only keep the top bits of the computation and throw away the rest - const ap_uint<(2 * kMantissaBits)> _m_mantissa = Karatsuba(a_mantissa_padded, b_mantissa_padded); - const bool limb_zero = _m_mantissa.range(kMantissaBits + kLimbBits - 1, kMantissaBits) == 0; - ap_uint m_mantissa = _m_mantissa; // Truncate - const Exponent m_exponent = a.GetExponent() + b.GetExponent() - limb_zero; -#else // Otherwise use MPFR semantics const ap_uint _m_mantissa = Karatsuba(a_mantissa_padded, b_mantissa_padded) >> (kMantissaBits - 1); // We need to shift the mantissa forward if the most significant bit is not set @@ -44,7 +30,6 @@ PackedFloat Multiply(PackedFloat const &a, PackedFloat const &b) { // Add up exponents. If the most significant bit was 1, we're done. Otherwise subtract 1 due to // the shift. const Exponent m_exponent = a.GetExponent() + b.GetExponent() - (should_be_shifted ? 1 : 0); -#endif // The sign is just the XOR of the existing signs PackedFloat result; result.SetMantissa(m_mantissa); diff --git a/device/Karatsuba.cpp b/device/Karatsuba.cpp deleted file mode 100644 index c2c5882..0000000 --- a/device/Karatsuba.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "Karatsuba.h" - -#include // std::enable_if - -#include "PipelinedAdd.h" - -template -auto _Karatsuba(ap_uint const &a, ap_uint const &b) -> - typename std::enable_if<(bits > kMultBaseBits), ap_uint<2 * bits>>::type { - static_assert(bits % 2 == 0, "Number of bits must be even."); - using Full = ap_uint; - using Half = ap_uint; - - // Decompose input operands into halves for the recursive step - Half a0 = a.range(bits / 2 - 1, 0); - Half a1 = a.range(bits - 1, bits / 2); - Half b0 = b.range(bits / 2 - 1, 0); - Half b1 = b.range(bits - 1, bits / 2); - - // Recurse on a_0 * b_0 and a_1 * b_1 - Full z0 = _Karatsuba(a0, b0); - Full z2 = _Karatsuba(a1, b1); - - // Compute |a_0 - a_1| and sign(a_0 - a_1) - bool a0a1_is_neg = a0 < a1; - Half a0a1 = PipelinedSub(a0a1_is_neg ? a1 : a0, a0a1_is_neg ? a0 : a1); -#pragma HLS BIND_OP variable = a0a1 op = sub impl = fabric latency = AddLatency(bits / 2) - // Compute |b_1 - b_0| and sign(b_1 - b_0) - bool b0b1_is_neg = b1 < b0; - Half b0b1 = PipelinedSub(b0b1_is_neg ? b0 : b1, b0b1_is_neg ? b1 : b0); -#pragma HLS BIND_OP variable = b0b1 op = sub impl = fabric latency = AddLatency(bits / 2) - - // XOR the two signs to get the final sign - bool a0a1b0b1_is_neg = a0a1_is_neg != b0b1_is_neg; - // Recurse on |a_0 - a_1| * |b_0 - b_1| - Full a0a1b0b1 = _Karatsuba(a0a1, b0b1); - ap_int a0a1b0b1_signed = a0a1b0b1_is_neg ? -ap_int(a0a1b0b1) : ap_int(a0a1b0b1); - ap_uint z1 = PipelinedAdd(ap_uint(a0a1b0b1_signed), PipelinedAdd(z0, z2)); - - // Align everything and combine - ap_uint<(2 * bits)> z0z2 = z0 | (ap_uint<(2 * bits)>(z2) << bits); - ap_uint<(bits + 2 + bits / 2)> z1_aligned = ap_uint<(bits + 2 + bits / 2)>(z1) << (bits / 2); - ap_uint<(2 * bits) + 1> z = PipelinedAdd<2 * bits>(z1_aligned, z0z2); - - return z; -} - -// Bottom out using SFINAE when the bit width is lower or equal to the specified base number of bits -template -auto _Karatsuba(ap_uint const &a, ap_uint const &b) -> - typename std::enable_if<(bits <= kMultBaseBits), ap_uint<2 * bits>>::type { -#pragma HLS INLINE - return a * b; -} - -ap_uint<2 * kBits> Karatsuba(ap_uint const &a, ap_uint const &b) { -#pragma HLS INLINE - return _Karatsuba(a, b); -} diff --git a/device/MatrixMultiplication.cpp b/device/MatrixMultiplication.cpp index 594cf69..eb2f15e 100644 --- a/device/MatrixMultiplication.cpp +++ b/device/MatrixMultiplication.cpp @@ -5,6 +5,8 @@ #include // hlslib::CeilDivide #include "ArithmeticOperations.h" +#include "Karatsuba.h" +#include "PipelinedAdd.h" // Annoyingly we have to specialize the innermost loop on whether multiple DRAM flits per number are required or not, // because HLS otherwise gets confused by pragmas applied to a loop of size 1 in the latter case. @@ -330,42 +332,89 @@ void WriteC(hlslib::Stream &from_kernel, DramLine *const mem, const //////////////////////////////////////////////////////////////////////////////// -void Compute(hlslib::Stream &a_in, hlslib::Stream &b_in, hlslib::Stream &c_in, - hlslib::Stream &c_out, int const size_n, int const size_k, int const size_m) { - PackedFloat a_buffer; // Just to make A symmetric to B and C +void ComputeEntry(hlslib::Stream &a_in, hlslib::Stream &b_in, + hlslib::Stream &a_out, hlslib::Stream &b_out, + hlslib::Stream> &ab_bypass, int const size_n, int const size_k, + int const size_m) { + PackedFloat a_buffer; PackedFloat b_buffer[kTileSizeM]; - PackedFloat c_buffer[kTileSizeN * kTileSizeM]; const int tiles_n = hlslib::CeilDivide(size_n, kTileSizeN); const int tiles_m = hlslib::CeilDivide(size_m, kTileSizeM); -Compute_TilesN: +ComputeEntry_TilesN: for (int n0 = 0; n0 < tiles_n; ++n0) { - Compute_TilesM: + ComputeEntry_TilesM: for (int m0 = 0; m0 < tiles_m; ++m0) { - Compute_K: + ComputeEntry_K: for (int k = 0; k < size_k; ++k) { - Compute_N: + ComputeEntry_N: for (int n1 = 0; n1 < ((n0 < tiles_n - 1) ? kTileSizeN : (size_n - n0 * kTileSizeN)); ++n1) { - Compute_M: + ComputeEntry_M: for (int m1 = 0; m1 < kTileSizeM; ++m1) { #pragma HLS PIPELINE II = 1 #pragma HLS LOOP_FLATTEN const PackedFloat a_read = a_in.Pop(); const PackedFloat b_read = b_in.Pop(); - const PackedFloat c_read = c_in.Pop(); - const PackedFloat a = (m1 == 0) ? a_read : a_buffer; - const PackedFloat b = (n1 == 0) ? b_read : b_buffer[m1]; - const PackedFloat c = (k == 0) ? c_read : c_buffer[n1 * kTileSizeM + m1]; + PackedFloat a = (m1 == 0) ? a_read : a_buffer; + PackedFloat b = (n1 == 0) ? b_read : b_buffer[m1]; a_buffer = a; b_buffer[m1] = b; // Ignore contributions from out-of-bound indices const bool in_bounds = (n0 * kTileSizeN + n1 < size_n) && (m0 * kTileSizeM + m1 < size_m); - // Meat of the computation - const auto res = MultiplyAccumulate(in_bounds ? a : PackedFloat::Zero(), - in_bounds ? b : PackedFloat::Zero(), c); - // Write back to buffer + if (!in_bounds) { + a.SetZero(); + b.SetZero(); + } + // Multiplication prologue + ap_uint<8 * sizeof(Exponent)> sign_exponent; + reinterpret_cast(&sign_exponent)->sign = a.GetSignBit() != b.GetSignBit(); + reinterpret_cast(&sign_exponent)->exponent = + a.GetExponent() + b.GetExponent(); + ab_bypass.Push(sign_exponent); + a_out.Push(a.GetMantissa()); + b_out.Push(b.GetMantissa()); + } + } + } + } + } +} + +void ComputeExit(hlslib::Stream> &ab_mantissa_in, + hlslib::Stream> &ab_bypass, hlslib::Stream &c_in, + hlslib::Stream &c_out, int const size_n, int const size_k, int const size_m) { + PackedFloat c_buffer[kTileSizeN * kTileSizeM]; + const int tiles_n = hlslib::CeilDivide(size_n, kTileSizeN); + const int tiles_m = hlslib::CeilDivide(size_m, kTileSizeM); +ComputeExit_TilesN: + for (int n0 = 0; n0 < tiles_n; ++n0) { + ComputeExit_TilesM: + for (int m0 = 0; m0 < tiles_m; ++m0) { + ComputeExit_K: + for (int k = 0; k < size_k; ++k) { + ComputeExit_N: + for (int n1 = 0; n1 < ((n0 < tiles_n - 1) ? kTileSizeN : (size_n - n0 * kTileSizeN)); ++n1) { + ComputeExit_M: + for (int m1 = 0; m1 < kTileSizeM; ++m1) { +#pragma HLS PIPELINE II = 1 +#pragma HLS LOOP_FLATTEN + const auto _ab_mantissa = ab_mantissa_in.Pop(); + const auto ab_sign_exponent = ab_bypass.Pop(); + // Matrix multiplication epilogue + PackedFloat ab; + ab.SetSignExponent(ab_sign_exponent); + const bool should_be_shifted = !IsMostSignificantBitSet(_ab_mantissa); + const ap_uint m_mantissa = + should_be_shifted ? _ab_mantissa : (_ab_mantissa >> 1); + ab.SetMantissa(m_mantissa); + // If the most significant bit was 0, subtract 1 due to the shift. + ab.SetExponent(ab.GetExponent() - (should_be_shifted ? 1 : 0)); + // Addition + const PackedFloat c_read = c_in.Pop(); + const PackedFloat c = (k == 0) ? c_read : c_buffer[n1 * kTileSizeM + m1]; + const PackedFloat res = Add(ab, c); + c_out.Push(res); c_buffer[n1 * kTileSizeM + m1] = res; #pragma HLS DEPENDENCE variable = c_buffer false - c_out.Push(res); } } } @@ -375,8 +424,122 @@ void Compute(hlslib::Stream &a_in, hlslib::Stream &b_i //////////////////////////////////////////////////////////////////////////////// +template +void StreamingKaratsubaEntry(hlslib::Stream> &a_in, hlslib::Stream> &b_in, + hlslib::Stream> &a0_out, hlslib::Stream> &b0_out, + hlslib::Stream> &a1_out, hlslib::Stream> &b1_out, + hlslib::Stream> &a0a1_out, hlslib::Stream> &b0b1_out, + hlslib::Stream &sign_out) { +#pragma HLS PIPELINE II = 1 + using Half = ap_uint; + + const auto a = a_in.Pop(); + const auto b = b_in.Pop(); + + // Decompose input operands into halves for the recursive step + Half a0 = a.range(bits / 2 - 1, 0); + Half a1 = a.range(bits - 1, bits / 2); + Half b0 = b.range(bits / 2 - 1, 0); + Half b1 = b.range(bits - 1, bits / 2); + + // Compute |a_0 - a_1| and sign(a_0 - a_1) + bool a0a1_is_neg = a0 < a1; + Half a0a1 = PipelinedSub(a0a1_is_neg ? a1 : a0, a0a1_is_neg ? a0 : a1); + // Compute |b_1 - b_0| and sign(b_1 - b_0) + bool b0b1_is_neg = b1 < b0; + Half b0b1 = PipelinedSub(b0b1_is_neg ? b0 : b1, b0b1_is_neg ? b1 : b0); + + // XOR the two signs to get the final sign + bool a0a1b0b1_is_neg = a0a1_is_neg != b0b1_is_neg; + + a0_out.Push(a0); + a1_out.Push(a1); + b0_out.Push(b0); + b1_out.Push(b1); + a0a1_out.Push(a0a1); + b0b1_out.Push(b0b1); + sign_out.Push(a0a1b0b1_is_neg); +} + +template +void StreamingKaratsubaExit(hlslib::Stream> &z0_in, hlslib::Stream> &z2_in, + hlslib::Stream &sign_in, hlslib::Stream> &a0a1b0b1_in, + hlslib::Stream> &result_out) { +#pragma HLS PIPELINE II = 1 + using Full = ap_uint; + + // Get results from recursive modules + const Full z0 = z0_in.Pop(); + const Full z2 = z2_in.Pop(); + const bool a0a1b0b1_is_neg = sign_in.Pop(); + const Full a0a1b0b1 = a0a1b0b1_in.Pop(); + + const ap_int a0a1b0b1_signed = a0a1b0b1_is_neg ? -ap_int(a0a1b0b1) : ap_int(a0a1b0b1); + const ap_uint z1 = PipelinedAdd(ap_uint(a0a1b0b1_signed), PipelinedAdd(z0, z2)); + + // Align everything and combine + const ap_uint<(2 * bits)> z0z2 = z0 | (ap_uint<(2 * bits)>(z2) << bits); + const ap_uint<(bits + 2 + bits / 2)> z1_aligned = ap_uint<(bits + 2 + bits / 2)>(z1) << (bits / 2); + const ap_uint<(2 * bits) + 1> z = PipelinedAdd<2 * bits>(z1_aligned, z0z2); + + result_out.Push(z); +} + +template +auto StreamingKaratsuba(hlslib::Stream> &a_in, hlslib::Stream> &b_in, + hlslib::Stream> &result_out) -> + typename std::enable_if<(bits > kStreamingBaseBits), void>::type { + static_assert(bits % 2 == 0, "Number of bits must be even."); +#pragma HLS INLINE + hlslib::Stream> a0; + hlslib::Stream> b0; + hlslib::Stream> a1; + hlslib::Stream> b1; + hlslib::Stream> a0a1; + hlslib::Stream> b0b1; + hlslib::Stream sign; + hlslib::Stream> z0; + hlslib::Stream> z2; + hlslib::Stream> a0a1b0b1; + StreamingKaratsubaEntry(a_in, b_in, a0, b0, a1, b1, a0a1, b0b1, sign); + StreamingKaratsuba<(bits / 2)>(a0, b0, z0); + StreamingKaratsuba<(bits / 2)>(a1, b1, z2); + StreamingKaratsuba<(bits / 2)>(a0a1, b0b1, a0a1b0b1); + StreamingKaratsubaExit(z0, z2, sign, a0a1b0b1, result_out); +} + +template +auto StreamingKaratsuba(hlslib::Stream> &a_in, hlslib::Stream> &b_in, + hlslib::Stream> &result_out) -> + typename std::enable_if<(bits <= kStreamingBaseBits), void>::type { +#pragma HLS PIPELINE II = 1 + result_out.Push(Karatsuba(a_in.Pop(), b_in.Pop())); +} + +void Truncate(hlslib::Stream> &ab_in, hlslib::Stream> &ab_out) { +#pragma HLS PIPELINE II = 1 + const ap_uint ab_mantissa = ab_in.Pop() >> (kMantissaBits - 1); + ab_out.Push(ab_mantissa); +} + +void FreeRunningMultiplication(hlslib::Stream &a_to_kernel, hlslib::Stream &b_to_kernel, + hlslib::Stream> &ab_from_kernel) { +#pragma HLS INTERFACE axis port = a_to_kernel +#pragma HLS INTERFACE axis port = b_to_kernel +#pragma HLS INTERFACE axis port = ab_from_kernel +#pragma HLS interface ap_ctrl_none port = return +#pragma HLS DATAFLOW + hlslib::Stream> truncate; + StreamingKaratsuba(a_to_kernel, b_to_kernel, truncate); + Truncate(truncate, ab_from_kernel); +} + +//////////////////////////////////////////////////////////////////////////////// + void MatrixMultiplication(DramLine const *const a, DramLine const *const b, DramLine const *const c_read, - DramLine *const c_write, const int size_n, const int size_k, int const size_m) { + DramLine *const c_write, const int size_n, const int size_k, int const size_m, + hlslib::Stream &a_to_kernel, hlslib::Stream &b_to_kernel, + hlslib::Stream> &ab_from_kernel) { #pragma HLS INTERFACE m_axi offset = slave port = a bundle = a #pragma HLS INTERFACE m_axi offset = slave port = b bundle = b // Even though they actually point to the same memory location, we use two separate interfaces for reading and writing @@ -390,6 +553,9 @@ void MatrixMultiplication(DramLine const *const a, DramLine const *const b, Dram #pragma HLS INTERFACE s_axilite port = size_n #pragma HLS INTERFACE s_axilite port = size_k #pragma HLS INTERFACE s_axilite port = size_m +#pragma HLS INTERFACE axis port = a_to_kernel +#pragma HLS INTERFACE axis port = b_to_kernel +#pragma HLS INTERFACE axis port = ab_from_kernel #pragma HLS STABLE variable = a #pragma HLS STABLE variable = b #pragma HLS STABLE variable = c_read @@ -399,21 +565,26 @@ void MatrixMultiplication(DramLine const *const a, DramLine const *const b, Dram #pragma HLS STABLE variable = size_m #pragma HLS DATAFLOW hlslib::Stream a_to_feeder("a_to_feeder"); - hlslib::Stream a_to_kernel("a_to_kernel"); + hlslib::Stream a_to_entry("a_to_entry"); hlslib::Stream b_to_feeder("b_to_feeder"); - hlslib::Stream b_to_kernel("b_to_kernel"); + hlslib::Stream b_to_entry("b_to_entry"); + hlslib::Stream, 1024> ab_bypass("ab_bypass"); hlslib::Stream c_to_feeder("c_to_feeder"); hlslib::Stream c_to_kernel("c_to_kernel"); hlslib::Stream c_from_kernel("c_from_kernel"); + hlslib::Stream c_from_exit("c_from_exit"); hlslib::Stream c_from_drainer("c_from_drainer"); HLSLIB_DATAFLOW_INIT(); HLSLIB_DATAFLOW_FUNCTION(ReadA, a, a_to_feeder, size_n, size_k, size_m); - HLSLIB_DATAFLOW_FUNCTION(FeedA, a_to_feeder, a_to_kernel, size_n, size_k, size_m); + HLSLIB_DATAFLOW_FUNCTION(FeedA, a_to_feeder, a_to_entry, size_n, size_k, size_m); HLSLIB_DATAFLOW_FUNCTION(ReadB, b, b_to_feeder, size_n, size_k, size_m); - HLSLIB_DATAFLOW_FUNCTION(FeedB, b_to_feeder, b_to_kernel, size_n, size_k, size_m); + HLSLIB_DATAFLOW_FUNCTION(FeedB, b_to_feeder, b_to_entry, size_n, size_k, size_m); HLSLIB_DATAFLOW_FUNCTION(ReadC, c_read, c_to_feeder, size_n, size_m); HLSLIB_DATAFLOW_FUNCTION(FeedC, c_to_feeder, c_to_kernel, size_n, size_k, size_m); - HLSLIB_DATAFLOW_FUNCTION(Compute, a_to_kernel, b_to_kernel, c_to_kernel, c_from_kernel, size_n, size_k, size_m); + HLSLIB_DATAFLOW_FUNCTION(ComputeEntry, a_to_entry, b_to_entry, a_to_kernel, b_to_kernel, ab_bypass, size_n, size_k, + size_m); + HLSLIB_DATAFLOW_FUNCTION(ComputeExit, ab_from_kernel, ab_bypass, c_to_kernel, c_from_kernel, size_n, size_k, + size_m); HLSLIB_DATAFLOW_FUNCTION(DrainC, c_from_kernel, c_from_drainer, size_n, size_k, size_m); HLSLIB_DATAFLOW_FUNCTION(WriteC, c_from_drainer, c_write, size_n, size_m); HLSLIB_DATAFLOW_FINALIZE(); diff --git a/host/TestProgram.cpp b/host/TestProgram.cpp index 4655a58..4661db9 100644 --- a/host/TestProgram.cpp +++ b/host/TestProgram.cpp @@ -1,4 +1,5 @@ #include +#include #include #include // putenv @@ -22,6 +23,13 @@ struct MpfrWrapper { }; #ifdef HLSLIB_SIMULATE_OPENCL +void RunFreeRunningKernel(hlslib::Stream &a_mantissa_in, hlslib::Stream &b_mantissa_in, + hlslib::Stream> &ab_mantissa_out) { + while (true) { + FreeRunningMultiplication(a_mantissa_in, b_mantissa_in, ab_mantissa_out); + } +} + bool RunTestSimulation(int size_n, int size_k, int size_m, bool verify) { const std::string kernel_path(""); #else @@ -111,9 +119,14 @@ bool RunTest(std::string const &kernel_path, int size_n, int size_k, int size_m, // In simulation mode, this will call the function "MatrixMultiplication" and run it in software. // Otherwise, the provided path to a kernel binary will be loaded and executed. std::vector kernels; + hlslib::Stream a_to_kernel[kComputeUnits]; + hlslib::Stream b_to_kernel[kComputeUnits]; + hlslib::Stream> ab_from_kernel[kComputeUnits]; for (int i = 0; i < kComputeUnits; ++i) { - kernels.emplace_back(program.MakeKernel(MatrixMultiplication, "MatrixMultiplication", a_device[i], b_device[i], - c_device[i], c_device[i], n_partition_size[i], size_k, size_m)); + kernels.emplace_back(program.MakeKernel( + MatrixMultiplication, "MatrixMultiplication", a_device[i], b_device[i], c_device[i], c_device[i], + n_partition_size[i], size_k, size_m, hlslib::ocl::SimulationOnly(a_to_kernel[i]), + hlslib::ocl::SimulationOnly(b_to_kernel[i]), hlslib::ocl::SimulationOnly(ab_from_kernel[i]))); } const float expected_runtime = expected_cycles / 0.3e9; @@ -126,6 +139,14 @@ bool RunTest(std::string const &kernel_path, int size_n, int size_k, int size_m, << bandwidth << " GB/s.\n"; std::cout << "Executing kernel...\n"; +#ifdef HLSLIB_SIMULATE_OPENCL + for (int i = 0; i < kComputeUnits; ++i) { + std::thread free_running(RunFreeRunningKernel, std::ref(a_to_kernel[i]), std::ref(b_to_kernel[i]), + std::ref(ab_from_kernel[i])); + // Will be killed when the program exits + free_running.detach(); + } +#endif std::vector events; auto start = std::chrono::high_resolution_clock::now(); for (int i = 0; i < kComputeUnits; ++i) { diff --git a/include/ArithmeticOperations.h b/include/ArithmeticOperations.h index 1776e60..e633441 100644 --- a/include/ArithmeticOperations.h +++ b/include/ArithmeticOperations.h @@ -5,3 +5,9 @@ PackedFloat MultiplyAccumulate(PackedFloat const &a, PackedFloat const &b, PackedFloat const &c); PackedFloat Multiply(PackedFloat const &a, PackedFloat const &b); PackedFloat Add(PackedFloat const &a, PackedFloat const &b); + +template +inline bool IsMostSignificantBitSet(ap_uint const &num) { +#pragma HLS INLINE + return num.test(bits - 1); +} diff --git a/include/Config.h.in b/include/Config.h.in index f8f03d7..4aee15e 100644 --- a/include/Config.h.in +++ b/include/Config.h.in @@ -3,6 +3,7 @@ constexpr int kBits = ${APFP_BITS}; constexpr int kBytes = kBits / 8; constexpr int kMultBaseBits = ${APFP_MULT_BASE_BITS}; +constexpr int kStreamingBaseBits = ${APFP_STREAMING_BASE_BITS}; constexpr int kTileSizeN = ${APFP_TILE_SIZE_N}; constexpr int kTileSizeM = ${APFP_TILE_SIZE_M}; constexpr int kComputeUnits = ${APFP_COMPUTE_UNITS}; diff --git a/include/Karatsuba.h b/include/Karatsuba.h index e6c3d31..16de678 100644 --- a/include/Karatsuba.h +++ b/include/Karatsuba.h @@ -2,6 +2,55 @@ #include +#include // std::enable_if + #include "Config.h" +#include "Karatsuba.h" +#include "PipelinedAdd.h" + +template +auto Karatsuba(ap_uint const &a, ap_uint const &b) -> + typename std::enable_if<(bits > kMultBaseBits), ap_uint<2 * bits>>::type { + static_assert(bits % 2 == 0, "Number of bits must be even."); + using Full = ap_uint; + using Half = ap_uint; + + // Decompose input operands into halves for the recursive step + Half a0 = a.range(bits / 2 - 1, 0); + Half a1 = a.range(bits - 1, bits / 2); + Half b0 = b.range(bits / 2 - 1, 0); + Half b1 = b.range(bits - 1, bits / 2); + + // Recurse on a_0 * b_0 and a_1 * b_1 + Full z0 = Karatsuba(a0, b0); + Full z2 = Karatsuba(a1, b1); + + // Compute |a_0 - a_1| and sign(a_0 - a_1) + bool a0a1_is_neg = a0 < a1; + Half a0a1 = PipelinedSub(a0a1_is_neg ? a1 : a0, a0a1_is_neg ? a0 : a1); + // Compute |b_1 - b_0| and sign(b_1 - b_0) + bool b0b1_is_neg = b1 < b0; + Half b0b1 = PipelinedSub(b0b1_is_neg ? b0 : b1, b0b1_is_neg ? b1 : b0); + + // XOR the two signs to get the final sign + bool a0a1b0b1_is_neg = a0a1_is_neg != b0b1_is_neg; + // Recurse on |a_0 - a_1| * |b_0 - b_1| + Full a0a1b0b1 = Karatsuba(a0a1, b0b1); + ap_int a0a1b0b1_signed = a0a1b0b1_is_neg ? -ap_int(a0a1b0b1) : ap_int(a0a1b0b1); + ap_uint z1 = PipelinedAdd(ap_uint(a0a1b0b1_signed), PipelinedAdd(z0, z2)); + + // Align everything and combine + ap_uint<(2 * bits)> z0z2 = z0 | (ap_uint<(2 * bits)>(z2) << bits); + ap_uint<(bits + 2 + bits / 2)> z1_aligned = ap_uint<(bits + 2 + bits / 2)>(z1) << (bits / 2); + ap_uint<(2 * bits) + 1> z = PipelinedAdd<2 * bits>(z1_aligned, z0z2); + + return z; +} -ap_uint<2 * kBits> Karatsuba(ap_uint const &a, ap_uint const &b); +// Bottom out using SFINAE when the bit width is lower or equal to the specified base number of bits +template +auto Karatsuba(ap_uint const &a, ap_uint const &b) -> + typename std::enable_if<(bits <= kMultBaseBits), ap_uint<2 * bits>>::type { +#pragma HLS INLINE + return a * b; +} diff --git a/include/MatrixMultiplication.h b/include/MatrixMultiplication.h index 43198fe..25566d9 100644 --- a/include/MatrixMultiplication.h +++ b/include/MatrixMultiplication.h @@ -1,7 +1,18 @@ #pragma once +#include + #include "Config.h" #include "DeviceTypes.h" +#include "PackedFloat.h" + +extern "C" void MatrixMultiplication(DramLine const *const a, DramLine const *const b, DramLine const *const c_read, + DramLine *const c_write, const int size_n, const int size_k, int const size_m, + hlslib::Stream &a_to_kernel, + hlslib::Stream &b_to_kernel, + hlslib::Stream> &ab_from_kernel); + +extern "C" void FreeRunningMultiplication(hlslib::Stream &a_mantissa_in, + hlslib::Stream &b_mantissa_in, + hlslib::Stream> &ab_mantissa_out); -extern "C" void MatrixMultiplication(DramLine const *a, DramLine const *b, DramLine const *c_read, DramLine *c_write, - int n, int m, int k); diff --git a/include/PackedFloat.h b/include/PackedFloat.h index 8ee0e3e..656d4fb 100644 --- a/include/PackedFloat.h +++ b/include/PackedFloat.h @@ -95,6 +95,11 @@ class PackedFloat { data_.set_bit(kBits - 1, sign); } + void SetSignExponent(ap_uint<8 * sizeof(Exponent)> const &sign_exponent) { +#pragma HLS INLINE + data_.range(kBits - 1, kBits - 8 * sizeof(Exponent)) = sign_exponent; + } + DramLine GetFlit(const size_t i) const { #pragma HLS INLINE return data_.range((i + 1) * 512 - 1, i * 512); @@ -113,10 +118,15 @@ class PackedFloat { } } + void SetZero() { +#pragma HLS INLINE + data_ = 0; + } + static PackedFloat Zero() { #pragma HLS INLINE PackedFloat x; - x.data_ = 0; + x.SetZero(); return x; }