diff --git a/BUILD.bazel b/BUILD.bazel index 7286ff580802..5c546496c232 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -9,7 +9,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") # buildifier: disable=out-of-order-load load("@rules_python//python:py_binary.bzl", "py_binary") -load(":build_defs.bzl", "xnnpack_aggregate_library", "xnnpack_cc_library", "xnnpack_gcc_std_copts", "xnnpack_min_size_copts", "xnnpack_msvc_std_copts", "xnnpack_slinky_defines", "xnnpack_slinky_deps", "xnnpack_slinky_srcs", "xnnpack_std_cxxopts", "xnnpack_transitive_source_list", "xnnpack_visibility") +load(":build_defs.bzl", "xnnpack_aggregate_library", "xnnpack_cc_library", "xnnpack_gcc_std_copts", "xnnpack_if_kleidiai_enabled", "xnnpack_kleidiai_defines", "xnnpack_min_size_copts", "xnnpack_msvc_std_copts", "xnnpack_slinky_defines", "xnnpack_slinky_deps", "xnnpack_slinky_srcs", "xnnpack_std_cxxopts", "xnnpack_transitive_source_list", "xnnpack_visibility") load("//gen:microkernels.bzl", "AARCH32_ASM_MICROKERNEL_SRCS", "AARCH32_JIT_MICROKERNEL_SRCS", "AARCH64_ASM_MICROKERNEL_SRCS", "AARCH64_JIT_MICROKERNEL_SRCS", "ALL_ARMSIMD32_MICROKERNEL_SRCS", "ALL_AVX2_MICROKERNEL_SRCS", "ALL_AVX512AMX_MICROKERNEL_SRCS", "ALL_AVX512FP16_MICROKERNEL_SRCS", "ALL_AVX512F_MICROKERNEL_SRCS", "ALL_AVX512SKX_MICROKERNEL_SRCS", "ALL_AVX512VBMI_MICROKERNEL_SRCS", "ALL_AVX512VNNIGFNI_MICROKERNEL_SRCS", "ALL_AVX512VNNI_MICROKERNEL_SRCS", "ALL_AVXVNNI_MICROKERNEL_SRCS", "ALL_AVX_MICROKERNEL_SRCS", "ALL_F16C_MICROKERNEL_SRCS", "ALL_FMA3_MICROKERNEL_SRCS", "ALL_FMA_MICROKERNEL_SRCS", "ALL_FP16ARITH_MICROKERNEL_SRCS", "ALL_HEXAGON_MICROKERNEL_SRCS", "ALL_HVX_MICROKERNEL_SRCS", "ALL_NEONBF16_AARCH64_MICROKERNEL_SRCS", "ALL_NEONBF16_MICROKERNEL_SRCS", "ALL_NEONDOTFP16ARITH_MICROKERNEL_SRCS", "ALL_NEONDOT_AARCH64_MICROKERNEL_SRCS", "ALL_NEONDOT_MICROKERNEL_SRCS", "ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS", "ALL_NEONFMA_MICROKERNEL_SRCS", "ALL_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS", "ALL_NEONFP16ARITH_MICROKERNEL_SRCS", "ALL_NEONFP16_MICROKERNEL_SRCS", "ALL_NEONI8MM_MICROKERNEL_SRCS", "ALL_NEONV8_MICROKERNEL_SRCS", "ALL_NEON_AARCH64_MICROKERNEL_SRCS", "ALL_NEON_MICROKERNEL_SRCS", "ALL_RVVFP16ARITH_MICROKERNEL_SRCS", "ALL_RVV_MICROKERNEL_SRCS", "ALL_SCALAR_MICROKERNEL_SRCS", "ALL_SSE2_MICROKERNEL_SRCS", "ALL_SSE41_MICROKERNEL_SRCS", "ALL_SSE_MICROKERNEL_SRCS", "ALL_SSSE3_MICROKERNEL_SRCS", "ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS", "ALL_WASMSIMD_MICROKERNEL_SRCS", "ALL_WASM_MICROKERNEL_SRCS", "WASM32_ASM_MICROKERNEL_SRCS", "WASM32_JIT_MICROKERNEL_SRCS", "WASMRELAXEDSIMD32_JIT_MICROKERNEL_SRCS", "WASMSIMD32_JIT_MICROKERNEL_SRCS") licenses(["notice"]) @@ -135,6 +135,7 @@ MICROKERNEL_HDRS = [ "src/xnnpack/lut.h", "src/xnnpack/maxpool.h", "src/xnnpack/packb.h", + "src/xnnpack/packq.h", "src/xnnpack/packw.h", "src/xnnpack/packx.h", "src/xnnpack/pad.h", @@ -725,9 +726,10 @@ xnnpack_cc_library( "src/amalgam/gen/neon-aarch64.c", "src/amalgam/gen/neon.c", ], + defines = xnnpack_kleidiai_defines(), gcc_copts = xnnpack_gcc_std_copts(), msvc_copts = xnnpack_msvc_std_copts(), - deps = MICROKERNEL_DEPS, + deps = MICROKERNEL_DEPS + xnnpack_if_kleidiai_enabled(["@KleidiAI//:kleidiai_neon"]), ) xnnpack_cc_library( @@ -739,9 +741,12 @@ xnnpack_cc_library( ], aarch32_srcs = ALL_NEON_MICROKERNEL_SRCS, aarch64_srcs = ALL_NEON_MICROKERNEL_SRCS + ALL_NEON_AARCH64_MICROKERNEL_SRCS, + defines = xnnpack_kleidiai_defines(), gcc_copts = xnnpack_gcc_std_copts(), msvc_copts = xnnpack_msvc_std_copts(), - deps = MICROKERNEL_DEPS, +
deps = MICROKERNEL_DEPS + xnnpack_if_kleidiai_enabled([ + "@KleidiAI//:kleidiai_neon", + ]), ) xnnpack_cc_library( @@ -2564,6 +2569,18 @@ config_setting( define_values = {"xnn_enable_cpuinfo": "false"}, ) +# Enables usage of the KleidiAI library. +config_setting( + name = "xnn_enable_kleidiai_explicit_true", + define_values = {"xnn_enable_kleidiai": "true"}, +) + +# Disables usage of the KleidiAI library. +config_setting( + name = "xnn_enable_kleidiai_explicit_false", + define_values = {"xnn_enable_kleidiai": "false"}, +) + # Enables usage of assembly kernels. config_setting( name = "xnn_enable_assembly_explicit_true", @@ -2946,6 +2963,22 @@ alias( }), ) +selects.config_setting_group( + name = "kleidiai_enabled_by_default", + match_any = [ + "//build_config:aarch64", + ], +) + +alias( + name = "kleidiai_enabled", + actual = select({ + ":xnn_enable_kleidiai_explicit_true": ":xnn_enable_kleidiai_explicit_true", + # When explicitly disabled, point the alias at the "true" config_setting, + # which does not match in this configuration, so selects on + # ":kleidiai_enabled" fall through to their default branch. + ":xnn_enable_kleidiai_explicit_false": ":xnn_enable_kleidiai_explicit_true", + "//conditions:default": ":kleidiai_enabled_by_default", + }), +) + selects.config_setting_group( name = "assembly_enabled_by_default", match_any = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index a556da87eba7..7b9c8f812dd8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,7 @@ OPTION(USE_GNU_SOURCE "Use _GNU_SOURCE macro" OFF) IF(XNNPACK_BUILD_BENCHMARKS OR XNNPACK_BUILD_TESTS) SET(XNNPACK_BUILD_ALL_MICROKERNELS ON) ENDIF() +OPTION(XNNPACK_ENABLE_KLEIDIAI "Use KleidiAI GEMM microkernels for Arm" OFF) # ---[ Determine target processor IF(CMAKE_OSX_ARCHITECTURES) @@ -160,6 +161,7 @@ ADD_COMPILE_DEFINITIONS("XNN_ENABLE_SPARSE=$<BOOL:${XNNPACK_ENABLE_SPARSE}>") ADD_COMPILE_DEFINITIONS("XNN_ENABLE_GEMM_M_SPECIALIZATION=$<BOOL:${XNNPACK_ENABLE_GEMM_M_SPECIALIZATION}>") ADD_COMPILE_DEFINITIONS("XNN_ENABLE_DWCONV_MULTIPASS=$<BOOL:${XNNPACK_ENABLE_DWCONV_MULTIPASS}>") ADD_COMPILE_DEFINITIONS("XNN_ENABLE_HVX=$<BOOL:${XNNPACK_ENABLE_HVX}>") +ADD_COMPILE_DEFINITIONS("XNN_ENABLE_KLEIDIAI=$<BOOL:${XNNPACK_ENABLE_KLEIDIAI}>") IF(XNNPACK_PLATFORM_JIT STREQUAL "ON" OR XNNPACK_PLATFORM_JIT STREQUAL "OFF") ADD_COMPILE_DEFINITIONS("XNN_PLATFORM_JIT=$<BOOL:${XNNPACK_PLATFORM_JIT}>") @@ -253,6 +255,16 @@ IF(NOT XNNPACK_USE_SYSTEM_LIBS) WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/googlebenchmark-download") SET(GOOGLEBENCHMARK_SOURCE_DIR "${CMAKE_BINARY_DIR}/googlebenchmark-source" CACHE STRING "Google Benchmark source directory") ENDIF() + + IF(XNNPACK_ENABLE_KLEIDIAI AND NOT DEFINED KLEIDIAI_SOURCE_DIR) + MESSAGE(STATUS "Downloading KleidiAI to ${CMAKE_BINARY_DIR}/kleidiai-source (define KLEIDIAI_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadKleidiAI.cmake "${CMAKE_BINARY_DIR}/kleidiai-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/kleidiai-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build .
+ WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/kleidiai-download") + SET(KLEIDIAI_SOURCE_DIR "${CMAKE_BINARY_DIR}/kleidiai-source" CACHE STRING "kleidiai source directory") + ENDIF() ENDIF() # ---[ XNNPACK library @@ -1136,6 +1148,34 @@ IF(XNNPACK_BUILD_LIBRARY) PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) ENDIF() +# ---[ Configure KleidiAI +IF(XNNPACK_ENABLE_KLEIDIAI) + IF(NOT TARGET kleidiai) + IF(NOT XNNPACK_USE_SYSTEM_LIBS) + SET(KLEIDIAI_BUILD_TESTS OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${KLEIDIAI_SOURCE_DIR}" + "${CMAKE_BINARY_DIR}/kleidiai") + ELSE() + ADD_LIBRARY(kleidiai SHARED IMPORTED) + FIND_LIBRARY(KLEIDIAI_LIBRARY kleidiai PATHS "${KLEIDIAI_SOURCE_DIR}/lib") + IF(NOT KLEIDIAI_LIBRARY) + MESSAGE(FATAL_ERROR "Cannot find KleidiAI") + ENDIF() + TARGET_INCLUDE_DIRECTORIES(kleidiai INTERFACE "${KLEIDIAI_SOURCE_DIR}") + SET_PROPERTY(TARGET kleidiai PROPERTY IMPORTED_LOCATION "${KLEIDIAI_LIBRARY}") + SET_PROPERTY(TARGET kleidiai PROPERTY IMPORTED_IMPLIB "${KLEIDIAI_LIBRARY}") + ENDIF() + ENDIF() + IF(XNNPACK_BUILD_ALL_MICROKERNELS) + TARGET_LINK_LIBRARIES(microkernels-all PRIVATE kleidiai) + ENDIF() + TARGET_LINK_LIBRARIES(microkernels-prod PRIVATE kleidiai) + IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(XNNPACK PRIVATE kleidiai) + ENDIF() +ENDIF() + # ---[ XNNPACK unit tests IF(XNNPACK_BUILD_TESTS) # ---[ Build google test @@ -1193,6 +1233,13 @@ IF(XNNPACK_BUILD_TESTS) TARGET_INCLUDE_DIRECTORIES(convolution-test-helpers PRIVATE include src) TARGET_LINK_LIBRARIES(convolution-test-helpers PRIVATE fp16) + ADD_LIBRARY(packq-microkernel-tester STATIC test/packq-microkernel-tester.cc) + TARGET_INCLUDE_DIRECTORIES(packq-microkernel-tester PRIVATE . include src test) + TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE XNNPACK fp16 pthreadpool GTest::gtest) + IF(XNNPACK_ENABLE_KLEIDIAI) + TARGET_LINK_LIBRARIES(packq-microkernel-tester PRIVATE kleidiai) + ENDIF() + IF(XNNPACK_BUILD_LIBRARY) # ---[ Build size tests ADD_EXECUTABLE(operator-size-test test/operator-size.c) @@ -3340,6 +3387,15 @@ IF(XNNPACK_BUILD_TESTS) TARGET_LINK_LIBRARIES(x32-packb-test PRIVATE hardware-config logging microkernels-all packing) ADD_TEST(NAME x32-packb-test COMMAND x32-packb-test) + ADD_EXECUTABLE(x8-packq-test test/x8-packq.cc) + TARGET_INCLUDE_DIRECTORIES(x8-packq-test PRIVATE include src test) + TARGET_LINK_LIBRARIES(x8-packq-test PRIVATE pthreadpool GTest::gtest GTest::gtest_main) + TARGET_LINK_LIBRARIES(x8-packq-test PRIVATE packq-microkernel-tester hardware-config logging microkernels-all packing) + IF(XNNPACK_ENABLE_KLEIDIAI) + TARGET_LINK_LIBRARIES(x8-packq-test PRIVATE kleidiai) + ENDIF() + ADD_TEST(NAME x8-packq-test COMMAND x8-packq-test) + ADD_EXECUTABLE(x8-packw-test test/x8-packw.cc) TARGET_INCLUDE_DIRECTORIES(x8-packw-test PRIVATE include src test) TARGET_LINK_LIBRARIES(x8-packw-test PRIVATE pthreadpool GTest::gtest GTest::gtest_main) @@ -3519,6 +3575,14 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(bench-utils PRIVATE logging memory) ENDIF() + # Helper libraries + ADD_LIBRARY(packq-benchmark STATIC bench/packq-benchmark.cc) + TARGET_INCLUDE_DIRECTORIES(packq-benchmark PRIVATE . include src bench) + TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE XNNPACK benchmark::benchmark bench-utils) + IF(XNNPACK_ENABLE_KLEIDIAI) + TARGET_LINK_LIBRARIES(packq-benchmark PRIVATE kleidiai) + ENDIF() + # ---[ Build accuracy microbenchmarks ADD_EXECUTABLE(f16-exp-ulp-eval eval/f16-exp-ulp.cc) TARGET_INCLUDE_DIRECTORIES(f16-exp-ulp-eval PRIVATE . 
src) @@ -4494,6 +4558,14 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(x24-transpose-bench PRIVATE benchmark::benchmark pthreadpool) TARGET_LINK_LIBRARIES(x24-transpose-bench PRIVATE bench-utils hardware-config logging microkernels-all microparams-init) + ADD_EXECUTABLE(x8-packq-bench bench/x8-packq.cc) + TARGET_INCLUDE_DIRECTORIES(x8-packq-bench PRIVATE . include src) + TARGET_LINK_LIBRARIES(x8-packq-bench PRIVATE benchmark::benchmark pthreadpool) + IF(XNNPACK_ENABLE_KLEIDIAI) + TARGET_LINK_LIBRARIES(x8-packq-bench PRIVATE kleidiai) + ENDIF() + TARGET_LINK_LIBRARIES(x8-packq-bench PRIVATE packq-benchmark bench-utils hardware-config logging microkernels-all packing) + ADD_EXECUTABLE(x8-packw-bench bench/x8-packw.cc) TARGET_INCLUDE_DIRECTORIES(x8-packw-bench PRIVATE . include src) TARGET_LINK_LIBRARIES(x8-packw-bench PRIVATE benchmark::benchmark pthreadpool) diff --git a/WORKSPACE b/WORKSPACE index 7a3e527b11f4..211198a6e481 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -33,6 +33,7 @@ http_archive( ], ) +# LINT.IfChange # Google Test framework, used by most unit-tests. http_archive( name = "com_google_googletest", @@ -40,7 +41,9 @@ http_archive( strip_prefix = "googletest-e23cdb78e9fef1f69a9ef917f447add5638daf2a", urls = ["https://github.com/google/googletest/archive/e23cdb78e9fef1f69a9ef917f447add5638daf2a.zip"], ) +# LINT.ThenChange(cmake/DownloadGoogleTest.cmake) +# LINT.IfChange # Google Benchmark library, used in micro-benchmarks. http_archive( name = "com_google_benchmark", @@ -48,7 +51,9 @@ http_archive( strip_prefix = "benchmark-d2a8a4ee41b923876c034afb939c4fc03598e622", urls = ["https://github.com/google/benchmark/archive/d2a8a4ee41b923876c034afb939c4fc03598e622.zip"], ) +# LINT.ThenChange(cmake/DownloadGoogleBenchmark.cmake) +# LINT.IfChange # FP16 library, used for half-precision conversions http_archive( name = "FP16", @@ -59,7 +64,9 @@ http_archive( "https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip", ], ) +# LINT.ThenChange(cmake/DownloadFP16.cmake) +# LINT.IfChange # FXdiv library, used for repeated integer division by the same factor http_archive( name = "FXdiv", @@ -67,6 +74,7 @@ http_archive( strip_prefix = "FXdiv-b408327ac2a15ec3e43352421954f5b1967701d1", urls = ["https://github.com/Maratyszcza/FXdiv/archive/b408327ac2a15ec3e43352421954f5b1967701d1.zip"], ) +# LINT.ThenChange(cmake/DownloadFXdiv.cmake) # LINT.IfChange # pthreadpool library, used for parallelization @@ -78,6 +86,7 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadPThreadPool.cmake) +# LINT.IfChange # cpuinfo library, used for detecting processor characteristics http_archive( name = "cpuinfo", @@ -87,6 +96,19 @@ http_archive( "https://github.com/pytorch/cpuinfo/archive/d6860c477c99f1fce9e28eb206891af3c0e1a1d7.zip" ], ) +# LINT.ThenChange(cmake/DownloadCpuinfo.cmake) + +# LINT.IfChange +# KleidiAI library, used for ARM microkernels. 
+http_archive( + name = "KleidiAI", + sha256 = "39b26d8840ec719afaa480b0622a77952d0f22dbb8e8ba58ec9f93e39895a205", + strip_prefix = "kleidiai-1976f8661e8d5aa7d4cdca0f3d2a915e5ecb4c53", + urls = [ + "https://gitlab.arm.com/kleidi/kleidiai/-/archive/1976f8661e8d5aa7d4cdca0f3d2a915e5ecb4c53/kleidiai-1976f8661e8d5aa7d4cdca0f3d2a915e5ecb4c53.zip" + ], +) +# LINT.ThenChange(cmake/DownloadKleidiAI.cmake) # Ruy library, used to benchmark against http_archive( diff --git a/bench/BUILD.bazel b/bench/BUILD.bazel index af84acb4f481..f187b7de7d9f 100644 --- a/bench/BUILD.bazel +++ b/bench/BUILD.bazel @@ -20,18 +20,18 @@ load( MICROKERNEL_BENCHMARK_DEPS = [ ":bench_utils", + "@FP16", "//:aligned_allocator", "//:test_microkernels", "//:common", "//:enable_assembly", "//:jit", "//:microkernels_h", + "//:microparams_init", + "//:microparams", "//:packing", "//:params", - "//:microparams", - "//:microparams_init", "//:xnnpack_h", - "@FP16", ] OPERATOR_BENCHMARK_DEPS = [ @@ -961,6 +961,31 @@ xnnpack_benchmark( deps = MICROKERNEL_BENCHMARK_DEPS, ) +xnnpack_cc_library( + name = "packq_benchmark", + srcs = [ + "bgemm.h", + "packq-benchmark.cc", + ], + hdrs = ["packq-benchmark.h"], + deps = MICROKERNEL_BENCHMARK_DEPS + [ + "@com_google_benchmark//:benchmark", + ], +) + +xnnpack_benchmark( + name = "x8_packq_bench", + srcs = [ + "bgemm.h", + "x8-packq.cc", + ], + deps = MICROKERNEL_BENCHMARK_DEPS + [ + ":packq_benchmark", + "//:allocator", + "//:math", + ], +) + xnnpack_benchmark( name = "x8_packw_bench", srcs = [ diff --git a/bench/packq-benchmark.cc b/bench/packq-benchmark.cc new file mode 100644 index 000000000000..6e09544478f4 --- /dev/null +++ b/bench/packq-benchmark.cc @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "packq-benchmark.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "bench/utils.h" +#include + +void x8_packq(benchmark::State& state, xnn_x8_packq_f32qp8_ukernel_fn packq, + size_t mr, size_t kr, size_t sr, + benchmark::utils::IsaCheckFunction isa_check) { + if (isa_check != nullptr && !isa_check(state)) { + return; + } + + const size_t batch = state.range(0); + const size_t dim_m = state.range(2); + const size_t dim_k = state.range(3); + + const size_t rounded_n = benchmark::utils::RoundUp(dim_m, mr); + const size_t rounded_k = benchmark::utils::RoundUp(dim_k, kr); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto f32rng = [&]() { + return std::uniform_real_distribution(-10, 10)(rng); + }; + + // Compute a num_buffers that fit cache with source weights + packed_weights. 
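+  // Rotating through more than one copy of the input keeps the benchmark
+  // from repeatedly packing a cache-resident buffer, so the numbers below
+  // reflect packing from memory rather than from a warm cache.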
+ const size_t num_buffers = + 1 + benchmark::utils::DivideRoundUp( + benchmark::utils::GetMaxCacheSize(), + sizeof(int8_t) * batch * + (dim_m * dim_k + rounded_n * rounded_k + rounded_n)); + + std::vector<float, AlignedAllocator<float, 64>> input(num_buffers * batch * + dim_m * dim_k); + std::generate(input.begin(), input.end(), f32rng); + const size_t packed_size = + xnn_x8_packq_f32qp8_packed_size(batch * dim_m, dim_k, mr, kr, sr); + std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_weights(num_buffers * + packed_size); + + size_t buffer_index = 0; + for (auto _ : state) { + if (++buffer_index == num_buffers) { + buffer_index = 0; + } + + packq(batch * dim_m, dim_k, mr, kr, sr, + /*m_idx_start=*/buffer_index * dim_m, + input.data() + buffer_index * batch * dim_m * dim_k, + dim_k * sizeof(float), packed_weights.data()); + } + + const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); + if (cpu_frequency != 0) { + state.counters["cpufreq"] = cpu_frequency; + } + + const size_t elements_per_iteration = batch * dim_m * dim_k; + state.counters["elements"] = benchmark::Counter( + static_cast<double>(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); + + const size_t bytes_per_iteration = + (elements_per_iteration + batch * (rounded_n * rounded_k + rounded_n)) * + sizeof(int8_t); + state.counters["bytes"] = benchmark::Counter( + static_cast<double>(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); +} diff --git a/bench/packq-benchmark.h b/bench/packq-benchmark.h new file mode 100644 index 000000000000..62e315797e0d --- /dev/null +++ b/bench/packq-benchmark.h @@ -0,0 +1,24 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef __XNNPACK_BENCH_PACKQ_BENCHMARK_H +#define __XNNPACK_BENCH_PACKQ_BENCHMARK_H + +#include +#include +#include +#include +#include + +#include + +#include "bench/utils.h" +#include + +void x8_packq(benchmark::State& state, xnn_x8_packq_f32qp8_ukernel_fn packq, + size_t mr, size_t kr, size_t sr, + benchmark::utils::IsaCheckFunction isa_check = nullptr); + +#endif // __XNNPACK_BENCH_PACKQ_BENCHMARK_H diff --git a/bench/x8-packq.cc b/bench/x8-packq.cc new file mode 100644 index 000000000000..48ca740e6c31 --- /dev/null +++ b/bench/x8-packq.cc @@ -0,0 +1,165 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit!
+// Specification: test/x8-packq.yaml +// Generator: tools/generate-packq-test.py + + +#include +#include + +#include +#include "bench/bgemm.h" +#include "bench/packq-benchmark.h" + + +#if XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_1_kr_1( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/1, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_1_kr_1) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_1_kr_2( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/1, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_1_kr_2) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_1_kr_4( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/1, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_1_kr_4) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_2_kr_1( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/2, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_2_kr_1) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_2_kr_2( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/2, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_2_kr_2) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_2_kr_4( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/2, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_2_kr_4) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_4_kr_1( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/4, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_4_kr_1) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_4_kr_2( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/4, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_4_kr_2) + static void x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_4_kr_4( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2, + /*mr=*/4, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEON); + } + BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__aarch64_neon_u2_mr_4_kr_4) + + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ARCH_ARM64 + + +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_1_kr_1( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/1, /*kr=*/1, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_1_kr_1) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_1_kr_2( + benchmark::State& state, const char* net) { + x8_packq(state, + 
xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/1, /*kr=*/2, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_1_kr_2) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_1_kr_4( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/1, /*kr=*/4, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_1_kr_4) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_2_kr_1( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/2, /*kr=*/1, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_2_kr_1) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_2_kr_2( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/2, /*kr=*/2, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_2_kr_2) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_2_kr_4( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/2, /*kr=*/4, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_2_kr_4) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_4_kr_1( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/4, /*kr=*/1, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_4_kr_1) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_4_kr_2( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/4, /*kr=*/2, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_4_kr_2) +static void x8_packq_f32qp8_ukernel__scalar_u1_mr_4_kr_4( + benchmark::State& state, const char* net) { + x8_packq(state, + xnn_x8_packq_f32qp8_ukernel__scalar_u1, + /*mr=*/4, /*kr=*/4, /*sr=*/1); +} +BENCHMARK_BGEMM(x8_packq_f32qp8_ukernel__scalar_u1_mr_4_kr_4) + + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/build_defs.bzl b/build_defs.bzl index cd315313e847..381bed47341c 100644 --- a/build_defs.bzl +++ b/build_defs.bzl @@ -71,6 +71,18 @@ def xnnpack_slinky_deps(): def xnnpack_slinky_defines(): return [] +def xnnpack_if_kleidiai_enabled(enabled = [], not_enabled = []): + return select({ + ":kleidiai_enabled": enabled, + "//conditions:default": not_enabled, + }) + +def xnnpack_kleidiai_defines(): + return xnnpack_if_kleidiai_enabled( + enabled = ["XNN_ENABLE_KLEIDIAI=1"], + not_enabled = ["XNN_ENABLE_KLEIDIAI=0"], + ) + def xnnpack_cc_library( name, srcs = [], diff --git a/cmake/DownloadKleidiAI.cmake b/cmake/DownloadKleidiAI.cmake new file mode 100644 index 000000000000..14d4a85741b6 --- /dev/null +++ b/cmake/DownloadKleidiAI.cmake @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# Copyright 2019 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) + +PROJECT(kleidiai-download NONE) + +# Set file timestamps to the time of extraction. 
+IF(POLICY CMP0135) + CMAKE_POLICY(SET CMP0135 NEW) +ENDIF() + +INCLUDE(ExternalProject) +ExternalProject_Add(kleidiai + URL https://gitlab.arm.com/kleidi/kleidiai/-/archive/8c6cf04366a3602ca974c4b262e29ae24b699556/kleidiai-8c6cf04366a3602ca974c4b262e29ae24b699556.zip + URL_HASH SHA256=e6f2b475378173e5f5a41147d8255bdeee0d264f80ec6d23409bd0ea8dc88cd1 + SOURCE_DIR "${CMAKE_BINARY_DIR}/kleidiai-source" + BINARY_DIR "${CMAKE_BINARY_DIR}/kleidiai" + CONFIGURE_COMMAND "" + PATCH_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cmake/gen/neon_aarch64_microkernels.cmake b/cmake/gen/neon_aarch64_microkernels.cmake index f7e727b86d7a..dcc2d77d3795 100644 --- a/cmake/gen/neon_aarch64_microkernels.cmake +++ b/cmake/gen/neon_aarch64_microkernels.cmake @@ -23,5 +23,6 @@ SET(ALL_NEON_AARCH64_MICROKERNEL_SRCS src/x8-lut/gen/x8-lut-aarch64-neon-tbx128x4-u32.c src/x8-lut/gen/x8-lut-aarch64-neon-tbx128x4-u48.c src/x8-lut/gen/x8-lut-aarch64-neon-tbx128x4-u64.c + src/x8-packq/x8-packq-aarch64-neon-f32qp8-u2.c src/x24-transposec/x24-transposec-4x4-aarch64-neon-tbl128.c src/x32-transposec/x32-transposec-4x4-aarch64-neon-tbl128.c) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index eabc359b835f..3715aa5f1d30 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -1061,6 +1061,7 @@ SET(ALL_SCALAR_MICROKERNEL_SRCS src/x8-lut/gen/x8-lut-scalar-u4.c src/x8-lut/gen/x8-lut-scalar-u8.c src/x8-lut/gen/x8-lut-scalar-u16.c + src/x8-packq/x8-packq-scalar-f32qp8-u1.c src/x8-packw/gen/x8-packw-x2-gemm-goi-scalar-int-u2.c src/x8-packw/gen/x8-packw-x2-gemm-goi-scalar-int-u4.c src/x8-packw/gen/x8-packw-x4-gemm-goi-scalar-int-u2.c diff --git a/gen/neon_aarch64_microkernels.bzl b/gen/neon_aarch64_microkernels.bzl index 66574ff966b3..bb450add8c50 100644 --- a/gen/neon_aarch64_microkernels.bzl +++ b/gen/neon_aarch64_microkernels.bzl @@ -19,6 +19,7 @@ ALL_NEON_AARCH64_MICROKERNEL_SRCS = [ "src/x8-lut/gen/x8-lut-aarch64-neon-tbx128x4-u32.c", "src/x8-lut/gen/x8-lut-aarch64-neon-tbx128x4-u48.c", "src/x8-lut/gen/x8-lut-aarch64-neon-tbx128x4-u64.c", + "src/x8-packq/x8-packq-aarch64-neon-f32qp8-u2.c", "src/x24-transposec/x24-transposec-4x4-aarch64-neon-tbl128.c", "src/x32-transposec/x32-transposec-4x4-aarch64-neon-tbl128.c", ] diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index 5e789d5dc13b..cee0e4374847 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -1057,6 +1057,7 @@ ALL_SCALAR_MICROKERNEL_SRCS = [ "src/x8-lut/gen/x8-lut-scalar-u4.c", "src/x8-lut/gen/x8-lut-scalar-u8.c", "src/x8-lut/gen/x8-lut-scalar-u16.c", + "src/x8-packq/x8-packq-scalar-f32qp8-u1.c", "src/x8-packw/gen/x8-packw-x2-gemm-goi-scalar-int-u2.c", "src/x8-packw/gen/x8-packw-x2-gemm-goi-scalar-int-u4.c", "src/x8-packw/gen/x8-packw-x4-gemm-goi-scalar-int-u2.c", diff --git a/include/xnnpack.h b/include/xnnpack.h index 9ae998b9d605..aac5b9252dc3 100644 --- a/include/xnnpack.h +++ b/include/xnnpack.h @@ -236,20 +236,30 @@ enum xnn_datatype { xnn_datatype_fp32 = 1, /// IEEE754 half-precision floating-point. xnn_datatype_fp16 = 2, - /// Quantized 8-bit signed integer with shared per-Value quantization parameters. + /// Quantized 8-bit signed integer with shared per-Value quantization + /// parameters. xnn_datatype_qint8 = 3, - /// Quantized 8-bit unsigned integer with shared per-Value quantization parameters. + /// Quantized 8-bit unsigned integer with shared per-Value quantization + /// parameters. 
xnn_datatype_quint8 = 4, - /// Quantized 32-bit signed integer with shared per-Value quantization parameters. + /// Quantized 32-bit signed integer with shared per-Value quantization + /// parameters. xnn_datatype_qint32 = 5, - /// Quantized 8-bit signed integer with shared per-channel quantization parameters. + /// Quantized 8-bit signed integer with shared per-channel quantization + /// parameters. xnn_datatype_qcint8 = 6, - /// Quantized 32-bit signed integer with shared per-channel quantization parameters. + /// Quantized 32-bit signed integer with shared per-channel quantization + /// parameters. xnn_datatype_qcint32 = 7, - /// Quantized 4-bit signed integer with shared per-channel quantization parameters. + /// Quantized 4-bit signed integer with shared per-channel quantization + /// parameters. xnn_datatype_qcint4 = 8, - /// Dynamically quantized 8-bit signed integer with per-batch quantization parameters. + /// Dynamically quantized 8-bit signed integer with per-batch quantization + /// parameters. xnn_datatype_qdint8 = 9, + /// Dynamically quantized 8-bit signed integers packed with their per-row + /// quantization parameters. + xnn_datatype_qpint8 = 10, }; /// Define a tensor-type Value and add it to a Subgraph. @@ -2846,6 +2856,17 @@ enum xnn_status xnn_setup_convert_nc_f32_qd8( int8_t* output, struct xnn_dynamic_quantization_params* quantization_params); +enum xnn_status xnn_create_convert_nc_f32_qp8(uint32_t flags, + xnn_operator_t* convert_op_out); + +enum xnn_status xnn_reshape_convert_nc_f32_qp8( + xnn_operator_t convert_op, size_t batch_size, size_t channels, + size_t input_stride, pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convert_nc_f32_qp8(xnn_operator_t convert_op, + const float* input, + int8_t* output); + enum xnn_status xnn_create_convert_nc_f32_f16( uint32_t flags, xnn_operator_t* convert_op_out); diff --git a/scripts/generate-tests.sh b/scripts/generate-tests.sh index fb3f346a7502..395eacd24849 100755 --- a/scripts/generate-tests.sh +++ b/scripts/generate-tests.sh @@ -7,6 +7,9 @@ ### Tests for packing micro-kernels tools/generate-pack-test.py --spec test/x32-packx.yaml --output test/x32-packx.cc & +### Tests for Pack quantized micro-kernels +tools/generate-packq-test.py --spec test/x8-packq.yaml --output test/x8-packq.cc --output-bench bench/x8-packq.cc & + ### Tests for Pack Weights micro-kernels tools/generate-packw-test.py --spec test/x8-packw.yaml --output test/x8-packw.cc --output-bench bench/x8-packw.cc & tools/generate-packw-test.py --spec test/x16-packw.yaml --output test/x16-packw.cc --output-bench bench/x16-packw.cc & diff --git a/src/amalgam/gen/neon-aarch64.c b/src/amalgam/gen/neon-aarch64.c index 62bda313196d..bfe8f8acad7c 100644 --- a/src/amalgam/gen/neon-aarch64.c +++ b/src/amalgam/gen/neon-aarch64.c @@ -4,7 +4,13 @@ // LICENSE file in the root directory of this source tree. 
#include +#include #include +#include +#include +#if XNN_ENABLE_KLEIDIAI +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#endif // XNN_ENABLE_KLEIDIAI #include @@ -13,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -581,3 +588,17 @@ void xnn_x8_lut_ukernel__aarch64_neon_tbx128x4_u64( } } } + +void xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2(size_t m, size_t k, size_t mr, + size_t kr, size_t sr, + size_t m_idx_start, + const float* XNN_RESTRICT lhs, + size_t lhs_stride, + void* XNN_RESTRICT lhs_packed) { +#if XNN_ENABLE_KLEIDIAI + kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, m_idx_start, lhs, + lhs_stride, lhs_packed); +#else + assert("Not compiled with XNN_ENABLE_KLEIDIAI" && 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/amalgam/gen/scalar.c b/src/amalgam/gen/scalar.c index b753ff46ea99..d4ef0c5bbac8 100644 --- a/src/amalgam/gen/scalar.c +++ b/src/amalgam/gen/scalar.c @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -28223,6 +28225,140 @@ void xnn_x8_lut_ukernel__scalar_u4( } } +inline static size_t k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. + size_t kr_sr_roundedup4 = round_up(kr * sr, 4); + return round_up(k, kr_sr_roundedup4); +} + +inline static size_t lhs_packed_stride(size_t k, size_t mr, size_t kr, + size_t sr) { + const size_t k_internal = k_roundedup(k, kr, sr); + + assert((k_internal % 2) == 0); + + // Assuming the same sizeof() for kai_num_bytes_per_offset and + // kai_num_bytes_per_multiplier + static const size_t num_bytes_per_multiplier = sizeof(float); + static const size_t num_bytes_per_offset = sizeof(int32_t); + + return mr * (k_internal * sizeof(int8_t) + num_bytes_per_multiplier + + num_bytes_per_offset); +} + +void xnn_x8_packq_f32qp8_ukernel__scalar_u1(size_t m, size_t k, size_t mr, + size_t kr, size_t sr, + size_t m_idx_start, + const float* XNN_RESTRICT lhs, + size_t lhs_stride, + void* XNN_RESTRICT lhs_packed) { + assert((kr % sr) == 0); + + // Assuming the same sizeof() for kai_num_bytes_per_offset and + // kai_num_bytes_per_multiplier + static const size_t num_bytes_per_multiplier = sizeof(float); + static const size_t num_bytes_per_offset = sizeof(int32_t); + assert(num_bytes_per_offset == num_bytes_per_multiplier); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + + const float* src_ptr = lhs; + + const size_t dst_stride = lhs_packed_stride(k, mr, kr, sr); + const size_t k_internal = k_roundedup(k, kr, sr); + const int32_t k_block_len = (int32_t)(kr / sr); + + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + float max0 = 0.0f; + float min0 = 0.0f; + + // Find min/max for each channel + int32_t k_idx = 0; + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = *(src_ptr + (size_t)k_idx); + max0 = math_max_f32(src0_0, max0); + min0 = math_min_f32(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float scale0 = min0 == max0 ? 1.F : (qmax - qmin) / (max0 - min0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 
1.0F / scale0 : 0.0F; + + const float descaled_min0 = min0 * scale0; + const float descaled_max0 = max0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 + ? qmin - descaled_min0 + : qmax - descaled_max0; + + zero_point0 = math_max_f32(zero_point0, qmin); + zero_point0 = math_min_f32(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = + (uint8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t); + + // Quantize the channels + k_idx = 0; + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; + ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = min((size_t)k_idx + k_block_idx, k - 1); + + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = math_max_s32(v0_s32, INT8_MIN); + v0_s32 = math_min_s32(v0_s32, INT8_MAX); + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + dst_ptr += mr * num_bytes_per_multiplier; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} + void xnn_x8_packw_gemm_goi_ukernel_x16__scalar_int_u2( size_t g, size_t nc, diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index 680ece569299..c4b5b1cf4829 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -16,11 +16,11 @@ #include #include #include +#include #include #include #include - static struct xnn_unary_elementwise_config f16_abs_config = {0}; static struct xnn_unary_elementwise_config f16_clamp_config = {0}; static struct xnn_unary_elementwise_config f16_elu_config = {0}; @@ -56,6 +56,7 @@ static struct xnn_unary_elementwise_config f32_sqrt_config = {0}; static struct xnn_unary_elementwise_config f32_tanh_config = {0}; static struct xnn_unary_elementwise_config f32_to_f16_cvt_config = {0}; static struct xnn_unary_elementwise_config f32_to_qs8_cvt_config = {0}; +static struct xnn_unary_elementwise_config f32_to_qp8_cvt_config = {0}; static struct xnn_unary_elementwise_config f32_to_qu8_cvt_config = {0}; static struct xnn_unary_elementwise_config qs8_cvt_config = {0}; static struct xnn_unary_elementwise_config qs8_lrelu_config = {0}; @@ -105,6 +106,7 @@ static struct xnn_unary_elementwise_config xx_copy_config = {0}; static INIT_ONCE init_guard_f32_sqrt = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_f32_tanh = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_f32_to_f16_cvt = INIT_ONCE_STATIC_INIT; + static INIT_ONCE init_guard_f32_to_qp8_cvt = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_f32_to_qs8_cvt = INIT_ONCE_STATIC_INIT; static
INIT_ONCE init_guard_f32_to_qu8_cvt = INIT_ONCE_STATIC_INIT; static INIT_ONCE init_guard_qs8_cvt = INIT_ONCE_STATIC_INIT; @@ -154,6 +156,7 @@ static struct xnn_unary_elementwise_config xx_copy_config = {0}; static pthread_once_t init_guard_f32_tanh = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_f32_to_f16_cvt = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_f32_to_qs8_cvt = PTHREAD_ONCE_INIT; + static pthread_once_t init_guard_f32_to_qp8_cvt = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_f32_to_qu8_cvt = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_qs8_cvt = PTHREAD_ONCE_INIT; static pthread_once_t init_guard_qs16_to_qs8_cvt = PTHREAD_ONCE_INIT; @@ -1573,6 +1576,16 @@ static void init_f32_to_f16_cvt_config(void) { #endif } +static void init_f32_to_qp8_cvt_config(void) { +#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + f32_to_qp8_cvt_config.ukernel = + (xnn_vunary_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2; +#else + f32_to_qp8_cvt_config.ukernel = + (xnn_vunary_ukernel_fn)xnn_x8_packq_f32qp8_ukernel__scalar_u1; +#endif +} + static void init_f32_to_qs8_cvt_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -2536,6 +2549,11 @@ static void init_xx_copy_config(void) { return TRUE; } + static BOOL CALLBACK init_f32_to_qp8_cvt_config_windows(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { + init_f32_to_qp8_cvt_config(); + return TRUE; + } + static BOOL CALLBACK init_f32_to_qs8_cvt_config_windows(PINIT_ONCE init_once, PVOID parameter, PVOID* context) { init_f32_to_qs8_cvt_config(); return TRUE; @@ -3027,6 +3045,21 @@ const struct xnn_unary_elementwise_config* xnn_init_f32_to_f16_cvt_config() { return &f32_to_f16_cvt_config; } +const struct xnn_unary_elementwise_config* xnn_init_f32_to_qp8_cvt_config() { + const struct xnn_hardware_config* hardware_config = + xnn_init_hardware_config(); + if (hardware_config == NULL) { + return NULL; + } +#if XNN_PLATFORM_WINDOWS + InitOnceExecuteOnce(&init_guard_f32_to_qp8_cvt, + &init_f32_to_qp8_cvt_config_windows, NULL, NULL); +#else + pthread_once(&init_guard_f32_to_qp8_cvt, &init_f32_to_qp8_cvt_config); +#endif + return &f32_to_qp8_cvt_config; +} + const struct xnn_unary_elementwise_config* xnn_init_f32_to_qs8_cvt_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL) { diff --git a/src/enums/datatype-strings.c b/src/enums/datatype-strings.c index 348bed4a7705..384e1cdefc75 100644 --- a/src/enums/datatype-strings.c +++ b/src/enums/datatype-strings.c @@ -35,6 +35,8 @@ const char* xnn_datatype_to_string(enum xnn_datatype type) { return "QCINT32"; case xnn_datatype_qdint8: return "QDINT8"; + case xnn_datatype_qpint8: + return "QPINT8"; } XNN_UNREACHABLE; return NULL; diff --git a/src/enums/operator-type.c b/src/enums/operator-type.c index f6686f28afde..0ad6ac999926 100644 --- a/src/enums/operator-type.c +++ b/src/enums/operator-type.c @@ -12,15 +12,15 @@ #include -static const uint16_t offset[159] = { +static const uint16_t offset[160] = { 0, 8, 22, 36, 50, 64, 78, 92, 119, 147, 175, 203, 230, 257, 289, 321, 364, 382, 400, 425, 451, 467, 483, 498, 513, - 535, 558, 581, 604, 627, 650, 673, 696, 719, 737, 760, 783, 807, 825, 848, 872, 896, 920, 944, 979, 1014, 1038, 1062, - 1086, 1100, 1115, 1130, 1156, 1182, 1219, 1245, 1271, 1303, 1335, 1361, 1388, 1415, 1432, 1449, 1483, 1517, 1531, - 1545, 1559, 1575, 1591, 1617, 1643, 1675, 1707, 1744, 1781, 1818, 1855, 1881, 1913, 1939, 1973, 2007, 2041, 
2075, - 2109, 2143, 2173, 2203, 2223, 2243, 2264, 2285, 2306, 2327, 2351, 2375, 2398, 2421, 2439, 2457, 2472, 2487, 2505, - 2523, 2542, 2561, 2580, 2599, 2616, 2633, 2649, 2665, 2698, 2731, 2759, 2787, 2815, 2843, 2870, 2897, 2914, 2931, - 2972, 3013, 3031, 3049, 3067, 3085, 3100, 3116, 3132, 3150, 3168, 3186, 3212, 3239, 3266, 3283, 3300, 3322, 3344, - 3373, 3402, 3421, 3440, 3459, 3478, 3493, 3508, 3523, 3538, 3557, 3577, 3597, 3617, 3638, 3659 + 535, 558, 581, 604, 627, 650, 673, 696, 719, 742, 760, 783, 806, 830, 848, 871, 895, 919, 943, 967, 1002, 1037, 1061, + 1085, 1109, 1123, 1138, 1153, 1179, 1205, 1242, 1268, 1294, 1326, 1358, 1384, 1411, 1438, 1455, 1472, 1506, 1540, + 1554, 1568, 1582, 1598, 1614, 1640, 1666, 1698, 1730, 1767, 1804, 1841, 1878, 1904, 1936, 1962, 1996, 2030, 2064, + 2098, 2132, 2166, 2196, 2226, 2246, 2266, 2287, 2308, 2329, 2350, 2374, 2398, 2421, 2444, 2462, 2480, 2495, 2510, + 2528, 2546, 2565, 2584, 2603, 2622, 2639, 2656, 2672, 2688, 2721, 2754, 2782, 2810, 2838, 2866, 2893, 2920, 2937, + 2954, 2995, 3036, 3054, 3072, 3090, 3108, 3123, 3139, 3155, 3173, 3191, 3209, 3235, 3262, 3289, 3306, 3323, 3345, + 3367, 3396, 3425, 3444, 3463, 3482, 3501, 3516, 3531, 3546, 3561, 3580, 3600, 3620, 3640, 3661, 3682 }; static const char data[] = @@ -55,6 +55,7 @@ static const char data[] = "Convert (NC, F16, QD8)\0" "Convert (NC, F32, F16)\0" "Convert (NC, F32, QD8)\0" + "Convert (NC, F32, QP8)\0" "Convert (NC, F32, QS8)\0" "Convert (NC, F32, QU8)\0" "Convert (NC, QS8)\0" diff --git a/src/enums/operator-type.yaml b/src/enums/operator-type.yaml index a64ca63e334f..75beb71f7d53 100644 --- a/src/enums/operator-type.yaml +++ b/src/enums/operator-type.yaml @@ -67,6 +67,8 @@ string: "Convert (NC, F32, F16)" - name: xnn_operator_type_convert_nc_f32_qd8 string: "Convert (NC, F32, QD8)" +- name: xnn_operator_type_convert_nc_f32_qp8 + string: "Convert (NC, F32, QP8)" - name: xnn_operator_type_convert_nc_f32_qs8 string: "Convert (NC, F32, QS8)" - name: xnn_operator_type_convert_nc_f32_qu8 diff --git a/src/operator-run.c b/src/operator-run.c index 681818b451a9..193bfd4c4b19 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "pthreadpool.h" @@ -2247,6 +2248,21 @@ void xnn_compute_f32_qd8_convert( context->convert_ukernel(n, input, output, &params); } +void xnn_compute_f32_qp8_convert( + const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t m_idx_start) { + const float* lhs = (const float*)((const char*)context->lhs + + m_idx_start * context->lhs_stride); + int8_t* lhs_packed = + context->lhs_packed + + xnn_x8_packq_f32qp8_packed_offset(m_idx_start, context->k, context->mr, + context->kr, context->sr); + + context->packq_ukernel(/*m=*/1, context->k, context->mr, context->kr, + context->sr, m_idx_start, lhs, context->lhs_stride, + lhs_packed); +} + void xnn_compute_u8_softmax( const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index) diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c index 3b613d0f3878..91b7aaec1a07 100644 --- a/src/operators/unary-elementwise-nc.c +++ b/src/operators/unary-elementwise-nc.c @@ -26,6 +26,19 @@ #include "pthreadpool.h" #include +static enum xnn_status check_op_type(xnn_operator_t op, + enum xnn_operator_type expected_type) { + if (op->type != expected_type) { + xnn_log_error( + "failed to setup operator: operator type mismatch (expected %s, got " + "%s)", +
xnn_operator_type_to_string(expected_type), + xnn_operator_type_to_string(op->type)); + return xnn_status_invalid_parameter; + } + return xnn_status_success; +} + static void init_unary_elementwise_nc( uint32_t flags, const void* params, @@ -593,6 +606,27 @@ enum xnn_status xnn_create_convert_nc_f32_qd8( xnn_operator_type_convert_nc_f32_qd8, convert_op_out); } +enum xnn_status xnn_create_convert_nc_f32_qp8(uint32_t flags, + xnn_operator_t* convert_op_out) { + const struct xnn_reduce_config* f32_rminmax_config = + xnn_init_f32_rminmax_config(); + if (f32_rminmax_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qp8)); + return xnn_status_unsupported_hardware; + } + + union xnn_f32_default_params params; + if (f32_rminmax_config->init.f32_default != NULL) { + f32_rminmax_config->init.f32_default(&params); + } + + return create_unary_elementwise_nc( + flags, xnn_init_f32_to_qp8_cvt_config(), f32_rminmax_config, &params, + sizeof(params), xnn_operator_type_convert_nc_f32_qp8, convert_op_out); +} + enum xnn_status xnn_create_convert_nc_f32_qu8( float output_scale, uint8_t output_zero_point, @@ -1719,6 +1753,60 @@ enum xnn_status xnn_reshape_convert_nc_f32_qd8( return xnn_status_success; } +enum xnn_status xnn_reshape_convert_nc_f32_qp8(xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + pthreadpool_t threadpool) { + if (convert_op->type != xnn_operator_type_convert_nc_f32_qp8) { + xnn_log_error( + "failed to setup operator: operator type mismatch (expected %s, got " + "%s)", + xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qp8), + xnn_operator_type_to_string(convert_op->type)); + return xnn_status_invalid_parameter; + } + convert_op->state = xnn_run_state_invalid; + + if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { + xnn_log_error( + "failed to setup %s operator: XNNPACK is not initialized", + xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qp8)); + return xnn_status_uninitialized; + } + + if (batch_size == 0) { + convert_op->state = xnn_run_state_skip; + return xnn_status_success; + } + + convert_op->batch_size = batch_size; + + const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_nr2_config(); + + convert_op->context.f32_qp8_convert = (struct f32_qp8_convert_context){ + .m = batch_size, + .k = channels, + .mr = gemm_config->mr, + .kr = 1 << gemm_config->log2_kr, + .sr = 1 << gemm_config->log2_sr, + .lhs_stride = input_stride, + .packq_ukernel = (xnn_x8_packq_f32qp8_ukernel_fn) + convert_op->unary_elementwise_config->ukernel, + }; + + // TODO(b/340399245) - Ideally, this should parallelize along `batch` in + // groups of `mr`.
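+  // Each 1D task currently packs a single row: xnn_compute_f32_qp8_convert
+  // (src/operator-run.c) invokes the packq ukernel with m=1, passing the
+  // task index as m_idx_start.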
+ convert_op->compute[0].type = xnn_parallelization_type_1d; + convert_op->compute[0].task_1d = + (pthreadpool_task_1d_t)xnn_compute_f32_qp8_convert; + convert_op->compute[0].range[0] = batch_size; + + convert_op->state = xnn_run_state_needs_setup; + + return xnn_status_success; +} + enum xnn_status xnn_reshape_convert_nc_f32_qs8( xnn_operator_t convert_op, size_t batch_size, @@ -2541,6 +2629,38 @@ enum xnn_status xnn_setup_convert_nc_f32_qd8( return xnn_status_success; } +enum xnn_status xnn_setup_convert_nc_f32_qp8(xnn_operator_t convert_op, + const float* input, + int8_t* output) { + enum xnn_status status = + check_op_type(convert_op, xnn_operator_type_convert_nc_f32_qp8); + if (status != xnn_status_success) { + return status; + } + + switch (convert_op->state) { + case xnn_run_state_skip: + return xnn_status_success; + case xnn_run_state_invalid: + xnn_log_error( + "failed to setup %s operator: operator has not been reshaped yet", + xnn_operator_type_to_string(convert_op->type)); + return xnn_status_invalid_state; + case xnn_run_state_needs_setup: + // Operator has been reshaped, but not setup, continue with setup. + case xnn_run_state_ready: + // Operator has been reshaped, and we are setting up with different + // pointers. + break; + } + + convert_op->context.f32_qp8_convert.lhs = input; + convert_op->context.f32_qp8_convert.lhs_packed = output; + convert_op->state = xnn_run_state_ready; + + return xnn_status_success; +} + enum xnn_status xnn_setup_convert_nc_f32_qs8( xnn_operator_t convert_op, const float* input, diff --git a/src/tensor.c b/src/tensor.c index 755655fa3844..b087e85f6ca5 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -513,6 +513,7 @@ size_t xnn_tensor_get_size(const struct xnn_value* value) case xnn_datatype_qint8: case xnn_datatype_quint8: case xnn_datatype_qcint8: + case xnn_datatype_qpint8: size = 1; break; case xnn_datatype_qint32: @@ -528,6 +529,11 @@ size_t xnn_tensor_get_size(const struct xnn_value* value) // Adjustments for nibbles, assume that we can't have sizes are byte-aligned (rounded up). if (value->datatype == xnn_datatype_qcint4) { size = round_up_po2(size, 2) >> 1; + } else if (value->datatype == xnn_datatype_qpint8) { + // TODO(b/340399245): Compute the correct size depending on the shape and + // packing constraints/alignment. + xnn_log_fatal("Support for %s is not yet implemented.", + xnn_datatype_to_string(value->datatype)); } return size; diff --git a/src/x8-packq/x8-packq-aarch64-neon-f32qp8-u2.c b/src/x8-packq/x8-packq-aarch64-neon-f32qp8-u2.c new file mode 100644 index 000000000000..65e1cf46cf16 --- /dev/null +++ b/src/x8-packq/x8-packq-aarch64-neon-f32qp8-u2.c @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if XNN_ENABLE_KLEIDIAI +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#endif // XNN_ENABLE_KLEIDIAI + + +// This function just wraps KleidiAI's `kai_run_lhs_quant_pack_qai8dxp_f32`, but +// with a name that is recognized by our tooling.
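+// The arguments are forwarded to KleidiAI unchanged; callers are expected to
+// size `lhs_packed` with xnn_x8_packq_f32qp8_packed_size() from
+// src/xnnpack/packq.h.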
+ +void xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2(size_t m, size_t k, size_t mr, + size_t kr, size_t sr, + size_t m_idx_start, + const float* XNN_RESTRICT lhs, + size_t lhs_stride, + void* XNN_RESTRICT lhs_packed) { +#if XNN_ENABLE_KLEIDIAI + kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, m_idx_start, lhs, + lhs_stride, lhs_packed); +#else + assert("Not compiled with XNN_ENABLE_KLEIDIAI" && 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/x8-packq/x8-packq-scalar-f32qp8-u1.c b/src/x8-packq/x8-packq-scalar-f32qp8-u1.c new file mode 100644 index 000000000000..9601d87de10e --- /dev/null +++ b/src/x8-packq/x8-packq-scalar-f32qp8-u1.c @@ -0,0 +1,150 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include + +// These functions have been adapted from KleidiAI's +// `kai_run_lhs_quant_pack_qai8dxp_f32` as a reference scalar implementation. + +inline static size_t k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. + size_t kr_sr_roundedup4 = round_up(kr * sr, 4); + return round_up(k, kr_sr_roundedup4); +} + +inline static size_t lhs_packed_stride(size_t k, size_t mr, size_t kr, + size_t sr) { + const size_t k_internal = k_roundedup(k, kr, sr); + + assert((k_internal % 2) == 0); + + // Assuming the same sizeof() for kai_num_bytes_per_offset and + // kai_num_bytes_per_multiplier + static const size_t num_bytes_per_multiplier = sizeof(float); + static const size_t num_bytes_per_offset = sizeof(int32_t); + + return mr * (k_internal * sizeof(int8_t) + num_bytes_per_multiplier + + num_bytes_per_offset); +} + +void xnn_x8_packq_f32qp8_ukernel__scalar_u1(size_t m, size_t k, size_t mr, + size_t kr, size_t sr, + size_t m_idx_start, + const float* XNN_RESTRICT lhs, + size_t lhs_stride, + void* XNN_RESTRICT lhs_packed) { + assert((kr % sr) == 0); + + // Assuming the same sizeof() for kai_num_bytes_per_offset and + // kai_num_bytes_per_multiplier + static const size_t num_bytes_per_multiplier = sizeof(float); + static const size_t num_bytes_per_offset = sizeof(int32_t); + assert(num_bytes_per_offset == num_bytes_per_multiplier); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + + const float* src_ptr = lhs; + + const size_t dst_stride = lhs_packed_stride(k, mr, kr, sr); + const size_t k_internal = k_roundedup(k, kr, sr); + const int32_t k_block_len = (int32_t)(kr / sr); + + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + float max0 = 0.0f; + float min0 = 0.0f; + + // Find min/max for each channel + int32_t k_idx = 0; + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = *(src_ptr + (size_t)k_idx); + max0 = math_max_f32(src0_0, max0); + min0 = math_min_f32(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float scale0 = min0 == max0 ? 1.F : (qmax - qmin) / (max0 - min0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 
1.0F / scale0 : 0.0F; + + const float descaled_min0 = min0 * scale0; + const float descaled_max0 = max0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 + ? qmin - descaled_min0 + : qmax - descaled_max0; + + zero_point0 = math_max_f32(zero_point0, qmin); + zero_point0 = math_min_f32(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = + (uint8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t); + + // Quantize the channels + k_idx = 0; + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; + ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = min((size_t)k_idx + k_block_idx, k - 1); + + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = math_max_s32(v0_s32, INT8_MIN); + v0_s32 = math_min_s32(v0_s32, INT8_MAX); + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + dst_ptr += mr * num_bytes_per_multiplier; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 915ac1c14896..2c6e9d6dd0bc 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -1583,17 +1583,36 @@ struct f32_qd8_convert_context { size_t batch_index); #endif -struct u8_softmax_context { - size_t n; - const uint8_t* x; - size_t x_stride; - const uint32_t* t; - uint8_t* y; - size_t y_stride; - xnn_u8_rmax_ukernel_fn rmax_ukernel; - xnn_u8_lut32norm_ukernel_fn lut_norm_ukernel; +struct f32_qp8_convert_context { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + const float* XNN_RESTRICT lhs; + size_t lhs_stride; + int8_t* XNN_RESTRICT lhs_packed; + xnn_x8_packq_f32qp8_ukernel_fn packq_ukernel; }; +#ifndef __cplusplus + XNN_PRIVATE void xnn_compute_f32_qp8_convert( + const struct f32_qp8_convert_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t m_idx_start); +#endif + + struct u8_softmax_context { + size_t n; + const uint8_t* x; + size_t x_stride; + const uint32_t* t; + uint8_t* y; + size_t y_stride; + xnn_u8_rmax_ukernel_fn rmax_ukernel; + xnn_u8_lut32norm_ukernel_fn lut_norm_ukernel; + }; + #ifndef __cplusplus XNN_PRIVATE void xnn_compute_u8_softmax( const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 8ac4d0093ecf..34caa137dfa2 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -291,6 +291,8 @@ XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_sqr_config( XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_sqrt_config();
diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h
index 8ac4d0093ecf..34caa137dfa2 100644
--- a/src/xnnpack/config.h
+++ b/src/xnnpack/config.h
@@ -291,6 +291,8 @@ XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_sqr_config(
 XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_sqrt_config();
 XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_tanh_config();
 XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_to_f16_cvt_config();
+XNN_INTERNAL const struct xnn_unary_elementwise_config*
+xnn_init_f32_to_qp8_cvt_config();
 XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_to_qs8_cvt_config();
 XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_f32_to_qu8_cvt_config();
 XNN_INTERNAL const struct xnn_unary_elementwise_config* xnn_init_qs8_cvt_config();
diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h
index 28087bdf8933..fc31c6375bd8 100644
--- a/src/xnnpack/microfnptr.h
+++ b/src/xnnpack/microfnptr.h
@@ -1378,6 +1378,22 @@ typedef void (*xnn_x32_zerob_gemm_ukernel_fn)(
     size_t channel_subtile_stride,
     const union xnn_x32_packb_params params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]);
 
+// PACKQ: PACK and Quantize the left-hand operand (LHS) for GEMM matrix
+// multiplication.
+
+typedef void (*xnn_x8_packq_f32qp8_ukernel_fn)(
+    size_t m,            // Number of rows to pack.
+    size_t k,            // Number of columns/channels per row.
+    size_t mr,           // Number of rows to interleave in the same output row.
+    size_t kr,           // Number of columns/channels loaded per step in the
+                         // matmul microkernel.
+    size_t sr,           // Number of `kr` splits.
+    size_t m_idx_start,  // Starting index in `lhs_packed`.
+    const float* XNN_RESTRICT lhs,  // Left-hand operand to pack.
+    size_t lhs_stride,   // Stride in bytes between the rows of `lhs`.
+    void* XNN_RESTRICT lhs_packed   // The quantized and packed output.
+);
+
 // PACKW: PACK W (weights) for GEMM matrix multiplication
 // Weights in GOI layout: Groups, Output channels, Input channels.
diff --git a/src/xnnpack/operator-type.h b/src/xnnpack/operator-type.h
index 53cdb87a44fb..d0b8eb168f6a 100644
--- a/src/xnnpack/operator-type.h
+++ b/src/xnnpack/operator-type.h
@@ -48,6 +48,7 @@ enum xnn_operator_type {
   xnn_operator_type_convert_nc_f16_qd8,
   xnn_operator_type_convert_nc_f32_f16,
   xnn_operator_type_convert_nc_f32_qd8,
+  xnn_operator_type_convert_nc_f32_qp8,
   xnn_operator_type_convert_nc_f32_qs8,
   xnn_operator_type_convert_nc_f32_qu8,
   xnn_operator_type_convert_nc_qs8,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 26d205c18091..ed99d559d575 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -429,6 +429,7 @@ struct xnn_operator {
     struct u8_softmax_context u8_softmax;
     struct f16_qd8_convert_context f16_qd8_convert;
     struct f32_qd8_convert_context f32_qd8_convert;
+    struct f32_qp8_convert_context f32_qp8_convert;
     struct univector_contiguous_context univector_contiguous;
     struct univector_strided_context univector_strided;
     struct unpooling_context unpooling;
diff --git a/src/xnnpack/packq.h b/src/xnnpack/packq.h
new file mode 100644
index 000000000000..217ecc2fadd4
--- /dev/null
+++ b/src/xnnpack/packq.h
@@ -0,0 +1,71 @@
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#ifndef __XNNPACK_SRC_XNNPACK_PACKQ_H
+#define __XNNPACK_SRC_XNNPACK_PACKQ_H
+
+#include
+#include
+#include
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// These functions have been adapted from KleidiAI's
+// `kai_run_lhs_quant_pack_qai8dxp_f32` as a reference scalar implementation.
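+
+// Worked example (illustrative, not part of the original header): with
+// k = 10, kr = 2, sr = 2, the helpers below pad each row to
+// k_internal = round_up(10, round_up(2 * 2, 4)) = 12 bytes, so each group of
+// `mr` interleaved rows occupies
+// mr * (12 * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)) = mr * 20
+// bytes: the int8 data of all `mr` rows, followed by their int32 zero-point
+// offsets and float reciprocal scales.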
+
+inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) {
+  // Since we pack a float and int32 value at the end of the row,
+  // we must make sure that k is a multiple of 4 for memory alignment.
+  size_t kr_sr_roundedup4 = round_up(kr * sr, 4);
+  return round_up(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr,
+                                           size_t sr) {
+  const size_t k_internal = kai_k_roundedup(k, kr, sr);
+
+  assert((k_internal % 2) == 0);
+
+  static const size_t kai_num_bytes_per_multiplier = sizeof(float);
+  static const size_t kai_num_bytes_per_offset = sizeof(int32_t);
+
+  return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier +
+               kai_num_bytes_per_offset);
+}
+
+XNN_INLINE size_t xnn_x8_packq_f32qp8_packed_offset(size_t m_idx, size_t k,
+                                                    size_t mr, size_t kr,
+                                                    size_t sr) {
+  // It always points to the beginning of the row
+  return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr);
+}
+
+XNN_INLINE size_t xnn_x8_packq_f32qp8_packed_size(size_t m, size_t k, size_t mr,
+                                                  size_t kr, size_t sr) {
+  const size_t num_rows = round_up(m, mr) / mr;
+
+  return num_rows * kai_lhs_packed_stride(k, mr, kr, sr);
+}
+
+#define DECLARE_X8_PACKQ_UKERNEL_FUNCTION(fn_name) \
+  XNN_INTERNAL void fn_name(size_t m, size_t k, size_t mr, size_t kr, \
+                            size_t sr, size_t m_idx_start, \
+                            const float* XNN_RESTRICT lhs, size_t lhs_stride, \
+                            void* XNN_RESTRICT lhs_packed);
+
+DECLARE_X8_PACKQ_UKERNEL_FUNCTION(xnn_x8_packq_f32qp8_ukernel__scalar_u1)
+
+#if XNN_ENABLE_KLEIDIAI
+DECLARE_X8_PACKQ_UKERNEL_FUNCTION(xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2)
+#endif // XNN_ENABLE_KLEIDIAI
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif // __XNNPACK_SRC_XNNPACK_PACKQ_H
diff --git a/test/BUILD.bazel b/test/BUILD.bazel
index 0806ceb00c7d..ffc03844bbcd 100644
--- a/test/BUILD.bazel
+++ b/test/BUILD.bazel
@@ -137,6 +137,14 @@ xnnpack_cc_library(
     deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
 )
 
+xnnpack_cc_library(
+    name = "packq_microkernel_tester",
+    testonly = True,
+    srcs = ["packq-microkernel-tester.cc"],
+    hdrs = ["packq-microkernel-tester.h"],
+    deps = MICROKERNEL_TEST_DEPS + xnnpack_test_deps_for_library(),
+)
+
 ######################### Unit tests for micro-kernels #########################
 
 [xnnpack_unit_test(
@@ -1310,6 +1318,16 @@ xnnpack_unit_test(
     deps = MICROKERNEL_TEST_DEPS,
 )
 
+xnnpack_unit_test(
+    name = "x8_packq_test",
+    srcs = [
+        "x8-packq.cc",
+    ],
+    deps = MICROKERNEL_TEST_DEPS + [
+        ":packq_microkernel_tester",
+    ],
+)
+
 [xnnpack_unit_test(
     name = "%s_test" % kernel,
     srcs = [
@@ -1527,7 +1545,7 @@ xnnpack_unit_test(
         "convert-operator-tester.h",
     ],
     shard_count = 5,
-    deps = OPERATOR_TEST_DEPS,
+    deps = OPERATOR_TEST_DEPS + ["//:microkernels_h"],
 )
 
 xnnpack_unit_test(
@@ -1536,7 +1554,7 @@ xnnpack_unit_test(
         "convert-nc-eager.cc",
         "convert-operator-tester.h",
     ],
-    deps = OPERATOR_TEST_DEPS,
+    deps = OPERATOR_TEST_DEPS + ["//:microkernels_h"],
 )
 
 xnnpack_unit_test(
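
Taken together, the size helper and the scalar ukernel declared in the new `src/xnnpack/packq.h` compose into a complete packing call. The sketch below is illustrative only: `pack_lhs_example` is a hypothetical caller, and the mr/kr/sr values are assumptions rather than anything this patch prescribes:

    #include <stdlib.h>

    // Hypothetical caller (not part of the patch): quantize and pack an m x k
    // row-major float matrix into the qp8 layout with the scalar reference
    // ukernel. The caller owns (and must free()) the returned buffer.
    static void* pack_lhs_example(size_t m, size_t k, const float* lhs) {
      const size_t mr = 4, kr = 2, sr = 1;  // Example packing parameters.
      // Size the packed buffer with the helper from packq.h.
      void* lhs_packed =
          malloc(xnn_x8_packq_f32qp8_packed_size(m, k, mr, kr, sr));
      // Pack all m rows in one call; consecutive rows are k floats apart. (The
      // tester over-allocates the input by XNN_EXTRA_BYTES to cover padded
      // reads past the last channel.)
      xnn_x8_packq_f32qp8_ukernel__scalar_u1(m, k, mr, kr, sr,
                                             /*m_idx_start=*/0, lhs,
                                             /*lhs_stride=*/k * sizeof(float),
                                             lhs_packed);
      return lhs_packed;
    }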
diff --git a/test/convert-nc.cc b/test/convert-nc.cc
index 2a113ed76112..f45704b6fddc 100644
--- a/test/convert-nc.cc
+++ b/test/convert-nc.cc
@@ -280,40 +280,6 @@ TEST(CONVERT_NC_F32_QD8, small_batch_with_input_and_output_stride) {
   }
 }
 
-TEST(CONVERT_NC_F32_QD8, output_min) {
-  for (int16_t qmin = std::numeric_limits<int8_t>::min();
-       qmin < std::numeric_limits<int8_t>::max();
-       qmin += 51)
-  {
-    for (size_t channels = 1; channels < 100; channels++) {
-      ConvertOperatorTester()
-        .batch_size(3)
-        .channels(channels)
-        .qmin(qmin)
-        .qmax(std::numeric_limits<int8_t>::max())
-        .iterations(3)
-        .TestF32toQD8();
-    }
-  }
-}
-
-TEST(CONVERT_NC_F32_QD8, output_max) {
-  for (int16_t qmax = std::numeric_limits<int8_t>::min() + 1;
-       qmax <= std::numeric_limits<int8_t>::max();
-       qmax += 51)
-  {
-    for (size_t channels = 1; channels < 100; channels++) {
-      ConvertOperatorTester()
-        .batch_size(3)
-        .channels(channels)
-        .qmin(std::numeric_limits<int8_t>::min())
-        .qmax(qmax)
-        .iterations(3)
-        .TestF32toQD8();
-    }
-  }
-}
-
 TEST(CONVERT_NC_F32_QS8, unit_batch) {
   for (size_t channels = 1; channels < 100; channels++) {
     ConvertOperatorTester()
@@ -921,3 +887,40 @@ TEST(CONVERT_NC_QU8_F32, input_zero_point) {
     }
   }
 }
+
+TEST(CONVERT_NC_F32_QP8, unit_batch) {
+  for (size_t channels = 1; channels < 100; channels++) {
+    ConvertOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .qmin(std::numeric_limits<int8_t>::min())
+      .qmax(std::numeric_limits<int8_t>::max())
+      .iterations(3)
+      .TestF32toQP8();
+  }
+}
+
+TEST(CONVERT_NC_F32_QP8, small_batch) {
+  for (size_t channels = 1; channels < 100; channels++) {
+    ConvertOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .qmin(std::numeric_limits<int8_t>::min())
+      .qmax(std::numeric_limits<int8_t>::max())
+      .iterations(3)
+      .TestF32toQP8();
+  }
+}
+
+TEST(CONVERT_NC_F32_QP8, small_batch_with_input_stride) {
+  for (size_t channels = 10; channels < 11; channels += 15) {
+    ConvertOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .qmin(std::numeric_limits<int8_t>::min())
+      .qmax(std::numeric_limits<int8_t>::max())
+      .iterations(3)
+      .TestF32toQP8();
+  }
+}
diff --git a/test/convert-operator-tester.h b/test/convert-operator-tester.h
index 90d679f9f9c6..95af4a53addd 100644
--- a/test/convert-operator-tester.h
+++ b/test/convert-operator-tester.h
@@ -6,7 +6,9 @@
 #pragma once
 
 #include
+#include
 #include
+#include
 
 #include
 #include
@@ -356,6 +358,65 @@ class ConvertOperatorTester {
     }
   }
 
+  void TestF32toQP8() const {
+    xnnpack::ReplicableRandomDevice rng;
+
+    // The parameters of the GEMM config are used as packing parameters.
+    const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_nr2_config();
+
+    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
+                             (batch_size() - 1) * input_stride() + channels());
+    std::vector<int8_t> output(xnn_x8_packq_f32qp8_packed_size(
+        batch_size(), channels(), gemm_config->mr, 1 << gemm_config->log2_kr,
+        1 << gemm_config->log2_sr));
+    std::uniform_real_distribution<float> range_dist(-100000, 100000);
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      const float first_val = range_dist(rng);
+      const float second_val = range_dist(rng);
+      std::uniform_real_distribution<float> f32dist(
+          std::min(first_val, second_val), std::max(first_val, second_val));
+      std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
+      std::fill(output.begin(), output.end(), INT8_C(0xA5));
+
+      // Create, setup, run, and destroy Convert operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t convert_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+                xnn_create_convert_nc_f32_qp8(0, &convert_op));
+      ASSERT_NE(nullptr, convert_op);
+
+      // Smart pointer to automatically delete convert op.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
+          auto_convert_op(convert_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+                xnn_reshape_convert_nc_f32_qp8(convert_op, batch_size(),
+                                               channels(), input_stride(),
+                                               /*threadpool=*/nullptr));
+      ASSERT_EQ(xnn_status_success,
+                xnn_setup_convert_nc_f32_qp8(convert_op, input.data(),
+                                             output.data()));
+      ASSERT_EQ(xnn_status_success,
+                xnn_run_operator(convert_op, /*threadpool=*/nullptr));
+
+      // Verify results.
+      for (size_t i = 0; i < batch_size(); i++) {
+        // const float* input_ptr = &input[i * input_stride()];
+        // const auto minmax =
+        //     std::minmax_element(input_ptr, input_ptr + channels());
+        // const float rmin = math_min_f32(0.0f, *minmax.first);
+        // const float rmax = math_max_f32(0.0f, *minmax.second);
+        // const float max_acceptable_error =
+        //     0.5001f * (rmax - rmin) / std::numeric_limits<int8_t>::max();
+
+        // TODO(b/340399245) - Find a way to extract individual quantized values
+        // from the packing?
+        ASSERT_TRUE(true);
+      }
+    }
+  }
+
   void TestF32toQS8() const {
     ASSERT_GE(qmin(), std::numeric_limits<int8_t>::min());
     ASSERT_LE(qmax(), std::numeric_limits<int8_t>::max());
diff --git a/test/packq-microkernel-tester.cc b/test/packq-microkernel-tester.cc
new file mode 100644
index 000000000000..326227d3905a
--- /dev/null
+++ b/test/packq-microkernel-tester.cc
@@ -0,0 +1,59 @@
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "packq-microkernel-tester.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace xnnpack {
+
+void PackQMicrokernelTester::Test(xnn_x8_packq_f32qp8_ukernel_fn packq) const {
+  // Allocate the input and output data.
+  std::vector<float> input(m() * k() + XNN_EXTRA_BYTES / sizeof(float));
+  const size_t packed_size =
+      xnn_x8_packq_f32qp8_packed_size(m(), k(), mr(), kr(), sr());
+  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_size);
+  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w_ref(packed_size);
+
+  // Populate the input and output data.
+  std::iota(input.begin(), input.end(), 0);
+  std::fill(packed_w.begin(), packed_w.end(), INT8_C(0x12));
+  std::fill(packed_w_ref.begin(), packed_w_ref.end(), INT8_C(0x7B));
+
+  // Compute reference results.
+  xnn_x8_packq_f32qp8_ukernel__scalar_u1(
+      m(), k(), mr(), kr(), sr(), /*m_idx_start=*/0, input.data(),
+      /*lhs_stride=*/k() * sizeof(float), packed_w_ref.data());
+
+  // Call optimized micro-kernel.
+  packq(m(), k(), mr(), kr(), sr(), /*m_idx_start=*/0, input.data(),
+        /*lhs_stride=*/k() * sizeof(float), packed_w.data());
+
+  // Verify results.
+  for (size_t i = 0; i < packed_size; i++) {
+    if (packed_w_ref[i] != INT8_C(0x7B)) {  // Allow pad to differ
+      ASSERT_EQ((int32_t)packed_w[i], (int32_t)packed_w_ref[i])
+          << "at n " << i << " of " << packed_size << ", m=" << m()
+          << ", k=" << k() << ", mr=" << mr() << ", kr=" << kr()
+          << ", sr=" << sr();
+    }
+  }
+}
+
+};  // namespace xnnpack
diff --git a/test/packq-microkernel-tester.h b/test/packq-microkernel-tester.h
new file mode 100644
index 000000000000..22ab35cdfd44
--- /dev/null
+++ b/test/packq-microkernel-tester.h
@@ -0,0 +1,82 @@
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#ifndef __XNNPACK_TEST_PACKQ_MICROKERNEL_TESTER_H
+#define __XNNPACK_TEST_PACKQ_MICROKERNEL_TESTER_H
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace xnnpack {
+
+class PackQMicrokernelTester {
+ public:
+  PackQMicrokernelTester& m(size_t m) {
+    this->m_ = m;
+    return *this;
+  }
+
+  size_t m() const { return this->m_; }
+
+  PackQMicrokernelTester& k(size_t k) {
+    assert(k != 0);
+    this->k_ = k;
+    return *this;
+  }
+
+  size_t k() const { return this->k_; }
+
+  PackQMicrokernelTester& mr(size_t mr) {
+    this->mr_ = mr;
+    return *this;
+  }
+
+  size_t mr() const { return this->mr_; }
+
+  PackQMicrokernelTester& kr(size_t kr) {
+    this->kr_ = kr;
+    return *this;
+  }
+
+  size_t kr() const { return this->kr_; }
+
+  PackQMicrokernelTester& sr(size_t sr) {
+    this->sr_ = sr;
+    return *this;
+  }
+
+  size_t sr() const { return this->sr_; }
+
+  size_t packed_k() const { return round_up_po2(k(), kr() * sr()); }
+
+  size_t packed_m() const { return round_up(m(), mr()); }
+
+  PackQMicrokernelTester& nullbias(bool nullbias) {
+    this->nullbias_ = nullbias;
+    return *this;
+  }
+
+  bool nullbias() const { return this->nullbias_; }
+
+  void Test(xnn_x8_packq_f32qp8_ukernel_fn packq) const;
+
+ private:
+  size_t m_{1};
+  size_t k_{1};
+  size_t mr_{1};
+  size_t kr_{1};
+  size_t sr_{1};
+  bool nullbias_{false};
+};
+
+};  // namespace xnnpack
+
+#endif  // __XNNPACK_TEST_PACKQ_MICROKERNEL_TESTER_H
diff --git a/test/qd8-f32-qc4w-gemm-minmax-2.cc b/test/qd8-f32-qc4w-gemm-minmax-2.cc
index 5d8e7a721dbb..e026d0d7862f 100644
--- a/test/qd8-f32-qc4w-gemm-minmax-2.cc
+++ b/test/qd8-f32-qc4w-gemm-minmax-2.cc
@@ -1311,6 +1311,7 @@ INSTANTIATE_TEST_SUITE_P(
     return info.param.test_name;
   });
 
+
 #if XNN_ARCH_WASM || XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
 INSTANTIATE_TEST_SUITE_P(
     QD8_F32_QC4W_GEMM_MINMAX_2X4__WASM, GemmTest,
diff --git a/test/x8-packq.cc b/test/x8-packq.cc
new file mode 100644
index 000000000000..848a3696cf5a
--- /dev/null
+++ b/test/x8-packq.cc
@@ -0,0 +1,142 @@
+// Copyright 2023 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+//
+// Auto-generated file. Do not edit!
+// Specification: test/x8-packq.yaml +// Generator: tools/generate-packq-test.py + + +#include +#include +#include + +#include + +#include "packq-microkernel-tester.h" +#include + + +namespace xnnpack { + +TEST(X8_PACKQ_F32QP8__SCALAR_U1, k_div_kr_m_div_mr) { + for (size_t kr = 1; kr <= 4; kr++) { + for (size_t mr = 1; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(mr * 10) + .k(kr * 10) + .mr(mr) + .kr(kr) + .Test(xnn_x8_packq_f32qp8_ukernel__scalar_u1); + } + } +} + +TEST(X8_PACKQ_F32QP8__SCALAR_U1, k_div_kr_m_div_mr_kr_div_sr) { + for (size_t sr = 1; sr <= 4; sr++) { + for (size_t kr = sr; kr <= 4 * sr; kr += sr) { + for (size_t mr = 1; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(mr * 10) + .k(kr * 10) + .mr(mr) + .kr(kr) + .sr(sr) + .Test(xnn_x8_packq_f32qp8_ukernel__scalar_u1); + } + } + } +} + +TEST(X8_PACKQ_F32QP8__SCALAR_U1, k_div_kr_m_lt_mr) { + for (size_t kr = 1; kr <= 4; kr++) { + for (size_t mr = 2; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(mr - 1) + .k(kr * 10) + .mr(mr) + .kr(kr) + .Test(xnn_x8_packq_f32qp8_ukernel__scalar_u1); + } + } +} + +TEST(X8_PACKQ_F32QP8__SCALAR_U1, k_div_kr_m_gt_mr) { + for (size_t kr = 1; kr <= 4; kr++) { + for (size_t mr = 2; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(2 * mr + 1) + .k(kr * 10) + .mr(mr) + .kr(kr) + .Test(xnn_x8_packq_f32qp8_ukernel__scalar_u1); + } + } +} + +#if XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + TEST(X8_PACKQ_F32QP8__AARCH64_NEON_U2, k_div_kr_m_div_mr) { + TEST_REQUIRES_ARM_NEON; + for (size_t kr = 1; kr <= 4; kr++) { + for (size_t mr = 1; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(mr * 20) + .k(kr * 20) + .mr(mr) + .kr(kr) + .Test(xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2); + } + } + } + + TEST(X8_PACKQ_F32QP8__AARCH64_NEON_U2, k_div_kr_m_div_mr_kr_div_sr) { + TEST_REQUIRES_ARM_NEON; + for (size_t sr = 1; sr <= 4; sr++) { + for (size_t kr = sr; kr <= 4 * sr; kr += sr) { + for (size_t mr = 1; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(mr * 20) + .k(kr * 20) + .mr(mr) + .kr(kr) + .sr(sr) + .Test(xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2); + } + } + } + } + + TEST(X8_PACKQ_F32QP8__AARCH64_NEON_U2, k_div_kr_m_lt_mr) { + TEST_REQUIRES_ARM_NEON; + for (size_t kr = 1; kr <= 4; kr++) { + for (size_t mr = 2; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(mr - 1) + .k(kr * 20) + .mr(mr) + .kr(kr) + .Test(xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2); + } + } + } + + TEST(X8_PACKQ_F32QP8__AARCH64_NEON_U2, k_div_kr_m_gt_mr) { + TEST_REQUIRES_ARM_NEON; + for (size_t kr = 1; kr <= 4; kr++) { + for (size_t mr = 2; mr <= 4; mr++) { + PackQMicrokernelTester() + .m(2 * mr + 1) + .k(kr * 20) + .mr(mr) + .kr(kr) + .Test(xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2); + } + } + } + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ARCH_ARM64 + + +}; // namespace xnnpack diff --git a/test/x8-packq.yaml b/test/x8-packq.yaml new file mode 100644 index 000000000000..376a0da12ae6 --- /dev/null +++ b/test/x8-packq.yaml @@ -0,0 +1,12 @@ +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+
+# Scalar
+- name: xnn_x8_packq_f32qp8_ukernel__scalar_u1
+
+# Aarch64 NEON
+- name: xnn_x8_packq_f32qp8_ukernel__aarch64_neon_u2
+  cpp_check: XNN_ENABLE_KLEIDIAI
diff --git a/tools/generate-packq-test.py b/tools/generate-packq-test.py
new file mode 100755
index 000000000000..8c417426515a
--- /dev/null
+++ b/tools/generate-packq-test.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python
+# Copyright 2023 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import codecs
+import math
+import os
+import re
+import sys
+import yaml
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+from primes import next_prime
+import xngen
+import xnncommon
+
+
+parser = argparse.ArgumentParser(description='PackQ microkernel test generator')
+parser.add_argument("-s", "--spec", metavar="FILE", required=True,
+                    help="Specification (YAML) file")
+parser.add_argument("-o", "--output", metavar="FILE", required=True,
+                    help='Output (C++ source) file')
+parser.add_argument(
+    "-b",
+    "--output-bench",
+    metavar="FILE",
+    required=False,
+    help="Benchmark output (C++ source) file(s)")
+parser.set_defaults(defines=list())
+
+def split_ukernel_name(name):
+  match = re.fullmatch(
+      r"xnn_(x8|x16|x32)_packq_(f32|f16)qp8_ukernel__(.+)_u(\d+)", name
+  )
+  assert match is not None
+  unroll = int(match.group(4))
+  arch, isa, _ = xnncommon.parse_target_name(target_name=match.group(3))
+  return arch, isa, unroll
+
+PACKQ_BENCHMARK_TEMPLATE = """\
+$if CPP_CHECK:
+  #if ${CPP_CHECK}
+$for MR in (1, 2, 4):
+  $for KR in (1, 2, 4):
+    static void ${BENCHMARK_NAME}_mr_${MR}_kr_${KR}(
+        benchmark::State& state, const char* net) {
+      ${DATATYPE}_packq(state,
+                        ${UKERNEL_NAME},
+      $if ISA_CHECK:
+        /*mr=*/${MR}, /*kr=*/${KR}, /*sr=*/1,
+        benchmark::utils::${ISA_CHECK});
+      $else:
+        /*mr=*/${MR}, /*kr=*/${KR}, /*sr=*/1);
+    }
+    BENCHMARK_BGEMM(${BENCHMARK_NAME}_mr_${MR}_kr_${KR})
+$if CPP_CHECK:
+  #endif // ${CPP_CHECK}
+"""
+
+PACKQ_TEST_TEMPLATE = """\
+$if CPP_CHECK:
+  #if ${CPP_CHECK}
+TEST(${TEST_NAME}, k_div_kr_m_div_mr) {
+  $if ISA_CHECK:
+    ${ISA_CHECK};
+  for (size_t kr = 1; kr <= 4; kr++) {
+    for (size_t mr = 1; mr <= 4; mr++) {
+      PackQMicrokernelTester()
+        .m(mr * ${UNROLL * 10})
+        .k(kr * ${UNROLL * 10})
+        .mr(mr)
+        .kr(kr)
+        .Test(${", ".join(TEST_ARGS)});
+    }
+  }
+}
+
+TEST(${TEST_NAME}, k_div_kr_m_div_mr_kr_div_sr) {
+  $if ISA_CHECK:
+    ${ISA_CHECK};
+  for (size_t sr = 1; sr <= 4; sr++) {
+    for (size_t kr = sr; kr <= 4 * sr; kr += sr) {
+      for (size_t mr = 1; mr <= 4; mr++) {
+        PackQMicrokernelTester()
+          .m(mr * ${UNROLL * 10})
+          .k(kr * ${UNROLL * 10})
+          .mr(mr)
+          .kr(kr)
+          .sr(sr)
+          .Test(${", ".join(TEST_ARGS)});
+      }
+    }
+  }
+}
+
+TEST(${TEST_NAME}, k_div_kr_m_lt_mr) {
+  $if ISA_CHECK:
+    ${ISA_CHECK};
+  for (size_t kr = 1; kr <= 4; kr++) {
+    for (size_t mr = 2; mr <= 4; mr++) {
+      PackQMicrokernelTester()
+        .m(mr - 1)
+        .k(kr * ${UNROLL * 10})
+        .mr(mr)
+        .kr(kr)
+        .Test(${", ".join(TEST_ARGS)});
+    }
+  }
+}
+
+TEST(${TEST_NAME}, k_div_kr_m_gt_mr) {
+  $if ISA_CHECK:
+    ${ISA_CHECK};
+  for (size_t kr = 1; kr <= 4; kr++) {
+    for (size_t mr = 2; mr <= 4; mr++) {
+      PackQMicrokernelTester()
+        .m(2 * mr + 1)
+        .k(kr * ${UNROLL * 10})
+        .mr(mr)
+        .kr(kr)
+        .Test(${", ".join(TEST_ARGS)});
+    }
+  }
+}
+$if CPP_CHECK:
+  #endif // ${CPP_CHECK}
+"""
+
+
+def generate_cases(ukernel, cpp_check, isa, unroll):
+  """Generates all test cases for a PACKQ micro-kernel.
+
+  Args:
+    ukernel: C name of the micro-kernel function.
+ cpp_check: Optional preprocessor macro to check for the availability of the + micro-kernel. + isa: instruction set required to run the micro-kernel. Generated unit test + will skip execution if the host processor doesn't support this ISA. + unroll: The number of inputs processed per step. + + Returns: + Code for the test and benchmark cases. + """ + _, test_name = ukernel.split("_", 1) + _, datatype, _ = ukernel.split("_", 2) + test_case = xngen.preprocess( + PACKQ_TEST_TEMPLATE, + { + "TEST_NAME": test_name.upper().replace("UKERNEL_", ""), + "TEST_ARGS": [ukernel], + "UNROLL": unroll, + "ISA_CHECK": xnncommon.generate_isa_check_macro(isa), + "next_prime": next_prime, + "CPP_CHECK": cpp_check, + }, + ) + + benchmark = xngen.preprocess( + PACKQ_BENCHMARK_TEMPLATE, + { + "DATATYPE": datatype, + "BENCHMARK_NAME": test_name, + "UKERNEL_NAME": ukernel, + "ISA_CHECK": xnncommon.generate_isa_utilcheck_macro(isa), + "next_prime": next_prime, + "CPP_CHECK": cpp_check, + }, + ) + + return test_case, benchmark + + +def main(args): + options = parser.parse_args(args) + + with codecs.open(options.spec, "r", encoding="utf-8") as spec_file: + spec_yaml = yaml.safe_load(spec_file) + if not isinstance(spec_yaml, list): + raise ValueError("expected a list of micro-kernels in the spec") + + tests = """\ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! +// Specification: {specification} +// Generator: {generator} + + +#include +#include +#include + +#include + +#include "packq-microkernel-tester.h" +#include + + +namespace xnnpack {{""".format( + specification=options.spec, generator=sys.argv[0] + ) + + bench_output = """\ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! 
+// Specification: {specification}
+// Generator: {generator}
+
+
+#include
+#include
+
+#include
+#include "bench/bgemm.h"
+#include "bench/packq-benchmark.h"
+""".format(specification=options.spec, generator=sys.argv[0])
+
+  isa_hierarchy = xnncommon._ISA_HIERARCHY_MAP
+  benches = [""] * len(isa_hierarchy)
+
+  for ukernel_spec in spec_yaml:
+    name = ukernel_spec["name"]
+    cpp_check = ukernel_spec.get("cpp_check")
+    arch, isa, unroll = split_ukernel_name(name)
+
+    test_case, benchmark = generate_cases(name, cpp_check, isa, unroll)
+    tests += "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa)
+
+    benches[isa_hierarchy.get(isa, 0)] += \
+        "\n\n" + xnncommon.postprocess_test_case(benchmark, arch, isa)
+
+  tests += "\n\n}; // namespace xnnpack\n"
+
+  xnncommon.overwrite_if_changed(options.output, tests)
+
+  for arch_idx in reversed(range(len(isa_hierarchy))):
+    bench_output += benches[arch_idx]
+
+  bench_output += """\n
+#ifndef XNNPACK_BENCHMARK_NO_MAIN
+BENCHMARK_MAIN();
+#endif
+"""
+
+  if options.output_bench:
+    output_name = options.output_bench
+    xnncommon.overwrite_if_changed(output_name, bench_output)
+
+if __name__ == "__main__":
+  main(sys.argv[1:])
diff --git a/tools/update-microkernels.py b/tools/update-microkernels.py
index dea82dbcf205..49179281a01b 100755
--- a/tools/update-microkernels.py
+++ b/tools/update-microkernels.py
@@ -7,6 +7,7 @@
 import argparse
 from functools import cmp_to_key
 import io
+import itertools
 import os
 import re
 import sys
@@ -126,11 +127,29 @@
 """
 
+UNWANTED_INCLUDES = (
+    'arm_acle.h',
+    'arm_fp16.h',
+    'arm_neon.h',
+    'emmintrin.h',
+    'immintrin.h',
+    'nmmintrin.h',
+    'smmintrin.h',
+    'tmmintrin.h',
+    'xmmintrin.h',
+    'riscv_vector.h',
+    'wasm_simd128.h',
+)
 
 parser = argparse.ArgumentParser(
-    description='Utility for re-generating microkernel lists')
-parser.add_argument('-a', '--amalgamate', action='store_true',
-                    help='Amalgamate production microkernels')
+    description='Utility for re-generating microkernel lists'
+)
+parser.add_argument(
+    '-a',
+    '--amalgamate',
+    action='store_true',
+    help='Amalgamate production microkernels',
+)
 
 def human_sort_key(text):
@@ -141,9 +160,14 @@ def human_sort_key(text):
   ]
 
 
+def _discard(l, val):
+  if val in l:
+    l.remove(val)
+
+
 def amalgamate_microkernel_sources(source_paths, include_header):
   amalgam_lines = list()
-  amalgam_includes = set()
+  amalgam_includes = []
   for filepath in sorted(source_paths):
     with open(filepath, 'r', encoding='utf-8') as file:
       source_lines = file.read().splitlines()
@@ -157,7 +181,8 @@ def amalgamate_microkernel_sources(source_paths, include_header):
         continue
       elif line.lstrip().startswith('#'):
         if not consumed_includes:
-          amalgam_includes.add(line)
+          if line not in amalgam_includes:
+            amalgam_includes.append(line)
           continue
         consumed_license = True
       elif not line:
@@ -173,29 +198,29 @@ def amalgamate_microkernel_sources(source_paths, include_header):
   amalgam_lines.append('')
 
   # Multi-line sequence for XOP intrinsics, which don't have a standardized header
-  amalgam_includes.discard('#ifdef _MSC_VER')
-  amalgam_includes.discard(' #include ')
-  amalgam_includes.discard('#else')
-  amalgam_includes.discard(' #include ')
-  amalgam_includes.discard('#endif')
+  _discard(amalgam_includes, '#ifdef _MSC_VER')
+  _discard(amalgam_includes, ' #include ')
+  _discard(amalgam_includes, '#else')
+  _discard(amalgam_includes, ' #include ')
+  _discard(amalgam_includes, '#endif')
 
   # Single-line sequences for intrinsics with a standardized header
-  amalgam_includes.discard('#include ')
-
amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') - amalgam_includes.discard('#include ') + for filename in UNWANTED_INCLUDES: + _discard(amalgam_includes, f'#include <{filename}>') amalgam_text = AMALGAMATION_HEADER - amalgam_text += "\n".join(sorted(inc for inc in amalgam_includes if - not inc.startswith('#include