diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index 9345241acff..bd9a43adb7a 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -28,6 +28,8 @@ COPY .. . RUN make FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime +ENV CUDA_MAIN_VERSION=12.3 +ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH WORKDIR /app RUN apt-get update && \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ae6c4ce9b45..2b51051848f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,10 +15,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -36,7 +36,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Dependencies run: | @@ -53,10 +53,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Build - uses: cross-platform-actions/action@v0.15.0 + uses: cross-platform-actions/action@v0.24.0 with: operating_system: freebsd version: '13.2' @@ -77,10 +77,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -105,10 +105,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -133,10 +133,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Build ${{ matrix.arch }} run: | @@ -150,6 +150,164 @@ jobs: make ctest -L gh --output-on-failure' + ubuntu-22-cmake-sycl: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@v4 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + mkdir build + cd build + cmake -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake --build . 
--config Release -j $(nproc) + + ubuntu-22-cmake-sycl-fp16: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@v4 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + mkdir build + cd build + cmake -DWHISPER_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake --build . --config Release -j $(nproc) + + windows-msys2: + runs-on: windows-latest + + strategy: + fail-fast: false + matrix: + include: + - { sys: UCRT64, env: ucrt-x86_64, build: Release } + - { sys: CLANG64, env: clang-x86_64, build: Release } + + steps: + - name: Clone + uses: actions/checkout@v4 + + - name: Setup ${{ matrix.sys }} + uses: msys2/setup-msys2@v2 + with: + update: true + msystem: ${{matrix.sys}} + install: >- + base-devel + mingw-w64-${{matrix.env}}-toolchain + mingw-w64-${{matrix.env}}-cmake + mingw-w64-${{matrix.env}}-SDL2 + mingw-w64-${{matrix.env}}-openblas + + - name: Build using make + shell: msys2 {0} + run: | + make -j $(nproc) + + - name: Clean after building using make + shell: msys2 {0} + run: | + make clean + + - name: Build using make w/ OpenBLAS + shell: msys2 {0} + run: | + make WHISPER_OPENBLAS=1 -j $(nproc) + + - name: Build using CMake + shell: msys2 {0} + run: | + cmake -B build + cmake --build build --config ${{ matrix.build }} -j $(nproc) + + - name: Clean after building using CMake + shell: msys2 {0} + run: | + rm -rf build + + - name: Build using CMake w/ OpenBLAS + shell: msys2 {0} + run: | + cmake -B build -DWHISPER_OPENBLAS=ON + cmake --build build --config ${{ matrix.build }} -j $(nproc) + windows: runs-on: windows-latest @@ -170,10 +328,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1 + uses: microsoft/setup-msbuild@v2 - name: Fetch SDL2 and set SDL2_DIR if: matrix.sdl2 == 'ON' @@ -198,14 +356,14 @@ jobs: run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - name: Upload dll - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ matrix.jnaPath }}_whisper.dll path: build/bin/${{ matrix.build }}/whisper.dll - name: Upload binaries if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: whisper-bin-${{ matrix.arch }} path: build/bin/${{ matrix.build }} @@ -234,10 +392,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1 + uses: microsoft/setup-msbuild@v2 - name: Fetch OpenBLAS if: matrix.blas == 'ON' @@ -295,13 +453,13 @@ jobs: - name: Upload binaries if: matrix.blas == 'ON' && 
matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }} path: build/bin/${{ matrix.build }} windows-cublas: - runs-on: windows-latest + runs-on: windows-2019 strategy: matrix: @@ -318,14 +476,14 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1 + uses: microsoft/setup-msbuild@v2 - name: Install CUDA Toolkit id: cuda-toolkit - uses: Jimver/cuda-toolkit@v0.2.11 + uses: Jimver/cuda-toolkit@v0.2.15 with: cuda: '${{ matrix.cuda-toolkit }}' @@ -340,7 +498,7 @@ jobs: run: > cmake -S . -B ./build -A ${{ matrix.arch }} -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DWHISPER_CUBLAS=${{ matrix.cublas }} + -DWHISPER_CUDA=${{ matrix.cublas }} -DWHISPER_SDL2=${{ matrix.sdl2 }} - name: Build ${{ matrix.cuda-toolkit }} @@ -361,7 +519,7 @@ jobs: - name: Upload binaries if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }} path: build/bin/${{ matrix.build }} @@ -375,10 +533,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v12 + uses: mymindstorm/setup-emsdk@v14 - name: Verify run: emcc -v @@ -397,7 +555,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Configure run: | @@ -415,61 +573,75 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 + with: + path: whisper + + - name: Clone + uses: actions/checkout@v4 + with: + repository: ggerganov/ggml + path: ggml - name: Install Java - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: distribution: zulu - java-version: 17 + java-version: 21 - name: Setup Android SDK - uses: android-actions/setup-android@v2 + uses: android-actions/setup-android@v3 - name: Build run: | - cd examples/whisper.android + cd whisper/examples/whisper.android ./gradlew assembleRelease --no-daemon + - name: Build with external ggml + run: | + export PATH_TO_GGML=$PWD/ggml + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon -PGGML_HOME=$PATH_TO_GGML + android_java: runs-on: ubuntu-latest steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: set up JDK 11 - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: java-version: '11' distribution: 'temurin' cache: gradle - name: Setup Android SDK - uses: android-actions/setup-android@v2 + uses: android-actions/setup-android@v3 with: - api-level: 30 - build-tools-version: 30.0.3 + cmdline-tools-version: 9.0 - name: Build run: | cd examples/whisper.android.java - chmod +x ./gradlew + chmod +x ./gradlew ./gradlew assembleRelease java: needs: [ 'windows' ] runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Java - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 17 + distribution: zulu + java-version: 20 - name: Download Windows lib - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: win32-x86-64_whisper.dll path: bindings/java/build/generated/resources/main/win32-x86-64 @@ -482,7 +654,7 @@ jobs: ./gradlew build - name: Upload jar - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: whispercpp.jar path: 
bindings/java/build/libs/whispercpp-*.jar @@ -504,7 +676,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Test quantize run: | diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index ddaf5e9de5d..808dd18c0b7 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -37,7 +37,7 @@ jobs: run: npm install - name: Compile addon.node - run: npx cmake-js compile -T whisper-addon -B Release + run: npx cmake-js compile -T addon.node -B Release - name: Download test model run: | diff --git a/.gitignore b/.gitignore index 30fca39ad6d..e3319ad0329 100644 --- a/.gitignore +++ b/.gitignore @@ -6,8 +6,11 @@ .vs/ .vscode/ .DS_Store +.vimspector.json +/CMakeSettings.json build/ +build-blas/ build-coreml/ build-em/ build-debug/ @@ -58,4 +61,4 @@ benchmark_results.csv cmake-build-debug/ .cxx/ .gradle/ -local.properties \ No newline at end of file +local.properties diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000000..33e6c9649c7 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,301 @@ +# date: Tue Apr 9 20:27:03 EEST 2024 +# this file is auto-generated by scripts/gen-authors.sh + +0/0 +0cc4m +0xsourcecode <134374803+0xsourcecode@users.noreply.github.com> +AT +Aarni Koskela +Aaron Pham <29749331+aarnphm@users.noreply.github.com> +Aaron Taylor +Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> +Abitofevrything <54505189+abitofevrything@users.noreply.github.com> +AfryMask +Ahmad Bilal +AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> +Akash Mahajan +Akash Mahajan +Al Hoang <3811822-hoanga@users.noreply.gitlab.com> +Alan +Aleksander Andrzejewski <18704749+aleksanderandrzejewski@users.noreply.github.com> +Alex Azarov +Alex Bacart <13940752+alex-bacart@users.noreply.github.com> +Alex Evgrashin +Alexandr Graschenkov +Alexandru Mariuti +Alexey Kharlamov +Alfredo Montesinos +Ali Alameh +Ananta Bastola +Andreu Huguet +Andrew Huynh +Andrew S +Andy Maloney +Anton Kostin +Artyom Mezin +Asad Memon +Ashraful Islam +AsukaMinato +AustinMroz +Avik Sengupta +Bader-eddine Ouaich <49657842+baderouaich@users.noreply.github.com> +Baffin Lee +Ben Nortier +Benjamin Heiniger +Bo-Yi Wu +Boris Bliznioukov +Borislav Stanimirov +Brad Murray <59848399+bradmurray-dt@users.noreply.github.com> +Brian Murray +CRD716 +Canis Lupus +Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> +ChangSeok Oh +Chaoqun <27287694+OpenWaygate@users.noreply.github.com> +Chia-Hsiang Cheng <88014292+garychia@users.noreply.github.com> +Chidi Williams +Christian <12550267+iceychris@users.noreply.github.com> +Clifford Heath +Colin +DGdev91 +Damian Czaja +Daniel Bevenius +David +David Thorpe +Davidson Francis +Dener Stassun +Didzis Gosko +Digipom +Dimo +Dody Suria Wijaya +Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> +Duncan McConnell +Egor Egorov +Elkana Bardugo +Emmanuel Schmidbauer +Engininja2 <139037756+Engininja2@users.noreply.github.com> +Eric Swanson +Eric Tendian +Erik Scholz +Evan Jones +Evan Martin +Eve <139727413+netrunnereve@users.noreply.github.com> +Evgeny Kuznetsov +F1L1P <78918286+F1L1Pv2@users.noreply.github.com> +Fangjun Kuang +Felix +Finn Voorhees +FlippFuzz <41221030+FlippFuzz@users.noreply.github.com> +Gang Chen +Gavin Cai +George Hindle +Georgi Gerganov +GitAritron <103900385+GitAritron@users.noreply.github.com> +GiviMAD +Gleicon Moraes +Gregor Jasny +Guillaume Wenzek +HY. 
Kelvin Lee <34256578+hykelvinlee42@users.noreply.github.com> +Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> +Hang +Herman Semenov +Hrishikesh Barman +Ian Bicking +Ian Bull +Ikko Ashimine +InconsolableCellist <23345188+InconsolableCellist@users.noreply.github.com> +Ismatulla Mansurov <47342870+sapoepsilon@users.noreply.github.com> +Ivan Gorin +JJ <103335846+computerscienceiscool@users.noreply.github.com> +Jack Mousseau +JacobLinCool +Jakub Ráček +Jared Van Bortel +Jay Binks +Jhen-Jie Hong +Jhen-Jie Hong +JidongZhang-THU <1119708529@qq.com> +Jo Liss +Johan +Johannes Gäßler +John Balis +Jonathan Soo +Jonno <1160532+razodactyl@users.noreply.github.com> +Joonas Pihlajamaa +Jose <34888496+Jerry-Master@users.noreply.github.com> +Josh Bleecher Snyder +Judd +Jumper775 <78500318+jumpers775@users.noreply.github.com> +Justine Tunney +KP Kaiser +Kamilake +Kartik Saranathan <278928+Kartiku@users.noreply.github.com> +Kasumi <90275229+kasumi-1@users.noreply.github.com> +Kawrakow <48489457+ikawrakow@users.noreply.github.com> +Kevin Brothaler +Konstantin Zhuravlyov +Kreijstal +Kylin <56434533+KyL0N@users.noreply.github.com> +LBlue <153975653+lbluep@users.noreply.github.com> +Larry Battle +Laytan Laats +Leo Moll +Lexevolution <31176843+Lexevolution@users.noreply.github.com> +LittleLoli <26589867+WhichWho@users.noreply.github.com> +Lucas Zanek <57494138+LucasZNK@users.noreply.github.com> +Luis Herrera +Lukas Rist +M. A. Ali <73258591+MightyStud@users.noreply.github.com> +M. Eren Akbiyik +Maciek +Marcin Mielniczuk +Martin Warnaar +Matheus de Sousa <23645013+keyehzy@users.noreply.github.com> +Mathijs de Bruin +Matija Pevec +Maximiliano Levi <8160966+maxilevi@users.noreply.github.com> +Meng, Hengyu +Michael Podvitskiy +Michael Rienstra +Mikhail Grigorev +Mohammadreza Hendiani +Mohit Agarwal +Murilo Santana +Neil Chudleigh +Neo Zhang Jianyu +Neuman Vong +Nicholas Albion +Niels Mayer +Okabintaro <103938900+Okabintaro@users.noreply.github.com> +Oleg Sidorov +Oleg Sidorov +Ondrej Kokes +Ouadie EL FAROUKI +Paul Tsochantaris +Philipp Zabel +Philippe Normand +Przemysław Pawełczyk +Qianhe Chen <54462604+chenqianhe@users.noreply.github.com> +Radosław Gryta +Reinforce-II +Reinis Muiznieks +RelatedTitle +RhinoDevel +Rich Jones +Robin +Roddur Dasgupta +Roland Rabien +Rotem Dan +Ryan Hitchman +Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> +RyanChang +Sam <49637763+Onlyartist9@users.noreply.github.com> +Sam Pullara +Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> +Sergio López +Siddharth Ramakrishnan +Simon Moisselin +Sindre Sorhus +Slava Primenko +Syahmi Azhar +Syed Jafri +Sơn Phan Trung +Taisei Mima +Takeshi Inoue +Tamotsu Takahashi +Taras Glek +Tauseef Mohiuddin <35351464+tauseefmohammed2@users.noreply.github.com> +Thijs Raymakers +Thomas Fitzsimmons +Tiago Fassoni +Tienshiao Ma +Timothy Cronin <40186632+4imothy@users.noreply.github.com> +Tobrun +Todd +Tong Li <31761981+litongjava@users.noreply.github.com> +Topping1 <78745143+Topping1@users.noreply.github.com> +Travis Cline +UEXTM.com <84163508+uextm@users.noreply.github.com> +Vadim Peretokin +Valentin Gosu <1454649+valenting@users.noreply.github.com> +Vulcan <93451215+trholding@users.noreply.github.com> +WhiteOlivierus <36532695+WhiteOlivierus@users.noreply.github.com> +Xiang (Kevin) Li +Xiao-Yong Jin +XiaotaoChen +Yajing Tang +Yang Shen +Yunès +ZaBlazzingZephyrus <119159668+blazingzephyr@users.noreply.github.com> +Zigfrid Zvezdin +Zollner <24618122+Zolliner@users.noreply.github.com> +ai-at-home 
<149282006+ai-at-home@users.noreply.github.com> +alonfaraj +andypayne +ardfork <134447697+ardfork@users.noreply.github.com> +automaticcat +be-next +bert hubert +bmwl +bobqianic <129547291+bobqianic@users.noreply.github.com> +bocytko +boolemancer <48014766+boolemancer@users.noreply.github.com> +boolemancer +bradmit <151883577+bradmit@users.noreply.github.com> +brunofaustino +bssrdf +byte-6174 <88070277+byte-6174@users.noreply.github.com> +cdosoftei +clach04 +compilade <113953597+compilade@users.noreply.github.com> +conradg +ddpasa <112642920+ddpasa@users.noreply.github.com> +denersc +dscripka +duthils +ecneladis +faker +fitzsim +fraxy-v <65565042+fraxy-v@users.noreply.github.com> +genevera (she/her) +geniusnut +greeshmay +hydai +iamthad +james wolf +joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> +jorismertz <35079666+jorismertz@users.noreply.github.com> +junkfood <69683722+JunkFood02@users.noreply.github.com> +jwijffels +kamranjon +katsu560 +kennethge <57784063+kenneth-ge@users.noreply.github.com> +keyehzy +leejet +litong <31761981+litongjava@users.noreply.github.com> +lnyan +m.bell +mkiol +novag <7754358+novag@users.noreply.github.com> +pajowu +polarmoon <90010972+polarmoon@users.noreply.github.com> +rlapray +sandrohanea <40202887+sandrohanea@users.noreply.github.com> +semiformal-net <84111142+semiformal-net@users.noreply.github.com> +shibukazu <61775791+shibukazu@users.noreply.github.com> +shikokuchuo <53399081+shikokuchuo@users.noreply.github.com> +slaren +slashlib +snadampal <87143774+snadampal@users.noreply.github.com> +st-gr <38470677+st-gr@users.noreply.github.com> +texmex76 <40733439+texmex76@users.noreply.github.com> +thefinaldegree +trixirt +ulatekh +undef +venkr +vicalloy +xdrudis +zhouwg <6889919+zhouwg@users.noreply.github.com> +布客飞龙 <562826179@qq.com> +Артём Земляк diff --git a/CMakeLists.txt b/CMakeLists.txt index 6db8273272c..329ebb6fe94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,10 @@ cmake_minimum_required (VERSION 3.5) -project(whisper.cpp VERSION 1.5.4) +# Allow for the creation of solution folders. 
+set_property(GLOBAL PROPERTY USE_FOLDERS ON) + +project(whisper.cpp VERSION 1.6.2) +set(SOVERSION 1) # Add path to modules list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") @@ -55,10 +59,17 @@ option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDA option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) -option(WHISPER_NO_AVX "whisper: disable AVX" OFF) -option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF) -option(WHISPER_NO_FMA "whisper: disable FMA" OFF) -option(WHISPER_NO_F16C "whisper: disable F16c" OFF) +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) +endif() + +option(WHISPER_NO_AVX "whisper: disable AVX" OFF) +option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF) +option(WHISPER_NO_AVX512 "whisper: disable AVX512" ON) +option(WHISPER_NO_AVX512_VBMI "whisper: disable AVX512-VBMI" ON) +option(WHISPER_NO_AVX512_VNNI "whisper: disable AVX512-VNNI" ON) +option(WHISPER_NO_FMA "whisper: disable FMA" OFF) +option(WHISPER_NO_F16C "whisper: disable F16c" OFF) option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF) @@ -68,13 +79,22 @@ if (APPLE) option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF) option(WHISPER_COREML "whisper: enable Core ML framework" OFF) option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF) + option(WHISPER_METAL_EMBED_LIBRARY "whisper: embed Metal library" OFF) + option(WHISPER_BLAS "whisper: use BLAS" ON) + set (WHISPER_BLAS_VENDOR "Apple" CACHE STRING + "whisper: BLAS library vendor") else() - option(WHISPER_BLAS "whisper: use BLAS libraries" OFF) - option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic) - option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF) - option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF) - option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF) - option(WHISPER_CLBLAST "whisper: use CLBlast" OFF) + option(WHISPER_CUDA "whisper: support for CUDA" OFF) + option(WHISPER_CUDA_FA_ALL_QUANTS "whisper: compile all quants for FlashAttention" OFF) + option(WHISPER_CUBLAS "whisper: support for CUDA (deprecated)" OFF) + option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF) + option(WHISPER_CLBLAST "whisper: use CLBlast" OFF) + option(WHISPER_MKL "whisper: use Intel Math Kernel Library (MKL)" OFF) + option(WHISPER_SYCL "whisper: use SYCL" OFF) + option(WHISPER_SYCL_F16 "whisper: use 16 bit floats for sycl calculations" OFF) + option(WHISPER_BLAS "whisper: use BLAS" OFF) + set (WHISPER_BLAS_VENDOR "Generic" CACHE STRING + "whisper: BLAS library vendor") endif() option(WHISPER_PERF "whisper: enable perf timings" OFF) @@ -105,6 +125,33 @@ endif() find_package(Threads REQUIRED) +#compile flag sycl +if (WHISPER_SYCL) + set(CMAKE_CXX_STANDARD 17) +else() + set(CMAKE_CXX_STANDARD 11) +endif() + +if (WHISPER_FFMPEG) + # As of cmake 3.27, there is no official cmake support for FindFFmpeg. + # Consequently we added a FindFFmpeg.cmake script in the cmake subfolder: + # whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE + # libswresample performs highly optimized audio resampling, rematrixing and sample format conversion operations + # libavcodec provides a generic encoding/decoding framework and contains multiple decoders and encoders for audio, video and subtitle streams, and several bitstream filters.
+ # libavformat provides a generic framework for multiplexing and demultiplexing (muxing and demuxing) audio, video and subtitle streams. + find_package(FFmpeg REQUIRED) + if (NOT ${FFMPEG_FOUND}) + message(FATAL_ERROR "Cannot find ffmpeg libs/headers") + endif() + message(STATUS "Found ffmpeg libs: ${FFMPEG_LIBRARIES}") + message(STATUS "Found ffmpeg headers in: ${FFMPEG_INCLUDE_DIRS}") + message(STATUS "ffmpeg definitions: ${FFMPEG_DEFINITIONS}") + message(STATUS "Found avformat ${AVFORMAT_VERSION}") + include_directories(${FFMPEG_INCLUDE_DIRS}) + add_compile_definitions(WHISPER_FFMPEG) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${FFMPEG_LIBRARIES}) +endif() + # on APPLE if (APPLE) # include Accelerate framework @@ -115,7 +162,7 @@ if (APPLE) message(STATUS "Accelerate framework found") set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) - set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE) + set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64) else() message(FATAL_ERROR "Accelerate framework not found") endif() @@ -145,8 +192,42 @@ if (APPLE) set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) - # copy ggml-metal.metal to bin directory + # copy ggml-common.h and ggml-metal.metal to bin directory + configure_file(ggml-common.h bin/ggml-common.h COPYONLY) configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) + + if (WHISPER_METAL_EMBED_LIBRARY) + enable_language(ASM) + set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_EMBED_LIBRARY) + + set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(COMMON_HEADER "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h") + + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + set(EMBED_METALLIB_ASSEMBLY "${CMAKE_BINARY_DIR}/autogenerated/ggml-embed-metallib.s") + set(EMBED_METALLIB_SOURCE "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-combined.metal") + + add_custom_command( + OUTPUT ${EMBED_METALLIB_SOURCE} + COMMAND sed -e "/^#include \\\"ggml-common.h\\\"/r ${COMMON_HEADER}" -e "/^#include \\\"ggml-common.h\\\"/d" ${METALLIB_SOURCE} > ${EMBED_METALLIB_SOURCE} + DEPENDS ${METALLIB_SOURCE} ${COMMON_HEADER} + COMMENT "Generating combined Metal library for embedding" + ) + + add_custom_command( + OUTPUT ${EMBED_METALLIB_ASSEMBLY} + COMMAND echo ".section __DATA,__ggml_metallib" > ${EMBED_METALLIB_ASSEMBLY} + COMMAND echo ".globl _ggml_metallib_start" >> ${EMBED_METALLIB_ASSEMBLY} + COMMAND echo "_ggml_metallib_start:" >> ${EMBED_METALLIB_ASSEMBLY} + COMMAND echo ".incbin \\\"${EMBED_METALLIB_SOURCE}\\\"" >> ${EMBED_METALLIB_ASSEMBLY} + COMMAND echo ".globl _ggml_metallib_end" >> ${EMBED_METALLIB_ASSEMBLY} + COMMAND echo "_ggml_metallib_end:" >> ${EMBED_METALLIB_ASSEMBLY} + DEPENDS ${EMBED_METALLIB_SOURCE} + COMMENT "Generate assembly for embedded Metal library" + ) + + set(GGML_SOURCES_METAL ${GGML_SOURCES_METAL} ${EMBED_METALLIB_ASSEMBLY}) + endif() endif() if (WHISPER_COREML) @@ -167,65 +248,161 @@ if (APPLE) endif() endif() -if (WHISPER_OPENBLAS) - set(WHISPER_BLAS_VENDOR "OpenBLAS") - set(WHISPER_BLAS ON) -endif() - if (WHISPER_BLAS) - if (WIN32) - if(DEFINED ENV{OPENBLAS_PATH}) - set(BLAS_LIBRARIES $ENV{OPENBLAS_PATH}/lib/libopenblas.dll.a) - message(STATUS "Libraries ${BLAS_LIBRARIES}") - set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS) - include_directories($ENV{OPENBLAS_PATH}/include) - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES}) - else () - message(FATAL_ERROR "BLAS 
library was not found. Environment variable OPENBLAS_PATH not defined.") - endif () - else () - set(BLA_STATIC 1) - set(BLA_VENDOR ${WHISPER_BLAS_VENDOR}) - set(BLA_SIZEOF_INTEGER 8) - set(BLA_PREFER_PKGCONFIG 1) - find_package(BLAS) - - if(BLAS_FOUND) - message(STATUS "BLAS compatible library found") - message(STATUS "Libraries ${BLAS_LIBRARIES}") - find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include/openblas /usr/local/include/openblas $ENV{BLAS_HOME}/include) - set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS) - include_directories(${BLAS_INCLUDE_DIRS}) - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES}) - else() - message(FATAL_ERROR "BLAS library was not found") + if (WHISPER_STATIC) + set(BLA_STATIC ON) + endif() + #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22) + # set(BLA_SIZEOF_INTEGER 8) + #endif() + + set(BLA_VENDOR ${WHISPER_BLAS_VENDOR}) + find_package(BLAS) + + if (BLAS_FOUND) + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + + if (("${BLAS_INCLUDE_DIRS}" STREQUAL "") AND NOT (${WHISPER_BLAS_VENDOR} MATCHES "Apple")) + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. + # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${WHISPER_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS REQUIRED blas) + elseif (${WHISPER_BLAS_VENDOR} MATCHES "OpenBLAS") + # As of openblas v0.3.22, the 64-bit is named openblas64.pc + pkg_check_modules(DepBLAS openblas64) + if (NOT DepBLAS_FOUND) + pkg_check_modules(DepBLAS REQUIRED openblas) + endif() + elseif (${WHISPER_BLAS_VENDOR} MATCHES "FLAME") + pkg_check_modules(DepBLAS REQUIRED blis) + elseif (${WHISPER_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS REQUIRED blas-atlas) + elseif (${WHISPER_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS REQUIRED flexiblas_api) + elseif (${WHISPER_BLAS_VENDOR} MATCHES "Intel") + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS REQUIRED mkl-sdl) + elseif (${WHISPER_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() endif() - endif () -endif () + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + + add_compile_options(${BLAS_LINKER_FLAGS}) + + add_compile_definitions(GGML_USE_BLAS) + + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${WHISPER_BLAS_VENDOR} MATCHES "Generic" OR ${WHISPER_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + + set(GGML_HEADERS_BLAS ggml-blas.h) + set(GGML_SOURCES_BLAS ggml-blas.cpp) + + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES}) + set(WHISPER_EXTRA_INCLUDES ${WHISPER_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS not found, please 
refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct WHISPER_BLAS_VENDOR") + endif() +endif() + +if (WHISPER_MKL) + find_package(MKL CONFIG REQUIRED PATHS $ENV{MKLROOT}) + message(STATUS "Imported oneMKL targets: ${MKL_IMPORTED_TARGETS}") + set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS) + set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_BLAS_USE_MKL) +endif() if (WHISPER_CUBLAS) - cmake_minimum_required(VERSION 3.17) + message(WARNING "WHISPER_CUBLAS is deprecated and will be removed in the future.\nUse WHISPER_CUDA instead") + set(WHISPER_CUDA ON) +endif() + +if (WHISPER_CUDA) + cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES find_package(CUDAToolkit) if (CUDAToolkit_FOUND) message(STATUS "cuBLAS found") + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # 52 == lowest CUDA 12 standard + # 60 == f16 CUDA intrinsics + # 61 == integer CUDA intrinsics + # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + if (WHISPER_CUDA_F16 OR WHISPER_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + else() + set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work + endif() + endif() + message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + enable_language(CUDA) - set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) + file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu") + list(APPEND GGML_SOURCES_CUDA ggml-cuda.h) + list(APPEND GGML_SOURCES_CUDA ggml-cuda.cu) + + file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + + if (WHISPER_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + else() + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + endif() - add_compile_definitions(GGML_USE_CUBLAS) + add_compile_definitions(GGML_USE_CUDA) + add_compile_definitions(GGML_CUDA_USE_GRAPHS) if (WHISPER_STATIC) if (WIN32) # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft) else () - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static) endif() else() - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft) endif() set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver) @@ -250,40 +427,70 @@ if (WHISPER_HIPBLAS) if (${hipblas_FOUND} AND ${hip_FOUND}) message(STATUS "HIP and hipBLAS found") - 
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) - add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) - set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON) - set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) - target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + set(GGML_HEADERS_ROCM "ggml-cuda.h") + + file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu") + list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu") + + file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + + if (WHISPER_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + else() + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + endif() + + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA) + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) if (WHISPER_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm) + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} hip::device PUBLIC hip::host roc::rocblas roc::hipblas) else() message(FATAL_ERROR "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") endif() endif() -if (WHISPER_CLBLAST) - find_package(CLBlast) - if (CLBlast_FOUND) - message(STATUS "CLBlast found") - - set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) +if( WHISPER_OPENVINO ) + find_package(OpenVINO REQUIRED COMPONENTS Runtime) +endif() - add_compile_definitions(GGML_USE_CLBLAST) +if (WHISPER_SYCL) + if ( NOT DEFINED ENV{ONEAPI_ROOT}) + message(FATAL_ERROR "Not detect ENV {ONEAPI_ROOT}, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh") + endif() + #todo: AOT - set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast) - else() - message(FATAL_ERROR "CLBlast not found") + find_package(IntelSYCL REQUIRED) + if (WHISPER_SYCL_F16) + add_compile_definitions(GGML_SYCL_F16) endif() -endif() + add_compile_definitions(GGML_USE_SYCL) -if( WHISPER_OPENVINO ) - find_package(OpenVINO REQUIRED COMPONENTS Runtime) -endif() + add_compile_options(-I./) #include DPCT + add_compile_options(-I/${SYCL_INCLUDE_DIR}) + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl -L${MKLROOT}/lib") + + set(GGML_HEADERS_SYCL ggml-sycl.h) + file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp") + list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) +endif() # compiler flags if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) @@ -332,16 +539,30 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /utf-8") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /utf-8") - if(NOT WHISPER_NO_AVX2) + if(NOT WHISPER_NO_AVX512) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX512") + set(CMAKE_CXX_FLAGS_RELEASE 
"${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX512") + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. + if (NOT WHISPER_NO_AVX512_VBMI) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) + endif() + if (NOT WHISPER_NO_AVX512_VNNI) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) + endif() + elseif(NOT WHISPER_NO_AVX2) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2") - else() - if(NOT WHISPER_NO_AVX) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX") - endif() + elseif(NOT WHISPER_NO_AVX) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX") endif() else() if (EMSCRIPTEN) @@ -354,6 +575,15 @@ else() if(NOT WHISPER_NO_AVX2) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2") endif() + if(NOT WHISPER_NO_AVX512) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw") + if(NOT WHISPER_NO_AVX512_VBMI) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512vbmi") + endif() + if(NOT WHISPER_NO_AVX512_VNNI) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512vnni") + endif() + endif() if(NOT WHISPER_NO_FMA) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma") endif() @@ -440,6 +670,7 @@ if (WHISPER_COREML) set_target_properties(${TARGET} PROPERTIES COMPILE_FLAGS "-fobjc-arc" ) + set_target_properties(${TARGET} PROPERTIES FOLDER "libs") endif() if (WHISPER_OPENVINO) @@ -458,6 +689,7 @@ if (WHISPER_OPENVINO) set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_OPENVINO) target_link_libraries(${TARGET} PRIVATE openvino::runtime) + set_target_properties(${TARGET} PROPERTIES FOLDER "libs") endif() # @@ -477,12 +709,27 @@ add_library(${TARGET} ggml-quants.c ${GGML_SOURCES_METAL} ${GGML_SOURCES_CUDA} - ${GGML_SOURCES_OPENCL} + ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL} + ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM} + ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS} unicode.h whisper.h whisper.cpp ) +if (WHISPER_CUDA) + target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu) +endif() + +include_directories ( + . 
+) +# Set the version numbers +set_target_properties(whisper PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + include(DefaultTargetOptions) target_include_directories(${TARGET} PUBLIC @@ -497,6 +744,10 @@ if (WHISPER_OPENVINO) target_link_libraries(${TARGET} PRIVATE whisper.openvino) endif() +if (WHISPER_MKL) + target_link_libraries(${TARGET} PUBLIC MKL::MKL) +endif() + if (MSVC) target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT}) @@ -549,6 +800,7 @@ target_compile_definitions(${TARGET} PUBLIC ) set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h") +set_target_properties(${TARGET} PROPERTIES FOLDER "libs") include(GNUInstallDirs) diff --git a/LICENSE b/LICENSE index 76f67efdc64..acb96ce78e0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Georgi Gerganov +Copyright (c) 2023-2024 The ggml authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index b9eee85bb69..2c474039a61 100644 --- a/Makefile +++ b/Makefile @@ -18,12 +18,25 @@ ifndef NVCC_VERSION endif endif +# In GNU make default CXX is g++ instead of c++. Let's fix that so that users +# of non-gcc compilers don't have to provide g++ alias or wrapper. +DEFCC := cc +DEFCXX := c++ +ifeq ($(origin CC),default) +CC := $(DEFCC) +endif +ifeq ($(origin CXX),default) +CXX := $(DEFCXX) +endif + CCV := $(shell $(CC) --version | head -n 1) CXXV := $(shell $(CXX) --version | head -n 1) # Mac OS + Arm can report x86_64 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789 ifeq ($(UNAME_S),Darwin) + WHISPER_NO_OPENMP := 1 + ifneq ($(UNAME_P),arm) SYSCTL_M := $(shell sysctl -n hw.optional.arm64) ifeq ($(SYSCTL_M),1) @@ -42,6 +55,12 @@ CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC CXXFLAGS = -I. 
-I./examples -O3 -DNDEBUG -std=c++11 -fPIC LDFLAGS = +ifdef MACOSX_DEPLOYMENT_TARGET + CFLAGS += -mmacosx-version-min=$(MACOSX_DEPLOYMENT_TARGET) + CXXFLAGS += -mmacosx-version-min=$(MACOSX_DEPLOYMENT_TARGET) + LDFLAGS += -mmacosx-version-min=$(MACOSX_DEPLOYMENT_TARGET) +endif + # clock_gettime came in POSIX.1b (1993) # CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional # posix_memalign came in POSIX.1-2001 / SUSv3 @@ -125,41 +144,68 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64)) CPUINFO_CMD := sysinfo -cpu endif + # x86 ISA extensions (chronological order) ifdef CPUINFO_CMD + SSE3_M := $(shell $(CPUINFO_CMD) | grep -iwE 'PNI|SSE3') + SSSE3_M := $(shell $(CPUINFO_CMD) | grep -iw 'SSSE3') AVX_M := $(shell $(CPUINFO_CMD) | grep -iwE 'AVX|AVX1.0') + F16C_M := $(shell $(CPUINFO_CMD) | grep -iw 'F16C') + FMA_M := $(shell $(CPUINFO_CMD) | grep -iw 'FMA') + AVX2_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX2') + AVX512F_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX512F') + AVX512VBMI_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX512VBMI') + AVX512VNNI_M := $(shell $(CPUINFO_CMD) | grep -iwE 'AVX512_VNNI|AVX512VNNI') + + # AVX-512 has many subsets, so let's make it easy to disable them all + ifneq ($(filter-out 0,$(WHISPER_NO_AVX512)),) + AVX512F_M := + AVX512VBMI_M := + AVX512VNNI_M := + endif + + ifneq (,$(SSE3_M)) + CFLAGS += -msse3 + CXXFLAGS += -msse3 + endif + + ifneq (,$(SSSE3_M)) + CFLAGS += -mssse3 + CXXFLAGS += -mssse3 + endif + ifneq (,$(AVX_M)) CFLAGS += -mavx CXXFLAGS += -mavx endif - AVX2_M := $(shell $(CPUINFO_CMD) | grep -iw 'AVX2') - ifneq (,$(AVX2_M)) - CFLAGS += -mavx2 - CXXFLAGS += -mavx2 + ifneq (,$(F16C_M)) + CFLAGS += -mf16c + CXXFLAGS += -mf16c endif - FMA_M := $(shell $(CPUINFO_CMD) | grep -iw 'FMA') ifneq (,$(FMA_M)) CFLAGS += -mfma CXXFLAGS += -mfma endif - F16C_M := $(shell $(CPUINFO_CMD) | grep -iw 'F16C') - ifneq (,$(F16C_M)) - CFLAGS += -mf16c - CXXFLAGS += -mf16c + ifneq (,$(AVX2_M)) + CFLAGS += -mavx2 + CXXFLAGS += -mavx2 endif - SSE3_M := $(shell $(CPUINFO_CMD) | grep -iwE 'PNI|SSE3') - ifneq (,$(SSE3_M)) - CFLAGS += -msse3 - CXXFLAGS += -msse3 + ifneq (,$(AVX512F_M)) + CFLAGS += -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw + CXXFLAGS += -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw endif - SSSE3_M := $(shell $(CPUINFO_CMD) | grep -iw 'SSSE3') - ifneq (,$(SSSE3_M)) - CFLAGS += -mssse3 - CXXFLAGS += -mssse3 + ifneq (,$(AVX512VBMI_M)) + CFLAGS += -mavx512vbmi + CXXFLAGS += -mavx512vbmi + endif + + ifneq (,$(AVX512VNNI_M)) + CFLAGS += -mavx512vnni + CXXFLAGS += -mavx512vnni endif endif endif @@ -178,8 +224,14 @@ endif ifndef WHISPER_NO_ACCELERATE # Mac M1 - include Accelerate framework ifeq ($(UNAME_S),Darwin) - CFLAGS += -DGGML_USE_ACCELERATE - LDFLAGS += -framework Accelerate + CFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS + CFLAGS += -DACCELERATE_NEW_LAPACK + CFLAGS += -DACCELERATE_LAPACK_ILP64 + CXXFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS + CXXFLAGS += -DACCELERATE_NEW_LAPACK + CXXFLAGS += -DACCELERATE_LAPACK_ILP64 + LDFLAGS += -framework Accelerate + WHISPER_OBJ += ggml-blas.o endif endif @@ -202,26 +254,70 @@ ifndef WHISPER_NO_METAL endif endif +ifndef WHISPER_NO_OPENMP + CXXFLAGS += -DGGML_USE_OPENMP + CFLAGS += -fopenmp + CXXFLAGS += -fopenmp +endif # WHISPER_NO_OPENMP + ifdef WHISPER_OPENBLAS - CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas - LDFLAGS += -lopenblas -endif + CXXFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas) + CFLAGS += $(shell pkg-config 
--cflags-only-other openblas) + LDFLAGS += $(shell pkg-config --libs openblas) + WHISPER_OBJ += ggml-blas.o +endif # WHISPER_OPENBLAS + +ifdef WHISPER_OPENBLAS64 + CXXFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas64) + CFLAGS += $(shell pkg-config --cflags-only-other openblas64) + LDFLAGS += $(shell pkg-config --libs openblas64) + WHISPER_OBJ += ggml-blas.o +endif # WHISPER_OPENBLAS64 + +ifdef WHISPER_BLIS + CXXFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis + LDFLAGS += -lblis -L/usr/local/lib + WHISPER_OBJ += ggml-blas.o +endif # WHISPER_BLIS ifdef WHISPER_CUBLAS +# WHISPER_CUBLAS is deprecated and will be removed in the future + WHISPER_CUDA := 1 +endif + +OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu)) +OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu)) +ifdef WHISPER_CUDA_FA_ALL_QUANTS + OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu)) +else + OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu)) + OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu)) + OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*f16-f16.cu)) +endif # WHISPER_CUDA_FA_ALL_QUANTS + +ifdef WHISPER_CUDA ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1) - CUDA_ARCH_FLAG=native + CUDA_ARCH_FLAG ?= native else - CUDA_ARCH_FLAG=all + CUDA_ARCH_FLAG ?= all endif - CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include - CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include - LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib - WHISPER_OBJ += ggml-cuda.o + CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS + LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib + WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o + WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu)) + WHISPER_OBJ += $(OBJS_CUDA_TEMP_INST) NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG) -ggml-cuda.o: ggml-cuda.cu ggml-cuda.h +ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh + $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@ + +ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh) + $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ + +whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif @@ -229,30 +325,20 @@ ifdef WHISPER_HIPBLAS ROCM_PATH ?= /opt/rocm HIPCC ?= $(ROCM_PATH)/bin/hipcc GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) - CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS - CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS + CFLAGS += -DGGML_USE_HIPBLAS 
-DGGML_USE_CUDA + CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib LDFLAGS += -lhipblas -lamdhip64 -lrocblas HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) WHISPER_OBJ += ggml-cuda.o + WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu)) + WHISPER_OBJ += $(OBJS_CUDA_TEMP_INST) -ggml-cuda.o: ggml-cuda.cu ggml-cuda.h +ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< -endif -ifdef WHISPER_CLBLAST - CFLAGS += -DGGML_USE_CLBLAST - CXXFLAGS += -DGGML_USE_CLBLAST - LDFLAGS += -lclblast - ifeq ($(UNAME_S),Darwin) - LDFLAGS += -framework OpenCL - else - LDFLAGS += -lOpenCL - endif - WHISPER_OBJ += ggml-opencl.o - -ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h - $(CXX) $(CXXFLAGS) -c $< -o $@ +ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh) + $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< endif ifdef WHISPER_GPROF @@ -301,6 +387,13 @@ $(info I CC: $(CCV)) $(info I CXX: $(CXXV)) $(info ) +ifdef WHISPER_CUBLAS +$(info !!!!) +$(info WHISPER_CUBLAS is deprecated and will be removed in the future. Use WHISPER_CUDA instead.) +$(info !!!!) +$(info ) +endif + # # Build library # @@ -317,9 +410,12 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h $(CC) $(CFLAGS) -c $< -o $@ +ggml-blas.o: ggml-blas.cpp ggml-blas.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o -whisper.o: whisper.cpp whisper.h unicode.h ggml.h ggml-cuda.h +whisper.o: whisper.cpp whisper.h unicode.h whisper-mel.hpp ggml.h ggml-cuda.h $(CXX) $(CXXFLAGS) -c $< -o $@ ifndef WHISPER_COREML @@ -339,6 +435,26 @@ ggml-metal.o: ggml-metal.m ggml-metal.h $(CC) $(CFLAGS) -c $< -o $@ WHISPER_OBJ += ggml-metal.o + +ifdef WHISPER_METAL_EMBED_LIBRARY +CFLAGS += -DGGML_METAL_EMBED_LIBRARY + +ggml-metal-embed.o: ggml-metal.metal ggml-common.h + @echo "Embedding Metal library" + $(eval TEMP_ASSEMBLY=$(shell mktemp)) + $(eval TEMP_METALLIB=$(shell mktemp)) + @sed "/^#include \"ggml-common.h\"/{r ggml-common.h"$$'\n'"d;}" ggml-metal.metal > $(TEMP_METALLIB) + @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY) + @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY) + @echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY) + @echo ".incbin \"$(TEMP_METALLIB)\"" >> $(TEMP_ASSEMBLY) + @echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY) + @echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY) + @$(AS) $(TEMP_ASSEMBLY) -o $@ + @rm -f $(TEMP_ASSEMBLY) $(TEMP_METALLIB) + +WHISPER_OBJ += ggml-metal-embed.o +endif endif libwhisper.a: $(WHISPER_OBJ) @@ -349,6 +465,8 @@ libwhisper.so: $(WHISPER_OBJ) clean: rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so + rm -vrf ggml-cuda/*.o + rm -vrf ggml-cuda/template-instances/*.o # # Examples @@ -356,7 +474,7 @@ clean: CC_SDL=`sdl2-config --cflags --libs` -SRC_COMMON = examples/common.cpp examples/common-ggml.cpp +SRC_COMMON = examples/common.cpp examples/common-ggml.cpp examples/grammar-parser.cpp SRC_COMMON_SDL = examples/common-sdl.cpp SRC_CONSOLE = examples/console.cpp @@ -376,8 +494,8 @@ server: examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ) stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_CONSOLE) $(SRC_COMMON_SDL) 
$(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS) -command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_CONSOLE) $(SRC_COMMON_SDL) $(WHISPER_OBJ) - $(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS) +command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) + $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS) lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS) @@ -385,8 +503,8 @@ lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) $(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_CONSOLE) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS) -talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) - $(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_CONSOLE) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS) +talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) + $(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS) # # Audio samples diff --git a/Package.swift b/Package.swift index e8b85afce81..bbb7fb03b99 100644 --- a/Package.swift +++ b/Package.swift @@ -13,13 +13,9 @@ let package = Package( products: [ .library(name: "whisper", targets: ["whisper"]), ], - dependencies: [ - .package(url: "https://github.com/ggerganov/ggml.git", .branch("release")) - ], targets: [ .target( name: "whisper", - dependencies: ["ggml"], path: ".", exclude: [ "bindings", @@ -36,8 +32,14 @@ let package = Package( "Makefile" ], sources: [ + "ggml.c", "whisper.cpp", + "ggml-alloc.c", + "ggml-backend.c", + "ggml-quants.c", + "ggml-metal.m" ], + resources: [.process("ggml-metal.metal")], publicHeadersPath: "spm-headers", cSettings: [ .unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]), diff --git a/README.md b/README.md index 5702e4d7be7..289869d861c 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ [![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) +[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.5.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) +Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) 
automatic speech recognition (ASR) model: @@ -19,7 +20,6 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp - Zero memory allocations at runtime - Support for CPU-only inference - [Efficient GPU support for NVIDIA](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas) -- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast) - [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support) - [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h) @@ -278,6 +278,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in - To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools. - Python 3.10 is recommended. + - MacOS Sonoma (version 14) or newer is recommended, as older versions of MacOS might experience issues with transcription hallucination. - [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step: - To create an environment, use: `conda create -n py310-whisper python=3.10 -y` - To activate the environment, use: `conda activate py310-whisper` @@ -339,7 +340,7 @@ This can result in significant speedup in encoder performance. Here are the inst python -m venv openvino_conv_env openvino_conv_env\Scripts\activate python -m pip install --upgrade pip - pip install -r openvino-conversion-requirements.txt + pip install -r requirements-openvino.txt ``` Linux and macOS: @@ -349,7 +350,7 @@ This can result in significant speedup in encoder performance. Here are the inst python3 -m venv openvino_conv_env source openvino_conv_env/bin/activate python -m pip install --upgrade pip - pip install -r openvino-conversion-requirements.txt + pip install -r requirements-openvino.txt ``` - Generate an OpenVINO encoder model. For example, to generate a `base.en` model, use: @@ -413,45 +414,38 @@ For more information about the Core ML implementation please refer to PR [#1037] With NVIDIA cards the processing of the models is done efficiently on the GPU via cuBLAS and custom CUDA kernels. First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads -Now build `whisper.cpp` with cuBLAS support: +Now build `whisper.cpp` with CUDA support: ``` make clean -WHISPER_CUBLAS=1 make -j +WHISPER_CUDA=1 make -j ``` -## OpenCL GPU support via CLBlast - -For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APUs or low end devices for up to ~2x speedup. +## BLAS CPU support via OpenBLAS -First, make sure you have installed `CLBlast` for your OS or Distribution: https://github.com/CNugteren/CLBlast +Encoder processing can be accelerated on the CPU via OpenBLAS. +First, make sure you have installed `openblas`: https://www.openblas.net/ -Now build `whisper.cpp` with CLBlast support: +Now build `whisper.cpp` with OpenBLAS support: ``` -Makefile: -cd whisper.cpp make clean -WHISPER_CLBLAST=1 make -j - -CMake: -cd whisper.cpp -cmake -B build -DWHISPER_CLBLAST=ON -cmake --build build -j --config Release +WHISPER_OPENBLAS=1 make -j ``` -Run all the examples as usual. 
+## BLAS CPU support via Intel MKL -## BLAS CPU support via OpenBLAS - -Encoder processing can be accelerated on the CPU via OpenBLAS. -First, make sure you have installed `openblas`: https://www.openblas.net/ +Encoder processing can be accelerated on the CPU via the BLAS compatible interface of Intel's Math Kernel Library. +First, make sure you have installed Intel's MKL runtime and development packages: https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl-download.html -Now build `whisper.cpp` with OpenBLAS support: +Now build `whisper.cpp` with Intel MKL BLAS support: ``` -make clean -WHISPER_OPENBLAS=1 make -j +source /opt/intel/oneapi/setvars.sh +mkdir build +cd build +cmake -DWHISPER_MKL=ON .. +WHISPER_MKL=1 make -j ``` ## Docker @@ -486,6 +480,16 @@ docker run -it --rm \ whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav" ``` +## Installing with Conan + +You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command: + +``` +conan install --requires="whisper-cpp/[*]" --build=missing +``` + +For detailed instructions on how to use Conan, please refer to the [Conan documentation](https://docs.conan.io/2/). + ## Limitations - Inference only @@ -694,7 +698,7 @@ The [main](examples/main) example provides support for output of karaoke-style m currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script. This requires to have `ffmpeg` installed. -Here are a few *"typical"* examples: +Here are a few _"typical"_ examples: ```bash ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts @@ -728,10 +732,10 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a ## Video comparison of different models -Use the [extra/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/extra/bench-wts.sh) script to generate a video in the following format: +Use the [scripts/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format: ```bash -./extra/bench-wts.sh samples/jfk.wav +./scripts/bench-wts.sh samples/jfk.wav ffplay ./samples/jfk.wav.all.mp4 ``` @@ -752,7 +756,7 @@ Additionally a script to run whisper.cpp with different models and audio files i You can run it with the following command, by default it will run against any standard model in the models folder. ```bash -python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2 +python3 scripts/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2 ``` It is written in python with the intention of being easy to modify and extend for your benchmarking use case. 
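For a quick comparison without the Python script, a plain shell loop over the downloaded models also works. A minimal sketch, assuming the usual `models/ggml-*.bin` layout and the `./main` flags shown earlier:

```bash
# Sketch: run the same sample through every downloaded model and time it.
# Assumes models were fetched into models/ and ./main was built as above.
for m in models/ggml-*.bin; do
    echo "== $m =="
    time ./main -m "$m" -f samples/jfk.wav -t 4 > /dev/null
done
```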
@@ -792,6 +796,7 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model - [NickDarvey/whisper](https://github.com/NickDarvey/whisper) - [x] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9) - [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython) + - [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp) - [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11) - [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper) - [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity) diff --git a/README_sycl.md b/README_sycl.md new file mode 100644 index 00000000000..9ea2a7908ab --- /dev/null +++ b/README_sycl.md @@ -0,0 +1,249 @@ +# whisper.cpp for SYCL + +[Background](#background) + +[OS](#os) + +[Intel GPU](#intel-gpu) + +[Linux](#linux) + +[Environment Variable](#environment-variable) + +[Known Issue](#known-issue) + +[Todo](#todo) + +## Background + +SYCL is a higher-level programming model designed to improve programming productivity on various hardware accelerators such as CPUs, GPUs, and FPGAs. It is a single-source embedded domain-specific language based on pure C++17. + +oneAPI is an open, standards-based specification that supports multiple architecture types, including but not limited to GPU, CPU, and FPGA. The spec has both direct programming and API-based programming paradigms. + +Intel uses SYCL as the direct programming language to support CPUs, GPUs, and FPGAs. + +To avoid re-inventing the wheel, this code refers to other code paths in llama.cpp (like OpenBLAS, cuBLAS, CLBlast). We use the open-source tool [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) (commercial release: [Intel DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) to migrate the code to SYCL. + +whisper.cpp for SYCL is used to support Intel GPUs. + +For Intel CPUs, we recommend using the x86 build of whisper.cpp (Intel MKL build). + +## OS + +|OS|Status|Verified| +|-|-|-| +|Linux|Support|Ubuntu 22.04| +|Windows|Ongoing| | + + +## Intel GPU + +|Intel GPU| Status | Verified Model| +|-|-|-| +|Intel Data Center Max Series| Support| Max 1550| +|Intel Data Center Flex Series| Support| Flex 170| +|Intel Arc Series| Support| Arc 770| +|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake| +|Intel iGPU| Support| iGPU in i5-1250P, i7-1165G7| + + +## Linux + +### Setup Environment + +1. Install the Intel GPU driver. + +a. Please install the Intel GPU driver following the official guide: [Install GPU Drivers](https://dgpu-docs.intel.com/driver/installation.html). + +Note: for iGPU, please install the client GPU driver. + +b. Add the user to the video and render groups. + +``` +sudo usermod -aG render username +sudo usermod -aG video username +``` + +Note: log out and log back in for the change to take effect. + +c. Check: + +``` +sudo apt install clinfo +sudo clinfo -l +``` + +Output (example): + +``` +Platform #0: Intel(R) OpenCL Graphics + `-- Device #0: Intel(R) Arc(TM) A770 Graphics + + +Platform #0: Intel(R) OpenCL HD Graphics + `-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49] +``` + +2. Install the Intel oneAPI Base Toolkit. + + +a. Please follow the procedure in [Get the Intel oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html). + +We recommend installing to the default folder: **/opt/intel/oneapi**. + +The following guide uses the default folder as an example.
If you use a different folder, adjust the paths in the following guide accordingly. + +b. Check: + +``` +source /opt/intel/oneapi/setvars.sh + +sycl-ls +``` + +There should be one or more Level Zero devices, such as **[ext_oneapi_level_zero:gpu:0]**. + +Output (example): +``` +[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] +[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000] +[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50] +[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918] + +``` + +3. Build locally: + +``` +mkdir -p build +cd build +source /opt/intel/oneapi/setvars.sh + +#for FP16 +#cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON + +#for FP32 +cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx + +#build example/main only +#cmake --build . --config Release --target main + +#build all binary +cmake --build . --config Release -v + +``` + +or + +``` +./examples/sycl/build.sh +``` + +Note: + +- By default, all binaries are built, which takes more time. To reduce the build time, we recommend building **example/main** only. + +### Run + +1. Put the model file in the **models** folder. + +2. Enable the oneAPI runtime environment: + +``` +source /opt/intel/oneapi/setvars.sh +``` + +3. List the device IDs. + +Run without parameters: + +``` +./build/bin/ls-sycl-device + +or + +./build/bin/main +``` + +Check the IDs in the startup log, for example: + +``` +found 4 SYCL devices: + Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3, + max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136 + Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2, + max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280 + Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0, + max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280 + Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0, + max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136 + +``` + +|Attribute|Note| +|-|-| +|compute capability 1.3|Level Zero runtime, recommended| +|compute capability 3.0|OpenCL runtime, slower than Level Zero in most cases| + +4. Set the device ID and run whisper.cpp. + +Set device ID = 0 via **GGML_SYCL_DEVICE=0**: + +``` +GGML_SYCL_DEVICE=0 ./build/bin/main -m models/ggml-base.en.bin -f samples/jfk.wav +``` +or run the provided script: + +``` +./examples/sycl/run_whisper.sh +``` + + + +5. Check the device ID in the output, for example: + +``` +Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device +``` + + +## Environment Variable + +#### Build + +|Name|Value|Function| +|-|-|-| +|WHISPER_SYCL|ON (mandatory)|Enable build with SYCL code path.
For FP32/FP16, WHISPER_SYCL=ON is mandatory.| +|WHISPER_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path.For FP32, do not set it.| +|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path| +|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path| + +#### Running + + +|Name|Value|Function| +|-|-|-| +|GGML_SYCL_DEVICE|0 (default) or 1|Set the device id used. Check the device ids by default running output| +|GGML_SYCL_DEBUG|0 (default) or 1|Enable log function by macro: GGML_SYCL_DEBUG| + +## Known Issue + +- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`. + + Miss to enable oneAPI running environment. + + Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`. + + +- Hang during startup + + llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block. + + Solution: add **--no-mmap**. + +## Todo + +- Support to build in Windows. + +- Support multiple cards. \ No newline at end of file diff --git a/bindings/go/examples/go-whisper/flags.go b/bindings/go/examples/go-whisper/flags.go index ea204455c80..766c92f1827 100644 --- a/bindings/go/examples/go-whisper/flags.go +++ b/bindings/go/examples/go-whisper/flags.go @@ -68,10 +68,6 @@ func (flags *Flags) GetOut() string { return strings.ToLower(flags.Lookup("out").Value.String()) } -func (flags *Flags) IsSpeedup() bool { - return flags.Lookup("speedup").Value.String() == "true" -} - func (flags *Flags) IsTokens() bool { return flags.Lookup("tokens").Value.String() == "true" } @@ -111,10 +107,6 @@ func (flags *Flags) SetParams(context whisper.Context) error { fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration) context.SetDuration(duration) } - if flags.IsSpeedup() { - fmt.Fprintf(flags.Output(), "Setting speedup to true\n") - context.SetSpeedup(true) - } if threads := flags.GetThreads(); threads != 0 { fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads) context.SetThreads(threads) @@ -146,7 +138,6 @@ func registerFlags(flag *Flags) { flag.Duration("offset", 0, "Time offset") flag.Duration("duration", 0, "Duration of audio to process") flag.Uint("threads", 0, "Number of threads to use") - flag.Bool("speedup", false, "Enable speedup") flag.Uint("max-len", 0, "Maximum segment length in characters") flag.Uint("max-tokens", 0, "Maximum tokens per segment") flag.Float64("word-thold", 0, "Maximum segment score") diff --git a/bindings/go/params.go b/bindings/go/params.go index 408decc74ea..8ed10ad57e9 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -43,10 +43,6 @@ func (p *Params) SetPrintTimestamps(v bool) { p.print_timestamps = toBool(v) } -func (p *Params) SetSpeedup(v bool) { - p.speed_up = toBool(v) -} - // Set language id func (p *Params) SetLanguage(lang int) error { if lang == -1 { @@ -173,9 +169,6 @@ func (p *Params) String() string { if p.token_timestamps { str += " token_timestamps" } - if p.speed_up { - str += " speed_up" - } return str + ">" } diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 7b61f323532..ead92648f3e 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -76,9 +76,8 @@ func (context *context) SetTranslate(v bool) { context.params.SetTranslate(v) } -// Set speedup flag -func (context *context) SetSpeedup(v bool) { - context.params.SetSpeedup(v) +func (context *context) SetSplitOnWord(v bool) { + context.params.SetSplitOnWord(v) } // Set 
number of threads to use diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index b6c47397894..b430e7ce853 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -41,7 +41,7 @@ type Context interface { SetOffset(time.Duration) // Set offset SetDuration(time.Duration) // Set duration SetThreads(uint) // Set number of threads to use - SetSpeedup(bool) // Set speedup flag + SetSplitOnWord(bool) // Set split on word flag SetTokenThreshold(float32) // Set timestamp token probability threshold SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold SetMaxSegmentLength(uint) // Set max segment length in characters diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 9660662084f..87da83f0f10 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -10,7 +10,7 @@ import ( /* #cgo LDFLAGS: -lwhisper -lm -lstdc++ -#cgo darwin LDFLAGS: -framework Accelerate +#cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics #include #include diff --git a/bindings/ios b/bindings/ios index b21b6ff3259..a2085436c2e 160000 --- a/bindings/ios +++ b/bindings/ios @@ -1 +1 @@ -Subproject commit b21b6ff32595d73694045f4c5568a0981096cbf7 +Subproject commit a2085436c2eb796af90956b62bd64731f5e5b823 diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index 56a37380136..1a73cee1181 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -20,7 +20,7 @@ public interface WhisperCppJnaLibrary extends Library { * @return Whisper context on success, null on failure */ Pointer whisper_init_from_file(String path_model); - + /** * Provides default params which can be used with `whisper_init_from_file_with_params()` etc. * Because this function allocates memory for the params, the caller must call either: @@ -304,14 +304,6 @@ public interface WhisperCppJnaLibrary extends Library { /** Language id associated with the provided state */ int whisper_full_lang_id_from_state(Pointer state); - /** - * Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. - * The resulting spectrogram is stored inside the default state of the provided whisper context. - * @return 0 on success - */ - int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads); - - int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads); /** Get the start time of the specified segment. */ long whisper_full_get_segment_t0(Pointer ctx, int i_segment); diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 43c9a0dcf34..90d8c15767c 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -129,14 +129,6 @@ public void splitOnWord(boolean enable) { /** Maximum tokens per segment (0, default = no limit) */ public int max_tokens; - /** Flag to speed up the audio by 2x using Phase Vocoder. 
(default = false) */ - public CBool speed_up; - - /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */ - public void speedUp(boolean enable) { - speed_up = enable ? CBool.TRUE : CBool.FALSE; - } - /** Overwrite the audio context size (0 = use default). */ public int audio_ctx; @@ -148,6 +140,9 @@ public void tdrzEnable(boolean enable) { tdrz_enable = enable ? CBool.TRUE : CBool.FALSE; } + /** Regular expression matching tokens to suppress. */ + public String suppress_regex; + /** Tokens to provide to the whisper decoder as an initial prompt. * These are prepended to any existing text context from a previous call. */ public String initial_prompt; @@ -318,8 +313,8 @@ protected List getFieldOrder() { return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate", "no_context", "single_segment", "no_timestamps", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", - "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", - "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", + "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx", + "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "new_segment_callback", "new_segment_callback_user_data", diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 0391ed25169..39506de25b3 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.5.4", + "version": "1.6.2", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile new file mode 100644 index 00000000000..354d8ef2547 --- /dev/null +++ b/bindings/ruby/Rakefile @@ -0,0 +1,12 @@ +require 'rake/clean' + require 'rubygems/package' + +desc 'Build gem' +task :package do + spec_source = File.read File.join(File.dirname(__FILE__),'whispercpp.gemspec') + spec = nil + # see: http://gist.github.com/16215 + Thread.new { spec = eval("#{spec_source}") }.join + spec.validate + Gem::Package.build(spec) +end diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index c736d30237f..f22c550ee37 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -1,6 +1,7 @@ require 'mkmf' system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .") @@ -9,6 +10,7 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-backend-impl.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-backend.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-backend.c')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-common.h')} .") system("cp 
#{File.join(File.dirname(__FILE__),'..','..','..','ggml-quants.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-quants.c')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .") diff --git a/bindings/ruby/ext/ggml-backend-impl.h b/bindings/ruby/ext/ggml-backend-impl.h index 31788cd6baa..f121e1de420 100644 --- a/bindings/ruby/ext/ggml-backend-impl.h +++ b/bindings/ruby/ext/ggml-backend-impl.h @@ -12,31 +12,63 @@ extern "C" { // Backend buffer // + // buffer type + typedef void * ggml_backend_buffer_type_context_t; + + struct ggml_backend_buffer_type_i { + const char * (*GGML_CALL get_name) (ggml_backend_buffer_type_t buft); + ggml_backend_buffer_t (*GGML_CALL alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); + size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment + size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft); // allocation max size + size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding + bool (*GGML_CALL supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend + // check if tensor data is in host memory + // should be equivalent to supports_backend(buft, ggml_backend_cpu_init()) + bool (*GGML_CALL is_host) (ggml_backend_buffer_type_t buft); + }; + + struct ggml_backend_buffer_type { + struct ggml_backend_buffer_type_i iface; + ggml_backend_buffer_type_context_t context; + }; + + // buffer typedef void * ggml_backend_buffer_context_t; struct ggml_backend_buffer_i { - void (*free_buffer) (ggml_backend_buffer_t buffer); - void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer - size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback - void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback - void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback + const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer); + void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer); + void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer); + void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer + void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value); + void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras }; struct ggml_backend_buffer { - struct ggml_backend_buffer_i iface; - - ggml_backend_t backend; + struct ggml_backend_buffer_i iface; + ggml_backend_buffer_type_t buft; ggml_backend_buffer_context_t context; - size_t size; + enum ggml_backend_buffer_usage usage; }; - GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( - struct ggml_backend * backend, + GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( + 
ggml_backend_buffer_type_t buft, struct ggml_backend_buffer_i iface, ggml_backend_buffer_context_t context, size_t size); + // do not use directly, use ggml_backend_tensor_copy instead + bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst); + + // buffer that contains a collection of buffers + GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers); + GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); + GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + // // Backend // @@ -44,44 +76,66 @@ extern "C" { typedef void * ggml_backend_context_t; struct ggml_backend_i { - const char * (*get_name)(ggml_backend_t backend); + const char * (*GGML_CALL get_name)(ggml_backend_t backend); - void (*free)(ggml_backend_t backend); + void (*GGML_CALL free)(ggml_backend_t backend); // buffer allocation - ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size); + ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend); - // get buffer alignment - size_t (*get_alignment)(ggml_backend_t backend); + // (optional) asynchronous tensor data access + void (*GGML_CALL set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); - // tensor data access - // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - void (*synchronize) (ggml_backend_t backend); + // (optional) complete all pending operations + void (*GGML_CALL synchronize)(ggml_backend_t backend); - // (optional) copy tensor between different backends, allow for single-copy tranfers - void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); - void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); + // compute graph with a plan (not used currently) + ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); + void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); // compute graph with a plan - ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); - void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - - // compute graph without a plan - bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph); + enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + // compute graph without a plan (async) + enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph); // check if the backend supports an 
operation - bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + + // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer + // these should be expensive operations with large batch sizes that may benefit from running on this backend + // even if the weight has to be copied from the CPU temporarily + bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op); + + // (optional) event synchronization + ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend); + void (*GGML_CALL event_free) (ggml_backend_event_t event); + void (*GGML_CALL event_record) (ggml_backend_event_t event); + void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + void (*GGML_CALL event_synchronize) (ggml_backend_event_t event); }; struct ggml_backend { - struct ggml_backend_i iface; + ggml_guid_t guid; + struct ggml_backend_i iface; ggml_backend_context_t context; }; + struct ggml_backend_event { + ggml_backend_t backend; + void * context; + }; + + // + // Backend registry + // + + typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data); + + GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data); + #ifdef __cplusplus } #endif diff --git a/bindings/ruby/ext/ggml-backend.c b/bindings/ruby/ext/ggml-backend.c index 128e33ce630..402d86ef3ac 100644 --- a/bindings/ruby/ext/ggml-backend.c +++ b/bindings/ruby/ext/ggml-backend.c @@ -9,31 +9,76 @@ #include #include -#define UNUSED GGML_UNUSED #define MAX(a, b) ((a) > (b) ? (a) : (b)) +// backend buffer type + +const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name(buft); +} + +GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return buft->iface.alloc_buffer(buft, size); +} + +size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { + return buft->iface.get_alignment(buft); +} + +size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { + // get_max_size is optional, defaults to SIZE_MAX + if (buft->iface.get_max_size) { + return buft->iface.get_max_size(buft); + } + return SIZE_MAX; +} + +GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) { + // get_alloc_size is optional, defaults to ggml_nbytes + if (buft->iface.get_alloc_size) { + size_t size = buft->iface.get_alloc_size(buft, tensor); + assert(size >= ggml_nbytes(tensor)); + return size; + } + return ggml_nbytes(tensor); +} + +bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return buft->iface.supports_backend(buft, backend); +} + +bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { + if (buft->iface.is_host) { + return buft->iface.is_host(buft); + } + return false; +} + // backend buffer -ggml_backend_buffer_t ggml_backend_buffer_init( - struct ggml_backend * backend, +GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, struct ggml_backend_buffer_i iface, ggml_backend_buffer_context_t context, size_t size) { ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); - GGML_ASSERT(iface.get_base != NULL); - (*buffer) = (struct ggml_backend_buffer) { 
/* .interface = */ iface, - /* .backend = */ backend, + /* .buft = */ buft, /* .context = */ context, /* .size = */ size, + /* .usage = */ GGML_BACKEND_BUFFER_USAGE_ANY }; return buffer; } +const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) { + return buffer->iface.get_name(buffer); +} + void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { if (buffer == NULL) { return; @@ -45,10 +90,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { free(buffer); } -size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) { - return ggml_backend_get_alignment(buffer->backend); -} - size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { return buffer->size; } @@ -61,32 +102,67 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return base; } +GGML_CALL void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + // init_tensor is optional + if (buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } +} + +size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer)); +} + +size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer)); +} + size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - // get_alloc_size is optional, defaults to ggml_nbytes - if (buffer->iface.get_alloc_size) { - return buffer->iface.get_alloc_size(buffer, tensor); + return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor); +} + +void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + buffer->iface.clear(buffer, value); +} + +bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { + return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer)); +} + +void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + buffer->usage = usage; + + // FIXME: add a generic callback to the buffer interface + if (ggml_backend_buffer_is_multi_buffer(buffer)) { + ggml_backend_multi_buffer_set_usage(buffer, usage); } - return ggml_nbytes(tensor); } -void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - // init_tensor is optional - if (buffer->iface.init_tensor) { - buffer->iface.init_tensor(buffer, tensor); +ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { + return buffer->buft; +} + +void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { + if (buffer->iface.reset) { + buffer->iface.reset(buffer); } } -void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - // free_tensor is optional - if (buffer->iface.free_tensor) { - buffer->iface.free_tensor(buffer, tensor); +bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer; + if (dst_buf->iface.cpy_tensor) { + return src->buffer->iface.cpy_tensor(dst_buf, src, dst); } + return false; } // backend -ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) { - return tensor->buffer ? 
tensor->buffer->backend : NULL; +ggml_guid_t ggml_backend_guid(ggml_backend_t backend) { + if (backend == NULL) { + return NULL; + } + return backend->guid; } const char * ggml_backend_name(ggml_backend_t backend) { @@ -104,59 +180,105 @@ void ggml_backend_free(ggml_backend_t backend) { backend->iface.free(backend); } +ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { + return backend->iface.get_default_buffer_type(backend); +} + ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { - return backend->iface.alloc_buffer(backend, size); + return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size); } size_t ggml_backend_get_alignment(ggml_backend_t backend) { - return backend->iface.get_alignment(backend); + return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend)); +} + +size_t ggml_backend_get_max_size(ggml_backend_t backend) { + return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend)); } -void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); +void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + if (backend->iface.set_tensor_async == NULL) { + ggml_backend_tensor_set(tensor, data, offset, size); + } else { + backend->iface.set_tensor_async(backend, tensor, data, offset, size); + } } -void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); +void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + if (backend->iface.get_tensor_async == NULL) { + ggml_backend_tensor_get(tensor, data, offset, size); + } else { + backend->iface.get_tensor_async(backend, tensor, data, offset, size); + } } -void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_t backend = ggml_get_backend(tensor); +GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(backend != NULL && "tensor backend not set"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - backend->iface.set_tensor_async(backend, tensor, data, offset, size); - backend->iface.synchronize(backend); + if (!size) { + return; + } + + buf->iface.set_tensor(buf, tensor, data, offset, size); } -void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_t backend = ggml_get_backend(tensor); +GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(backend != NULL && "tensor backend not set"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - backend->iface.get_tensor_async(backend, tensor, data, offset, size); - backend->iface.synchronize(backend); + if (!size) { + return; + } + + buf->iface.get_tensor(buf, tensor, data, offset, size); } void ggml_backend_synchronize(ggml_backend_t backend) { + if (backend->iface.synchronize == NULL) { + return; + } + backend->iface.synchronize(backend); } ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend->iface.graph_plan_create != NULL); + return backend->iface.graph_plan_create(backend, cgraph); } void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend->iface.graph_plan_free != NULL); + backend->iface.graph_plan_free(backend, plan); } -void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - backend->iface.graph_plan_compute(backend, plan); +enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend->iface.graph_plan_compute != NULL); + + return backend->iface.graph_plan_compute(backend, plan); +} + +enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph); + ggml_backend_synchronize(backend); + return err; } -bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { +enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) { return backend->iface.graph_compute(backend, cgraph); } @@ -164,6 +286,13 @@ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * return backend->iface.supports_op(backend, op); } +bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { + if (backend->iface.offload_op != NULL) { + return backend->iface.offload_op(backend, op); + } + return false; +} + // backend copy static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { @@ -182,27 +311,20 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml } void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { - //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], 
(int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); - //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); - // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src)); - if (src == dst) { return; } - // TODO: allow backends to support copy to/from same backend - - if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) { - ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst); - } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) { - ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst); - } else { - // shouldn't be hit when copying from/to CPU - #ifndef NDEBUG - fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend)); - #endif + if (ggml_backend_buffer_is_host(src->buffer)) { + ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src)); + } else if (ggml_backend_buffer_is_host(dst->buffer)) { + ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); + } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { +#ifndef NDEBUG + fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); +#endif size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -211,318 +333,846 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -// backend CPU +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); -struct ggml_backend_cpu_context { - int n_threads; - void * work_data; - size_t work_size; -}; + if (src == dst) { + return; + } -static const char * ggml_backend_cpu_name(ggml_backend_t backend) { - return "CPU"; + if (backend_dst->iface.cpy_tensor_async != NULL) { + if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { + return; + } + } - UNUSED(backend); + // an async copy would normally happen after all the queued operations on both backends are completed + // sync src, set_async dst + if (ggml_backend_buffer_is_host(src->buffer)) { + ggml_backend_synchronize(backend_src); + ggml_backend_tensor_set_async(backend_dst, dst, src->data, 0, ggml_nbytes(src)); + } else { + ggml_backend_synchronize(backend_src); + ggml_backend_tensor_copy(src, dst); + ggml_backend_synchronize(backend_dst); + } } -static void ggml_backend_cpu_free(ggml_backend_t backend) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - free(cpu_ctx->work_data); - free(cpu_ctx); - free(backend); +// events + +ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) { + if (backend->iface.event_new == NULL) { + return NULL; + } + return backend->iface.event_new(backend); +} + +void ggml_backend_event_free(ggml_backend_event_t event) { + if (event 
== NULL) { + return; + } + event->backend->iface.event_free(event); } -static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { - return (void *)buffer->context; +void ggml_backend_event_record(ggml_backend_event_t event) { + GGML_ASSERT(event->backend->iface.event_record != NULL); + + event->backend->iface.event_record(event); } -static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - free(buffer->context); - UNUSED(buffer); +void ggml_backend_event_synchronize(ggml_backend_event_t event) { + GGML_ASSERT(event->backend->iface.event_synchronize != NULL); + + event->backend->iface.event_synchronize(event); } -static struct ggml_backend_buffer_i cpu_backend_buffer_i = { - /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, // no initialization required - /* .free_tensor = */ NULL, // no cleanup required -}; +void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_ASSERT(backend->iface.event_wait != NULL); -// for buffers from ptr, free is not called -static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { - /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, - /* .free_tensor = */ NULL, + backend->iface.event_wait(backend, event); +} + +// backend registry + +#define GGML_REG_MAX_BACKENDS 16 + +struct ggml_backend_reg { + char name[128]; + ggml_backend_init_fn init_fn; + ggml_backend_buffer_type_t default_buffer_type; + void * user_data; }; -static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 +static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS]; +static size_t ggml_backend_registry_count = 0; -static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) { - size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned - void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC? 
+GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data); + +GGML_CALL static void ggml_backend_registry_init(void) { + static bool initialized = false; - GGML_ASSERT(data != NULL && "failed to allocate buffer"); + if (initialized) { + return; + } + + initialized = true; + + ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL); + + // add forward decls here to avoid including the backend headers +#ifdef GGML_USE_CUDA + extern GGML_CALL void ggml_backend_cuda_reg_devices(void); + ggml_backend_cuda_reg_devices(); +#endif + +#ifdef GGML_USE_SYCL + extern void ggml_backend_sycl_reg_devices(void); + ggml_backend_sycl_reg_devices(); +#endif - return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size); +#ifdef GGML_USE_METAL + extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); + extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); + ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL); +#endif + +#ifdef GGML_USE_VULKAN + extern GGML_CALL int ggml_backend_vk_reg_devices(void); + ggml_backend_vk_reg_devices(); +#endif + +#ifdef GGML_USE_KOMPUTE + extern GGML_CALL void ggml_backend_kompute_reg_devices(void); + ggml_backend_kompute_reg_devices(); +#endif } -static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) { - return TENSOR_ALIGNMENT; - UNUSED(backend); +GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { + GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS); + + size_t id = ggml_backend_registry_count; + + ggml_backend_registry[id] = (struct ggml_backend_reg) { + /* .name = */ {0}, + /* .fn = */ init_fn, + /* .default_buffer_type = */ default_buffer_type, + /* .user_data = */ user_data, + }; + + snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name); + +#ifndef NDEBUG + fprintf(stderr, "%s: registered backend %s\n", __func__, name); +#endif + + ggml_backend_registry_count++; } -static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); +size_t ggml_backend_reg_get_count(void) { + ggml_backend_registry_init(); - memcpy((char *)tensor->data + offset, data, size); + return ggml_backend_registry_count; +} + +size_t ggml_backend_reg_find_by_name(const char * name) { + ggml_backend_registry_init(); + + for (size_t i = 0; i < ggml_backend_registry_count; i++) { + // TODO: case insensitive in a portable way + if (strcmp(ggml_backend_registry[i].name, name) == 0) { + return i; + } + } - UNUSED(backend); + // not found + return SIZE_MAX; } -static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); +// init from backend:params string +ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) { + ggml_backend_registry_init(); - memcpy(data, (const char *)tensor->data + offset, size); + const char * params = strchr(backend_str, ':'); + 
char backend_name[128]; + if (params == NULL) { + snprintf(backend_name, sizeof(backend_name), "%s", backend_str); + params = ""; + } else { + snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str); + params++; + } + + size_t backend_i = ggml_backend_reg_find_by_name(backend_name); - UNUSED(backend); + if (backend_i == SIZE_MAX) { + fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); + return NULL; + } + + return ggml_backend_reg_init_backend(backend_i, params); } -static void ggml_backend_cpu_synchronize(ggml_backend_t backend) { - UNUSED(backend); +const char * ggml_backend_reg_get_name(size_t i) { + ggml_backend_registry_init(); + + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].name; } -static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); +ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) { + ggml_backend_registry_init(); - UNUSED(backend); + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data); } -static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src)); +ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) { + ggml_backend_registry_init(); - UNUSED(backend); + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].default_buffer_type; } -struct ggml_backend_plan_cpu { - struct ggml_cplan cplan; - struct ggml_cgraph cgraph; -}; +ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) { + ggml_backend_registry_init(); -static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size); +} - struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); +// backend CPU - cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); - cpu_plan->cgraph = *cgraph; +static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment - if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); - } +GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) { + return "CPU"; - return cpu_plan; + GGML_UNUSED(buffer); } -static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; +GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + uintptr_t data = (uintptr_t)buffer->context; - free(cpu_plan->cplan.work_data); - free(cpu_plan); + // align the buffer + if (data % TENSOR_ALIGNMENT != 0) { + data = GGML_PAD(data, TENSOR_ALIGNMENT); + } - UNUSED(backend); + return (void *)data; } -static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; +GGML_CALL static void 
ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); +} - ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); +GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); - UNUSED(backend); + GGML_UNUSED(buffer); } -static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; +GGML_CALL static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); - struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + GGML_UNUSED(buffer); +} - if (cpu_ctx->work_size < cplan.work_size) { - // TODO: may be faster to free and use malloc to avoid the copy - cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); - cpu_ctx->work_size = cplan.work_size; +GGML_CALL static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; } + return false; - cplan.work_data = cpu_ctx->work_data; - - ggml_graph_compute(cgraph, &cplan); + GGML_UNUSED(buffer); } -static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return true; - UNUSED(backend); - UNUSED(op); +GGML_CALL static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); } -static struct ggml_backend_i cpu_backend_i = { - /* .get_name = */ ggml_backend_cpu_name, - /* .free = */ ggml_backend_cpu_free, - /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_get_alignment, - /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async, - /* .synchronize = */ ggml_backend_cpu_synchronize, - /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from, - /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to, - /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, - /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, - /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, - /* .graph_compute = */ ggml_backend_cpu_graph_compute, - /* .supports_op = */ ggml_backend_cpu_supports_op, +static struct ggml_backend_buffer_i cpu_backend_buffer_i = { + /* .get_name = */ ggml_backend_cpu_buffer_name, + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, }; -ggml_backend_t ggml_backend_cpu_init(void) { - struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); - - ctx->n_threads = GGML_DEFAULT_N_THREADS; - ctx->work_data = NULL; - ctx->work_size = 0; +// for buffers from ptr, free is not called +static struct ggml_backend_buffer_i 
cpu_backend_buffer_i_from_ptr = { + /* .get_name = */ ggml_backend_cpu_buffer_name, + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; - ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); +GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU"; - *cpu_backend = (struct ggml_backend) { - /* .interface = */ cpu_backend_i, - /* .context = */ ctx - }; - return cpu_backend; + GGML_UNUSED(buft); } -bool ggml_backend_is_cpu(ggml_backend_t backend) { - return backend->iface.get_name == ggml_backend_cpu_name; +GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned + void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h) + if (data == NULL) { + fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size); } -void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { - GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); +GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; - struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; - ctx->n_threads = n_threads; + GGML_UNUSED(buft); } -ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) { - return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size); +GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_cpu(backend); + + GGML_UNUSED(buft); } -// scheduler +GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; -#define GGML_MAX_BACKENDS 4 -#define GGML_MAX_SPLITS 256 -#define GGML_MAX_SPLIT_INPUTS 16 + GGML_UNUSED(buft); +} -struct ggml_backend_sched_split { - ggml_tallocr_t tallocr; - int i_start; - int i_end; - struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS]; - int n_inputs; - struct ggml_cgraph * graph; -}; +GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; -struct ggml_backend_sched { - int n_backends; - ggml_backend_t backends[GGML_MAX_BACKENDS]; - ggml_tallocr_t tallocs[GGML_MAX_BACKENDS]; + return &ggml_backend_cpu_buffer_type; +} - 
ggml_gallocr_t galloc; +#ifdef GGML_USE_CPU_HBM - struct ggml_hash_set hash_set; - ggml_tallocr_t * node_talloc; // [hash_set.size] - struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS] +// buffer type HBM - struct ggml_cgraph * graph; - struct ggml_backend_sched_split splits[GGML_MAX_SPLITS]; - int n_splits; +#include - struct ggml_context * ctx; +GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_HBM"; - // align context_buffer to GGML_MEM_ALIGN - #ifdef _MSC_VER - __declspec(align(GGML_MEM_ALIGN)) - #else - __attribute__((aligned(GGML_MEM_ALIGN))) - #endif - char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)]; -}; + GGML_UNUSED(buft); +} -#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node) -#define node_allocr(node) sched->node_talloc[hash_id(node)] +GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) { + return "CPU_HBM"; -static bool ggml_is_view_op(enum ggml_op op) { - return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; + GGML_UNUSED(buf); } -// returns the priority of the backend, lower is better -static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) { - for (int i = 0; i < sched->n_backends; i++) { - if (sched->backends[i] == backend) { - return i; - } - } - return INT_MAX; +GGML_CALL static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); } -static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) { - for (int i = 0; i < sched->n_backends; i++) { - if (sched->tallocs[i] == allocr) { - return i; - } +GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + //void * ptr = hbw_malloc(size); + void * ptr; + int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size); + return NULL; } - return INT_MAX; -} + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name; + buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type_hbm; +} +#endif + +struct ggml_backend_cpu_context { + int n_threads; + void * work_data; + size_t work_size; + + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) { + return "CPU"; + + GGML_UNUSED(backend); +} + +GGML_CALL static void 
ggml_backend_cpu_free(ggml_backend_t backend) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + free(cpu_ctx->work_data); + free(cpu_ctx); + free(backend); +} + +GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(backend); +} + +struct ggml_backend_plan_cpu { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cgraph = *cgraph; // FIXME: deep copy + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + if (cpu_plan->cplan.work_data == NULL) { + free(cpu_plan); + return NULL; + } + } + + cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; + cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return cpu_plan; +} + +GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + free(cpu_plan->cplan.work_data); + free(cpu_plan); + + GGML_UNUSED(backend); +} + +GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + GGML_UNUSED(backend); +} + +GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + + if (cpu_ctx->work_size < cplan.work_size) { + free(cpu_ctx->work_data); + cpu_ctx->work_data = malloc(cplan.work_size); + if (cpu_ctx->work_data == NULL) { + cpu_ctx->work_size = 0; + return GGML_STATUS_ALLOC_FAILED; + } + cpu_ctx->work_size = cplan.work_size; + } + cplan.work_data = cpu_ctx->work_data; + + cplan.abort_callback = cpu_ctx->abort_callback; + cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return ggml_graph_compute(cgraph, &cplan); +} + +GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_CPY: + return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float + case GGML_OP_MUL_MAT: + return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + default: + return true; + } + + GGML_UNUSED(backend); +} + +static struct ggml_backend_i cpu_backend_i = { + /* .get_name = */ ggml_backend_cpu_name, + /* .free = */ ggml_backend_cpu_free, + /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* 
.graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute, + /* .supports_op = */ ggml_backend_cpu_supports_op, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static ggml_guid_t ggml_backend_cpu_guid(void) { + static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; + return &guid; +} + +ggml_backend_t ggml_backend_cpu_init(void) { + struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); + if (ctx == NULL) { + return NULL; + } + + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + ctx->abort_callback = NULL; + ctx->abort_callback_data = NULL; + + ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); + if (cpu_backend == NULL) { + free(ctx); + return NULL; + } + + *cpu_backend = (struct ggml_backend) { + /* .guid = */ ggml_backend_cpu_guid(), + /* .interface = */ cpu_backend_i, + /* .context = */ ctx + }; + return cpu_backend; +} + +GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid()); +} + +void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + +GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); +} + +GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) { + return ggml_backend_cpu_init(); + + GGML_UNUSED(params); + GGML_UNUSED(user_data); +} + +// multi-buffer buffer + +struct ggml_backend_multi_buffer_context { + ggml_backend_buffer_t * buffers; + size_t n_buffers; +}; + +typedef struct ggml_backend_multi_buffer_context * ggml_backend_multi_buffer_context_t; + +GGML_CALL static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + + return ctx->buffers[0]->iface.get_name(ctx->buffers[0]); +} + +GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_free(ctx->buffers[i]); + } + + free(ctx->buffers); + free(ctx); +} + +GGML_CALL static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) 
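// ----------------------------------------------------------------------------
// Usage sketch (annotation, not part of the diff above): the two CPU-backend
// additions shown here, the abort callback and the backend-independent
// ggml_backend_cpu_buffer_from_ptr(). The deadline logic, the 64-byte
// alignment choice and the example_* names are invented for illustration.
// ----------------------------------------------------------------------------
#include <stdbool.h>
#include <stdlib.h>
#include <time.h>

#include "ggml.h"
#include "ggml-backend.h"

static bool example_abort_after_deadline(void * data) {
    const time_t * deadline = (const time_t *) data;
    // returning true asks the CPU backend to stop the current graph compute
    return time(NULL) > *deadline;
}

static void example_cpu_backend_extras(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();

    static time_t deadline;
    deadline = time(NULL) + 5; // give up on graphs that run for more than ~5 s
    ggml_backend_cpu_set_abort_callback(backend, example_abort_after_deadline, &deadline);

    // wrap caller-owned, suitably aligned memory in a CPU buffer; the from_ptr
    // interface does not take ownership of the pointer (its free_buffer is NULL)
    const size_t size = 1024*1024;
    void * mem = aligned_alloc(64, size); // assumes 64 >= TENSOR_ALIGNMENT
    ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr(mem, size);

    // ... place tensors in `buf` via ggml-alloc and compute as usual ...

    ggml_backend_buffer_free(buf); // frees the buffer object only, not `mem`
    free(mem);
    ggml_backend_free(backend);
}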
buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_clear(ctx->buffers[i], value); + } +} + +static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(void) { + static struct ggml_backend_buffer_i multi_backend_buffer_i = { + /* .get_name = */ ggml_backend_multi_buffer_get_name, + /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer, + /* .get_base = */ NULL, + /* .init_tensor = */ NULL, + /* .set_tensor = */ NULL, + /* .get_tensor = */ NULL, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_multi_buffer_clear, + /* .reset = */ NULL, + }; + + return multi_backend_buffer_i; +} + +GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) { + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) malloc(sizeof(struct ggml_backend_multi_buffer_context)); + ctx->n_buffers = n_buffers; + ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t)); + + GGML_ASSERT(ctx->buffers != NULL); + + size_t total_size = 0; + for (size_t i = 0; i < n_buffers; i++) { + ctx->buffers[i] = buffers[i]; + total_size += ggml_backend_buffer_get_size(buffers[i]); + } + + return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_context_interface(), ctx, total_size); +} + +GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { + return buffer->iface.get_name == ggml_backend_multi_buffer_get_name; +} + +GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); + ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + for (size_t i = 0; i < ctx->n_buffers; i++) { + ggml_backend_buffer_set_usage(ctx->buffers[i], usage); + } +} + +// creates a copy of the tensor with the same memory layout +static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) { + struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor); + for (int i = 0; i < GGML_MAX_DIMS; i++) { + dup->nb[i] = tensor->nb[i]; + } + return dup; +} + +static bool ggml_is_view_op(enum ggml_op op) { + return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; +} + +// scheduler + +#ifndef GGML_SCHED_MAX_BACKENDS +#define GGML_SCHED_MAX_BACKENDS 16 +#endif + +#ifndef GGML_SCHED_MAX_SPLITS +#define GGML_SCHED_MAX_SPLITS 2048 +#endif + +#ifndef GGML_SCHED_MAX_SPLIT_INPUTS +#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC +#endif + +#ifndef GGML_SCHED_MAX_COPIES +#define GGML_SCHED_MAX_COPIES 4 +#endif + +struct ggml_backend_sched_split { + int backend_id; + int i_start; + int i_end; + struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; + int n_inputs; + // graph view of this split + struct ggml_cgraph graph; +}; + +struct ggml_backend_sched { + bool is_reset; // true if the scheduler has been reset since the last graph split + bool is_alloc; + + int n_backends; + + ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS]; + ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS]; + ggml_gallocr_t galloc; + + // hash keys of the nodes in the graph + struct ggml_hash_set hash_set; + // hash values + int * tensor_backend_id; + struct ggml_tensor * (* tensor_copies)[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; + + int * node_backend_ids; // [graph_size] + int * leaf_backend_ids; 
// [graph_size] + + // copy of the graph with modified inputs + struct ggml_cgraph * graph; + + // graph splits + struct ggml_backend_sched_split * splits; + int n_splits; + int splits_capacity; + + // pipeline parallelism support + int n_copies; + int cur_copy; + ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; + struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; + int n_graph_inputs; + + struct ggml_context * ctx; + + ggml_backend_sched_eval_callback callback_eval; + void * callback_eval_user_data; + + // align context_buffer to GGML_MEM_ALIGN +#ifdef _MSC_VER + __declspec(align(GGML_MEM_ALIGN)) +#else + __attribute__((aligned(GGML_MEM_ALIGN))) +#endif + char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; +}; + +#define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor) +#define tensor_backend_id(tensor) sched->tensor_backend_id[hash_id(tensor)] + +// returns the priority of the backend, lower id is higher priority +static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->backends[i] == backend) { + return i; + } + } + return -1; +} + +static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) { + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer == NULL) { + return -1; + } + + // find highest prio backend that supports the buffer type + for (int i = 0; i < sched->n_backends; i++) { + if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) { + return i; + } + } + + fprintf(stderr, "%s: error: no backend supports buffer type %s used in tensor %s\n", + __func__, ggml_backend_buffer_name(buffer), tensor->name); + GGML_ASSERT(false); + + return -1; +} + +#if 0 +static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only +#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) +#define GET_CAUSE(node) causes[hash_id(node)] +#else +#define SET_CAUSE(node, ...) +#define GET_CAUSE(node) "" +#endif // returns the backend that should be used for the node based on the current locations -char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove -static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) { - // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there - // ie. kv cache updates - // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend. 
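// ----------------------------------------------------------------------------
// Usage sketch (annotation, not part of the diff above): setting up the
// reworked scheduler against the priority convention defined here (index 0 is
// the highest-priority backend, the last entry must be the CPU backend). The
// ggml_backend_sched_* calls are the versions introduced further down in this
// patch; gpu_backend and measure_graph are assumed to be provided by the
// caller, e.g. a CUDA/Metal backend and a worst-case graph.
// ----------------------------------------------------------------------------
#include "ggml.h"
#include "ggml-backend.h"

static void example_sched_setup(ggml_backend_t gpu_backend, struct ggml_cgraph * measure_graph) {
    ggml_backend_t backends[2] = { gpu_backend, ggml_backend_cpu_init() };

    // NULL buffer types -> each backend's default buffer type; parallel=false
    // keeps a single copy per input (no pipeline parallelism)
    ggml_backend_sched_t sched = ggml_backend_sched_new(backends, NULL, 2, GGML_DEFAULT_GRAPH_SIZE, false);

    // reserve worst-case compute buffers once, using a representative graph
    ggml_backend_sched_reserve(sched, measure_graph);

    // per iteration: split the graph, allocate and run each split on its backend
    // struct ggml_cgraph * gf = build_graph(...); // placeholder
    // ggml_backend_sched_graph_compute(sched, gf);
    // ggml_backend_sched_reset(sched);

    ggml_backend_sched_free(sched);
    ggml_backend_free(backends[1]);
}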
- // dst - ggml_backend_t cur_backend = ggml_get_backend(node); - if (cur_backend != NULL) { - sprintf(causes[hash_id(node)], "1.dst"); - return cur_backend; +static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) { + // TODO: use supports_op to check if the backend supports the op + + // assign pre-allocated nodes to their backend + int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor); + if (cur_backend_id != -1) { + SET_CAUSE(tensor, "1.dst"); + return cur_backend_id; } // view_src - if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) { - sprintf(causes[hash_id(node)], "1.vsrc"); - return ggml_get_backend(node->view_src); + if (tensor->view_src != NULL) { + cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src); + if (cur_backend_id != -1) { + SET_CAUSE(tensor, "1.vsrc"); + return cur_backend_id; + } } - // src - int cur_prio = INT_MAX; - size_t cur_size = 0; + // graph input + if (tensor->flags & GGML_TENSOR_FLAG_INPUT) { + cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU) + SET_CAUSE(tensor, "1.inp"); + return cur_backend_id; + } + // assign nodes that use weights to the backend of the weights + // operations with weights are preferably run on the same backend as the weights for (int i = 0; i < GGML_MAX_SRC; i++) { - const struct ggml_tensor * src = node->src[i]; + const struct ggml_tensor * src = tensor->src[i]; if (src == NULL) { - break; + continue; } - ggml_backend_t src_backend = ggml_get_backend(src); - if (src_backend != NULL) { - int src_prio = sched_backend_prio(sched, src_backend); - size_t src_size = ggml_nbytes(src); - if (src_prio < cur_prio && src_size >= cur_size) { - cur_prio = src_prio; - cur_size = src_size; - cur_backend = src_backend; - sprintf(causes[hash_id(node)], "1.src%d", i); + if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src); + // check if a backend with higher prio wants to offload the op + if (src_backend_id == sched->n_backends - 1) { + for (int b = 0; b < src_backend_id; b++) { + if (ggml_backend_offload_op(sched->backends[b], tensor)) { + SET_CAUSE(tensor, "1.off"); + return b; + } + } } + SET_CAUSE(tensor, "1.wgt%d", i); + return src_backend_id; } } - return cur_backend; + + return -1; } static char * fmt_size(size_t size) { @@ -535,14 +1185,16 @@ static char * fmt_size(size_t size) { return buffer; } -static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { +static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { int cur_split = 0; for (int i = 0; i < graph->n_nodes; i++) { if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { - ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend; - fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); + ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id]; + fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), + sched->splits[cur_split].n_inputs); for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { - fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); + fprintf(stderr, "[%s 
(%5.5s)] ", sched->splits[cur_split].inputs[j]->name, + fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); } fprintf(stderr, "\n"); cur_split++; @@ -551,341 +1203,558 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra if (ggml_is_view_op(node->op)) { continue; } - ggml_tallocr_t node_allocr = node_allocr(node); - ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL; - fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]); + ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); + fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { - break; + continue; } - ggml_tallocr_t src_allocr = node_allocr(src); - ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL; - fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]); + ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src); + fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name, + fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); } fprintf(stderr, "\n"); } } -// creates a copy of the tensor with the same memory layout -static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) { - struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor); - for (int i = 0; i < GGML_MAX_DIMS; i++) { - dup->nb[i] = tensor->nb[i]; - } - return dup; -} +//#define DEBUG_PASS1 +//#define DEBUG_PASS2 +//#define DEBUG_PASS3 +//#define DEBUG_PASS4 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend -// TODO: merge passes -static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - // reset state - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); - memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size); - memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size); +static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + // reset splits sched->n_splits = 0; + sched->n_graph_inputs = 0; + sched->is_reset = false; struct ggml_init_params params = { - /*.mem_size = */ sizeof(sched->context_buffer), - /*.mem_buffer = */ sched->context_buffer, - /*.no_alloc = */ true + /* .mem_size = */ sizeof(sched->context_buffer), + /* .mem_buffer = */ sched->context_buffer, + /* .no_alloc = */ true }; - if (sched->ctx != NULL) { - ggml_free(sched->ctx); - } + ggml_free(sched->ctx); sched->ctx = ggml_init(params); + if (sched->ctx == NULL) { + fprintf(stderr, "%s: failed to initialize context\n", __func__); + GGML_ASSERT(false); + } - // pass 1: assign backends to ops with allocated inputs + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; - if 
(node_allocr(leaf) != NULL) { + int * leaf_backend_id = &tensor_backend_id(leaf); + if (*leaf_backend_id != -1) { // do not overwrite user assignments continue; } - ggml_backend_t leaf_backend = ggml_get_backend(leaf); - if (leaf_backend == NULL && leaf->view_src != NULL) { - leaf_backend = ggml_get_backend(leaf->view_src); - } - if (leaf_backend != NULL) { - node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend); - } + *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf); } for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; - if (node_allocr(node) != NULL) { + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { // do not overwrite user assignments continue; } - ggml_backend_t node_backend = sched_backend_from_cur(sched, node); - if (node_backend != NULL) { - node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend); + *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); + // src + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src); + } } } - //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS1 + fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); +#endif - // pass 2: assign backends to ops from current assignments - // TODO: - // - reuse sched_backend_from_cur - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - ggml_tallocr_t node_allocr = node_allocr(node); - if (node_allocr == NULL) { - int cur_prio = INT_MAX; - size_t cur_size = 0; - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; + // pass 2: expand current backend assignments + // assign the same backend to adjacent nodes + // expand gpu backends (i.e. 
non last prio) up and down, ignoring cpu (the lowest priority backend) + // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops + + + // pass 2.2 expand gpu down + { + int cur_backend_id = -1; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + if (*node_backend_id == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_backend_id = -1; + } else { + cur_backend_id = *node_backend_id; } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != NULL) { - int src_prio = sched_allocr_prio(sched, src_allocr); - size_t src_size = ggml_nbytes(src); - if (src_prio < cur_prio && src_size >= cur_size) { - cur_prio = src_prio; - cur_size = src_size; - node_allocr = src_allocr; - sprintf(causes[hash_id(node)], "2.src%d", j); - } + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.2"); + } + } + } + // pass 2.1 expand gpu up + { + int cur_backend_id = -1; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + if (*node_backend_id == sched->n_backends - 1) { + // skip cpu (lowest prio backend) + cur_backend_id = -1; + } else { + cur_backend_id = *node_backend_id; } + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.1"); } - if (node_allocr != NULL) { - node_allocr(node) = node_allocr; + } + } + // pass 2.4 expand rest down + { + int cur_backend_id = -1; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + cur_backend_id = *node_backend_id; + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.4"); + } + } + } + // pass 2.3 expand rest up + { + int cur_backend_id = -1; + for (int i = graph->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id != -1) { + cur_backend_id = *node_backend_id; + } else { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.3"); } } } - //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); - // pass 3: assign backends to remaining src from dst (should only be leafs) +#ifdef DEBUG_PASS2 + fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); +#endif + + // pass 3: assign backends to remaining src from dst and view_src for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; - ggml_tallocr_t node_allocr = node_allocr(node); + int * cur_backend_id = &tensor_backend_id(node); + if (node->view_src != NULL && *cur_backend_id == -1) { + *cur_backend_id = tensor_backend_id(node->view_src); + SET_CAUSE(node, "3.vsrc"); + } for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { - break; + continue; } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr == NULL) { - node_allocr(src) = node_allocr; + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + if (src->view_src != NULL) { + // views are always on the same backend as the 
source + *src_backend_id = tensor_backend_id(src->view_src); + SET_CAUSE(src, "3.vsrc"); + } else { + *src_backend_id = *cur_backend_id; + SET_CAUSE(src, "3.cur"); + } } } } - //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); +#ifdef DEBUG_PASS3 + fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); +#endif // pass 4: split graph, find tensors that need to be copied - // TODO: - // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost - // find first backend - int cur_split = 0; - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - if (node->view_src == NULL) { - sched->splits[0].tallocr = node_allocr(node); - break; - } - } - sched->splits[0].i_start = 0; - sched->splits[0].n_inputs = 0; - memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK - ggml_tallocr_t cur_allocr = sched->splits[0].tallocr; - size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr); - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - - if (ggml_is_view_op(node->op)) { - continue; + { + int i_split = 0; + struct ggml_backend_sched_split * split = &sched->splits[0]; + // find the backend of the first split, skipping view ops + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (!ggml_is_view_op(node->op)) { + split->backend_id = tensor_backend_id(node); + break; + } } + split->i_start = 0; + split->n_inputs = 0; + memset(split->inputs, 0, sizeof(split->inputs)); //HACK + int cur_backend_id = split->backend_id; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + + if (ggml_is_view_op(node->op)) { + continue; + } - ggml_tallocr_t node_allocr = node_allocr(node); + const int node_backend_id = tensor_backend_id(node); - if (node_allocr != cur_allocr) { - sched->splits[cur_split].i_end = i; - cur_split++; - GGML_ASSERT(cur_split < GGML_MAX_SPLITS); - sched->splits[cur_split].tallocr = node_allocr; - sched->splits[cur_split].i_start = i; - sched->splits[cur_split].n_inputs = 0; - memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK - cur_allocr = node_allocr; - cur_backend_id = sched_allocr_prio(sched, cur_allocr); - } + GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now - // find inputs that are not on the same backend - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; + // check if we should start a new split based on the sources of the current node + bool need_new_split = false; + if (node_backend_id == cur_backend_id && split->n_inputs > 0) { + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + // check if a weight is on a different backend + // by starting a new split, the memory of the previously offloaded weights can be reused + if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + int src_backend_id = tensor_backend_id(src); + if (src_backend_id != -1 && src_backend_id != cur_backend_id) { + need_new_split = true; + break; + } + } + // check if the split has too many inputs + if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) { + const size_t id = hash_id(src); + int src_backend_id = sched->tensor_backend_id[id]; + if 
(src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) { + //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name); + need_new_split = true; + break; + } + } + } } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != node_allocr) { - int n_inputs = sched->splits[cur_split].n_inputs++; - GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS); - sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src; - - // create copies - size_t id = hash_id(src); - if (sched->node_copies[id][cur_backend_id] == NULL) { - struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); - sched->node_copies[id][cur_backend_id] = tensor_copy; - node_allocr(tensor_copy) = cur_allocr; - ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend; - ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name); + + if (node_backend_id != cur_backend_id || need_new_split) { + split->i_end = i; + i_split++; + if (i_split >= sched->splits_capacity) { + sched->splits_capacity *= 2; + sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); + GGML_ASSERT(sched->splits != NULL); } - node->src[j] = sched->node_copies[id][cur_backend_id]; + GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS); + split = &sched->splits[i_split]; + split->backend_id = node_backend_id; + split->i_start = i; + split->n_inputs = 0; + cur_backend_id = node_backend_id; } - } - } - sched->splits[cur_split].i_end = graph->n_nodes; - sched->n_splits = cur_split + 1; - //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout); + // find inputs that are not on the same backend + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } -#if 1 - // sanity check: all sources should have the same backend as the node - for (int i = 0; i < graph->n_nodes; i++) { - struct ggml_tensor * node = graph->nodes[i]; - ggml_tallocr_t node_allocr = node_allocr(node); - if (node_allocr == NULL) { - fprintf(stderr, "!!!!!!! %s has no backend\n", node->name); - } - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - break; - } - ggml_tallocr_t src_allocr = node_allocr(src); - if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now - fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n", - node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL", - j, src->name, src_allocr ? 
ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL"); + const int src_backend_id = tensor_backend_id(src); + assert(src_backend_id != -1); // all inputs should be assigned by now + + if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { + size_t id = hash_id(src); + if (sched->tensor_copies[id][src_backend_id][0] == NULL) { + ggml_backend_t backend = sched->backends[src_backend_id]; + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * tensor_copy; + if (c == sched->cur_copy) { + tensor_copy = src; // use the original tensor as the current copy + } else { + tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); + } + if (sched->n_copies > 1) { + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor + } + sched->tensor_copies[id][src_backend_id][c] = tensor_copy; + SET_CAUSE(tensor_copy, "4.cpy"); + } + int n_graph_inputs = sched->n_graph_inputs++; + GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); + sched->graph_inputs[n_graph_inputs] = src; + } + } + + if (src_backend_id != node_backend_id) { + // create a copy of the input in the split's backend + const size_t id = hash_id(src); + if (sched->tensor_copies[id][cur_backend_id][0] == NULL) { + ggml_backend_t backend = sched->backends[cur_backend_id]; + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c); + if (sched->n_copies > 1) { + ggml_set_input(tensor_copy); + ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor + } + sched->tensor_copies[id][cur_backend_id][c] = tensor_copy; + SET_CAUSE(tensor_copy, "4.cpy"); + } + int n_inputs = split->n_inputs++; + GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); + split->inputs[n_inputs] = src; + } + node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy]; + } } } + split->i_end = graph->n_nodes; + sched->n_splits = i_split + 1; } +#ifdef DEBUG_PASS4 + fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); #endif // create copies of the graph for each split - // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way - struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false); + // TODO: avoid this copy + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2, false); for (int i = 0; i < sched->n_splits; i++) { struct ggml_backend_sched_split * split = &sched->splits[i]; - split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end); + split->graph = ggml_graph_view(graph, split->i_start, split->i_end); // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { + assert(graph_copy->size > (graph_copy->n_nodes + 1)); + struct ggml_tensor * input = split->inputs[j]; - struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)]; - input_cpy->src[0] = input; + const size_t input_id = hash_id(input); + struct ggml_tensor * input_cpy = sched->tensor_copies[input_id][split->backend_id][sched->cur_copy]; + + // add a dependency to the input 
source so that it is not freed before the copy is done + struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input); + input_dep->src[0] = input; + sched->node_backend_ids[graph_copy->n_nodes] = sched->tensor_backend_id[input_id]; + graph_copy->nodes[graph_copy->n_nodes++] = input_dep; + + // add a dependency to the input copy so that it is allocated at the start of the split + sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id; graph_copy->nodes[graph_copy->n_nodes++] = input_cpy; } for (int j = split->i_start; j < split->i_end; j++) { + assert(graph_copy->size > graph_copy->n_nodes); + sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]); graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j]; } } + + if (sched->n_copies > 1) { + // add input copies as leafs so that they are allocated first + for (int i = 0; i < sched->n_graph_inputs; i++) { + struct ggml_tensor * input = sched->graph_inputs[i]; + size_t id = hash_id(input); + int backend_id = tensor_backend_id(input); + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c]; + sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; + } + } + + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &sched->splits[i]; + int backend_id = split->backend_id; + for (int j = 0; j < split->n_inputs; j++) { + struct ggml_tensor * input = split->inputs[j]; + size_t id = hash_id(input); + for (int c = 0; c < sched->n_copies; c++) { + struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c]; + sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; + graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; + } + } + } + } + + // add leafs from the original graph + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf); + graph_copy->leafs[graph_copy->n_leafs++] = leaf; + } + sched->graph = graph_copy; } -static void sched_alloc_splits(ggml_backend_sched_t sched) { - ggml_gallocr_alloc_graph_n( - sched->galloc, - sched->graph, - sched->hash_set, - sched->node_talloc); -} +static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { + // allocate graph + if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + // the re-allocation may cause the split inputs to be moved to a different address + ggml_backend_sched_synchronize(sched); +#ifndef NDEBUG + fprintf(stderr, "%s: failed to allocate graph, reserving\n", __func__); +#endif + ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); + if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + fprintf(stderr, "%s: failed to allocate graph\n", __func__); + return false; + } + } -static void sched_compute_splits(ggml_backend_sched_t sched) { - uint64_t copy_us[GGML_MAX_BACKENDS] = {0}; - uint64_t compute_us[GGML_MAX_BACKENDS] = {0}; + return true; +} +static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { struct ggml_backend_sched_split * splits = sched->splits; for (int i = 0; i < sched->n_splits; i++) { struct ggml_backend_sched_split * split = &splits[i]; - ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend; - int split_backend_id = sched_backend_prio(sched, split_backend); + int split_backend_id = split->backend_id; + ggml_backend_t 
split_backend = sched->backends[split_backend_id]; // copy the input tensors to the split backend - uint64_t copy_start_us = ggml_time_us(); for (int j = 0; j < split->n_inputs; j++) { - struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)]; - if (split->inputs[j]->buffer == NULL) { - if (split->inputs[j]->view_src == NULL) { - fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name); - exit(1); + ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); + struct ggml_tensor * input = split->inputs[j]; + struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy]; + + if (input->flags & GGML_TENSOR_FLAG_INPUT) { + // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); } - struct ggml_tensor * view = split->inputs[j]; - view->backend = view->view_src->backend; - view->buffer = view->view_src->buffer; - view->data = (char *)view->view_src->data + view->view_offs; - ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view); - } - if (input_cpy->buffer == NULL) { - fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name); - exit(1); + ggml_backend_tensor_copy(input, input_cpy); + } else { + // wait for the split backend to finish using the input before overwriting it + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); } - GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend); - GGML_ASSERT(input_cpy->buffer->backend == split_backend); - ggml_backend_tensor_copy(split->inputs[j], input_cpy); } - // ggml_backend_synchronize(split_backend); - int64_t copy_end_us = ggml_time_us(); - copy_us[split_backend_id] += copy_end_us - copy_start_us; -#if 0 - char split_filename[GGML_MAX_NAME]; - snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend)); - ggml_graph_dump_dot(split->graph, NULL, split_filename); -#endif + if (!sched->callback_eval) { + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + } else { + // similar to ggml_backend_compare_graph_backend + for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { + struct ggml_tensor * t = split->graph.nodes[j0]; - uint64_t compute_start_us = ggml_time_us(); - ggml_backend_graph_compute(split_backend, split->graph); - // ggml_backend_synchronize(split_backend); - uint64_t compute_end_us = ggml_time_us(); - compute_us[split_backend_id] += compute_end_us - compute_start_us; - } + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); -#if 0 - // per-backend timings - fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits); - for (int i = 0; i < sched->n_backends; i++) { - if (copy_us[i] > 0 || compute_us[i] > 0) { - fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", 
ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]); + int j1 = j0; + + // determine the range [j0, j1] of nodes that can be computed together + while (!need && j1 < split->graph.n_nodes - 1) { + t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, sched->callback_eval_user_data); + } + + struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); + + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + + // TODO: pass backend to the callback, then the user can decide if they want to synchronize + ggml_backend_synchronize(split_backend); + + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { + break; + } + + j0 = j1; + } } - } -#endif -} -static void sched_reset(ggml_backend_sched_t sched) { - for (int i = 0; i < sched->n_backends; i++) { - ggml_tallocr_reset(sched->tallocs[i]); + // record the event of this copy + if (split->n_inputs > 0) { + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]); + } + } } + + sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies; + + return GGML_STATUS_SUCCESS; } -ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) { - GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS); +ggml_backend_sched_t ggml_backend_sched_new( + ggml_backend_t * backends, + ggml_backend_buffer_type_t * bufts, + int n_backends, + size_t graph_size, + bool parallel) { + GGML_ASSERT(n_backends > 0); + GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); + GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU + + struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1); - struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched)); - memset(sched, 0, sizeof(struct ggml_backend_sched)); + // initialize hash table + sched->hash_set = ggml_hash_set_new(graph_size); + sched->tensor_backend_id = calloc(sizeof(sched->tensor_backend_id[0]), sched->hash_set.size); + sched->tensor_copies = calloc(sizeof(sched->tensor_copies[0]), sched->hash_set.size); - fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024); + const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; + sched->node_backend_ids = calloc(sizeof(sched->node_backend_ids[0]), nodes_size); + sched->leaf_backend_ids = calloc(sizeof(sched->leaf_backend_ids[0]), nodes_size); sched->n_backends = n_backends; - for (int i = 0; i < n_backends; i++) { - sched->backends[i] = backends[i]; - } - sched->galloc = ggml_gallocr_new(); + sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; - // init measure allocs for each backend - for (int i = 0; i < n_backends; i++) { - sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]); + const int initial_splits_capacity = 16; + sched->splits = calloc(sizeof(sched->splits[0]), initial_splits_capacity); + sched->splits_capacity = initial_splits_capacity; + + for (int b = 0; b < n_backends; b++) { + sched->backends[b] = backends[b]; + sched->bufts[b] = bufts ? 
bufts[b] : ggml_backend_get_default_buffer_type(backends[b]); + GGML_ASSERT(ggml_backend_buft_supports_backend(sched->bufts[b], backends[b])); + if (sched->n_copies > 1) { + for (int c = 0; c < sched->n_copies; c++) { + sched->events[b][c] = ggml_backend_event_new(backends[b]); + } + } } + sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + + ggml_backend_sched_reset(sched); + return sched; } @@ -893,58 +1762,334 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { if (sched == NULL) { return; } - for (int i = 0; i < sched->n_backends; i++) { - ggml_tallocr_free(sched->tallocs[i]); + for (int b = 0; b < sched->n_backends; b++) { + for (int c = 0; c < sched->n_copies; c++) { + ggml_backend_event_free(sched->events[b][c]); + } } ggml_gallocr_free(sched->galloc); + ggml_free(sched->ctx); + free(sched->splits); free(sched->hash_set.keys); - free(sched->node_talloc); - free(sched->node_copies); + free(sched->tensor_backend_id); + free(sched->tensor_copies); + free(sched->node_backend_ids); + free(sched->leaf_backend_ids); free(sched); } -void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { - // initialize hash tables - size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS; - sched->hash_set.size = hash_size; - sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size); - sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size); - sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size); +void ggml_backend_sched_reset(ggml_backend_sched_t sched) { + // reset state for the next run + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT + memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); + memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); + + sched->is_reset = true; + sched->is_alloc = false; +} + +bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { + GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes); - sched_split_graph(sched, measure_graph); - sched_alloc_splits(sched); + ggml_backend_sched_split_graph(sched, measure_graph); - // allocate buffers and reset allocators - for (int i = 0; i < sched->n_backends; i++) { - size_t size = ggml_tallocr_max_size(sched->tallocs[i]); - ggml_tallocr_free(sched->tallocs[i]); - sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size); + // TODO: extract this to a separate function + if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { + return false; + } + + ggml_backend_sched_reset(sched); + ggml_backend_sched_synchronize(sched); + + return true; +} + +bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes); + + ggml_backend_sched_split_graph(sched, graph); + + if (!ggml_backend_sched_alloc_splits(sched)) { + return false; + } + + sched->is_alloc = true; + + return true; +} + +enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph); + ggml_backend_sched_synchronize(sched); + return err; +} + +enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + 
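// ----------------------------------------------------------------------------
// Usage sketch (annotation, not part of the diff above): the eval callback
// consulted by the split compute loop above. The scheduler first asks
// (ask == true) whether a node should be observed; nodes up to the first
// flagged one are computed as one batch, the backend is synchronized, and the
// callback is invoked again with ask == false and the flagged node's results.
// Returning false from that second call asks the scheduler to stop. The
// "logits" name filter is invented for illustration.
// ----------------------------------------------------------------------------
#include <stdbool.h>
#include <stdio.h>
#include <string.h>

#include "ggml.h"
#include "ggml-backend.h"

static bool example_observe_logits(struct ggml_tensor * t, bool ask, void * user_data) {
    (void) user_data;

    if (ask) {
        // only break the batch for the nodes we actually want to inspect
        return strcmp(t->name, "logits") == 0;
    }

    fprintf(stderr, "observed %s: type=%s, ne[0]=%lld\n",
            t->name, ggml_type_name(t->type), (long long) t->ne[0]);

    return true; // keep computing
}

// installed once, before computing graphs with the scheduler:
//   ggml_backend_sched_set_eval_callback(sched, example_observe_logits, NULL);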
if (!sched->is_reset && !sched->is_alloc) { + ggml_backend_sched_reset(sched); + } + + if (!sched->is_alloc) { + if (!ggml_backend_sched_alloc_graph(sched, graph)) { + return GGML_STATUS_ALLOC_FAILED; + } } - sched_reset(sched); + return ggml_backend_sched_compute_splits(sched); +} + +void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } } -void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS); +void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { + sched->callback_eval = callback; + sched->callback_eval_user_data = user_data; +} - sched_split_graph(sched, graph); - sched_alloc_splits(sched); - sched_compute_splits(sched); - sched_reset(sched); +int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { + return sched->n_splits; } -ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) { - int backend_index = sched_backend_prio(sched, backend); - return sched->tallocs[backend_index]; +int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { + return sched->n_copies; } -ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) { - int backend_index = sched_backend_prio(sched, backend); - return ggml_tallocr_get_buffer(sched->tallocs[backend_index]); +size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + + return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } -void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { - int backend_index = sched_backend_prio(sched, backend); +void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); - node_allocr(node) = sched->tallocs[backend_index]; + tensor_backend_id(node) = backend_index; +} + +ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { + int backend_index = tensor_backend_id(node); + if (backend_index == -1) { + return NULL; + } + return sched->backends[backend_index]; +} + +// utils + +void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->view_src != NULL); + GGML_ASSERT(tensor->view_src->buffer != NULL); + GGML_ASSERT(tensor->view_src->data != NULL); + + tensor->buffer = buffer; + tensor->data = (char *)tensor->view_src->data + tensor->view_offs; + tensor->backend = tensor->view_src->backend; + ggml_backend_buffer_init_tensor(buffer, tensor); +} + +void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->data == NULL); + GGML_ASSERT(tensor->view_src == NULL); + GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); + GGML_ASSERT((char *)addr + 
ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + + tensor->buffer = buffer; + tensor->data = addr; + ggml_backend_buffer_init_tensor(buffer, tensor); +} + +static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, + struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) { + + GGML_ASSERT(src != NULL); + GGML_ASSERT(src->data && "graph must be allocated"); + + size_t id = ggml_hash_insert(hash_set, src); + if (id == GGML_HASHTABLE_ALREADY_EXISTS) { + return node_copies[ggml_hash_find(hash_set, src)]; + } + + struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src); + if (src->view_src != NULL) { + dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src); + dst->view_offs = src->view_offs; + } + dst->op = src->op; + memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); + ggml_set_name(dst, src->name); + + // copy src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + continue; + } + dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s); + } + + node_copies[id] = dst; + return dst; +} + +static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) { + size_t id = ggml_hash_find(hash_set, src); + if (node_init[id]) { + return; + } + node_init[id] = true; + + struct ggml_tensor * dst = node_copies[id]; + if (dst->view_src != NULL) { + graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src); + ggml_backend_view_init(dst->view_src->buffer, dst); + } + else { + ggml_backend_tensor_copy(src, dst); + } + + // init src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + continue; + } + graph_copy_init_tensor(hash_set, node_copies, node_init, s); + } +} + +struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { + struct ggml_hash_set hash_set = { + /* .size = */ graph->visited_hash_table.size, + /* .keys = */ calloc(sizeof(hash_set.keys[0]), graph->visited_hash_table.size) // NOLINT + }; + struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]), hash_set.size); // NOLINT + bool * node_init = calloc(sizeof(node_init[0]), hash_set.size); + + struct ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true + }; + + struct ggml_context * ctx_allocated = ggml_init(params); + struct ggml_context * ctx_unallocated = ggml_init(params); + + if (ctx_allocated == NULL || ctx_unallocated == NULL) { + fprintf(stderr, "failed to allocate context for graph copy\n"); + free(hash_set.keys); + free(node_copies); + free(node_init); + ggml_free(ctx_allocated); + ggml_free(ctx_unallocated); + return (struct ggml_backend_graph_copy) { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + + // dup nodes + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node); + } + + // allocate nodes + 
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); + if (buffer == NULL) { + fprintf(stderr, "failed to allocate buffer for graph copy\n"); + free(hash_set.keys); + free(node_copies); + free(node_init); + ggml_free(ctx_allocated); + ggml_free(ctx_unallocated); + return (struct ggml_backend_graph_copy) { + /* .buffer = */ NULL, + /* .ctx_allocated = */ NULL, + /* .ctx_unallocated = */ NULL, + /* .graph = */ NULL, + }; + } + + //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024); + + // copy data and init views + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_copy_init_tensor(hash_set, node_copies, node_init, node); + } + + // build graph copy + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false); + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)]; + graph_copy->nodes[i] = node_copy; + } + graph_copy->n_nodes = graph->n_nodes; + + free(hash_set.keys); + free(node_copies); + free(node_init); + + return (struct ggml_backend_graph_copy) { + /* .buffer = */ buffer, + /* .ctx_allocated = */ ctx_allocated, + /* .ctx_unallocated = */ ctx_unallocated, + /* .graph = */ graph_copy, + }; +} + +void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { + ggml_backend_buffer_free(copy.buffer); + ggml_free(copy.ctx_allocated); + ggml_free(copy.ctx_unallocated); +} + +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) { + struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); + if (copy.buffer == NULL) { + return false; + } + + struct ggml_cgraph * g1 = graph; + struct ggml_cgraph * g2 = copy.graph; + + assert(g1->n_nodes == g2->n_nodes); + + for (int i = 0; i < g1->n_nodes; i++) { + //printf("eval %d/%d\n", i, g1->n_nodes); + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_tensor * t2 = g2->nodes[i]; + + assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); + + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + + ggml_backend_graph_compute(backend1, &g1v); + ggml_backend_graph_compute(backend2, &g2v); + + if (ggml_is_view_op(t1->op)) { + continue; + } + + // compare results, calculate rms etc + if (!callback(i, t1, t2, user_data)) { + break; + } + } + + ggml_backend_graph_copy_free(copy); + + return true; } diff --git a/bindings/ruby/ext/ggml-backend.h b/bindings/ruby/ext/ggml-backend.h index 793a0a9d65a..744b6a77457 100644 --- a/bindings/ruby/ext/ggml-backend.h +++ b/bindings/ruby/ext/ggml-backend.h @@ -7,69 +7,123 @@ extern "C" { #endif + typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; + typedef struct ggml_backend_buffer * ggml_backend_buffer_t; + typedef struct ggml_backend_event * ggml_backend_event_t; + typedef struct ggml_backend * ggml_backend_t; + typedef void * ggml_backend_graph_plan_t; + // // Backend buffer // - struct ggml_backend_buffer; - typedef struct ggml_backend_buffer * ggml_backend_buffer_t; - - // backend buffer functions - GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); - GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t 
buffer); - GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + // buffer type + GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); + GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); + GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); + GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); + GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend); + GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); + + // buffer + enum ggml_backend_buffer_usage { + GGML_BACKEND_BUFFER_USAGE_ANY = 0, + GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, + }; + + GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); + GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // // Backend // - struct ggml_backend; - typedef struct ggml_backend * ggml_backend_t; - typedef void * ggml_backend_graph_plan_t; - - GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor); - + GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend); GGML_API const char * ggml_backend_name(ggml_backend_t backend); GGML_API void ggml_backend_free(ggml_backend_t backend); - GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend); + GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct 
ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); - GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); - GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); + GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op); + GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op); // tensor copy between different backends GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + // asynchronous copy + // the copy is performed after all the currently queued operations in backend_src + // backend_dst will wait for the copy to complete before performing other operations + // automatic fallback to sync copy if async is not supported + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + + // events + GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend); + GGML_API void ggml_backend_event_free (ggml_backend_event_t event); + GGML_API void ggml_backend_event_record (ggml_backend_event_t event); + GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); + GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); // wait async on event + // // CPU backend // GGML_API ggml_backend_t ggml_backend_cpu_init(void); - GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); - GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); + GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); + GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + GGML_API void 
ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); // Create a backend buffer from an existing pointer - GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size); + GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + + GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); + +#ifdef GGML_USE_CPU_HBM + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void); +#endif + + // + // Backend registry + // + + // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way + GGML_API size_t ggml_backend_reg_get_count(void); + GGML_API size_t ggml_backend_reg_find_by_name(const char * name); + GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params] + GGML_API const char * ggml_backend_reg_get_name(size_t i); + GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific + GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i); + GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size); // // Backend scheduler @@ -83,53 +137,96 @@ extern "C" { /* Example usage: - sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); - // sched is initialized with measure allocators and cannot be used until allocated with a measure graph + // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned + // preferrably to run on the same backend as the buffer + ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - // initialize buffers from a measure graph - measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false); - // in build_graph: - build_graph(...) 
{ - // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) - alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu); - ggml_allocr_alloc(alloc_cpu, tensor); + // initialize buffers from a max size graph (optional) + reserve_graph = build_graph(sched, max_batch_size); - // manually assigning nodes to a backend (optional, shouldn't be needed in most cases) - struct ggml_tensor * node = ggml_mul_mat(ctx, ...); - ggml_backend_sched_set_node_backend(sched, node, backend_gpu); - } + // manually assign nodes to a backend (optional, should not be needed in most cases) + struct ggml_tensor * node = ggml_mul_mat(ctx, ...); + ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu); - // allocate backend buffers from measure graph - ggml_backend_sched_init_measure(sched, measure_graph); - - // the scheduler is now ready to compute graphs + ggml_backend_sched_reserve(sched, reserve_graph); // compute graph = build_graph(sched); ggml_backend_sched_graph_compute(sched, graph); + + // if there are graph inputs: + ggml_backend_sched_reset(sched); + ggml_backend_sched_alloc_graph(sched, graph); + ggml_backend_tensor_set(input_tensor, ...); + ggml_backend_sched_graph_compute(sched, graph); + } */ struct ggml_backend_sched; typedef struct ggml_backend_sched * ggml_backend_sched_t; - // Initialize a backend scheduler - GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends); + // when ask == true, the scheduler wants to know if the user wants to observe this node + // this allows the scheduler to batch nodes together in order to evaluate them in a single call + // + // when ask == false, the scheduler is passing the node tensor to the user for observation + // if the user returns false, the scheduler will cancel the graph compute + // + typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); - GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); + // Initialize a backend scheduler + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph - GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + + // Get the number of splits of the last graph + GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); + GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); + + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + + // Allocate and compute graph on the backend scheduler + GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + 
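    /*
       A minimal sketch of an eval callback matching the ggml_backend_sched_eval_callback
       typedef above. The function name observe_mul_mat and the choice to observe only
       GGML_OP_MUL_MAT nodes are illustrative assumptions, not part of this patch.

       #include <stdio.h>

       static bool observe_mul_mat(struct ggml_tensor * t, bool ask, void * user_data) {
           (void) user_data;
           if (ask) {
               // ask == true: tell the scheduler whether we want to observe this node
               // (returning false for most nodes lets the scheduler batch them together)
               return t->op == GGML_OP_MUL_MAT;
           }
           // ask == false: t now holds computed data; returning false would cancel the compute
           fprintf(stderr, "observed node: %s\n", t->name);
           return true;
       }

       // registered before computing the graph, e.g.:
       //   ggml_backend_sched_set_eval_callback(sched, observe_mul_mat, NULL);
    */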
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); + + // Reset all assignments and allocators - must be called before changing the node backends + GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); + + // Set a callback to be called for each resulting node during graph compute + GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + + // + // Utils + // + + struct ggml_backend_graph_copy { + ggml_backend_buffer_t buffer; + struct ggml_context * ctx_allocated; + struct ggml_context * ctx_unallocated; + struct ggml_cgraph * graph; + }; + + // Copy a graph to a different backend + GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); + GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); + + typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); - GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend); - GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend); + // Compare the output of two backends + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); - GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + // Tensor initialization + GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); + GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - // Allocate a graph on the backend scheduler - GGML_API void ggml_backend_sched_graph_compute( - ggml_backend_sched_t sched, - struct ggml_cgraph * graph); #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ggml-common.h b/bindings/ruby/ext/ggml-common.h new file mode 100644 index 00000000000..43c7978a098 --- /dev/null +++ b/bindings/ruby/ext/ggml-common.h @@ -0,0 +1,1853 @@ +#ifndef GGML_COMMON_DECL + +#if defined(GGML_COMMON_DECL_C) +#include + +typedef uint16_t ggml_half; +typedef uint32_t ggml_half2; + +#define GGML_COMMON_AGGR + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_METAL) +#include + +typedef half ggml_half; +typedef half2 ggml_half2; + +#define GGML_COMMON_AGGR + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_CUDA) +#include +#include + +typedef half ggml_half; +typedef half2 ggml_half2; + +#define GGML_COMMON_AGGR data + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_HIP) +#include +#include + +typedef half ggml_half; +typedef half2 ggml_half2; + +#define GGML_COMMON_AGGR data + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_SYCL) +#include +#include + +typedef sycl::half ggml_half; +typedef sycl::half2 ggml_half2; + +#define GGML_COMMON_AGGR data + +#define GGML_COMMON_DECL +#endif + +#if defined(GGML_COMMON_DECL) + +#ifndef __cplusplus +#ifndef static_assert +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) +#define static_assert(cond, msg) _Static_assert(cond, msg) +#else +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif +#endif +#endif // __cplusplus + +// QK = number of values after dequantization +// QK_K = super-block size + +#ifdef 
GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif // GGML_QKK_64 + +#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL) +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QI4_0 (QK4_0 / (4 * QR4_0)) +#define QR4_0 2 + +#define QI4_1 (QK4_1 / (4 * QR4_1)) +#define QR4_1 2 + +#define QI5_0 (QK5_0 / (4 * QR5_0)) +#define QR5_0 2 + +#define QI5_1 (QK5_1 / (4 * QR5_1)) +#define QR5_1 2 + +#define QI8_0 (QK8_0 / (4 * QR8_0)) +#define QR8_0 1 + +#define QI8_1 (QK8_1 / (4 * QR8_1)) +#define QR8_1 1 + +#define QI2_K (QK_K / (4*QR2_K)) +#define QR2_K 4 + +#define QI3_K (QK_K / (4*QR3_K)) +#define QR3_K 4 + +#define QI4_K (QK_K / (4*QR4_K)) +#define QR4_K 2 + +#define QI5_K (QK_K / (4*QR5_K)) +#define QR5_K 2 + +#define QI6_K (QK_K / (4*QR6_K)) +#define QR6_K 2 + +#define QI2_XXS (QK_K / (4*QR2_XXS)) +#define QR2_XXS 8 + +#define QI2_XS (QK_K / (4*QR2_XS)) +#define QR2_XS 8 + +#define QI2_S (QK_K / (4*QR2_S)) +#define QR2_S 8 + +#define QI3_XXS (QK_K / (4*QR3_XXS)) +#define QR3_XXS 8 + +#define QI3_XS (QK_K / (4*QR3_XS)) +#define QR3_XS 8 + +#define QI1_S (QK_K / (4*QR1_S)) +#define QR1_S 8 + +#define QI4_NL (QK4_NL / (4*QR4_NL)) +#define QR4_NL 2 + +#if QK_K == 64 +#define QI4_XS QI4_NL +#define QR4_XS QR4_NL +#else +#define QI4_XS (QK_K / (4*QR4_XS)) +#define QR4_XS 8 +#endif + +#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP + +#define QK4_0 32 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + union { + struct { + ggml_half d; // delta + ggml_half m; // min + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + ggml_half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + union { + struct { + ggml_half d; // delta + ggml_half m; // min + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + ggml_half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + union { + struct { + ggml_half d; // delta + ggml_half s; // d * sum(qs[i]) + } GGML_COMMON_AGGR; + ggml_half2 ds; + }; + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); + +// +// Super-block quantization structures +// + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elements each +// Effectively 2.625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + union { + 
struct { + ggml_half d; // super-block scale for quantized scales + ggml_half dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 3.4375 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[2]; + ggml_half d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); +#else +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + ggml_half d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +#endif + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_half d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + union { + struct { + ggml_half d; // super-block scale for quantized scales + ggml_half dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +#endif + +// 5-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_half d; // super-block scale + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_half) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + union { + struct { + ggml_half d; // super-block scale for quantized scales + ggml_half dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR; + ggml_half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + ggml_half d; // super-block scale +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + 
int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + +// (Almost) "true" 2-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 2.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + ggml_half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + +// 2.3125 bpw quants +typedef struct { + ggml_half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + +// 2.5625 bpw quants +typedef struct { + ggml_half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; +static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); + +// (Almost) "true" 3-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 3.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + ggml_half d; + uint8_t qs[3*QK_K/8]; +} block_iq3_xxs; +static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); + +// 3.4375 bpw +#if QK_K == 64 +#define IQ3S_N_SCALE 2 +#else +#define IQ3S_N_SCALE QK_K/64 +#endif +typedef struct { + ggml_half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[IQ3S_N_SCALE]; +} block_iq3_s; +static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); + +typedef struct { + ggml_half d; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; +} block_iq1_s; +static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); + +// 1.75 bpw +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) +#if QK_K == 64 + ggml_half d; +#endif + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; +#if QK_K == 64 +static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding"); +#else +static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); +#endif + +// Used by IQ1_M quants +typedef union { + ggml_half f16; + uint16_t u16; +} iq1m_scale_t; + +// Non-linear quants +#define QK4_NL 32 +typedef struct { + ggml_half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; +static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding"); + +#if QK_K == 64 +#define block_iq4_xs block_iq4_nl +#else +typedef struct { + ggml_half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; +static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#endif + +#endif // GGML_COMMON_DECL +#endif // GGML_COMMON_DECL + +//////////////////////////////////////////////////////////////////////////////// + +#ifndef GGML_COMMON_IMPL + +#if defined(GGML_COMMON_IMPL_C) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { +#define 
GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_METAL) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_SYCL) + +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { +#define GGML_TABLE_END() }; + +#define GGML_COMMON_IMPL +#endif + +#if defined(GGML_COMMON_IMPL) + +GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8) + 1, 2, 4, 8, 16, 32, 64, 128 +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +GGML_TABLE_END() + +//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) + 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, + 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, + 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff, + 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff, + 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff, + 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, + 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff, + 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff, + 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff, + 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff, + 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, + 0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, + 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff, + 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff, + 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff, + 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff, + 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff, + 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff, + 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff, + 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff, + 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, + 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff, + 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff, + 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 
0x00ff00ffffffffff, + 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff, + 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff, + 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, + 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff, + 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff, + 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff, + 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, + 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, +GGML_TABLE_END() +//#endif + + +GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 
0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 
0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 
0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 
0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 
0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 
0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 
0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256) + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 
0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 
0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 
0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +GGML_TABLE_END() + +#define NGRID_IQ1S 2048 +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +#if defined(GGML_COMMON_IMPL_C) +GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S) + 0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff, + 0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff, + 0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000, + 0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000, + 0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101, + 0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff, + 0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00, + 0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101, + 0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100, + 0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00, + 0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000, + 0xffffff000100ff00, 0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000, + 0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101, + 0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff, + 0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100, + 0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01, + 0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000, + 0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff, + 0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001, + 0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000, + 0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00, + 0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff, + 0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff, + 0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000, + 0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff, + 0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff, + 0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101, + 0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff, + 0xffff000001000101, 
0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000, + 0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100, + 0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01, + 0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00, + 0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00, + 0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01, + 0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 0xffff01ffff01ffff, + 0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000, + 0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff, + 0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000, + 0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101, + 0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100, + 0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff, + 0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff, + 0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001, + 0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01, + 0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff, + 0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000, + 0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000, + 0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff, + 0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01, + 0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00, + 0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000, + 0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000, + 0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000, + 0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001, + 0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff, + 0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01, + 0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff, + 0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff, + 0xff00ff000000ff00, 0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000, + 0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000, + 0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00, + 0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001, + 0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100, + 0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101, + 0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00, + 0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00, + 0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000, + 0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001, + 0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff, + 0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000, + 0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000, + 0xff0000ff000101ff, 0xff0000ff01ff00ff, 
0xff0000ff01ff0100, 0xff0000ff0100ffff, + 0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100, + 0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000, + 0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101, + 0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001, + 0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000, + 0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff, + 0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01, + 0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, + 0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000, + 0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 0xff00000001ffff01, + 0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101, + 0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000, + 0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff, + 0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff, + 0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001, + 0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001, + 0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000, + 0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000, + 0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00, + 0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100, + 0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01, + 0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00, + 0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff, + 0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000, + 0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00, + 0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000, + 0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff, + 0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00, + 0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000, + 0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000, + 0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff, + 0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00, + 0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00, + 0xff00010001010001, 0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001, + 0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001, + 0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000, + 0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100, + 0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101, + 0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101, + 0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000, + 0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00, + 0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff, + 0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 
0xff01ffff01000000, + 0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff, + 0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff, + 0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff, + 0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101, + 0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001, + 0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100, + 0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101, + 0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01, + 0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00, + 0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00, + 0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000, + 0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101, + 0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100, + 0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 0xff0100ff00ff0001, + 0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff, + 0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00, + 0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000, + 0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100, + 0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00, + 0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01, + 0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff, + 0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101, + 0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100, + 0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100, + 0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000, + 0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff, + 0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff, + 0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000, + 0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101, + 0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000, + 0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101, + 0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101, + 0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100, + 0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01, + 0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001, + 0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001, + 0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff, + 0xff010101ffffff01, 0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff, + 0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000, + 0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000, + 0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101, + 0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff, + 0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001, + 
0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000, + 0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff, + 0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00, + 0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff, + 0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000, + 0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00, + 0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff, + 0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000, + 0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000, + 0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff, + 0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000, + 0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000, + 0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000, + 0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101, + 0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff, + 0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100, + 0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101, + 0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 0x00ff00ff00000001, + 0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff, + 0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100, + 0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff, + 0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01, + 0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff, + 0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff, + 0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff, + 0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff, + 0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001, + 0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff, + 0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01, + 0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00, + 0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100, + 0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101, + 0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff, + 0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00, + 0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01, + 0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00, + 0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100, + 0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00, + 0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000, + 0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000, + 0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000, + 0x00ff01ff00ffff01, 0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000, + 0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001, + 0x00ff01ff00010100, 
0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff, + 0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00, + 0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff, + 0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00, + 0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000, + 0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff, + 0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff, + 0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101, + 0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000, + 0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff, + 0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff, + 0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff, + 0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff, + 0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000, + 0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000, + 0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100, + 0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00, + 0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100, + 0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100, + 0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00, + 0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff, + 0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 0x0000ff00ff0000ff, + 0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100, + 0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff, + 0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000, + 0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00, + 0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, + 0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00, + 0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100, + 0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff, + 0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff, + 0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001, + 0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00, + 0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101, + 0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001, + 0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000, + 0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100, + 0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000, + 0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001, + 0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000, + 0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff, + 0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101, + 0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, + 0x000000ff00000001, 0x000000ff000001ff, 
0x000000ff00000100, 0x000000ff00000101, + 0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000, + 0x000000ff00010001, 0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff, + 0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000, + 0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff, + 0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01, + 0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100, + 0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff, + 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101, + 0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001, + 0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01, + 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff, + 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, + 0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff, + 0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00, + 0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff, + 0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff, + 0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff, + 0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, + 0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff, + 0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff, + 0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001, + 0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff, + 0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff, + 0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000, + 0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 0x0000000100ffff00, + 0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100, + 0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01, + 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff, + 0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff, + 0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000, + 0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101, + 0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01, + 0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100, + 0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100, + 0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00, + 0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101, + 0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001, + 0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01, + 0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100, + 0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff, + 0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01, + 0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff, + 0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 
0x00000100ff010000, + 0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001, + 0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01, + 0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff, + 0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff, + 0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00, + 0x0000010001ff0000, 0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff, + 0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100, + 0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000, + 0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001, + 0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100, + 0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff, + 0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff, + 0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01, + 0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101, + 0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001, + 0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff, + 0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101, + 0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100, + 0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00, + 0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01, + 0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001, + 0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00, + 0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000, + 0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101, + 0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001, + 0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100, + 0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000, + 0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000, + 0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 0x000100ffff00ff00, + 0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101, + 0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff, + 0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101, + 0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001, + 0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01, + 0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100, + 0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000, + 0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00, + 0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff, + 0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001, + 0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001, + 0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff, + 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00, + 0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100, + 
0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00, + 0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101, + 0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff, + 0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff, + 0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff, + 0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00, + 0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff, + 0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff, + 0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101, + 0x00010001010000ff, 0x0001000101000001, 0x00010001010001ff, 0x0001000101000100, + 0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101, + 0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00, + 0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00, + 0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff, + 0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff, + 0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001, + 0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000, + 0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff, + 0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff, + 0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff, + 0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff, + 0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001, + 0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100, + 0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101, + 0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001, + 0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001, + 0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101, + 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101, + 0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff, + 0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff, + 0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000, + 0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101, + 0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 0x01ffff00ff000001, + 0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff, + 0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000, + 0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff, + 0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100, + 0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000, + 0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101, + 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff, + 0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100, + 0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff, + 0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01, + 0x01ffff01010101ff, 
0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100, + 0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000, + 0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100, + 0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100, + 0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001, + 0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000, + 0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000, + 0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000, + 0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00, + 0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff, + 0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001, + 0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000, + 0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101, + 0x01ff00010000ffff, 0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101, + 0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000, + 0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff, + 0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000, + 0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101, + 0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff, + 0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff, + 0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000, + 0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101, + 0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff, + 0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff, + 0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01, + 0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff, + 0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000, + 0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101, + 0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff, + 0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff, + 0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff, + 0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01, + 0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00, + 0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000, + 0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000, + 0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101, + 0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 0x0100ffff01000001, + 0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff, + 0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000, + 0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff, + 0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000, + 0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000, + 0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000, + 0x0100ff0001ff00ff, 0x0100ff0001ff0001, 
0x0100ff000100ff01, 0x0100ff0001000000, + 0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001, + 0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff, + 0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001, + 0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000, + 0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000, + 0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00, + 0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000, + 0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101, + 0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001, + 0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00, + 0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001, + 0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff, + 0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000, + 0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00, + 0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100, + 0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101, + 0x0100000000ffff00, 0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001, + 0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01, + 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff, + 0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff, + 0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00, + 0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01, + 0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100, + 0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000, + 0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff, + 0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff, + 0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff, + 0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000, + 0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff, + 0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101, + 0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff, + 0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001, + 0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001, + 0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001, + 0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000, + 0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000, + 0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff, + 0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101, + 0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001, + 0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 0x0100010000ff01ff, + 0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000, + 0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000, + 0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 
0x0100010001ff00ff, + 0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101, + 0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000, + 0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100, + 0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00, + 0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000, + 0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff, + 0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01, + 0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00, + 0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff, + 0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000, + 0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101, + 0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff, + 0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001, + 0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff, + 0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000, + 0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100, + 0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff, + 0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101, + 0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100, + 0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff, + 0x0101ff0101ff0101, 0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01, + 0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000, + 0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff, + 0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00, + 0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100, + 0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff, + 0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001, + 0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00, + 0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100, + 0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff, + 0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01, + 0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00, + 0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000, + 0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01, + 0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff, + 0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000, + 0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff, + 0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff, + 0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff, + 0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001, + 0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff, + 0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01, + 0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff, + 
0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 0x0101010000ffff00, + 0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00, + 0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001, + 0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101, + 0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101, + 0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff, + 0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000, + 0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101, +GGML_TABLE_END() +#else +GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S) + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 
0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 
0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 
0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 
0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 
0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 
0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 
0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +GGML_TABLE_END() +#endif + +#endif // GGML_COMMON_IMPL +#endif // GGML_COMMON_IMPL diff --git a/bindings/ruby/ext/ggml-cuda.h b/bindings/ruby/ext/ggml-cuda.h new file mode 100644 index 00000000000..5eb4af40f4d --- /dev/null +++ b/bindings/ruby/ext/ggml-cuda.h @@ -0,0 +1,43 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_HIPBLAS +#define GGML_CUDA_NAME "ROCm" +#define GGML_CUBLAS_NAME "hipBLAS" +#else +#define GGML_CUDA_NAME "CUDA" +#define GGML_CUBLAS_NAME "cuBLAS" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_CUDA_MAX_DEVICES 16 + +// backend API +GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device); + +GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend); + +// device buffer +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split); + +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); + +GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void); +GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size); +GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total); + +GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); +GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-impl.h b/bindings/ruby/ext/ggml-impl.h index d88f261449f..93a4f1a2b72 100644 --- a/bindings/ruby/ext/ggml-impl.h +++ b/bindings/ruby/ext/ggml-impl.h @@ -5,6 +5,7 @@ // GGML internal header #include +#include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ #include #include #include // memcpy @@ -18,6 +19,7 @@ extern "C" { // fall back to the _Static_assert C11 keyword. 
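Illustrative aside, not part of the patch: the new bindings/ruby/ext/ggml-cuda.h above only declares the CUDA backend entry points, so a minimal sketch of how they could be exercised may help. It assumes ggml was built with CUDA support and that ggml_backend_free() is available from ggml-backend.h; device 0 and the buffer sizes are arbitrary.

#include <stdio.h>
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main(void) {
    // enumerate the visible CUDA devices and print a short summary of each
    const int n_devices = ggml_backend_cuda_get_device_count();
    for (int i = 0; i < n_devices; ++i) {
        char   desc[128];
        size_t mem_free = 0, mem_total = 0;
        ggml_backend_cuda_get_device_description(i, desc, sizeof(desc));
        ggml_backend_cuda_get_device_memory(i, &mem_free, &mem_total);
        printf("device %d: %s (%zu / %zu bytes free)\n", i, desc, mem_free, mem_total);
    }

    // initialize a backend on device 0, then release it again
    ggml_backend_t backend = ggml_backend_cuda_init(0);
    if (backend != NULL && ggml_backend_is_cuda(backend)) {
        ggml_backend_free(backend);
    }
    return 0;
}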
// if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef __cplusplus #ifndef static_assert #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) #define static_assert(cond, msg) _Static_assert(cond, msg) @@ -25,6 +27,7 @@ extern "C" { #define static_assert(cond, msg) struct global_scope_noop_trick #endif #endif +#endif // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) @@ -34,16 +37,17 @@ extern "C" { #ifndef __F16C__ #define __F16C__ #endif +#endif + +// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available +#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) #ifndef __SSE3__ #define __SSE3__ #endif +#ifndef __SSSE3__ +#define __SSSE3__ +#endif #endif - -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) // 16-bit float // on Arm, we use __fp16 @@ -56,14 +60,30 @@ extern "C" { // #include -#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define GGML_COMPUTE_FP32_TO_FP16(x) (x) +typedef __fp16 ggml_fp16_internal_t; -#define GGML_FP16_TO_FP32(x) ((float) (x)) -#define GGML_FP32_TO_FP16(x) (x) +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + ggml_fp16_internal_t tmp; + memcpy(&tmp, &h, sizeof(ggml_fp16_t)); + return (float)tmp; +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + ggml_fp16_internal_t tmp = f; + memcpy(&res, &tmp, sizeof(ggml_fp16_t)); + return res; +} #else +typedef uint16_t ggml_fp16_internal_t; + #ifdef __wasm_simd128__ #include #else @@ -217,8 +237,7 @@ extern float ggml_table_f32_f16[1 << 16]; // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. 
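Illustrative aside, not part of the patch: the ggml-impl.h hunk above replaces the direct-cast fp16 macros on Arm with memcpy-based conversions through ggml_fp16_internal_t, while other targets keep a lookup-table path. A small round trip through the macros shows the intended behavior; it assumes the ggml sources are on the include path and the program links against ggml, and ggml_init() is called first because targets that use the lookup table fill it there.

#include <stdio.h>

#include "ggml.h"
#include "ggml-impl.h"

int main(void) {
    // ggml_init() also populates the fp16 -> fp32 lookup table used on some targets
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const float       x = 3.14159f;
    const ggml_fp16_t h = GGML_FP32_TO_FP16(x); // narrow to 16-bit storage
    const float       y = GGML_FP16_TO_FP32(h); // widen back to 32-bit

    printf("%f -> fp16 -> %f\n", x, y);

    ggml_free(ctx);
    return 0;
}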
-#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) - +#if !defined(GGML_FP16_TO_FP32) inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { uint16_t s; memcpy(&s, &f, sizeof(uint16_t)); @@ -226,19 +245,23 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { } #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) +#endif +#if !defined(GGML_FP32_TO_FP16) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) #endif #define GGML_HASHTABLE_FULL ((size_t)-1) #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) +struct ggml_hash_set ggml_hash_set_new(size_t size); + bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key); // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key); -// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full +// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key); // return index, asserts if table is full diff --git a/bindings/ruby/ext/ggml-kompute.h b/bindings/ruby/ext/ggml-kompute.h new file mode 100644 index 00000000000..171465456a5 --- /dev/null +++ b/bindings/ruby/ext/ggml-kompute.h @@ -0,0 +1,46 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_vk_device { + int index; + int type; // same as VkPhysicalDeviceType + size_t heapSize; + const char * name; + const char * vendor; + int subgroupSize; + uint64_t bufferAlignment; + uint64_t maxAlloc; +}; + +struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count); +bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name); +bool ggml_vk_has_vulkan(void); +bool ggml_vk_has_device(void); +struct ggml_vk_device ggml_vk_current_device(void); + +// +// backend API +// + +// forward declaration +typedef struct ggml_backend * ggml_backend_t; + +GGML_API ggml_backend_t ggml_backend_kompute_init(int device); + +GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend); + +GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-metal.h b/bindings/ruby/ext/ggml-metal.h new file mode 100644 index 00000000000..a5c542189c2 --- /dev/null +++ b/bindings/ruby/ext/ggml-metal.h @@ -0,0 +1,66 @@ +// An interface allowing to compute ggml_cgraph with Metal +// +// This is a fully functional interface that extends ggml with GPU support for Apple devices. +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) +// +// How it works? +// +// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this +// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you +// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) +// +// You only need to make sure that all memory buffers that you used during the graph creation +// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is +// used during the graph evaluation to determine the arguments of the compute kernels. 
+// +// Synchronization between device and host memory (for example for input and output tensors) +// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. +// + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend API +// user-code should use only these functions +// + +GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); + +GGML_API ggml_backend_t ggml_backend_metal_init(void); + +GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); + +GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); + +GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); + +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); + +// helper to check if the device supports a specific family +// ideally, the user code should be doing these checks +// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf +GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); + +// capture all command buffers committed the next time `ggml_backend_graph_compute` is called +GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); + +#ifdef __cplusplus +} +#endif + diff --git a/ggml-opencl.h b/bindings/ruby/ext/ggml-opencl.h similarity index 100% rename from ggml-opencl.h rename to bindings/ruby/ext/ggml-opencl.h diff --git a/bindings/ruby/ext/ggml-quants.c b/bindings/ruby/ext/ggml-quants.c index 740be6dc5c7..32e84434a8c 100644 --- a/bindings/ruby/ext/ggml-quants.c +++ b/bindings/ruby/ext/ggml-quants.c @@ -1,10 +1,18 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + #include "ggml-quants.h" #include "ggml-impl.h" +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + #include #include #include #include +#include // for qsort +#include // for GGML_ASSERT #ifdef __ARM_NEON @@ -14,32 +22,12 @@ // #include -#if !defined(__aarch64__) -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} -#endif - #else #ifdef __wasm_simd128__ #include #else -#ifdef __POWER9_VECTOR__ +#if defined(__POWER9_VECTOR__) || defined(__powerpc64__) #include #undef bool #define bool _Bool @@ -47,13 +35,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #if defined(_MSC_VER) || defined(__MINGW32__) #include #else -#if !defined(__riscv) && !defined(__s390__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if !defined(__riscv) #include #endif #endif #endif #endif #endif +#endif #ifdef __riscv_v_intrinsic #include @@ -61,9 +51,13 @@ inline static int32_t 
vaddvq_s32(int32x4_t v) { #undef MIN #undef MAX + #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define UNUSED GGML_UNUSED + +// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) @@ -138,7 +132,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if __AVXVNNI__ +#if defined(__AVXVNNI__) || defined(__AVX512VNNI__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); @@ -284,8 +278,50 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #if defined(__ARM_NEON) +#ifdef _MSC_VER + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif + #if !defined(__aarch64__) +// 64-bit compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + inline static int32_t vaddvq_s32(int32x4_t v) { return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); } @@ -311,7 +347,185 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { return res; } +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static 
ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + #endif + #endif #if defined(__ARM_NEON) || defined(__wasm_simd128__) @@ -330,7 +544,7 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -367,11 +581,12 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } } -void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_0_reference(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * 
restrict y, int k) { + +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -408,11 +623,11 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } } -void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_1_reference(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -456,11 +671,11 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } } -void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_0_reference(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -504,12 +719,12 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } } -void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_1_reference(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -534,7 +749,7 @@ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict } } -void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -723,7 +938,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -739,7 +954,7 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); int sum = 0; @@ -754,11 +969,11 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict sum += y[i].qs[QK8_1/2 + j]; } - y[i].s = sum*d; + y[i].s = GGML_FP32_TO_FP16(sum*d); } } -void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -782,7 +997,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { const float d = amax / ((1 << 7) - 1); const float id = d ? 
1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); int32x4_t accv = vdupq_n_s32(0); @@ -798,7 +1013,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { accv = vaddq_s32(accv, vi); } - y[i].s = d * vaddvq_s32(accv); + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } #elif defined(__wasm_simd128__) for (int i = 0; i < nb; i++) { @@ -821,7 +1036,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); v128_t accv = wasm_i32x4_splat(0); @@ -837,10 +1052,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { accv = wasm_i32x4_add(accv, vi); } - y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3)); + y[i].s = GGML_FP32_TO_FP16( + d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3))); } #elif defined(__AVX2__) || defined(__AVX__) for (int i = 0; i < nb; i++) { @@ -865,7 +1081,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { // Quantize these floats const float d = maxScalar / 127.f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -889,7 +1105,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { #if defined(__AVX2__) // Compute the sum of the quants and set y[i].s - y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -919,7 +1135,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { // Compute the sum of the quants and set y[i].s const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); // Convert int32 to int16 ni0 = _mm_packs_epi32( ni0, ni1 ); @@ -950,7 +1166,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { const float d = amax / ((1 << 7) - 1); const float id = d ? 
1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); @@ -967,7 +1183,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { // set y[i].s int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); - y[i].s = sum*d; + y[i].s = GGML_FP32_TO_FP16(sum*d); } #else GGML_UNUSED(nb); @@ -976,7 +1192,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { #endif } -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -996,7 +1212,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int } } -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_1; assert(k % qk == 0); @@ -1017,7 +1233,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int } } -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -1043,7 +1259,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int } } -void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_1; assert(k % qk == 0); @@ -1070,7 +1286,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int } } -void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK8_0; assert(k % qk == 0); @@ -1100,7 +1316,8 @@ static inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type, + const float * restrict qw) { float max = 0; float amax = 0; for (int i = 0; i < n; ++i) { @@ -1126,14 +1343,18 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * rmse_type = -rmse_type; return_early = true; } - int weight_type = rmse_type%2; float sumlx = 0; float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else for (int i = 0; i < n; ++i) { +#endif int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); L[i] = l + nmax; - float w = weight_type == 1 ? x[i] * x[i] : 1; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); sumlx += w*x[i]*l; suml2 += w*l*l; } @@ -1149,7 +1370,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * for (int i = 0; i < n; ++i) { int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); - float w = weight_type == 1 ? x[i] * x[i] : 1; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? 
fabsf(x[i]) : sqrtf(fabsf(x[i])); sumlx += w*x[i]*l; suml2 += w*l*l; } @@ -1273,7 +1494,12 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f float max = x[0]; float sum_w = weights[0]; float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else for (int i = 1; i < n; ++i) { +#endif if (x[i] < min) min = x[i]; if (x[i] > max) max = x[i]; float w = weights[i]; @@ -1355,7 +1581,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1432,7 +1658,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1478,64 +1704,322 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int } } -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q2_K_reference(x, vy, k); } -size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - (void)hist; // TODO: collect histograms - - for (int j = 0; j < n; j += k) { - block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; - quantize_row_q2_K_reference(src + j, y, k); +static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights ? weights[0] : x[0]*x[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights ? weights[i] : x[i]*x[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) { + min = 0; + } + if (max <= min) { + memset(L, 0, n); + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights ? 
weights[i] : x[i]*x[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff*diff; + float w = weights ? weights[i] : x[i]*x[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } } - return (n/QK_K*sizeof(block_q2_K)); + *the_min = -min; + return scale; } -//========================= 3-bit (de)-quantization +static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) { + float max = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + if (!max) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = nmax / max; + for (int i = 0; i < n; ++i) { + L[i] = nearest_int(iscale * x[i]); + } + float scale = 1/iscale; + float best_mse = 0; + for (int i = 0; i < n; ++i) { + float diff = x[i] - scale*L[i]; + float w = quant_weights[i]; + best_mse += w*diff*diff; + } + for (int is = -4; is <= 4; ++is) { + if (is == 0) continue; + float iscale_is = (0.1f*is + nmax)/max; + float scale_is = 1/iscale_is; + float mse = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale_is*x[i]); + l = MIN(nmax, l); + float diff = x[i] - scale_is*l; + float w = quant_weights[i]; + mse += w*diff*diff; + } + if (mse < best_mse) { + best_mse = mse; + iscale = iscale_is; + } + } + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MIN(nmax, l); + L[i] = l; + float w = quant_weights[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = quant_weights[i]; + float slx = sumlx - w*x[i]*L[i]; + float sl2 = suml2 - w*L[i]*L[i]; + if (slx > 0 && sl2 > 0) { + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MIN(nmax, new_l); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + return sumlx / suml2; +} -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { +static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) { + GGML_ASSERT(quant_weights); assert(k % QK_K == 0); const int nb = k / QK_K; + const bool requantize = true; - int8_t L[QK_K]; - float scales[QK_K / 16]; + uint8_t L[QK_K]; + uint8_t Laux[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + float sw[QK_K/16]; + float weight[16]; + uint8_t Ls[QK_K/16], Lm[QK_K/16]; for (int i = 0; i < nb; i++) { - - float max_scale = 0; - float amax = 0; + memset(sw, 0, QK_K/16*sizeof(float)); + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = sumx2/QK_K; for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); - float scale = fabsf(scales[j]); - if (scale > amax) { - amax = scale; max_scale = scales[j]; - } + const float * restrict qw = quant_weights 
+ QK_K * i + 16*j; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); + for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l]; + scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } -#if QK_K == 256 - memset(y[i].scales, 0, 12); + float dm, mm; +#if QK_K == 64 + float max_scale = 0, max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + max_scale = MAX(max_scale, scales[j]); + max_min = MAX(max_min, mins[j]); + } + dm = max_scale/15; + mm = max_min/15; if (max_scale) { - float iscale = -32.f/max_scale; + float id = 1/dm; for (int j = 0; j < QK_K/16; ++j) { - int8_t l = nearest_int(iscale*scales[j]); - l = MAX(-32, MIN(31, l)) + 32; - if (j < 8) { - y[i].scales[j] = l & 0xF; - } else { - y[i].scales[j-8] |= ((l & 0xF) << 4); - } - l >>= 4; - y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + int l = nearest_int(id*scales[j]); + Ls[j] = MAX(0, MIN(15, l)); } - y[i].d = GGML_FP32_TO_FP16(1/iscale); } else { - y[i].d = GGML_FP32_TO_FP16(0.f); + memset(Ls, 0, QK_K/16); } + if (max_min) { + float id = 1/mm; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(id*mins[j]); + Lm[j] = MAX(0, MIN(15, l)); + } + } else { + memset(Lm, 0, QK_K/16); + } +#else + dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); +#endif + y[i].d = GGML_FP32_TO_FP16(dm); + y[i].dmin = GGML_FP32_TO_FP16(mm); + dm = GGML_FP16_TO_FP32(y[i].d); + mm = GGML_FP16_TO_FP32(y[i].dmin); - int8_t sc; for (int j = 0; j < QK_K/16; ++j) { - sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + y[i].scales[j] = Ls[j] | (Lm[j] << 4); + } + + if (requantize) { + for (int j = 0; j < QK_K/16; ++j) { + const float d = dm * (y[i].scales[j] & 0xF); + if (!d) continue; + const float m = mm * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + m)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + } + +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + + } +} + +size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); + if (!quant_weights) { + quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + +#if QK_K == 256 + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = 
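
quantize_q2_K above establishes the driver shape that most of the new quantize_qX_Y entry points in this patch repeat: without importance weights the data goes through the existing flat reference routine, otherwise it is quantized row by row against the weights. A generic sketch of that control flow; the function-pointer wrapper is hypothetical (the real code simply repeats the pattern per type) and row_size would come from ggml_row_size:

#include <stddef.h>
#include <stdint.h>

typedef void (*quant_row_impl_t)(const float * x, void * y, int64_t n_per_row, const float * qw);
typedef void (*quant_ref_t)     (const float * x, void * y, int64_t k);

// Shared shape of the quantize_qX_Y drivers; the same quant_weights slice is
// reused for every row, since the importance matrix is per-column.
static size_t quantize_rows_sketch(const float * src, void * dst,
                                   int64_t nrow, int64_t n_per_row, size_t row_size,
                                   const float * quant_weights,
                                   quant_ref_t reference, quant_row_impl_t impl) {
    if (!quant_weights) {
        reference(src, dst, nrow*n_per_row);   // one flat pass, no per-row weighting
    } else {
        char * qrow = (char *)dst;
        for (int64_t row = 0; row < nrow; ++row) {
            impl(src, qrow, n_per_row, quant_weights);
            src  += n_per_row;
            qrow += row_size;
        }
    }
    return (size_t)nrow * row_size;
}
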
nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(1/iscale); + } else { + y[i].d = GGML_FP32_TO_FP16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; float d = GGML_FP16_TO_FP32(y[i].d) * sc; if (!d) { @@ -1608,7 +2092,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } #if QK_K == 256 -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1658,7 +2142,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } } #else -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); assert(QK_K == 64); const int nb = k / QK_K; @@ -1691,23 +2175,118 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } #endif -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q3_K_reference(x, vy, k); } -size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - (void)hist; // TODO: collect histograms +static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q3_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int nb = n_per_row / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + float weight[16]; + float sw[QK_K / 16]; + int8_t Ls[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float sumx2 = 0; + for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; + float sigma2 = 2*sumx2/QK_K; + + for (int j = 0; j < QK_K/16; ++j) { + if (quant_weights) { + const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL; + for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]); + } else { + for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l]; + } + float sumw = 0; + for (int l = 0; l < 16; ++l) sumw += weight[l]; + sw[j] = sumw; + + scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight); + + } + + memset(y[i].scales, 0, 12); + + float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw); + for (int j = 0; j < QK_K/16; ++j) { + int l = Ls[j]; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(d_block); + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? 
y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } + + x += QK_K; + } +#endif +} - for (int j = 0; j < n; j += k) { - block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; - quantize_row_q3_K_reference(src + j, y, k); +size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); + if (!quant_weights) { + quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); } - return (n/QK_K*sizeof(block_q3_K)); + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; } // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1814,7 +2393,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1853,28 +2432,111 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int } } -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; quantize_row_q4_K_reference(x, y, k); } -size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms +static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q4_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + uint8_t Ls[QK_K/32]; + uint8_t Lm[QK_K/32]; + float weights[32]; + float sw[QK_K/32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = 2*sum_x2/QK_K; + float av_x = sqrtf(sigma2); + + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + 
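
The hmask loop in the q3_K quantizer above splits each 3-bit quant into a 2-bit body stored in qs and a high bit collected into hmask, at byte j % 32, bit j / 32. A sketch of the inverse, assuming QK_K == 256; q3_quant_sketch is an illustrative helper derived from the packing loops shown, not part of ggml:

#include <stdint.h>

// Recover the signed 3-bit q3_K quant at position j (0..255) of a block.
static int q3_quant_sketch(const uint8_t * qs, const uint8_t * hmask, int j) {
    // low 2 bits: qs packs four quants per byte, in 32-wide lanes (j, j+32, j+64, j+96)
    const int base = (j >> 7) * 32;                        // second 128-quant half starts at byte 32
    const int lo   = (qs[base + (j & 31)] >> (2*((j & 127) >> 5))) & 3;
    // high bit: byte j % 32, bit j / 32 -- this is what the m/hm loop above encodes
    const int hi   = (hmask[j & 31] >> (j >> 5)) & 1;
    return lo + 4*hi - 4;                                  // back to the original [-4, 3] range
}
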
l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + } + float sumw = 0; + for (int l = 0; l < 32; ++l) sumw += weights[l]; + sw[j] = sumw; + scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); + } + + float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); + float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(d_block); + y[i].dmin = GGML_FP32_TO_FP16(m_block); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +#endif +} - for (int j = 0; j < n; j += k) { - block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; - quantize_row_q4_K_reference(src + j, y, k); +size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); + if (!quant_weights) { + quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } } - return (n/QK_K*sizeof(block_q4_K)); + return nrow * row_size; } // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 uint8_t L[QK_K]; @@ -1965,7 +2627,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict #else float max_scale = 0, amax = 0; for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL); float abs_scale = fabsf(scales[j]); if (abs_scale > amax) { amax = abs_scale; @@ -2014,9 +2676,9 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict } } -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2059,78 +2721,181 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int } } -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; quantize_row_q5_K_reference(x, y, k); } -size_t 
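
quantize_row_q4_K_impl above squeezes eight 6-bit scales and eight 6-bit mins into the 12-byte scales array: the first four pairs occupy the low 6 bits of bytes 0..7, while the last four are split across the nibbles of bytes 8..11 and the spare top bits. The existing get_scale_min_k4 helper it calls undoes that; a self-contained sketch of the same unpacking, derived from the packing loop shown:

#include <stdint.h>

// Inverse of the scales/mins packing above: sub-block j gets a 6-bit scale and a 6-bit min.
static void unpack_scale_min_k4_sketch(int j, const uint8_t * scales, uint8_t * sc, uint8_t * m) {
    if (j < 4) {
        *sc = scales[j]     & 63;
        *m  = scales[j + 4] & 63;
    } else {
        *sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4);
        *m  = (scales[j + 4] >>   4) | ((scales[j    ] >> 6) << 4);
    }
}
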
ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms - - for (int j = 0; j < n; j += k) { - block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; - quantize_row_q5_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q5_K)); -} - -// ====================== 6-bit (de)-quantization - -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; +static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q5_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const int64_t nb = n_per_row / QK_K; - int8_t L[QK_K]; - float scales[QK_K/16]; + uint8_t L[QK_K]; + uint8_t Laux[32]; + uint8_t Ls[QK_K/32]; + uint8_t Lm[QK_K/32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float sw[QK_K/32]; + float weights[32]; for (int i = 0; i < nb; i++) { - float max_scale = 0; - float max_abs_scale = 0; - - for (int ib = 0; ib < QK_K/16; ++ib) { - - const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); - scales[ib] = scale; + float sum_x2 = 0; + for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l]; + float sigma2 = 2*sum_x2/QK_K; + float av_x = sqrtf(sigma2); - const float abs_scale = fabsf(scale); - if (abs_scale > max_abs_scale) { - max_abs_scale = abs_scale; - max_scale = scale; + for (int j = 0; j < QK_K/32; ++j) { + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 32*j; + for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]); + } else { + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); } + float sumw = 0; + for (int l = 0; l < 32; ++l) sumw += weights[l]; + sw[j] = sumw; + scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } - if (!max_abs_scale) { - memset(&y[i], 0, sizeof(block_q6_K)); - y[i].d = GGML_FP32_TO_FP16(0.f); - x += QK_K; - continue; - } + float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw); + float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw); - float iscale = -128.f/max_scale; - y[i].d = GGML_FP32_TO_FP16(1/iscale); - for (int ib = 0; ib < QK_K/16; ++ib) { - y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } } + y[i].d = GGML_FP32_TO_FP16(d_block); + y[i].dmin = GGML_FP32_TO_FP16(m_block); - for (int j = 0; j < QK_K/16; ++j) { - float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; - if (!d) { - continue; - } - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-32, MIN(31, l)); - L[16*j + ii] = l + 32; + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; } } - uint8_t * restrict ql = y[i].ql; uint8_t * 
restrict qh = y[i].qh; -#if QK_K == 256 + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } + + x += QK_K; + + } +#endif +} + +size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); + if (!quant_weights) { + quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + } + return nrow * row_size; +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { const uint8_t q1 = L[j + l + 0] & 0xF; @@ -2160,9 +2925,9 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict } } -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2207,459 +2972,815 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int } } -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; quantize_row_q6_K_reference(x, y, k); } -size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK_K == 0); - (void)hist; // TODO: collect histograms +static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { +#if QK_K != 256 + (void)quant_weights; + quantize_row_q6_K_reference(x, y, n_per_row); +#else + assert(n_per_row % QK_K == 0); + const 
int64_t nb = n_per_row / QK_K; - for (int j = 0; j < n; j += k) { - block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; - quantize_row_q6_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q6_K)); -} + int8_t L[QK_K]; + float scales[QK_K/16]; + //float weights[16]; -//===================================== Q8_K ============================================== + for (int i = 0; i < nb; i++) { -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; + //float sum_x2 = 0; + //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j]; + //float sigma2 = sum_x2/QK_K; - for (int i = 0; i < nb; i++) { + float max_scale = 0; + float max_abs_scale = 0; - float max = 0; - float amax = 0; - for (int j = 0; j < QK_K; ++j) { - float ax = fabsf(x[j]); - if (ax > amax) { - amax = ax; max = x[j]; + for (int ib = 0; ib < QK_K/16; ++ib) { + + float scale; + if (quant_weights) { + const float * qw = quant_weights + QK_K*i + 16*ib; + //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]); + //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights); + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw); + } else { + scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL); + } + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; } + } - if (!amax) { - y[i].d = 0; - memset(y[i].qs, 0, QK_K); + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); x += QK_K; continue; } - const float iscale = -128.f/max; - for (int j = 0; j < QK_K; ++j) { - int v = nearest_int(iscale*x[j]); - y[i].qs[j] = MIN(127, v); + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); } + for (int j = 0; j < QK_K/16; ++j) { - int sum = 0; + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } for (int ii = 0; ii < 16; ++ii) { - sum += y[i].qs[j*16 + ii]; + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; } - y[i].bsums[j] = sum; } - y[i].d = 1/iscale; + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } + x += QK_K; + } +#endif } -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK_K; ++j) { - *y++ = x[i].d * x[i].qs[j]; +size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); + if (!quant_weights) { + quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); + } + else { + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights); 
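
The ql/qh layout written by the q6_K code above keeps the low nibble of each 6-bit quant in ql and its top two bits in qh. A sketch of recovering one quant, assuming QK_K == 256; q6_quant_sketch is illustrative only and mirrors what dequantize_row_q6_K computes:

#include <stdint.h>

// Recover the signed 6-bit q6_K quant at position p (0..255) of a block.
static int q6_quant_sketch(const uint8_t * ql, const uint8_t * qh, int p) {
    const int half = p >> 7;         // which 128-quant half of the block
    const int l    = p & 31;         // lane inside a 32-wide group
    const int k    = (p & 127) >> 5; // which of the four groups in this half
    const int lo   = (ql[64*half + l + 32*(k & 1)] >> (4*(k >> 1))) & 0x0F;
    const int hi   = (qh[32*half + l] >> (2*k)) & 3;
    return (lo | (hi << 4)) - 32;    // back to the original [-32, 31] range
}
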
+ src += n_per_row; + qrow += row_size; } } + return nrow * row_size; } -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { - quantize_row_q8_K_reference(x, y, k); -} +static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_0 == 32, "QK4_0 must be 32"); -//===================================== Dot ptoducts ================================= + if (!quant_weights) { + quantize_row_q4_0_reference(x, y, n_per_row); + return; + } -// -// Helper functions -// -#if __AVX__ || __AVX2__ || __AVX512F__ + float weight[QK4_0]; + int8_t L[QK4_0]; -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK4_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_0 * ib; + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } } -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return _mm_loadu_si128((const __m128i*)k_shuffle + i); + +size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); + return 
nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; } -#endif -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; +static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_1 == 32, "QK4_1 must be 32"); - assert(n % qk == 0); + if (!quant_weights) { + quantize_row_q4_1_reference(x, y, n_per_row); + return; + } - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; + float weight[QK4_1]; + uint8_t L[QK4_1], Laux[QK4_1]; -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; - assert(nb % 2 == 0); // TODO: handle odd nb + const int64_t nb = n_per_row/QK4_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_1 * ib; + const float * qw = quant_weights + QK4_1 * ib; + for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = GGML_FP32_TO_FP16(d); + y[ib].m = GGML_FP32_TO_FP16(-min); + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; +size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); +static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK5_0 == 32, "QK5_0 must be 32"); - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); + if (!quant_weights) { + quantize_row_q5_0_reference(x, y, n_per_row); + return; + } - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + float weight[QK5_0]; + int8_t L[QK5_0]; - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += 
x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + const int64_t nb = n_per_row/QK5_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_0 * ib; + const float * qw = quant_weights + QK5_0 * ib; + for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight); + y[ib].d = GGML_FP32_TO_FP16(d); -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + uint32_t qh = 0; - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + + memcpy(&y[ib].qh, &qh, sizeof(qh)); } +} - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); +size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} - // Main loop - for (int i = 0; i < nb; ++i) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); +static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict 
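
quantize_row_q5_0_impl above keeps the low nibble of each 5-bit value in qs and gathers the 32 fifth bits into a single uint32_t qh, bit j for the first half of the block and bit j + 16 for the second. A sketch of the inverse, mirroring what dequantize_row_q5_0 does; unpack_q5_0_sketch is an illustrative name:

#include <stdint.h>
#include <string.h>

#define QK5_0_SKETCH 32

static void unpack_q5_0_sketch(const uint8_t * qs, const uint8_t * qh_bytes, int8_t * out) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));               // qh is stored unaligned in the block
    for (int j = 0; j < QK5_0_SKETCH/2; ++j) {
        const uint8_t xi0 = (qs[j] & 0x0F) | (((qh >> (j +  0)) & 1) << 4);
        const uint8_t xi1 = (qs[j] >>   4) | (((qh >> (j + 16)) & 1) << 4);
        out[j]                  = (int8_t)(xi0 - 16); // make_qx_quants stored l + 16
        out[j + QK5_0_SKETCH/2] = (int8_t)(xi1 - 16);
    }
}
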
y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK5_1 == 32, "QK5_1 must be 32"); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + if (!quant_weights) { + quantize_row_q5_1_reference(x, y, n_per_row); + return; + } - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); + float weight[QK5_1]; + uint8_t L[QK5_1], Laux[QK5_1]; - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const int64_t nb = n_per_row/QK5_1; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK5_1 * ib; + const float * qw = quant_weights + QK5_1 * ib; + for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float min; + float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false); + y[ib].d = GGML_FP32_TO_FP16(d); + y[ib].m = GGML_FP32_TO_FP16(-min); - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); + uint32_t qh = 0; + for (int j = 0; j < 16; ++j) { + const uint8_t xi0 = L[j]; + const uint8_t xi1 = L[j+16]; + y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(&y[ib].qh, &qh, sizeof(qh)); } +} - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); +size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); +size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); + quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); +// ====================== "True" 2-bit (de)-quantization - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); +void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m128i bx = _mm_and_si128(lowMask, tmp); - __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; - bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_1 = 
mul_sum_i8_pairs(bx, by); + for (int i = 0; i < nb; i++) { - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + const float d = GGML_FP16_TO_FP32(x[i].d); - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t)); + const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + } } +} - *s = hsum_float_8(acc); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); +// ====================== 2.3125 bpw (de)-quantization - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); +void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + float db[2]; - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + for (int i = 0; i < nb; i++) { - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + const float d = GGML_FP16_TO_FP32(x[i].d); - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511)); + const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9]; + for (int j = 0; j < 8; ++j) { + y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); + } + y += 8; + } + } + } +} - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); +// ====================== 2.5625 bpw (de)-quantization - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); +void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + float db[2]; - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); + for (int i = 0; i < nb; i++) { - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( d_2_3, p3 ); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; + db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; + for (int l = 0; l < 4; ++l) { + const float dl = db[l/2]; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + } + y += 8; + } + qs += 4; + signs += 4; + } } +} - assert(nb % 2 == 0); // TODO: handle odd nb - - // Main loop - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); +// ====================== 3.0625 bpw (de)-quantization - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); +void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + uint32_t aux32; - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + for (int i = 0; i < nb; i++) { - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * scales_and_signs = qs + QK_K/4; + + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t)); + const float db = d * (0.5f + (aux32 >> 28)) * 0.5f; + for (int l = 0; l < 4; ++l) { + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
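
The iq2/iq3 dequantizers above share one inner step: a byte of sign bits flips the signs of the eight grid values in a group, with ksigns_iq2xs only expanding the stored 7-bit index into that byte. A sketch of the step in isolation, taking the sign byte as given; kmask_iq2xs[j] selects bit j, so a plain shift is used here:

#include <stdint.h>

static void apply_signs_sketch(float db, const uint8_t * grid, uint8_t signs, float * y) {
    for (int j = 0; j < 8; ++j) {
        y[j] = db * grid[j] * (((signs >> j) & 1) ? -1.0f : 1.0f);
    }
}
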
-1.f : 1.f); + } + y += 8; + } + qs += 8; + } + } +} - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); +// ====================== 3.3125 bpw (de)-quantization - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); +void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + for (int i = 0; i < nb; i++) { - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = x[i].signs; + + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qs += 8; + signs += 4; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qh += 2; + qs += 8; + signs += 4; + } + } +} - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); +// ====================== 1.5625 bpw (de)-quantization - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); +void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + for (int i = 0; i < nb; i++) { - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); + const float d = GGML_FP16_TO_FP32(x[i].d); + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float dl = d * (2*((qh[ib] >> 12) & 7) + 1); + const float delta = qh[ib] & 0x8000 ? 
-IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + y[j] = dl * (grid[j] + delta); + } + y += 8; + } + qs += 4; + } } +} - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); +void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; - for (int i = 0; i < nb; i++) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + float delta[4]; + uint16_t idx[4]; - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); +#if QK_K != 64 + iq1m_scale_t scale; +#endif - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + for (int i = 0; i < nb; i++) { - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + const uint16_t * sc = (const uint16_t *)x[i].scales; +#if QK_K == 64 + const float d = GGML_FP16_TO_FP32(x[i].d); +#else + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = GGML_FP16_TO_FP32(scale.f16); +#endif + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; - // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + for (int ib = 0; ib < QK_K/32; ++ib) { +#if QK_K == 64 + const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1); + const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1); +#else + const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1); + const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1); +#endif + idx[0] = qs[0] | ((qh[0] << 8) & 0x700); + idx[1] = qs[1] | ((qh[0] << 4) & 0x700); + idx[2] = qs[2] | ((qh[1] << 8) & 0x700); + idx[3] = qs[3] | ((qh[1] << 4) & 0x700); + delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA; + delta[3] = qh[1] & 0x80 ? 
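
The iq1_s and iq1_m paths above reconstruct each group of eight grid values as dl * (grid[j] + delta), where dl = d * (2*scale + 1) and delta is plus or minus IQ1S_DELTA depending on one stored bit. A sketch with the delta magnitude left as a parameter rather than assuming the constant's value; dequant_iq1_group_sketch is illustrative only:

#include <stdint.h>

// One group of an iq1 block: a shared odd-multiplier scale and a shared signed offset.
static void dequant_iq1_group_sketch(float d, int scale_bits, int delta_sign_bit,
                                     float delta_magnitude, const int8_t * grid, float * y) {
    const float dl    = d * (2*scale_bits + 1);
    const float delta = delta_sign_bit ? -delta_magnitude : delta_magnitude;
    for (int j = 0; j < 8; ++j) {
        y[j] = dl * (grid[j] + delta);
    }
}
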
-IQ1S_DELTA : IQ1S_DELTA; + for (int l = 0; l < 2; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl1 * (grid[j] + delta[l]); + } + y += 8; + } + for (int l = 2; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]); + for (int j = 0; j < 8; ++j) { + y[j] = dl2 * (grid[j] + delta[l]); + } + y += 8; + } + qs += 4; + qh += 2; + } + } +} - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); +void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + const int64_t nb = k / QK4_NL; - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + for (int i = 0; i < nb; i++) { - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + const uint8_t * qs = x[i].qs; - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + const float d = GGML_FP16_TO_FP32(x[i].d); + for (int j = 0; j < QK4_NL/2; ++j) { + y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf]; + y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4]; + } + y += QK4_NL; + qs += QK4_NL/2; } +} - *s = sumf; +void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); +#if QK_K == 64 + dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k); #else - // scalar - float sumf = 0.0; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { - int sumi = 0; - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; - const int v1 = (x[i].qs[j] >> 4) - 8; + const uint8_t * qs = x[i].qs; - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); + const float dl = d * (ls - 32); + for (int j = 0; j < 16; ++j) { + y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; + y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; + } + y += 32; + qs += 16; } + } +#endif +} - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + //const float iscale = -128.f/max; + // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward + const float iscale = -127.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; } +} - *s = sumf; -#endif +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int i = 0; i < nb; 
i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } } -void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; +void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q8_K_reference(x, y, k); +} + +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; const int nb = n / qk; assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; - // TODO: add WASM SIMD +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_0 * restrict vx0 = vx; + const block_q4_0 * restrict vx1 = vx + bx; + + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = vy + by; + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_0 * restrict b_x0 
= &vx0[i]; + const block_q4_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); + const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); + const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); + const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - float summs = 0; - assert(nb % 2 == 0); // TODO: handle odd nb for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2670,393 +3791,445 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = 
vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - float summs = 0; - // Main loop for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - const float d1 = y[i].d; - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); + __m256i qx = bytes_from_nibbles_32(x[i].qs); - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. 
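As an aside on what this AVX2 loop computes per block: it is the vector form of the scalar fallback that appears later in this hunk for ggml_vec_dot_q4_0_q8_0. A minimal scalar sketch of one block's contribution, not part of the patch itself (hypothetical helper name; assumes the block_q4_0/block_q8_0 layouts, QK8_0 == 32, and GGML_FP16_TO_FP32 defined earlier in ggml-quants.c):

    // Illustrative scalar equivalent of one iteration of the SIMD loop above.
    static float q4_0_q8_0_block_dot_ref(const block_q4_0 * x, const block_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < QK8_0/2; ++j) {
            const int v0 = (x->qs[j] & 0x0F) - 8;   // low nibble,  recentred to [-8 .. 7]
            const int v1 = (x->qs[j] >>   4) - 8;   // high nibble, recentred to [-8 .. 7]
            sumi += v0 * y->qs[j] + v1 * y->qs[j + QK8_0/2];
        }
        // scale the integer dot product by the product of the two block scales
        return GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d) * sumi;
    }

Subtracting 8 turns the unsigned nibbles into signed values centred on zero, which is what lets the SIMD paths use signed int8 dot products (mul_sum_i8_pairs_float here, ggml_vdotq_s32 on NEON).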
+ const __m256i off = _mm256_set1_epi8( 8 ); + qx = _mm256_sub_epi8( qx, off ); - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 xy = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(qx, qy); - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); } - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); - size_t vl = __riscv_vsetvl_e8m1(qk/2); + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - for (int i = 0; i < nb; i++) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + __m128i bx_0 = _mm_and_si128(lowMask, tmp); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - *s = sumf; -#else - // scalar - float sumf = 0.0; + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); - for (int i = 0; i < nb; i++) { - int sumi = 
0; + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); - *s = sumf; -#endif -} - -void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); - assert(n % qk == 0); - assert(qk == QK5_0); + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - uint32_t qh0; - uint32_t qh1; + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); - uint64_t tmp0[4]; - uint64_t tmp1[4]; + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + acc_3 = _mm_mul_ps( d_2_3, p3 ); + } assert(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; + // Main loop + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); - const uint8x16_t m4b = vdupq_n_u8(0x0F); + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - tmp1[0] = table_b2b_1[(qh1 >> 0) & 
0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), 
vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - uint32_t qh; - uint64_t tmp[4]; + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; - const v128_t m4b = wasm_i8x16_splat(0x0F); + size_t vl = __riscv_vsetvl_e8m1(qk/2); - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + for (int i = 0; i < nb; i++) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - const v128_t v0 = wasm_v128_load(x0->qs); + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - // dot product 
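The q5_0 paths being reworked in this hunk all decode values the same way the scalar branch further below does: the low nibble comes from qs, the fifth bit is pulled out of the packed 32-bit qh word, and the result is recentred by subtracting 16. A minimal scalar sketch of one block's contribution, not part of the patch itself (hypothetical helper name; assumes the block_q5_0/block_q8_0 layouts, QK5_0 == 32, <string.h>, and GGML_FP16_TO_FP32):

    // Illustrative scalar reference mirroring the scalar fallback of ggml_vec_dot_q5_0_q8_0.
    static float q5_0_q8_0_block_dot_ref(const block_q5_0 * x, const block_q8_0 * y) {
        uint32_t qh;
        memcpy(&qh, x->qh, sizeof(qh));
        int sumi = 0;
        for (int j = 0; j < QK5_0/2; ++j) {
            const uint8_t xh_0 = ((qh & (1u << (j +  0))) >> (j +  0)) << 4;  // bit j      -> bit 4 of the low value
            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));       // bit j + 16 -> bit 4 of the high value
            const int32_t x0 = ((x->qs[j] & 0x0F) | xh_0) - 16;
            const int32_t x1 = ((x->qs[j] >>   4) | xh_1) - 16;
            sumi += x0 * y->qs[j] + x1 * y->qs[j + QK5_0/2];
        }
        return GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d) * sumi;
    }

The table_b2b_1 lookups in the NEON and WASM code nearby realize the same decoding in vector form: each entry holds ((!b) << 4) per bit, so subtracting the looked-up byte from the nibble is equivalent to OR-ing in the high bit and then subtracting 16, as the "add high bit and sub 16" comments in the hunk note.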
- sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + *s = sumf; +#else + // scalar + float sumf = 0.0; - // Main loop for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - bx = _mm256_or_si256(bx, bxhi); + int sumi = 0; - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; - const __m256 q = mul_sum_i8_pairs_float(bx, by); + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); + *s = sumf; +#endif +} - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); +void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; - const __m256 q = mul_sum_i8_pairs_float(bx, by); +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_1 * restrict vx0 = vx; + const block_q4_1 * restrict vx1 = vx + bx; + const block_q8_1 * restrict vy0 = vy; + const block_q8_1 * restrict vy1 = vy + by; - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t summs0 = vdupq_n_f32(0.0f); - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; + for (int i = 0; i < nb; i++) { + const block_q4_1 * restrict b_x0 = 
&vx0[i]; + const block_q4_1 * restrict b_x1 = &vx1[i]; + const block_q8_1 * restrict b_y0 = &vy0[i]; + const block_q8_1 * restrict b_y1 = &vy1[i]; - uint32_t qh; + float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)}; + summs0 += summs_t; - size_t vl = __riscv_vsetvl_e8m1(qk/2); + const uint8x16_t m4b = vdupq_n_u8(0x0F); - // These tempory registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + // 4-bit -> 8-bit + const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + // mmla into int32x4_t + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, + GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, + GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, + GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - // load + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + sumv2 = sumv2 + summs0; + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, 
vget_high_f32(sumv2)); + return; + } +#endif + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = GGML_FP16_TO_FP32(y[i].d); + + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[i].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + // load elements vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - vint8mf2_t v0 = 
__riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); @@ -3068,7 +4241,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); } *s = sumf; @@ -3077,45 +4250,41 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri float sumf = 0.0; for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - int sumi = 0; for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + const int v0 = (x[i].qs[j] & 0x0F); + const int v1 = (x[i].qs[j] >> 4); - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); } *s = sumf; #endif } -void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; +void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; const int nb = n / qk; assert(n % qk == 0); - assert(qk == QK5_1); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - float summs0 = 0.0f; - float summs1 = 0.0f; - uint32_t qh0; uint32_t qh1; @@ -3125,29 +4294,26 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri assert(nb % 2 == 0); // TODO: handle odd nb for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; + const block_q5_0 * restrict x0 = &x[i]; + const block_q5_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i]; + const block_q8_0 * restrict y1 = &y[i + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); - summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; - summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; - - // extract the 5th bit via lookup table ((b) << 4) + // extract the 5th bit via lookup table ((!b) << 4) memcpy(&qh0, x0->qh, sizeof(qh0)); memcpy(&qh1, x1->qh, sizeof(qh1)); - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + 
tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); @@ -3158,16 +4324,16 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const uint8x16_t v0_1 = vld1q_u8(x1->qs); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); @@ -3175,59 +4341,35 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); -#if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, 
ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); - float summs = 0.0f; - uint32_t qh; uint64_t tmp[4]; // TODO: check if unrolling this is better for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s; + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; - const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit memcpy(&qh, x0->qh, sizeof(qh)); - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3238,9 +4380,9 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const v128_t v0l = wasm_v128_and (v0, m4b); const v128_t v0h = wasm_u8x16_shr(v0, 4); - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); // load y const v128_t v1l = wasm_v128_load(y0->qs); @@ -3258,77 +4400,71 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), wasm_i32x4_dot_i16x8(v0lfh, v1lh)), wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); } *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - float summs = 0.0f; - // Main loop for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i qx = bytes_from_nibbles_32(x[i].qs); __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - bx = _mm256_or_si256(bx, bxhi); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + qx = _mm256_or_si256(qx, bxhi); - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + __m256i qy = 
_mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(qx, qy); - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); } - *s = hsum_float_8(acc) + summs; + *s = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; + __m128i mask = _mm_set1_epi8((char)0xF0); // Main loop for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); + bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_us8_pairs_float(bx, by); + const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); } - *s = hsum_float_8(acc) + summs; + *s = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) float sumf = 0.0; @@ -3336,30 +4472,30 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri size_t vl = __riscv_vsetvl_e8m1(qk/2); - // temporary registers for shift operations + // These temporary registers are for masking and shift operations vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); for (int i = 0; i < nb; i++) { memcpy(&qh, x[i].qh, sizeof(uint32_t)); - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = 
__riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load @@ -3374,8 +4510,11 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); @@ -3387,7 +4526,7 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } *s = sumf; @@ -3402,540 +4541,499 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri int sumi = 0; for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } *s = sumf; #endif } -void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; +void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; const int nb = n / qk; assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + float summs0 = 0.0f; + float summs1 = 0.0f; - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + uint32_t qh0; + uint32_t qh1; - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const 
int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + uint64_t tmp0[4]; + uint64_t tmp1[4]; -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + assert(nb % 2 == 0); // TODO: handle odd nb - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + for (int i = 0; i < nb; i += 2) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q5_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i]; + const block_q8_1 * restrict y1 = &y[i + 1]; -#else - const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); - const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); - const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); - const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); - const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - - const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } + const uint8x16_t m4b = vdupq_n_u8(0x0F); - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); + summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif - } + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk); + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const 
int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - for (int i = 0; i < nb; i++) { - // load elements - vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = sumf; -#else - // scalar - float sumf = 0.0; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); - for (int i = 0; i < nb; i++) { - int sumi = 0; + float summs = 0.0f; - for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; - } + uint32_t qh; + uint64_t tmp[4]; - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); - } + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; - *s = sumf; -#endif -} + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); -#if QK_K == 256 -void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const v128_t m4b = wasm_i8x16_splat(0x0F); - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); - const int nb = n / QK_K; + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; -#ifdef __ARM_NEON + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); - const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const v128_t v0 = wasm_v128_load(x0->qs); - int8x16x2_t q2bytes; - uint8_t aux[16]; + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); - float sum = 0; + // add high bit + const v128_t v0lf = 
wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); - for (int i = 0; i < nb; ++i) { + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict sc = x[i].scales; + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); - int isum = 0; - int is = 0; + float summs = 0.0f; -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#if defined(__ARM_FEATURE_DOTPROD) -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; -#else -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - {\ - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ - isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ - } -#endif + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = 
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + qx = _mm256_or_si256(qx, bxhi); - for (int j = 0; j < QK_K/128; ++j) { + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + const __m256 q = mul_sum_us8_pairs_float(qx, qy); - int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - MULTIPLY_ACCUM_WITH_SCALE(0); + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + *s = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + float summs = 0.0f; - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - is += 8; - } - sum += d * isum; + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); - } + __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); - *s = sum; + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); -#elif defined __AVX2__ + const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); - const __m256i m3 = _mm256_set1_epi8(3); - const __m128i m4 = _mm_set1_epi8(0xF); + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } - __m256 acc = _mm256_setzero_ps(); + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; - for (int i = 0; i < nb; ++i) { + uint32_t qh; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + size_t vl = __riscv_vsetvl_e8m1(qk/2); - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m256i mins = _mm256_cvtepi8_epi16(mins8); - const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + // load qh + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - 
const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - __m256i sumi = _mm256_setzero_si256(); + // ((qh >> (j + 12)) ) & 0x10; + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - for (int j = 0; j < QK_K/128; ++j) { + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - const __m256i q2_0 = _mm256_and_si256(q2bits, m3); - const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); - const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); - const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); - __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); - } + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - } + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - *s = hsum_float_8(acc); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); -#elif defined __AVX__ + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + } - const __m128i m3 = _mm_set1_epi8(0x3); - const __m128i m4 = 
_mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); + *s = sumf; +#else + // scalar + float sumf = 0.0; - __m256 acc = _mm256_setzero_ps(); + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int i = 0; i < nb; ++i) { + int sumi = 0; - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; - // load mins and scales from block_q2_K.scales[QK_K/16] - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); - const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } - // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 - const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); - const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + } - // sumf += -dmin * summs in 32bits*8 - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + *s = sumf; +#endif +} - const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); - const __m128i scales[2] = { scales_0, scales_1 }; +void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - for (int j = 0; j < QK_K/128; ++j) { + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; - // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q8_0 * restrict vx0 = vx; + const block_q8_0 * restrict vx1 = vx + bx; + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = vy + by; - // load 2bits*16*8 from block_q2_K.qs[QK_K/4] - __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_0 = 
_mm_and_si128(q2bits, m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_1 = _mm_and_si128(q2bits, m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + float32x4_t sumv0 = vdupq_n_f32(0.0f); - // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 - __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); - __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); - __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); - __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); - __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); - __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); - __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); - __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + for (int i = 0; i < nb; i++) { + const block_q8_0 * restrict b_x0 = &vx0[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; - // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 - __m128i shuffle = _mm_set1_epi16(0x0100); - p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); - shuffle = _mm_add_epi16(shuffle, m2); - p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); - shuffle = _mm_add_epi16(shuffle, m2); - p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); - shuffle = _mm_add_epi16(shuffle, m2); - p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); - shuffle = _mm_add_epi16(shuffle, m2); - p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); - shuffle = _mm_add_epi16(shuffle, m2); - p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); - shuffle = _mm_add_epi16(shuffle, m2); - p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); - shuffle = _mm_add_epi16(shuffle, m2); - p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + const block_q8_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; - p0 = _mm_add_epi32(p0, p1); - p2 = _mm_add_epi32(p2, p3); - p4 = _mm_add_epi32(p4, p5); - p6 = _mm_add_epi32(p6, p7); + const int8x16_t x0_l = vld1q_s8(b_x0->qs); + const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); + const int8x16_t x1_l = vld1q_s8(b_x1->qs); + const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); - // isum in 32bits*4*2 - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); - } + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - // sumf += dall * isum - dmin * summs in 32bits - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); - } + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; - *s = hsum_float_8(acc); + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); -#elif defined __riscv_v_intrinsic + 
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - for (int i = 0; i < nb; ++i) { + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + return; + } +#endif +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); - size_t vl = 16; + assert(nb % 2 == 0); // TODO: handle odd nb - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + sumv1 = vmlaq_n_f32(sumv1, 
vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + const __m256 q = mul_sum_i8_pairs_float(qx, qy); - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + // Multiply q with scale and accumulate +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d, q, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif + } - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + size_t vl = __riscv_vsetvl_e8m1(qk); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + for (int i = 0; i < nb; i++) { + // load elements + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - q2+=32; q8+=128; is=8; + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); - } + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - sumf += dall * isum; + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); } *s = sumf; - #else + // scalar + float sumf = 0.0; - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; i++) { + int sumi = 0; - int summs = 0; - for (int j = 0; j < 16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; } - const float dall = y[i].d * 
GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - int isum = 0; - int is = 0; - int d; - for (int k = 0; k < QK_K/128; ++k) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - d = sc[is++] & 0xF; - int isuml = 0; - for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - d = sc[is++] & 0xF; - isuml = 0; - for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - shift += 2; - q8 += 32; - } - q2 += 32; - } - sumf += dall * isum - dmin * summs; + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); } + *s = sumf; #endif } -#else - -void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +#if QK_K == 256 +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); const block_q2_K * restrict x = vx; const block_q8_K * restrict y = vy; @@ -3943,66 +5041,69 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; #ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const uint8x16_t m4 = vdupq_n_u8(0xF); - int8x16x4_t q2bytes; + const int32x4_t vzero = vdupq_n_s32(0); - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; + ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; float sum = 0; for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + const uint8_t * restrict sc = x[i].scales; - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); - sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - int isum1 = 0, isum2 = 0; + int isum = 0; + int is = 0; - const uint8x16_t q2bits = vld1q_u8(q2); +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; - const int8x16x4_t q8bytes = 
vld1q_s8_x4(q8); +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); - q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); - q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); + for (int j = 0; j < QK_K/128; ++j) { + const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; -#if defined(__ARM_FEATURE_DOTPROD) - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum1 += vaddvq_s16(p1) * scales[0]; - isum2 += vaddvq_s16(p2) * scales[1]; - - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum1 += vaddvq_s16(p3) * scales[2]; - isum2 += vaddvq_s16(p4) * scales[3]; -#endif - sum += d * (isum1 + isum2); + ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; } *s = sum; @@ -4010,17 +5111,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __AVX2__ const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); __m256 acc = _mm256_setzero_ps(); - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; - - float summs = 0; - - // TODO: optimize this - for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); @@ -4029,145 +5123,242 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = 
_mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); - const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + __m256i sumi = _mm256_setzero_si256(); - const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + for (int j = 0; j < QK_K/128; ++j) { - const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); - const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); - const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); - const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); - } + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - *s = hsum_float_8(acc) + summs; + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); -#elif defined __AVX__ + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - const __m128i m3 = _mm_set1_epi8(3); + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); - __m256 acc = _mm256_setzero_ps(); + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } - float summs = 0; + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - // TODO: optimize this + } + + *s 
= hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; - const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); - const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); - const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); - const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); - const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + for (int j = 0; j < QK_K/128; ++j) { - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + // load Q8 quants int8*16*8 from 
block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); } - *s = hsum_float_8(acc) + summs; + *s = hsum_float_8(acc); #elif defined __riscv_v_intrinsic - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; - float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; + const uint8_t * q2 = 
x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + size_t vl = 16; - sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - int isum1 = 0; - int isum2 = 0; + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - size_t vl = 16; + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - // load Q2 - vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); + vl = 32; - vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); - vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); - vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); - vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - // load Q8, and take product with Q2 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + uint8_t is=0; + int isum=0; - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); - vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); - vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); - sumf += d * (isum1 + isum2); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, 
__riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; } @@ -4177,8 +5368,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri float sumf = 0; - int isum[4]; - for (int i = 0; i < nb; ++i) { const uint8_t * q2 = x[i].qs; @@ -4186,156 +5375,95 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8_t * sc = x[i].scales; int summs = 0; - for (int j = 0; j < QK_K/16; ++j) { + for (int j = 0; j < 16; ++j) { summs += y[i].bsums[j] * (sc[j] >> 4); } const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - isum[0] = isum[1] = isum[2] = isum[3] = 0; - for (int l = 0; l < 16; ++l) { - isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); - isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); - isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); - isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); - } - for (int l = 0; l < 4; ++l) { - isum[l] *= (sc[l] & 0xF); + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; } - sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + sumf += dall * isum - dmin * summs; } *s = sumf; #endif } -#endif -#if QK_K == 256 -void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); +#else - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q3_K * restrict x = vx; + const block_q2_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON + const uint8x16_t m3 = vdupq_n_u8(0x3); - uint32_t aux[3]; - uint32_t 
utmp[4]; - - const uint8x16_t m3b = vdupq_n_u8(0x3); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif + const int32x4_t vzero = vdupq_n_s32(0); - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; + ggml_int8x16x4_t q2bytes; - int8x16x4_t q3bytes; + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; float sum = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; + const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - uint8x16x2_t qhbits = vld1q_u8_x2(qh); - - uint8x16x4_t q3h; - - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - for (int j = 0; j < QK_K/128; ++j) { - - const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; - const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; - const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; - - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; -#else - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); - isum += vaddvq_s16(p0) * 
scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + int isum1 = 0, isum2 = 0; - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + const uint8x16_t q2bits = vld1q_u8(q2); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); - } - sum += d * isum; + isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; + sum += d * (isum1 + isum2); } *s = sum; @@ -4343,261 +5471,256 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __AVX2__ const __m256i m3 = _mm256_set1_epi8(3); - const __m256i mone = _mm256_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); __m256 acc = _mm256_setzero_ps(); - uint32_t aux[3]; + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float 
summs = 0; + + // TODO: optimize this for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; - // high bit - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; - // integer accumulator - __m256i sumi = _mm256_setzero_si256(); + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); - int bit = 0; - int is = 0; + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - // prepare low and high bits - const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); - const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); - const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } - const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); - const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; + *s = hsum_float_8(acc) + summs; - const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); - const __m256i q3h_3 = 
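// [editor's aside — illustrative note, not part of the upstream patch]
// _mm_maddubs_epi16 multiplies unsigned bytes (here the 2-bit quants, 0..3) by
// signed bytes (q8, -128..127) and horizontally adds adjacent pairs into signed
// 16-bit lanes. With 2-bit quants the largest pair magnitude is 2*3*128 = 768,
// so the i16 intermediates cannot saturate. Roughly, one 16-bit lane of p0..p3 holds
//
//     (int16_t)q2[2*l+0]*q8[2*l+0] + (int16_t)q2[2*l+1]*q8[2*l+1]
//
// Each lane is then widened to 32 bits and accumulated with weight d * db[k],
// while the per-sub-block minimums were already folded in once per block via
// mb[] and y[i].bsums above.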
_mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; +#elif defined __AVX__ - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m128i m3 = _mm_set1_epi8(3); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + __m256 acc = _mm256_setzero_ps(); - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + float summs = 0; - // multiply with scales - p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + // TODO: optimize this - // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + for (int i = 0; i < nb; ++i) { - } + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; - } + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; - *s = hsum_float_8(acc); + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; -#elif defined __AVX__ + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - __m256 acc = _mm256_setzero_ps(); + const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); 
+ const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); - const uint32_t *aux; + const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); + const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); + const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); + const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - // Set up scales - aux = (const uint32_t *)x[i].scales; - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); - const __m128i scales[2] = { scales_0, scales_1 }; + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; - // high bit *128*2 from block_q3_K.hmask[QK_K/8] - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - // integer accumulator - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + int isum1 = 0; + int isum2 = 0; - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] - const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + size_t vl = 16; - // prepare low and high bits - const int bit = j << 2; + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); - const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); - const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); - const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + // load Q2 + vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); - const __m128i 
q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); - const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); + vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); - const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); - const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); - const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + // load Q8, and take product with Q2 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); - const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); - const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); + vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); + vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); - // load Q8 quants from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + sumf += d * (isum1 + isum2); - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + } - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); + *s = sumf; - // multiply with scales - __m128i shuffle = _mm_set1_epi16(0x0100); - p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); - shuffle = _mm_add_epi16(shuffle, m2); - p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); - shuffle = _mm_add_epi16(shuffle, m2); - p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); - shuffle = _mm_add_epi16(shuffle, m2); - p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); - shuffle = _mm_add_epi16(shuffle, m2); - p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); - shuffle = _mm_add_epi16(shuffle, m2); - p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); - shuffle = _mm_add_epi16(shuffle, m2); - p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); - shuffle = _mm_add_epi16(shuffle, m2); - p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); +#else - // accumulate - p16_0 = _mm_add_epi32(p16_0, p16_1); - p16_2 = _mm_add_epi32(p16_2, p16_3); - p16_4 = _mm_add_epi32(p16_4, p16_5); - p16_6 = _mm_add_epi32(p16_6, p16_7); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + float sumf = 0; + + int isum[QK_K/16]; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); } - // multiply with block scale and accumulate - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + memset(isum, 0, (QK_K/16)*sizeof(int)); + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < QK_K/16; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; } + *s = sumf; +#endif +} +#endif - *s = hsum_float_8(acc); 
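// [editor's aside — illustrative sketch, not part of the upstream patch]
// For orientation: a q3_K super-block stores sixteen 6-bit sub-block scales
// packed into 12 bytes, and the QK_K == 256 paths below all start by undoing
// that packing with kmask1/kmask2 into utmp[]/scale[]. A scalar sketch of the
// same layout (get_scale is a hypothetical helper, given only to document the
// bit positions):
//
//     // low 4 bits: nibbles of bytes 0..7; high 2 bits: bit pairs of bytes 8..11
//     static inline int8_t get_scale(const uint8_t * sc, int j) {
//         const int lo = j < 8 ? (sc[j] & 0xF) : (sc[j - 8] >> 4);
//         const int hi = (sc[8 + (j % 4)] >> (2 * (j / 4))) & 3;
//         return (int8_t)(lo | (hi << 4)) - 32;   // signed 6-bit scale, -32..31
//     }
//
// which agrees with scale[j] as computed from utmp[] (minus 32) in the code below.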
+#if QK_K == 256 +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); -#elif defined __riscv_v_intrinsic + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON uint32_t aux[3]; uint32_t utmp[4]; - float sumf = 0; + const uint8x16_t m3b = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + ggml_int8x16x4_t q3bytes; + + float sum = 0; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const uint8_t * restrict q3 = x[i].qs; const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; + const int8_t * restrict q8 = y[i].qs; + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales memcpy(aux, x[i].scales, 12); utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); @@ -4605,90 +5728,409 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; + for (int j = 0; j < 16; ++j) scale[j] -= m32; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + for (int j = 0; j < QK_K/128; ++j) { - int sum_t = 0; + const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; + const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; + const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; - for (int j = 0; j < QK_K; j += 128) { + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - vl = 32; + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = 
__riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + scale += 4; - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); - m <<= 1; + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); - m <<= 1; + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); - m <<= 1; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); - m <<= 1; + scale += 4; - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } - vl = 16; + } + sum += d * isum; - // retreive lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + } - vint32m1_t isum0 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + *s = sum; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); +#elif defined __AVX2__ - q3 += 32; q8 += 128; scale += 8; + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); - } + __m256 acc = _mm256_setzero_ps(); - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + uint32_t aux[3]; - sumf += d*sum_t; + for (int i = 0; i < nb; ++i) { - } + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - *s = sumf; + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; -#else - // scalar version - // This function is written like this so the compiler can manage to vectorize most of it + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = 
_mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + 
vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the // manually vectorized version above. Every other version I tried would run at least 4 times slower. 
// The ideal situation would be if we could just write the code once, and the compiler would @@ -4701,202 +6143,1541 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); - uint32_t auxs[4]; - const int8_t * scales = (const int8_t*)auxs; - + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +#else + +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t mh = vdupq_n_u8(4); + + ggml_int8x16x4_t q3bytes; + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + ggml_uint8x16x4_t q3h; + + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs); + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 
4)); + + q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; + + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); + const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + + memcpy(&aux64, x[i].hmask, 8); + + const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); + __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); + const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m1 = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); + const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); + const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); + const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + + memcpy(&aux64, x[i].hmask, 8); + + __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); + __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); + __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); + q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); + q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); + q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); + q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m128i q3l_0 = _mm_and_si128(q3bits, m3); + const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_1, p16_1); + p16_2 = _mm_madd_epi16(scale_2, p16_2); + p16_3 = _mm_madd_epi16(scale_3, p16_3); + + p16_0 = _mm_add_epi32(p16_0, p16_2); + p16_1 = _mm_add_epi32(p16_1, p16_3); + __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + + // multiply with block scale and accumulate + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // extend and combine both qh_x1 and qh_x2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); + vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); + vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); + vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + + // load Q8 and take product with Q3 + 
vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + + sumf += d * isum; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + int32_t scales[4]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 
0 : 4); + } + + scales[0] = (x[i].scales[0] & 0xF) - 8; + scales[1] = (x[i].scales[0] >> 4) - 8; + scales[2] = (x[i].scales[1] & 0xF) - 8; + scales[3] = (x[i].scales[1] >> 4) - 8; + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * 
GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, 
shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 
= __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + +#else + + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#else +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const int32x4_t mzero = vdupq_n_s32(0); + + float sumf = 0; + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x4_t q8bytes; + + float sum_mins = 0.f; + + uint16_t aux16[2]; + const uint8_t * restrict scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); + + q8bytes = ggml_vld1q_s8_x4(q8); + q4bytes.val[0] = 
vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf - sum_mins; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); + + const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); + const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); + const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p16_1 = 
_mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); + const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); + + const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); + const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __riscv_v_intrinsic + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + size_t vl = 32; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + + sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); + + sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + + } + + *s = sumf; + +#else + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + for (int j = 0; j < QK_K/32; ++j) { + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + const float dl = d * scales[j]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { 
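/*
 * A minimal scalar sketch of what ggml_vec_dot_q5_K_q8_K computes for QK_K == 256,
 * assuming the q5_K layout used here (4-bit low nibbles in x[i].qs, one high bit per
 * weight in x[i].qh, eight 6-bit scales and eight 6-bit mins packed into x[i].scales);
 * sc, mn, dot_j and bsum_j below are shorthand, not identifiers from this file:
 *
 *   float sumf = 0;
 *   for (int i = 0; i < nb; ++i) {
 *       int32_t sumi = 0, summ = 0;
 *       for (int j = 0; j < 8; ++j) {
 *           // q5 value = (low nibble) | (high bit << 4), i.e. an integer in [0, 31]
 *           sumi += sc[j] * dot_j;   // dot_j: integer dot product of the 32 q5 and q8 values of sub-block j
 *           summ += mn[j] * bsum_j;  // bsum_j: sum of the 32 q8 values of sub-block j (available via y[i].bsums)
 *       }
 *       sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d)    * sumi
 *             - y[i].d * GGML_FP16_TO_FP32(x[i].dmin) * summ;
 *   }
 *
 * The NEON, AVX2, AVX and RISC-V branches below accumulate these same two terms,
 * differing only in how the nibbles, high bits and packed scales are widened.
 */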
+ assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; + } + + sumf += d * sumi - dmin * sumi_mins; + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + 
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum 
= _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); 
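/*
 * At this point the kmask1/kmask2/kmask3 shuffle above has rearranged the 12 packed
 * scale bytes so that utmp[0..1] hold the eight 6-bit sub-block scales and utmp[2..3]
 * the eight 6-bit sub-block mins, which the byte pointers `scales` and `mins` alias.
 * A rough scalar equivalent of that unpacking, with s standing in for x[i].scales and
 * sc/mn as shorthand (not identifiers from this file):
 *
 *   for (int j = 0; j < 4; ++j) { sc[j] = s[j] & 63;  mn[j] = s[j + 4] & 63; }
 *   for (int j = 4; j < 8; ++j) { sc[j] = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4);
 *                                 mn[j] = (s[j + 4] >>  4) | ((s[j]     >> 6) << 4); }
 */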
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 
16 : 0); a += 32; m <<= 1; - q3 += 32; + q4 += 32; } - a = aux8; + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - for (int j = 0; j < QK_K/16; ++j) { + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; - #endif - } #else -void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q3_K * restrict x = vx; + const block_q5_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mh = vdupq_n_u8(16); + const int32x4_t mzero = vdupq_n_s32(0); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const uint8x16_t mh = vdupq_n_u8(4); - - int8x16x4_t q3bytes; - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; + ggml_int8x16x4_t q5bytes; + ggml_uint8x16x4_t q5h; - float sum = 0; + float sumf = 0; for (int i = 0; i < nb; ++i) { - uint8x16x4_t q3h; - - const uint8x8_t hbits = vld1_u8(x[i].hmask); - const uint8x16_t q3bits = vld1q_u8(x[i].qs); - const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * sc = x[i].scales; - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - const float d = y[i].d * (float)x[i].d; + const uint8x8_t 
qhbits = vld1_u8(qh); - const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); - q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); - q3h.val[1] = vandq_u8(mh, htmp); - q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); - q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); - q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); - q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); - q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; -#endif + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); - sum += d * isum; + int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); } - *s = sum; + *s = sumf; #elif defined __AVX2__ - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i m1 = _mm256_set1_epi8(1); + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i mone = _mm256_set1_epi8(1); __m256 acc = _mm256_setzero_ps(); - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - 
const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); - const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - memcpy(&aux64, x[i].hmask, 8); + const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); - const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); - __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); - q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); - q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); - // prepare low and high bits - const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); - const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - // load Q8 quants const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - - // multiply with scales - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); + const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); + const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); + const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); - p16_0 = _mm256_add_epi32(p16_0, p16_1); + const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); } @@ -4904,86 +7685,56 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __AVX__ - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m1 = _mm_set1_epi8(1); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mone = _mm_set1_epi8(1); __m256 acc = _mm256_setzero_ps(); - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); - const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); - const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); - const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - memcpy(&aux64, x[i].hmask, 8); + const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); + const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); + const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); + const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); - __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); - __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); - __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); - q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); - q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); - q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); - q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); + const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); + const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); + const __m128i q5h_3 = 
_mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); - // prepare low and high bits - const __m128i q3l_0 = _mm_and_si128(q3bits, m3); - const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); + const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); + const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); + const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); - // load Q8 quants const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_1, p16_1); - p16_2 = _mm_madd_epi16(scale_2, p16_2); - p16_3 = _mm_madd_epi16(scale_3, p16_3); + const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); + const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); - p16_0 = _mm_add_epi32(p16_0, p16_2); - p16_1 = _mm_add_epi32(p16_1, p16_3); - __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); + const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); - // multiply with block scale and accumulate - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), 
_mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); } @@ -4991,72 +7742,69 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #elif defined __riscv_v_intrinsic - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; - - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * sc = x[i].scales; - const float d = y[i].d * (float)x[i].d; + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); size_t vl = 16; - // extend and combine both qh_x1 and qh_x2 + // combine both qh_1 and qh_2 vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); - vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); - vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); + vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); + vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - // load Q3 - vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); + vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); + vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); + vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); - vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); - vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); - vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); - vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + // load q5 + vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); + vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); - vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); - vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); - vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); - vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); + vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); + vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, 
vl)); + vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); - // load Q8 and take product with Q3 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); + vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); + vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); + vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); + + // load Q8 and multiply it with Q5 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); + int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); + int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); + int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); - sumf += d * isum; + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); } @@ -5064,371 +7812,429 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #else - int8_t aux8[QK_K]; - int16_t aux16[8]; + int8_t aux8[QK_K]; + int16_t aux16[16]; float sums [8]; - int32_t aux32[8]; - int32_t scales[4]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; const int8_t * restrict q8 = y[i].qs; int8_t * restrict a = aux8; - for (int l = 0; l < 8; ++l) { - a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); - a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); - a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); - a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); - a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); - a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); - a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); - a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 
0 : 16); } - scales[0] = (x[i].scales[0] & 0xF) - 8; - scales[1] = (x[i].scales[0] >> 4) - 8; - scales[2] = (x[i].scales[1] & 0xF) - 8; - scales[3] = (x[i].scales[1] >> 4) - 8; + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * restrict sc = x[i].scales; - memset(aux32, 0, 8*sizeof(int32_t)); for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + const float dl = d * sc[j]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); + q8 += 16; a += 16; } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; - #endif - } #endif + #if QK_K == 256 -void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q4_K * restrict x = vx; + const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - #ifdef __ARM_NEON + float sum = 0; - const uint8x16_t m4b = vdupq_n_u8(0xf); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t mzero = vdupq_n_s32(0); -#endif + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); - int8x16x2_t q4bytes; - int8x16x2_t q8bytes; + const uint8x16_t mone = vdupq_n_u8(3); - float sumf = 0; + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); + const float d_all = GGML_FP16_TO_FP32(x[i].d); - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; + const int8_t * restrict scale = x[i].scales; - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; - const uint8_t * scales = (const uint8_t *)utmp; + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + 
vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + int32_t isum = 0; - int32_t sumi1 = 0; - int32_t sumi2 = 0; + for (int j = 0; j < QK_K/128; ++j) { - for (int j = 0; j < QK_K/64; ++j) { + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; + ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); -#ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + scale += 4; - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; -#else - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = 
vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); -#endif + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; } - - sumf += d * (sumi1 + sumi2); + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); } - - *s = sumf; + *s = sum; #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, 
get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); - __m256i sumi = _mm256_setzero_si256(); + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - for (int j = 0; j < QK_K/64; ++j) { + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - p16l = _mm256_madd_epi16(scale_l, p16l); + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); - const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - p16h = _mm256_madd_epi16(scale_h, p16h); - const __m256i sumj = _mm256_add_epi32(p16l, p16h); + sumi = 
_mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); - sumi = _mm256_add_epi32(sumi, sumj); } - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); } - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + *s = hsum_float_8(acc); #elif defined __AVX__ const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); __m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128(); - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { + __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + for (int j = 0; j < QK_K/128; ++j) { - const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_0 = _mm_and_si128(q4bits, m4); - const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_1 = _mm_and_si128(q4bits, m4); - const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); + const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); + const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); + const __m128i 
q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); + const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); - const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_0 = _mm_add_epi32(sumi_0, p16l); - const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16l = _mm_maddubs_epi16(q4l_1, q8l_1); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_1 = _mm_add_epi32(sumi_1, p16l); + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_0 = _mm_add_epi32(sumi_0, p16h); - const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16h = _mm_maddubs_epi16(q4h_1, q8h_1); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_1 = _mm_add_epi32(sumi_1, p16h); + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + 
p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); } - __m256 vd = _mm256_set1_ps(d); __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); } - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + *s = hsum_float_8(acc); #elif defined __riscv_v_intrinsic - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - float sumf = 0; - for (int i = 0; i < nb; ++i) { - size_t vl = 8; + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + const int8_t * restrict scale = x[i].scales; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + size_t vl; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + vuint8m1_t q6a_0 = 
__riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - vl = 32; + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - int32_t sum_1 = 0; - int32_t sum_2 = 0; + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + vl = 16; - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - q4 += 32; q8 += 64; + q6 += 64; qh += 32; q8 += 128; is=8; } - sumf += d*(sum_1 + sum_2); + sumf += d * sum_t; } @@ -5436,10 +8242,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri #else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; @@ -5448,35 +8250,26 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - a += 32; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - a += 32; q4 += 32; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; a = aux8; int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; @@ -5486,1797 +8279,4400 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; #endif } + #else -void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q4_K * restrict x = vx; + const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n 
/ QK_K; #ifdef __ARM_NEON + float sum = 0; - const uint8x16_t m4b = vdupq_n_u8(0xf); - -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t mzero = vdupq_n_s32(0); -#endif - - float sumf = 0; - - int8x16x2_t q4bytes; - int8x16x4_t q8bytes; + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int8x16_t m32s = vdupq_n_s8(32); + const int32x4_t vzero = vdupq_n_s32(0); - float sum_mins = 0.f; + const uint8x16_t mone = vdupq_n_u8(3); - uint16_t aux16[2]; - const uint8_t * restrict scales = (const uint8_t *)aux16; + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t * restrict a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); - sum_mins += y[i].d * (float)x[i].d[1] * summi; - - const float d = y[i].d * (float)x[i].d[0]; + const float d_all = GGML_FP16_TO_FP32(x[i].d); - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; -#ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int8_t * restrict scale = x[i].scales; - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + int32_t isum = 0; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + uint8x16_t qhbits = vld1q_u8(qh); + ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6); + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); - const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); -#else - q8bytes = vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - 
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); - int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; -#endif - sumf += d * (sumi1 + sumi2); + sum += isum * d_all * y[i].d; } - - *s = sumf - sum_mins; + *s = sum; #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); __m256 acc = _mm256_setzero_ps(); - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; - const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - - const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); - - const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); - - } - - *s = hsum_float_8(acc) - summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; - const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); - const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); - const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); - const 
__m128i q4_1 = _mm_and_si128(q4bits_1, m4); - const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); - const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - - const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); - const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); - - const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); - const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); - - } + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - *s = hsum_float_8(acc) - summs; + __m256i sumi = _mm256_setzero_si256(); -#elif defined __riscv_v_intrinsic + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - float sumf = 0; + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); - for (int i = 0; i < nb; ++i) { + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - size_t vl = 32; + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q4_a = 
__riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } - sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + *s = hsum_float_8(acc); - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); +#elif defined __AVX__ - sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); - } + __m256 acc = _mm256_setzero_ps(); - *s = sumf; + for (int i = 0; i < nb; ++i) { -#else + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - uint8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - uint8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; - for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); - for (int j = 0; j < QK_K/32; ++j) { - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - q8 += 16; a += 16; - for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; - q8 += 16; a += 16; - const float dl = d * scales[j]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); - } - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#endif + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); + const __m128i q4_3 = 
_mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); -#if QK_K == 256 -void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; + __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); - const int nb = n / QK_K; + __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); - uint32_t utmp[4]; + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); -#ifdef __ARM_NEON + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + } - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t mzero = vdupq_n_s32(0); -#endif + *s = hsum_float_8(acc); - int8x16x4_t q5bytes; +#elif defined __riscv_v_intrinsic float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const float d_all = GGML_FP16_TO_FP32(x[i].d); - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + const int8_t * restrict scale = x[i].scales; - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); + int32_t isum = 0; - const uint8_t * scales = (const uint8_t *)utmp; + size_t vl = 16; - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + vint32m1_t 
vzero = __riscv_vmv_v_x_i32m1(0, 1); - uint8x16x2_t qhbits = vld1q_u8_x2(qh); + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); - uint8x16x4_t q5h; + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); - int32_t sumi = 0; + vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - for (int j = 0; j < QK_K/64; ++j) { + vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); + vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); + vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); + vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); + vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); + vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); + vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + // load Q8 and take product + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + + sumf += isum * d_all * y[i].d; + + } -#if defined(__ARM_FEATURE_DOTPROD) + *s = sumf; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], 
q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; #else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; -#endif + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; } - - sumf += d * sumi - dmin * sumi_mins; - + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } - + for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; +#endif +} -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - -#if QK_K == 256 - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; -#else - // TODO - const float d = 0, dmin = 0; #endif - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); +#if defined (__AVX2__) || defined (__ARM_NEON) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 
-1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = 
_mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); + const block_iq2_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); - __m256i hmask = mone; + const int nb = n / QK_K; - __m256i sumi = _mm256_setzero_si256(); +#if defined(__ARM_NEON) - int bit = 0; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - for (int j = 0; j < QK_K/64; ++j) { + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; - const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); - hmask = _mm256_slli_epi16(hmask, 
1); +#elif defined(__AVX2__) - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); - hmask = _mm256_slli_epi16(hmask, 1); + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; - __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); +#else - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? 
-1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); +void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - } + const block_iq2_xs * restrict x = vx; + const block_q8_K * restrict y = vy; - *s = hsum_float_8(acc) + summs; + const int nb = n / QK_K; -#elif defined __AVX__ +#if defined(__ARM_NEON) - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m2 = _mm_set1_epi8(2); + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - __m256 acc = _mm256_setzero_ps(); + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; - float summs = 0.f; + int32x4x4_t scales32; + float sumf = 0; for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], 
q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); +#elif defined(__AVX2__) - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; + const __m256i mone = _mm256_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); + const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); + const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); +#if QK_K == 64 + static const uint8_t k_bit_helper[16] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i m511 = _mm_set1_epi16(511); + typedef union { + __m128i vec_index; + uint16_t index[8]; + } index_t; + + index_t idx; + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs); + idx.vec_index = _mm_and_si128(q2_data, m511); - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); + const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9); + const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13); + const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper); - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); - __m128i hmask = mone; + const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, 
partial_sign_bits_for_counting); + const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits); + const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits); - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32)); - int bit = 0; + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]], + iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]], + iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]); - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); + signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); - __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); - __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); - __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); + const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1)); + const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1)); - __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_0, p16_1); + const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2)); - q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); - q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); - q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - q5_0 = _mm_add_epi8(q5l_0, q5h_0); - q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf); - q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); - p16_2 = _mm_madd_epi16(scale_1, p16_2); - p16_3 = 
_mm_madd_epi16(scale_1, p16_3); + } - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + *s = 0.125f * hsum_float_8(accumf); +#else - } + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); + const __m256i m511 = _mm256_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + uint64_t aux64; - } + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; - *s = hsum_float_8(acc) + summs; + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; -#elif defined __riscv_v_intrinsic + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - float sumf = 0; - float sums = 0.0; + const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; + aux_gindex = _mm256_and_si256(q2_data, m511); - size_t vl; + const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); + const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); + const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); - for (int i = 0; i < nb; ++i) { + const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); - vl = 8; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); + const 
__m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); + const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); + const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); - vl = 32; - int32_t aux32 = 0; - int is = 0; + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); + const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); - for (int j = 0; j < QK_K/64; ++j) { - // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); + } - // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, 
q5_a, 16, vl); - m <<= 1; + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); - m <<= 1; + } - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + *s = 0.125f * hsum_float_8(accumf); +#endif - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); +#else - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? 
-1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); - q5 += 32; q8 += 64; +void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - } + const block_iq2_s * restrict x = vx; + const block_q8_K * restrict y = vy; - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + const int nb = n / QK_K; - } +#if defined(__ARM_NEON) - *s = sumf+sums; + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; -#else + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t m1 = vdupq_n_u8(1); + const int32x4_t vzero = vdupq_n_s32(0); - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); + uint8x16x2_t vs; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; float sumf = 0; for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 
16 : 0); - a += 32; m <<= 1; - q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#else + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; -void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); + q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); + q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); + q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); + qs += 8; - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); - const int nb = n / QK_K; + q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); + q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); -#ifdef __ARM_NEON + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mh = vdupq_n_u8(16); -#if defined(__ARM_FEATURE_DOTPROD) - 
const int32x4_t mzero = vdupq_n_s32(0); -#endif + signs += 4; - int8x16x4_t q5bytes; - uint8x16x4_t q5h; + q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); + q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); - float sumf = 0; + const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); - for (int i = 0; i < nb; ++i) { + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); + sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); + } + sumf += d*(sumi1 + sumi2); + } - const float d = y[i].d * (float)x[i].d; - const int8_t * sc = x[i].scales; + *s = 0.125f * sumf; - const uint8_t * restrict q5 = x[i].qs; +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); const int8_t * restrict q8 = y[i].qs; - const uint8x8_t qhbits = vld1_u8(qh); + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; - const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); - q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); - q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); - q5h.val[2] = vbicq_u8(mh, htmp); - q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + __m256i aux256 = 
_mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); - q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); - q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); - q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); + aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); -#if defined(__ARM_FEATURE_DOTPROD) + signs += 4; - int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); - int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); - int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); - int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); #else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? 
-1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); + *s = 0.125f * sumf; - sumf += d*sumi; #endif - } +} - *s = sumf; +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); -#elif defined __AVX2__ + const block_iq3_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i mone = _mm256_set1_epi8(1); + const int nb = n / QK_K; - __m256 acc = _mm256_setzero_ps(); +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + float sumf = 0; for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); + q3 += 16; + q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); + q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); + q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); + q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.5f * sumf; - const uint8_t * restrict q5 = 
x[i].qs; +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); +#else - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + uint32_t aux32; - const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); - const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? 
-1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); +void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); + const block_iq3_s * restrict x = vx; + const block_q8_K * restrict y = vy; - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const int nb = n / QK_K; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); +#if defined(__ARM_NEON) - const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); - const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); - const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); - const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; - const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; - acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - } + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - *s = hsum_float_8(acc); + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); -#elif defined __AVX__ + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); + const uint8x16_t m1 = vdupq_n_u8(1); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mone = _mm_set1_epi8(1); + uint8x16x2_t vs; + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + vec_index_t idx; - __m256 acc = _mm256_setzero_ps(); +#if QK_K == 256 + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; +#endif + float sumf = 0; for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; +#if QK_K == 256 + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; +#endif - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const uint8x16_t idx_l = 
vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + signs += 4; + + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); +#if QK_K == 256 + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; +#else + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#endif + } + sumf += d*(sumi1 + sumi2); + } + *s = sumf; - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); +#elif defined(__AVX2__) - const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); - const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); - const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); - const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; - const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); - const __m128i q5h_1 = 
_mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
- const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
- const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+ const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
- const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
- const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
- const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
- const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+ const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+ const __m256i idx_mask = _mm256_set1_epi32(256);
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+ typedef union {
+ __m256i vec[2];
+ uint32_t index[16];
+ } index_t;
- const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
- const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
- const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
- const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
- const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
- const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
- const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
- const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+ index_t idx;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
+ idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
+ idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
+ idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
+ idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
+ idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
+ idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
+
+ // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
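[Editor's note, not part of the patch] The comment above records why the kernel builds q2_1/q2_2 with _mm256_set_epi32 instead of the hardware gather that is kept commented out just below. As a minimal standalone sketch of the two approaches being compared (helper names are illustrative; `table` stands for any 32-bit lookup table such as iq3s_grid):

#include <immintrin.h>
#include <stdint.h>

// Variant A: a single vpgatherdd; on some microarchitectures this is the slower option.
static inline __m256i gather8_vpgatherdd(const uint32_t * table, __m256i vindex) {
    return _mm256_i32gather_epi32((const int *)table, vindex, 4); // 4-byte entries
}

// Variant B: eight scalar loads packed with _mm256_set_epi32, which is what the kernel above does.
static inline __m256i gather8_scalar(const uint32_t * table, const uint32_t idx[8]) {
    return _mm256_set_epi32(table[idx[7]], table[idx[6]], table[idx[5]], table[idx[4]],
                            table[idx[3]], table[idx[2]], table[idx[1]], table[idx[0]]);
}

Both variants load the same eight dwords; which one is faster is purely a property of the target CPU, which is what the Ryzen 7950X remark documents.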
+ //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = _mm256_set_epi32( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = _mm256_set_epi32( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint8_t * restrict signs = x[i].signs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? 
-1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} - const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); - const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); +#ifdef __AVX2__ +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#endif - } +void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - *s = hsum_float_8(acc); + const block_iq1_s * restrict x = vx; + const block_q8_K * restrict y = vy; -#elif defined __riscv_v_intrinsic + const int nb = n / QK_K; - float sumf = 0; +#if defined __ARM_NEON + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; + float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * (float)x[i].d; - const int8_t * sc = x[i].scales; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + for (int ib = 0; ib < QK_K/32; ib += 2) { - // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); - vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); + qs += 8; - size_t vl = 16; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - // combine both qh_1 and qh_2 - vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); - vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); - vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); - vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + sumi1 += vaddvq_s32(p1) * ls1; + 
sumi2 += vaddvq_s32(p2) * ls2; + sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); - vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); - vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); - vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); - vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); + } - // load q5 - vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); - vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); + } - vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); - vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); - vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); - vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); + *s = sumf; - vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); - vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); - vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); - vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); +#elif defined __AVX2__ - // load Q8 and multiply it with Q5 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); - int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); - int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); - int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); + __m256i sumi = _mm256_setzero_si256(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); + sumi1 += 
(y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); + accum1 += d * sumi1; } - *s = sumf; + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; #else - int8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) { - a[l+ 0] = q4[l] & 0xF; - a[l+32] = q4[l] >> 4; - } - for (int is = 0; is < 8; ++is) { - uint8_t m = 1 << is; - for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const int8_t * restrict sc = x[i].scales; + for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK_K/16; ++j) { - const float dl = d * sc[j]; - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); - q8 += 16; a += 16; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); } - for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + #endif } -#endif - -#if QK_K == 256 -void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; + const block_iq1_m * restrict x = vx; + const block_q8_K * restrict y = vy; const int nb = n / QK_K; -#ifdef __ARM_NEON +#if QK_K != 64 + iq1m_scale_t scale; +#endif - float sum = 0; +#if defined __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); +#if QK_K == 64 + const int32x4_t mask = vdupq_n_s32(0xf); +#else + const int32x4_t mask = vdupq_n_s32(0x7); #endif - //const int8x16_t m32s = vdupq_n_s8(32); + const int32x4_t mone = vdupq_n_s32(1); + const int32x4_t mzero = vdupq_n_s32(0); - const uint8x16_t mone = vdupq_n_u8(3); + ggml_int8x16x4_t deltas; + deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); + deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); + deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); + deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; - int8x16x4_t q6bytes; - uint8x16x4_t q6h; + uint32_t aux32; + const uint8_t * aux8 = (const uint8_t *)&aux32; + 
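/*
 * Sketch of the delta-sign trick used in the NEON loop below (helper name is
 * illustrative, not part of ggml): every qh byte carries two sign bits, 0x08
 * for the low 8 values and 0x80 for the high 8 values of its 16-value pair.
 * Shifting bit 3 down to bit 0 and bit 7 down to bit 1 turns each byte into a
 * 2-bit selector, so deltas.val[selector] is the matching (+1/-1, +1/-1)
 * pattern for that pair.
 */
static inline uint32_t iq1m_delta_selectors(uint32_t qh32) {
    return ((qh32 >> 3) & 0x01010101u)   // 0x08 bit -> selector bit 0 (sign of low 8 values)
         | ((qh32 >> 6) & 0x02020202u);  // 0x80 bit -> selector bit 1 (sign of high 8 values)
}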
float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d_all = GGML_FP16_TO_FP32(x[i].d); + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); +#endif - const int8_t * restrict scale = x[i].scales; + int32x4_t sumi1 = mzero; + int32x4_t sumi2 = mzero; - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + for (int ib = 0; ib < QK_K/32; ib += 2) { - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); - int32_t isum = 0; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - for (int j = 0; j < QK_K/128; ++j) { + const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); + const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); + const int32x4_t p12 = vpaddq_s32(p1, p2); - uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; - uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; - int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that + aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); + const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); + const int32x4_t p34 = vpaddq_s32(p3, p4); - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] 
= vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); +#if QK_K == 64 + int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12); +#else + int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); +#endif + scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); -#if defined(__ARM_FEATURE_DOTPROD) + sumi1 = vmlaq_s32(sumi1, scales_4, p12); + sumi2 = vmlaq_s32(sumi2, scales_4, p34); - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; + qs += 8; qh += 4; + + } + +#if QK_K == 64 + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); +#else + sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); +#endif + } + + *s = sumf; +#elif defined __AVX2__ + +#if QK_K == 64 + const __m256i mask = _mm256_set1_epi16(0xf); #else + const __m256i mask = _mm256_set1_epi16(0x7); +#endif + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); #endif - q8bytes = vld1q_s8_x4(q8); q8 += 64; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] + ); + const __m256i q1b_2 = _mm256_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], 
iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] + ); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + + const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m256i dot3 = mul_add_epi8(delta1, q8b_1); + const __m256i dot4 = mul_add_epi8(delta2, q8b_2); +#if QK_K == 64 + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8)); +#else + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); +#endif + scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); + scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); + const __m256i p1 = _mm256_madd_epi16(dot1, scale1); + const __m256i p2 = _mm256_madd_epi16(dot2, scale2); + const __m256i p3 = _mm256_madd_epi16(dot3, scale1); + const __m256i p4 = _mm256_madd_epi16(dot4, scale2); - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + qs += 8; qh += 4; + } -#if defined(__ARM_FEATURE_DOTPROD) +#if QK_K == 64 + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); +#else + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); +#endif + accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); + accum2 = _mm256_fmadd_ps(d, 
_mm256_cvtepi32_ps(sumi2), accum2); - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - //for (int l = 0; l < 4; ++l) { - // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); - // isum += vaddvq_s32(p) * *scale++; - //} #else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); #endif + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? 
-1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } +#if QK_K == 64 + const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1; + const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1; +#else + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; +#endif + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); +#if QK_K == 64 + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); +#else + sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); +#endif } - *s = sum; -#elif defined __AVX2__ + *s = sumf; - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); +#endif +} - __m256 acc = _mm256_setzero_ps(); +void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); - for (int i = 0; i < nb; ++i) { + const block_iq4_nl * restrict x = vx; + const block_q8_0 * restrict y = vy; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int nb = n / QK4_NL; - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1, prod_2; - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + float sumf = 0; - __m256i sumi = _mm256_setzero_si256(); + for (int ib = 0; ib < nb; ib += 2) { - int is = 0; + q4bits.val[0] = vld1q_u8(x[ib+0].qs); + q4bits.val[1] = vld1q_u8(x[ib+1].qs); + q8b.val[0] = vld1q_s8(y[ib+0].qs); + q8b.val[1] = vld1q_s8(y[ib+0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib+1].qs); + q8b.val[3] = vld1q_s8(y[ib+1].qs + 16); - for (int j = 0; j < QK_K/128; ++j) { + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i 
q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + sumf += + GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) + + GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2); + } - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + *s = sumf; + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int ib = 0; ib < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs); + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)), + _mm256_cvtepi32_ps(p_2), accum2); + + y += 2; + x += 2; + } + + *s = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#else + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +#endif +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); +#if QK_K == 64 + ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc); +#else + + const block_iq4_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + ggml_uint8x16x2_t q4bits; + ggml_int8x16x4_t q4b; + ggml_int8x16x4_t q8b; + int32x4_t prod_1, prod_2; - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), 
q4h_3); + float sumf = 0; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + for (int ibl = 0; ibl < nb; ++ibl) { - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + const int8_t * q8 = y[ibl].qs; + const uint8_t * q4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + sumi1 += vaddvq_s32(prod_1) * ls1; + sumi2 += vaddvq_s32(prod_2) * ls2; } - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); } - *s = hsum_float_8(acc); + *s = sumf; -#elif defined __AVX__ +#elif defined __AVX2__ - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + 
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); + const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); + sumi1 = _mm256_add_epi32(p_1, sumi1); + sumi2 = _mm256_add_epi32(p_2, sumi2); + } + accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); - __m256 acc = _mm256_setzero_ps(); +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +#endif +} - for (int i = 0; i < nb; ++i) { +// ================================ IQ2 quantization ============================================= + +typedef struct { + uint64_t * grid; + int * map; + uint16_t * neighbours; +} iq2_entry_t; + +static iq2_entry_t iq2_data[4] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq2_data_index(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + return type == GGML_TYPE_IQ2_XXS ? 0 : + type == GGML_TYPE_IQ2_XS ? 1 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3; +} - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); +static inline int iq2_grid_size(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + return type == GGML_TYPE_IQ2_XXS ? 256 : + type == GGML_TYPE_IQ2_XS ? 512 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024; +} - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +static int iq2_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 
1 : 0; +} - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); +void iq2xs_init_impl(enum ggml_type type) { + const int gindex = iq2_data_index(type); + const int grid_size = iq2_grid_size(type); + if (iq2_data[gindex].grid) { + return; + } + static const uint16_t kgrid_2bit_256[256] = { + 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97, + 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642, + 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288, + 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113, + 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240, + 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400, + 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260, + 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872, + 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516, + 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561, + 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488, + 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545, + 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874, + 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856, + 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142, + 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268, + }; + static const uint16_t kgrid_2bit_512[512] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257, + 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340, + 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597, + 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096, + 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348, + 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065, + 2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441, + 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160, + 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372, + 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125, + 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652, + 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197, + 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549, + 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894, + 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388, + 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 
16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480, + 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773, + 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473, + 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436, + 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497, + 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162, + 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528, + 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745, + 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234, + 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025, + 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810, + 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984, + 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462, + 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960, + 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, + 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, + }; + static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = { + 0, 2, 5, 8, 10, 17, 21, 32, 34, 40, 42, 69, 81, 84, 86, 101, + 128, 130, 136, 138, 149, 160, 162, 168, 170, 260, 261, 273, 276, 278, 281, 282, + 293, 321, 326, 329, 338, 341, 346, 353, 356, 358, 360, 389, 401, 404, 406, 421, + 512, 514, 520, 522, 533, 544, 546, 552, 554, 581, 593, 601, 612, 617, 640, 642, + 648, 650, 657, 661, 665, 672, 674, 680, 682, 1041, 1044, 1046, 1061, 1089, 1097, 1109, + 1114, 1124, 1125, 1169, 1177, 1189, 1281, 1284, 1285, 1286, 1301, 1304, 1306, 1321, 1344, 1349, + 1354, 1360, 1361, 1364, 1365, 1366, 1369, 1376, 1378, 1381, 1384, 1386, 1409, 1425, 1429, 1432, + 1434, 1441, 1444, 1445, 1446, 1449, 1556, 1561, 1601, 1604, 1616, 1618, 1621, 1624, 1632, 1633, + 1638, 1641, 1669, 1681, 1684, 1689, 2048, 2050, 2056, 2058, 2069, 2080, 2082, 2088, 2090, 2117, + 2129, 2134, 2149, 2176, 2178, 2184, 2186, 2197, 2208, 2210, 2216, 2218, 2309, 2321, 2324, 2329, + 2340, 2341, 2369, 2384, 2385, 2389, 2401, 2404, 2409, 2449, 2452, 2454, 2457, 2469, 2560, 2562, + 2568, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2641, 2649, 2657, 2661, 2688, 2690, 2693, 2696, + 2698, 2709, 2720, 2722, 2728, 2730, 4112, 4113, 4116, 4121, 4132, 4133, 4161, 4164, 4176, 4181, + 4184, 4193, 4196, 4197, 4201, 4241, 4244, 4246, 4257, 4261, 4353, 4356, 4358, 4361, 4368, 4370, + 4373, 4376, 4385, 4388, 4393, 4421, 4426, 4432, 4433, 4434, 4436, 4437, 4438, 4441, 4448, 4453, + 4484, 4498, 4501, 4513, 4516, 4625, 4628, 4630, 4645, 4672, 4678, 4681, 4690, 4693, 4696, 4698, + 4708, 4710, 4741, 4753, 4756, 4758, 4773, 5121, 5126, 5129, 5140, 5141, 5144, 5145, 5153, 5158, + 5185, 5189, 5190, 5192, 5194, 5201, 5204, 5205, 5206, 5209, 5218, 5221, 5224, 5252, 5257, 5264, + 5268, 5269, 5272, 5273, 5274, 5281, 5284, 5285, 5289, 
5378, 5381, 5386, 5393, 5396, 5397, 5398, + 5401, 5408, 5410, 5413, 5416, 5418, 5441, 5444, 5445, 5446, 5457, 5458, 5460, 5461, 5462, 5465, + 5466, 5473, 5476, 5477, 5478, 5481, 5504, 5506, 5508, 5509, 5512, 5514, 5520, 5521, 5524, 5525, + 5526, 5529, 5530, 5536, 5538, 5541, 5633, 5636, 5637, 5638, 5653, 5654, 5656, 5658, 5665, 5670, + 5696, 5698, 5700, 5701, 5704, 5706, 5713, 5717, 5718, 5720, 5721, 5729, 5732, 5733, 5736, 5737, + 5738, 5766, 5770, 5778, 5781, 5796, 5801, 6161, 6166, 6181, 6209, 6212, 6214, 6217, 6224, 6229, + 6232, 6234, 6240, 6241, 6244, 6246, 6249, 6277, 6289, 6292, 6309, 6416, 6418, 6421, 6426, 6433, + 6437, 6466, 6468, 6469, 6472, 6481, 6484, 6485, 6486, 6489, 6490, 6496, 6501, 6506, 6537, 6545, + 6546, 6549, 6552, 6561, 6566, 6569, 6665, 6678, 6692, 6694, 6724, 6726, 6729, 6736, 6738, 6741, + 6744, 6753, 6758, 6761, 6789, 6801, 6806, 6810, 8192, 8194, 8200, 8202, 8213, 8224, 8226, 8229, + 8232, 8234, 8261, 8273, 8281, 8289, 8293, 8320, 8322, 8328, 8330, 8341, 8352, 8354, 8357, 8360, + 8362, 8453, 8465, 8468, 8473, 8485, 8514, 8516, 8521, 8533, 8536, 8538, 8545, 8548, 8549, 8550, + 8581, 8592, 8598, 8601, 8613, 8705, 8712, 8714, 8721, 8725, 8736, 8738, 8744, 8746, 8773, 8785, + 8790, 8793, 8805, 8833, 8840, 8842, 8849, 8853, 8864, 8866, 8872, 8874, 9221, 9236, 9238, 9241, + 9253, 9284, 9285, 9286, 9289, 9298, 9301, 9304, 9306, 9318, 9349, 9361, 9364, 9369, 9377, 9381, + 9481, 9493, 9505, 9513, 9536, 9541, 9544, 9553, 9556, 9557, 9561, 9570, 9573, 9576, 9609, 9616, + 9620, 9621, 9624, 9626, 9633, 9636, 9638, 9641, 9733, 9744, 9746, 9753, 9765, 9793, 9801, 9813, + 9824, 9825, 9833, 9860, 9862, 9872, 9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282, + 10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521, + 10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752, + 10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890, + 10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484, + 16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673, + 16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772, + 16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986, + 16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494, + 17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666, + 17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744, + 17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809, + 17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953, + 17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049, + 18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517, + 18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704, + 18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784, + 18786, 
18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012, + 19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501, + 20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617, + 20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761, + 20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822, + 20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896, + 20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078, + 21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526, + 21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589, + 21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653, + 21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780, + 21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832, + 21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864, + 21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924, + 21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048, + 22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098, + 22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154, + 22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561, + 22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665, + 22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821, + 22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884, + 22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061, + 23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144, + 23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656, + 24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850, + 24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970, + 24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221, + 25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674, + 25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749, + 25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926, + 25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001, + 26005, 26006, 26009, 26010, 
26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176, + 26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250, + 26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721, + 26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949, + 26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044, + 27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270, + 27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852, + 32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046, + 33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161, + 33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369, + 33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877, + 33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117, + 34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192, + 34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394, + 34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858, + 34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986, + 35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172, + 35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412, + 35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901, + 36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124, + 37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205, + 37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396, + 37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889, + 37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985, + 37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161, + 38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226, + 38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290, + 38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432, + 38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538, + 38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998, + 39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194, + 39200, 39201, 39204, 39206, 39232, 39234, 39237, 
39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269, + 39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497, + 39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994, + 41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130, + 41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349, + 41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561, + 41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068, + 42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278, + 42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386, + 42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592, + 42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048, + 43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284, + 43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530, + 43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690, + }; + static const uint16_t kgrid_2bit_1024[1024] = { + 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, + 73, 80, 82, 85, 88, 97, 100, 102, 105, 128, 130, 133, 136, 145, 148, 160, + 165, 170, 257, 260, 262, 265, 272, 274, 277, 280, 289, 292, 320, 322, 325, 328, + 337, 340, 342, 345, 352, 357, 360, 385, 388, 400, 402, 405, 417, 420, 512, 514, + 517, 520, 529, 532, 544, 554, 577, 580, 582, 585, 592, 597, 640, 645, 650, 660, + 674, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1062, 1065, 1088, 1090, 1093, + 1096, 1098, 1105, 1108, 1110, 1113, 1120, 1122, 1125, 1153, 1156, 1158, 1161, 1168, 1173, 1176, + 1185, 1188, 1280, 1282, 1285, 1288, 1290, 1297, 1300, 1302, 1305, 1312, 1317, 1320, 1345, 1348, + 1350, 1353, 1360, 1362, 1365, 1368, 1377, 1380, 1408, 1410, 1413, 1416, 1425, 1428, 1440, 1537, + 1540, 1542, 1545, 1552, 1557, 1600, 1605, 1608, 1617, 1620, 1632, 1665, 1668, 1680, 2048, 2050, + 2053, 2056, 2065, 2068, 2070, 2073, 2080, 2085, 2090, 2113, 2116, 2118, 2121, 2128, 2130, 2133, + 2136, 2145, 2148, 2176, 2181, 2196, 2218, 2305, 2308, 2320, 2322, 2325, 2328, 2337, 2368, 2373, + 2376, 2385, 2388, 2400, 2433, 2448, 2560, 2577, 2580, 2594, 2600, 2602, 2640, 2713, 4097, 4100, + 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4134, 4160, 4162, 4165, 4168, 4177, 4180, 4182, + 4185, 4192, 4194, 4197, 4200, 4225, 4228, 4230, 4240, 4245, 4248, 4257, 4260, 4352, 4354, 4357, + 4360, 4362, 4369, 4372, 4374, 4377, 4384, 4386, 4389, 4392, 4417, 4420, 4422, 4425, 4432, 4434, + 4437, 4440, 4449, 4452, 4480, 4482, 4485, 4488, 4497, 4500, 4609, 4612, 4617, 4624, 4629, 4641, + 4644, 4672, 4677, 4689, 4692, 4737, 4740, 4752, 5120, 5122, 5125, 5128, 5137, 5140, 5142, 5145, + 5152, 5157, 5160, 5185, 5188, 5190, 5193, 5200, 5202, 5205, 5208, 5217, 5220, 5248, 5250, 5253, + 5256, 5265, 5268, 5280, 5377, 5380, 5382, 5385, 5392, 5394, 5397, 5400, 5409, 5412, 5440, 5442, + 5445, 5448, 5457, 5460, 5472, 5505, 5508, 5520, 5632, 5637, 5640, 5649, 5652, 5664, 5697, 5700, + 
5712, 5760, 5802, 6145, 6148, 6150, 6153, 6160, 6165, 6168, 6177, 6208, 6210, 6213, 6216, 6225, + 6228, 6240, 6273, 6276, 6400, 6402, 6405, 6408, 6417, 6420, 6432, 6465, 6468, 6480, 6505, 6562, + 6660, 6672, 6720, 6742, 8192, 8194, 8197, 8200, 8209, 8212, 8214, 8217, 8224, 8229, 8234, 8257, + 8260, 8272, 8274, 8277, 8292, 8320, 8330, 8340, 8362, 8449, 8452, 8464, 8466, 8469, 8481, 8512, + 8514, 8517, 8529, 8532, 8544, 8577, 8580, 8592, 8704, 8714, 8738, 8744, 8746, 8772, 8784, 8840, + 8842, 8872, 9217, 9220, 9222, 9225, 9232, 9237, 9240, 9249, 9252, 9280, 9282, 9285, 9288, 9297, + 9300, 9312, 9345, 9348, 9360, 9472, 9477, 9480, 9489, 9492, 9504, 9537, 9540, 9552, 9574, 9600, + 9729, 9732, 9744, 9792, 9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500, + 10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410, + 16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513, + 16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674, + 16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785, + 16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025, + 17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476, + 17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665, + 17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760, + 17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085, + 18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528, + 18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948, + 18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548, + 20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740, + 20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865, + 20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510, + 21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636, + 21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054, + 22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800, + 22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645, + 24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912, + 24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680, + 25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880, + 26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850, + 32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060, + 33088, 33090, 33093, 
33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345, + 33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873, + 33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176, + 34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076, + 35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928, + 36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200, + 37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968, + 38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976, + 39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130, + 41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121, + 42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690, + }; - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + const int kmap_size = 43692; + //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2; + const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2; + const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 : + type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : + type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024; + uint64_t * kgrid_q2xs; + int * kmap_q2xs; + uint16_t * kneighbors_q2xs; + + //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 8; ++i) { + int l = (kgrid[k] >> 2*i) & 0x3; + pos[i] = 2*l + 1; + } + } + kgrid_q2xs = the_grid; + iq2_data[gindex].grid = the_grid; + kmap_q2xs = (int *)malloc(kmap_size*sizeof(int)); + iq2_data[gindex].map = kmap_q2xs; + for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1; + uint64_t aux64; + uint8_t * aux8 = (uint8_t *)&aux64; + for (int i = 0; i < grid_size; ++i) { + aux64 = kgrid_q2xs[i]; + uint16_t index = 0; + for (int k=0; k<8; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 2*k); + } + kmap_q2xs[index] = i; + } + int8_t pos[8]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + //printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + 
num_not_in_map)*sizeof(uint16_t)); + iq2_data[gindex].neighbours = kneighbors_q2xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + kmap_q2xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} - __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - for (int j = 0; j < QK_K/128; ++j) { +void iq2xs_free_impl(enum ggml_type type) { + GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S); + const int gindex = iq2_data_index(type); + if (iq2_data[gindex].grid) { + free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL; + free(iq2_data[gindex].map); iq2_data[gindex].map = NULL; + free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL; + } +} - const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; +static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); - const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); - const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); - const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); - const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); +static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { - const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_1 = _mm_loadu_si128((const 
__m128i*)q4); q4 += 16; + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS); - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); - const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); - const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); - const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); - const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + const int kMaxQ = 3; - __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + const int64_t nbl = n/QK_K; - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); + block_iq2_xxs * y = vy; - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + uint8_t block_signs[4]; + uint32_t q2[2*(QK_K/32)]; - p16_0 = 
_mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); - p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); - p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); - p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + for (int ibl = 0; ibl < nbl; ++ibl) { - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; + + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); + float eff_max = scale*kMaxQ; + float best = 0; + for (int is = -6; is <= 6; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/eff_max; + float this_scale = 1/id; + for (int k = 0; k < 4; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + memcpy(L, Laux, 32); + } + } + if (scale > 0) { + float id = 1/scale; + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + const int8_t * pg = (const int8_t 
*)(kgrid_q2xs + grid_index); + for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + q2[2*ib+0] |= (grid_index << 8*k); + q2[2*ib+1] |= (block_signs[k] << 7*k); + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; } - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + q2[2*ib+1] |= ((uint32_t)l << 28); + } + memcpy(y[ibl].qs, q2, QK_K/4); } +} - *s = hsum_float_8(acc); +static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { -#elif defined __riscv_v_intrinsic + const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); - float sumf = 0; - for (int i = 0; i < nb; ++i) { + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + const int kMaxQ = 3; - const int8_t * restrict scale = x[i].scales; + const int64_t nbl = n/QK_K; - size_t vl; + block_iq2_xs * y = vy; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; + uint16_t q2[2*(QK_K/16)]; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(q2, 0, QK_K/4); + memset(y[ibl].scales, 0, QK_K/32); - int sum_t = 0; - int is = 0; + float max_scale = 0; - for (int j = 0; j < QK_K/128; ++j) { + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = sumx2/QK_K; - vl = 32; + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; 
k < 2; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 16); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + q2[2*ib+k] = grid_index | (block_signs[k] << 9); + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + if (!max_scale) { + memset(y[ibl].qs, 0, QK_K/4); + continue; + } - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); 
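Aside: the super-block scale encoding used in this hunk is easy to check in isolation. The search produces a float scale per 16-value block; the code then picks d = max_scale/31 and stores each block scale as a 4-bit value l, with the effective scale reconstructed as d*(2*l + 1). A minimal standalone sketch, with roundf standing in for nearest_int and all data and helper names being illustrative, not part of the patch:

```c
// Sketch of the 4-bit block-scale encoding: scale ~= d*(2*l + 1), l in 0..15.
#include <math.h>
#include <stdio.h>

static int nearest_int_demo(float f) { return (int)roundf(f); }

int main(void) {
    // per-block scales found by the search, and the resulting super-block scale d
    float scales[4] = {0.37f, 1.02f, 2.48f, 3.10f};
    float max_scale = 3.10f;
    float d  = max_scale / 31;   // the largest block then encodes to l = 15
    float id = 1 / d;
    for (int ib = 0; ib < 4; ++ib) {
        int l = nearest_int_demo(0.5f * (id * scales[ib] - 1));
        l = l < 0 ? 0 : l > 15 ? 15 : l;           // clamp to the 4-bit range
        float decoded = d * (2 * l + 1);           // what dequantization recovers
        printf("block %d: scale %.3f -> l = %2d -> decoded %.3f\n",
               ib, scales[ib], l, decoded);
    }
    return 0;
}
```

Dividing by 31 guarantees that the largest block scale maps exactly to l = 15, so the 4-bit field never clips the block that set max_scale.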
+ l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); + } + memcpy(y[ibl].qs, q2, QK_K/4); - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + } +} - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); +size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xxs); + } + return nrow * nblock * sizeof(block_iq2_xxs); +} - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); +size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_xs); + } + return nrow * nblock * sizeof(block_iq2_xs); +} + +// +// ============================================= 3-bit using D4 lattice +// + +typedef struct { + uint32_t * grid; + int * map; + uint16_t * neighbours; +} iq3_entry_t; + +static iq3_entry_t iq3_data[2] = { + {NULL, NULL, NULL}, + {NULL, NULL, NULL}, +}; + +static inline int iq3_data_index(int grid_size) { + (void)grid_size; + GGML_ASSERT(grid_size == 256 || grid_size == 512); + return grid_size == 256 ? 0 : 1; +} + +static int iq3_compare_func(const void * left, const void * right) { + const int * l = (const int *)left; + const int * r = (const int *)right; + return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 
1 : 0; +} - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); +void iq3xs_init_impl(int grid_size) { + const int gindex = iq3_data_index(grid_size); + if (iq3_data[gindex].grid) { + return; + } + static const uint16_t kgrid_256[256] = { + 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74, + 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159, + 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321, + 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531, + 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664, + 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978, + 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105, + 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228, + 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553, + 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722, + 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063, + 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389, + 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746, + 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153, + 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, + 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, + }; + static const uint16_t kgrid_512[512] = { + 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, + 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77, + 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142, + 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, + 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, + 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, + 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514, + 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576, + 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653, + 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, + 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, + 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, + 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, + 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, + 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, + 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, + 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, + 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 
1561, + 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, + 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, + 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, + 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, + 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, + 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, + 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, + 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, + 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, + 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, + 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, + 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, + 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, + 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, + }; - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + const int kmap_size = 4096; + const int nwant = grid_size == 256 ? 2 : 3; + const uint16_t * kgrid = grid_size == 256 ? 
kgrid_256 : kgrid_512; + uint32_t * kgrid_q3xs; + int * kmap_q3xs; + uint16_t * kneighbors_q3xs; + + //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size); + uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t)); + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + for (int i = 0; i < 4; ++i) { + int l = (kgrid[k] >> 3*i) & 0x7; + pos[i] = 2*l + 1; + } + } + kgrid_q3xs = the_grid; + iq3_data[gindex].grid = the_grid; + kmap_q3xs = (int *)malloc(kmap_size*sizeof(int)); + iq3_data[gindex].map = kmap_q3xs; + for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1; + uint32_t aux32; + uint8_t * aux8 = (uint8_t *)&aux32; + for (int i = 0; i < grid_size; ++i) { + aux32 = kgrid_q3xs[i]; + uint16_t index = 0; + for (int k=0; k<4; ++k) { + uint16_t q = (aux8[k] - 1)/2; + index |= (q << 3*k); + } + kmap_q3xs[index] = i; + } + int8_t pos[4]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + //printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq3_data[gindex].neighbours = kneighbors_q3xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + kmap_q3xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} - vl = 16; +void iq3xs_free_impl(int grid_size) { + GGML_ASSERT(grid_size == 256 || grid_size == 512); + const int gindex = iq3_data_index(grid_size); + if (iq3_data[gindex].grid) { + free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL; + free(iq3_data[gindex].map); iq3_data[gindex].map = NULL; + free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL; + } +} - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - 
vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); +static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_d2 = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 4; ++i) { + float q = pg[i]; + float diff = scale*q - xval[i]; + d2 += weight[i]*diff*diff; + } + if (d2 < best_d2) { + best_d2 = d2; grid_index = neighbours[j]; + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); +static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n, + const float * restrict quant_weights) { - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + const int gindex = iq3_data_index(grid_size); - q6 += 64; qh += 32; q8 += 128; is=8; + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; - } + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - sumf += d * sum_t; + const int kMaxQ = 8; + + const int64_t nbl = n/QK_K; + ggml_fp16_t * dh; + uint8_t * qs; + int block_size; + if (grid_size == 256) { + block_iq3_xxs * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_xxs); + } else { + block_iq3_s * y = vy; + dh = &y->d; + qs = y->qs; + block_size = sizeof(block_iq3_s); } + int quant_size = block_size - sizeof(ggml_fp16_t); - *s = sumf; + float scales[QK_K/32]; + float weight[32]; + float xval[32]; + int8_t L[32]; + int8_t Laux[32]; + float waux[32]; + bool is_on_grid[8]; + bool is_on_grid_aux[8]; + uint8_t block_signs[8]; + uint8_t q3[3*(QK_K/8)+QK_K/32]; + uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4); + uint8_t * qh = q3 + 3*(QK_K/8); + + for (int ibl = 0; ibl < nbl; ++ibl) { + + dh[0] = GGML_FP32_TO_FP16(0.f); + memset(q3, 0, 3*QK_K/8+QK_K/32); -#else + float max_scale = 0; - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; - float sumf = 0; 
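Aside: both the iq2 and iq3 paths rely on the same neighbour-table convention built by the *_init_impl functions and consumed by iq2_find_best_neighbour / iq3_find_best_neighbour: kmap[u] >= 0 means the code u is a grid point (the value is its grid index), while a negative entry encodes an offset into a packed list whose first element is the neighbour count, recovered with the pointer arithmetic `kneighbors - kmap[u] - 1`. A minimal sketch with made-up data; only the indexing scheme matches the patch:

```c
// Toy illustration of the kmap / kneighbors layout used by the iq2/iq3 setup code.
#include <stdio.h>
#include <stdint.h>

int main(void) {
    // kmap[u] >= 0 : u is on the grid; the value is its grid index.
    // kmap[u] <  0 : u is off-grid; the setup code stored -(offset + 1), where
    //                offset points at a [count, idx0, idx1, ...] run in kneighbors.
    int      kmap[4]       = { 0, -1, 1, -4 };
    uint16_t kneighbors[5] = { 2, 0, 1,     // u = 1: 2 neighbours -> grid points 0 and 1
                               1, 1 };      // u = 3: 1 neighbour  -> grid point 1

    for (int u = 0; u < 4; ++u) {
        if (kmap[u] >= 0) {
            printf("u = %d is on the grid, index %d\n", u, kmap[u]);
            continue;
        }
        // same pointer arithmetic as the patch: kneighbors - kmap[u] - 1
        const uint16_t * neighbours = kneighbors - kmap[u] - 1;
        int n = neighbours[0];
        printf("u = %d is off-grid, %d candidate(s):", u, n);
        for (int j = 1; j <= n; ++j) printf(" %d", neighbours[j]);
        printf("\n");
    }
    return 0;
}
```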
- for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 32*ib; + for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i]; } - a += 128; - q4 += 64; - qh += 32; + for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 4; ++k) { + int nflip = 0; + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i); + } + } + if (nflip%2) { + int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin]; + for (int i = 1; i < 8; ++i) { + float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i]; + if (ax < min) { + min = ax; imin = i; + } + } + xval[8*k+imin] = -xval[8*k+imin]; + s ^= (1 << imin); + } + block_signs[k] = s & 127; + } + float max = xval[0]; + for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int is = -15; is <= 15; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < 8; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 32; ++i) L[i] = Laux[i]; + for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 8; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 32; ++i) { + float w = weight[i]; + float q 
= 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127; + } + for (int k = 0; k < 8; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + if (grid_size == 256) { + q3[8*ib+k] = grid_index; + } else { + q3[8*ib+k] = grid_index & 255; + qh[ib] |= ((grid_index >> 8) << k); + } + + } + scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21); + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); } - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; + + if (!max_scale) { + memset(qs, 0, quant_size); + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + continue; } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + + float d = max_scale/31; + dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor + float id = 1/d; + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + scales_and_signs[ib] |= ((uint32_t)l << 28); + } + memcpy(qs, q3, quant_size); + + dh += block_size/sizeof(ggml_fp16_t); + qs += block_size; + } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif } -#else +size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_xxs); + } + return nrow * nblock * sizeof(block_iq3_xxs); +} -void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); +void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_xxs * restrict y = vy; + quantize_row_iq3_xxs_reference(x, y, k); +} - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; +void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_row_iq3_xxs_impl(256, x, y, k, NULL); +} - const int nb = n / QK_K; +static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n, + const float * restrict quant_weights, + float * scales, + float * weight, + float * xval, + int8_t * L, + int8_t * Laux, + float * waux, + bool * is_on_grid, + bool * is_on_grid_aux, + uint8_t * block_signs) { -#ifdef __ARM_NEON + const int gindex = 
iq3_data_index(512); - float sum = 0; + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int8x16_t m32s = vdupq_n_s8(32); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const uint8x16_t mone = vdupq_n_u8(3); + const int kMaxQ = 8; - int8x16x4_t q6bytes; - uint8x16x4_t q6h; + const int64_t nbl = n/QK_K; - for (int i = 0; i < nb; ++i) { + block_iq3_s * y = vy; - const float d_all = (float)x[i].d; + const int bs4 = block_size/4; + const int bs8 = block_size/8; - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + for (int ibl = 0; ibl < nbl; ++ibl) { - const int8_t * restrict scale = x[i].scales; + memset(&y[ibl], 0, sizeof(block_iq3_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); - int32_t isum = 0; + uint8_t * qs = y[ibl].qs; + uint8_t * qh = y[ibl].qh; + uint8_t * signs = y[ibl].signs; - uint8x16_t qhbits = vld1q_u8(qh); - uint8x16x2_t q6bits = vld1q_u8_x2(q6); - int8x16x4_t q8bytes = vld1q_s8_x4(q8); + float max_scale = 0; - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits, 2); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 4); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; - q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); - q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < bs8; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int k = 0; k < bs4; ++k) is_on_grid[k] = false; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < bs4; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, 
MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < block_size; ++i) L[i] = Laux[i]; + for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < bs4; ++k) { + //if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. 
+ scale = -scale; + for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < bs4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + qs[k] = grid_index & 255; + qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8)); + } + qs += bs4; + for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k]; + signs += bs8; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } -#if defined(__ARM_FEATURE_DOTPROD) + if (!max_scale) { + continue; + } - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; -#else + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f); + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ib += 2) { + int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); + l1 = MAX(0, MIN(15, l1)); + int l2 = nearest_int(0.5f*(id*scales[ib+1]-1)); + l2 = MAX(0, MIN(15, l2)); + y[ibl].scales[ib/2] = l1 | (l2 << 4); + } - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif + } +} - sum += isum * d_all * y[i].d; +#define IQ3S_BLOCK_SIZE 32 +size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + float scales[QK_K/IQ3S_BLOCK_SIZE]; + float weight[IQ3S_BLOCK_SIZE]; + float xval[IQ3S_BLOCK_SIZE]; + int8_t L[IQ3S_BLOCK_SIZE]; + int8_t Laux[IQ3S_BLOCK_SIZE]; + float waux[IQ3S_BLOCK_SIZE]; + bool is_on_grid[IQ3S_BLOCK_SIZE/4]; + bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4]; + uint8_t block_signs[IQ3S_BLOCK_SIZE/8]; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights, + scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs); + src += n_per_row; + qrow += nblock*sizeof(block_iq3_s); + } + return nrow * nblock * sizeof(block_iq3_s); +} - } - *s = sum; +void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_s * restrict y = vy; + quantize_row_iq3_s_reference(x, y, k); +} -#elif defined __AVX2__ +void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_s(x, y, 1, k, NULL); +} - const 
__m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - __m256 acc = _mm256_setzero_ps(); +// =================================== 1.5 bpw =================================================== + +static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = 0; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale * sumqx; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = (grid_i[j] - 3)/2; + sumqx += w*q*xval[j]; + sumq2 += w*q*q; + } + if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + *scale = sumqx/sumq2; best_score = *scale*sumqx; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + *scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result. + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
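Aside: the best_score bookkeeping in iq1_find_best_neighbour (and in the scale searches earlier in this patch) follows from weighted least squares. For a fixed candidate quantization q[i], the error E(s) = sum(w*(s*q - x)^2) is minimized at s* = sumqx/sumq2, where the residual is sum(w*x*x) - sumqx*sumqx/sumq2, so maximizing s* * sumqx is equivalent to minimizing the error. A small self-contained check with made-up data (names are illustrative, not from the patch):

```c
// Verifies that a larger scale*sumqx means a smaller weighted SSD at the optimal scale.
#include <stdio.h>

int main(void) {
    const float x[8]  = {0.9f, -1.1f, 0.2f, 1.8f, -0.4f, 1.0f, -2.0f, 0.5f};
    const float w[8]  = {1.0f,  0.8f, 1.2f, 1.0f,  0.9f, 1.1f,  1.0f, 0.7f};
    const int   qa[8] = {1, -1, 0, 1,  0, 1, -1, 0};   // two candidate quantizations
    const int   qb[8] = {1, -1, 1, 1, -1, 1, -1, 1};

    const int * cands[2] = {qa, qb};
    for (int c = 0; c < 2; ++c) {
        float sumqx = 0, sumq2 = 0, sumx2 = 0;
        for (int i = 0; i < 8; ++i) {
            sumqx += w[i]*x[i]*cands[c][i];
            sumq2 += w[i]*cands[c][i]*cands[c][i];
            sumx2 += w[i]*x[i]*x[i];
        }
        float scale = sumqx/sumq2;     // optimal scale for this candidate
        float score = scale*sumqx;     // the quantity the patch maximizes
        float err   = sumx2 - score;   // weighted SSD at that scale
        printf("candidate %d: scale = %.4f score = %.4f error = %.4f\n", c, scale, score, err);
    }
    return 0;
}
```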
+ const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - for (int i = 0; i < nb; ++i) { +static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, const float * restrict xg, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = xg[(pg[i] - 1)/2]; + float w = weight[i]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float d2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = xg[(grid_i[j] - 1)/2]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = xg[(pg[i] - 1)/2]; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); +static int iq1_sort_helper(const void * left, const void * right) { + const float * l = left; + const float * r = right; + return *l < *r ? -1 : *l > *r ? 
1 : 0; +} - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +#define IQ1S_BLOCK_SIZE 32 +#define IQ1M_BLOCK_SIZE 16 +static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, + float * scales, + float * weight, + float * sumx, + float * sumw, + float * pairs, + int8_t * L, + uint16_t * index, + int8_t * shifts) { - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); - __m256i sumi = _mm256_setzero_si256(); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + block_iq1_s * y = vy; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + const int64_t nbl = n/QK_K; - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + const int block_size = IQ1S_BLOCK_SIZE; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA}; + const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA}; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + int * idx = (int *)(pairs + 1); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + for (int ibl = 0; ibl < nbl; ++ibl) { - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].qh, 0, QK_K/16); + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (!max) { + scales[ib] = 0; + memset(L, 1, block_size); + continue; + } + // Here we solve exactly the sum of squared 
difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. + for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < block_size; ++j) { + int i = idx[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xb[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best_score = 0, scale = max; + int besti1 = -1, besti2 = -1, best_shift = 0; + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2]; + float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = 1; + } + sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2]; + sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2]; + if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { + scale = sumqx/sumq2; best_score = scale*sumqx; + besti1 = i1; besti2 = i2; best_shift = -1; + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; best_shift = -best_shift; + } + bool all_on_grid = true; + const float * xx = best_shift == 1 ? x_p : x_m; + for (int k = 0; k < block_size/8; ++k) { + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx = 0, sumq2 = 0; + for (int k = 0; k < block_size/8; ++k) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } + if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; + } + uint16_t h = 0; + for (int k = 0; k < block_size/8; ++k) { + y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255; + h |= (index[k] >> 8) << 3*k; + } + y[ibl].qh[ib] = h; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + shifts[ib] = best_shift; + max_scale = MAX(max_scale, scale); + } - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + if (!max_scale) { + continue; + } - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + float d = max_scale/15; + y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. 
Don't ask me why it is needed. + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(7, l)); + if (shifts[ib] == -1) l |= 8; + y[ibl].qh[ib] |= (l << 12); + } } +} - *s = hsum_float_8(acc); - -#elif defined __AVX__ +size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + float scales[QK_K/IQ1S_BLOCK_SIZE]; + float weight[IQ1S_BLOCK_SIZE]; + int8_t L[IQ1S_BLOCK_SIZE]; + float sumx[IQ1S_BLOCK_SIZE+1]; + float sumw[IQ1S_BLOCK_SIZE+1]; + float pairs[2*IQ1S_BLOCK_SIZE]; + uint16_t index[IQ1S_BLOCK_SIZE/8]; + int8_t shifts[QK_K/IQ1S_BLOCK_SIZE]; + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_s); + } + return nrow * nblock * sizeof(block_iq1_s); +} - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); +static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, + float * scales, + float * weight, + float * pairs, + int8_t * L, + uint16_t * index, + int8_t * shifts) { - __m256 acc = _mm256_setzero_ps(); + const int gindex = iq2_data_index(GGML_TYPE_IQ1_M); - for (int i = 0; i < nb; ++i) { + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + block_iq1_m * y = vy; - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + const int64_t nbl = n/QK_K; - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); + const int block_size = IQ1M_BLOCK_SIZE; - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA}; + const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; + const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + int * idx = (int *)(pairs + 1); - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); + float sumqx[4], sumq2[4]; - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), 
q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); + iq1m_scale_t s; + const float * xx; - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + for (int ibl = 0; ibl < nbl; ++ibl) { - __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); +#if QK_K == 64 + y[ibl].d = GGML_FP32_TO_FP16(0.f); +#endif + memset(y[ibl].qs, 0, QK_K/8); + memset(y[ibl].qh, 0, QK_K/16); + memset(y[ibl].scales, 0, QK_K/32); - __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + float max_scale = 0; - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + float max = fabsf(xb[0]); + for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + if (!max) { + scales[ib] = 0; + memset(L, 1, block_size); + continue; + } + // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. + // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two + // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights + // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and + // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale + // for each possible and score for each split. 
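// For reference: with the quant values q_j held fixed, the weighted error
//     E(s) = sum_j w_j * (x_j - s*q_j)^2
// is minimized at s* = sum_j w_j*q_j*x_j / sum_j w_j*q_j^2, and the attained value is
//     E(s*) = sum_j w_j*x_j^2 - (sum_j w_j*q_j*x_j)^2 / sum_j w_j*q_j^2.
// The first term does not depend on the split, so maximizing sumqx*sumqx/sumq2
// (implemented below as the test sumqx*sumqx > best_score*sumq2, where
// best_score = scale*sumqx of the best candidate so far) minimizes the weighted SSD.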
+ for (int j = 0; j < block_size; ++j) { + pairs[2*j] = xb[j]; + idx[2*j] = j; + } + qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); + float best_score = 0, scale = max; + int besti1 = -1, besti2 = -1, best_k = -1; + // 0: +, + + // 1: +, - + // 2: -, + + // 3: -, - + for (int i1 = 0; i1 <= block_size; ++i1) { + for (int i2 = i1; i2 <= block_size; ++i2) { + memset(sumqx, 0, 4*sizeof(float)); + memset(sumq2, 0, 4*sizeof(float)); + for (int j = 0; j < i1; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } else { + sumqx[0] += weight[i]*x_p[0]*xb[i]; + sumqx[2] += weight[i]*x_p[0]*xb[i]; + sumqx[1] += weight[i]*x_m[0]*xb[i]; + sumqx[3] += weight[i]*x_m[0]*xb[i]; + sumq2[0] += weight[i]*x_p[0]*x_p[0]; + sumq2[2] += weight[i]*x_p[0]*x_p[0]; + sumq2[1] += weight[i]*x_m[0]*x_m[0]; + sumq2[3] += weight[i]*x_m[0]*x_m[0]; + } + } + for (int j = i1; j < i2; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } else { + sumqx[0] += weight[i]*x_p[1]*xb[i]; + sumqx[2] += weight[i]*x_p[1]*xb[i]; + sumqx[1] += weight[i]*x_m[1]*xb[i]; + sumqx[3] += weight[i]*x_m[1]*xb[i]; + sumq2[0] += weight[i]*x_p[1]*x_p[1]; + sumq2[2] += weight[i]*x_p[1]*x_p[1]; + sumq2[1] += weight[i]*x_m[1]*x_m[1]; + sumq2[3] += weight[i]*x_m[1]*x_m[1]; + } + } + for (int j = i2; j < block_size; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } else { + sumqx[0] += weight[i]*x_p[2]*xb[i]; + sumqx[2] += weight[i]*x_p[2]*xb[i]; + sumqx[1] += weight[i]*x_m[2]*xb[i]; + sumqx[3] += weight[i]*x_m[2]*xb[i]; + sumq2[0] += weight[i]*x_p[2]*x_p[2]; + sumq2[2] += weight[i]*x_p[2]*x_p[2]; + sumq2[1] += weight[i]*x_m[2]*x_m[2]; + sumq2[3] += weight[i]*x_m[2]*x_m[2]; + } + } + for (int k = 0; k < 4; ++k) { + if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { + scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; + besti1 = i1; besti2 = i2; best_k = k; + } + } + } + } + GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; + for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; + for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; + if (scale < 0) { + for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j]; + scale = -scale; + best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0; + } + bool all_on_grid = true; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? 
x_p : x_m; + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; + } + if (!all_on_grid) { + float sumqx_f = 0, sumq2_f = 0; + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = best_k < 2 ? x_p : x_m; + else xx = best_k%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]; + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; + } + y[ibl].qs[2*ib + 0] = index[0] & 255; + y[ibl].qs[2*ib + 1] = index[1] & 255; + y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4); + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + shifts[ib] = best_k; + max_scale = MAX(max_scale, scale); + } - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + if (!max_scale) { + continue; + } - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + uint16_t * sc = (uint16_t *)y[ibl].scales; +#if QK_K == 64 + float d = max_scale/31; +#else + float d = max_scale/15; +#endif + float id = 1/d; + float sumqx_f = 0, sumq2_f = 0; + for (int ib = 0; ib < QK_K/block_size; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib+0]-1)); +#if QK_K == 64 + l = MAX(0, MIN(15, l)); + sc[ib/4] |= (l << 4*(ib%4)); +#else + l = MAX(0, MIN(7, l)); + sc[ib/4] |= (l << 3*(ib%4)); +#endif + y[ibl].qh[ib] |= masks[shifts[ib]]; + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int k = 0; k < block_size/8; ++k) { + if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m; + else xx = shifts[ib]%2 == 0 ? x_p : x_m; + const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700)); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = xx[(pg[j] - 1)/2]*(2*l+1); + sumqx_f += w*q*xb[8*k+j]; + sumq2_f += w*q*q; + } + } + } + if (sumq2_f > 0) d = sumqx_f/sumq2_f; + s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed. 
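// A minimal stand-alone sketch (QK_K == 256 path; a hypothetical helper, not part of
// this patch) of how a super-block produced by quantize_row_iq1_m_impl can be read
// back. It assumes `grid` is the same table the quantizer uses above (one uint64_t per
// grid point holding 8 int8 values in {1, 3, 5}); everything else mirrors the packing
// done above and below.
static void iq1_m_unpack_superblock_sketch(const block_iq1_m * blk, const uint64_t * grid, float * out) {
    const uint16_t * sc = (const uint16_t *)blk->scales;
    iq1m_scale_t s;
    // the fp16 super-block scale is spread over the top nibble of each 16-bit scale word
    s.u16 = (sc[0] >> 12) | ((sc[1] >> 12) << 4) | ((sc[2] >> 12) << 8) | ((sc[3] >> 12) << 12);
    const float d = GGML_FP16_TO_FP32(s.f16);
    const float x_p[3] = {-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA};
    const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
    for (int ib = 0; ib < QK_K/16; ++ib) {
        const int   ls = (sc[ib/4] >> (3*(ib%4))) & 0x7;   // 3-bit block scale
        const float dl = d * (2*ls + 1);
        for (int k = 0; k < 2; ++k) {
            // low 8 bits of the grid index are in qs, the high 3 bits in the nibbles of qh
            const int idx = blk->qs[2*ib + k] | (((blk->qh[ib] >> (4*k)) & 0x7) << 8);
            const int8_t * pg = (const int8_t *)(grid + idx);
            // qh bit 3 (first half) and bit 7 (second half) select the shifted value table
            const float * xx = (blk->qh[ib] & (k == 0 ? 0x08 : 0x80)) ? x_m : x_p;
            for (int j = 0; j < 8; ++j) {
                out[16*ib + 8*k + j] = dl * xx[(pg[j] - 1)/2];
            }
        }
    }
}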
+#if QK_K == 64 + y[ibl].d = s.f16; +#else + sc[0] |= ((s.u16 & 0x000f) << 12); + sc[1] |= ((s.u16 & 0x00f0) << 8); + sc[2] |= ((s.u16 & 0x0f00) << 4); + sc[3] |= ((s.u16 & 0xf000) << 0); +#endif } +} - *s = hsum_float_8(acc); +size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + float scales[QK_K/IQ1M_BLOCK_SIZE]; + float weight[IQ1M_BLOCK_SIZE]; + int8_t L[IQ1M_BLOCK_SIZE]; + float pairs[2*IQ1M_BLOCK_SIZE]; + uint16_t index[IQ1M_BLOCK_SIZE/8]; + int8_t shifts[QK_K/IQ1M_BLOCK_SIZE]; + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts); + src += n_per_row; + qrow += nblock*sizeof(block_iq1_m); + } + return nrow * nblock * sizeof(block_iq1_m); +} -#elif defined __riscv_v_intrinsic +// ============================ 4-bit non-linear quants - float sumf = 0; +static inline int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} - for (int i = 0; i < nb; ++i) { +static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * restrict x, + ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l, + float * scales, float * weight, uint8_t * L, + const int8_t * values, + const float * quant_weights, + const int ntry) { + + float sigma2 = 0; + for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/super_block_size; + + memset(q4, 0, super_block_size/2); + dh[0] = GGML_FP32_TO_FP16(0.f); + + float max_scale = 0, amax_scale = 0; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + const float * xb = x + ib*block_size; + uint8_t * Lb = L + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + scales[ib] = 0; + continue; + } + float d = ntry > 0 ? 
-max/values[0] : max/values[0]; + float id = 1/d; + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + Lb[j] = l; + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + d = sumqx/sumq2; + float best = d*sumqx; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx = sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d * sumqx; + } + } + scales[ib] = d; + float abs_d = fabsf(d); + if (abs_d > amax_scale) { + amax_scale = abs_d; max_scale = d; + } + } + + if (super_block_size/block_size > 1) { + int nb = super_block_size/block_size; + memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t)); + float d = -max_scale/32; + dh[0] = GGML_FP32_TO_FP16(d); + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < super_block_size/block_size; ++ib) { + int l = nearest_int(id*scales[ib]); + l = MAX(-32, MIN(31, l)); + float dl = d * l; + float idl = dl ? 1/dl : 0.f; + uint8_t * Lb = L + ib*block_size; + const float * xb = x + ib*block_size; + for (int j = 0; j < block_size; ++j) { + Lb[j] = best_index_int8(16, values, idl*xb[j]); + } + l += 32; + uint8_t l_l = l & 0xf; + uint8_t l_h = l >> 4; + if (ib%2 == 0) scales_l[ib/2] = l_l; + else scales_l[ib/2] |= (l_l << 4); + scales_h[ib/8] |= (l_h << 2*(ib%8)); + } + } else { + dh[0] = GGML_FP32_TO_FP16(scales[0]); + if (ntry > 0) { + float id = scales[0] ? 1/scales[0] : 0; + for (int j = 0; j < super_block_size; ++j) { + L[j] = best_index_int8(16, values, id*x[j]); + } + } + } - const float d_all = (float)x[i].d; + for (int i = 0; i < super_block_size/32; ++i) { + for (int j = 0; j < 16; ++j) { + q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); + } + } +} - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; +size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK4_NL == 0); + int64_t nblock = n_per_row/QK4_NL; + char * qrow = (char *)dst; + uint8_t L[QK4_NL]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; + for (int64_t row = 0; row < nrow; ++row) { + block_iq4_nl * iq4 = (block_iq4_nl *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? 
quant_weights + QK4_NL*ibl : NULL; + quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_nl); + } + return nrow * nblock * sizeof(block_iq4_nl); +} - const int8_t * restrict scale = x[i].scales; +void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) { + GGML_ASSERT(k%QK4_NL == 0); + int64_t nblock = k/QK4_NL; + uint8_t L[QK4_NL]; + float weight[QK4_NL]; + uint16_t unused_h; + uint8_t * unused_l = NULL; + float scale; + block_iq4_nl * iq4 = (block_iq4_nl *)vy; + for (int ibl = 0; ibl < nblock; ++ibl) { + quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, + &scale, weight, L, kvalues_iq4nl, NULL, -1); + } +} - int32_t isum = 0; +void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + quantize_row_iq4_nl(x, y, k); +} - size_t vl = 16; +size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { +#if QK_K == 64 + return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights); +#else + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + uint8_t L[QK_K]; + float weight[32]; + float scales[QK_K/32]; + for (int64_t row = 0; row < nrow; ++row) { + block_iq4_xs * iq4 = (block_iq4_xs *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; + quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l, + scales, weight, L, kvalues_iq4nl, qw, 7); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_xs); + } + return nrow * nblock * sizeof(block_iq4_xs); +#endif +} - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); +void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_xs * restrict y = vy; + quantize_row_iq4_xs_reference(x, y, k); +} - // load Q6 - vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); - vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); +void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL); +} - // load qh - vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); +// =============================== 2.5625 bpw - vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); +static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { - vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); - vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); - vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); - vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), 
qh3, vl); + const int gindex = iq2_data_index(GGML_TYPE_IQ2_S); - vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); - vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); - vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); - vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); + const uint64_t * kgrid_q2xs = iq2_data[gindex].grid; + const int * kmap_q2xs = iq2_data[gindex].map; + const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - // load Q8 and take product - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + const int kMaxQ = 3; - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + const int64_t nbl = n/QK_K; - sumf += isum * d_all * y[i].d; + block_iq2_s * y = vy; - } + float scales[QK_K/16]; + float weight[16]; + float xval[16]; + int8_t L[16]; + int8_t Laux[16]; + float waux[16]; + bool is_on_grid[2]; + bool is_on_grid_aux[2]; + uint8_t block_signs[2]; - *s = sumf; + for (int ibl = 0; ibl < nbl; ++ibl) { -#else + memset(&y[ibl], 0, sizeof(block_iq2_s)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); + float max_scale = 0; - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int l = 0; l < 16; ++l) { - a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/16; ++ib) { + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + } + for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < 2; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + 
i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + is_on_grid[0] = is_on_grid[1] = true; + for (int is = -9; is <= 9; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/max; + float this_scale = 1/id; + for (int k = 0; k < 2; ++k) { + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < 16; ++i) L[i] = Laux[i]; + for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < 2; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + L[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + scale = -scale; + for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + const int i8 = 2*ib + k; + y[ibl].qs[i8] = grid_index & 255; + y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4)); + y[ibl].qs[QK_K/8 + i8] = block_signs[k]; + } + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); } - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; + + if (!max_scale) { + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f); + float id = 1/d; + for (int ib = 0; ib < QK_K/16; ++ib) { + int l = nearest_int(0.5f*(id*scales[ib]-1)); + l = MAX(0, MIN(15, l)); + if (ib%2 == 0) y[ibl].scales[ib/2] = l; + else y[ibl].scales[ib/2] |= (l << 4); } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } - 
for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif } -#endif +size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_ASSERT(n_per_row%QK_K == 0); + int64_t nblock = n_per_row/QK_K; + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += nblock*sizeof(block_iq2_s); + } + return nrow * nblock * sizeof(block_iq2_s); +} + +void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_s(x, y, 1, k, NULL); +} + +void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_s * restrict y = vy; + quantize_row_iq2_s_reference(x, y, k); +} diff --git a/bindings/ruby/ext/ggml-quants.h b/bindings/ruby/ext/ggml-quants.h index 70c12c27465..4d436a8f06b 100644 --- a/bindings/ruby/ext/ggml-quants.h +++ b/bindings/ruby/ext/ggml-quants.h @@ -1,224 +1,133 @@ #pragma once -#include "ggml-impl.h" +#define GGML_COMMON_DECL_C +#include "ggml-common.h" -// GGML internal header - -#include -#include - -#define QK4_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -typedef struct { - ggml_fp16_t d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); - -// -// Super-block quantization structures -// - -// Super-block size -#ifdef GGML_QKK_64 -#define QK_K 64 -#define K_SCALE_SIZE 4 -#else -#define QK_K 256 -#define K_SCALE_SIZE 12 -#endif +#include "ggml.h" -// 2-bit quantization -// weight is represented as x = a * q + b -// 16 blocks of 16 elements each -// Effectively 2.5625 bits per weight -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins -} block_q2_K; -static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); - -// 3-bit quantization -// weight is represented as x = a * q -// 16 blocks of 
16 elements each -// Effectively 3.4375 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[2]; - ggml_fp16_t d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); -#else -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[12]; // scales, quantized with 6 bits - ggml_fp16_t d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); -#endif - -// 4-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 4.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_fp16_t d[2]; // super-block scales/mins - uint8_t scales[2]; // 4-bit block scales/mins - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); -#else -typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); -#endif +// GGML internal header -// 5-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 5.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_fp16_t d; // super-block scale - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); -#else -typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#ifdef __cplusplus +extern "C" { #endif -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - ggml_fp16_t d; // super-block scale -} block_q6_K; -static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); - -// This is only used for intermediate quantization and dot products -typedef struct { - float d; // delta - int8_t qs[QK_K]; // quants - int16_t bsums[QK_K/16]; // sum of quants in groups of 16 -} block_q8_K; -static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); - - // Quantization -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int 
k); -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k); -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k); -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k); -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k); - -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); - -void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); -void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_0(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_1(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_0(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_1(const float * restrict x, void * restrict y, int k); - -void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_reference (const float * GGML_RESTRICT 
x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); + +void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k); -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k); -void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k); -void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k); -//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k); - -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); +void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, 
int64_t k); +void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + +void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); - -void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t 
bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +// Quantization utilizing an importance matrix (a.k.a. 
"Activation aWare Quantization") +size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +void iq2xs_init_impl(enum ggml_type type); +void iq2xs_free_impl(enum ggml_type type); +void iq3xs_init_impl(int grid_size); +void iq3xs_free_impl(int grid_size); + +#ifdef __cplusplus +} +#endif + diff --git a/bindings/ruby/ext/ggml-sycl.h b/bindings/ruby/ext/ggml-sycl.h new file mode 100644 index 00000000000..a9f776fc1dd --- /dev/null +++ b/bindings/ruby/ext/ggml-sycl.h @@ -0,0 +1,49 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_SYCL_MAX_DEVICES 48 +#define GGML_SYCL_NAME "SYCL" + +// backend API +GGML_API ggml_backend_t ggml_backend_sycl_init(int device); + +// devide buffer +GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); + +// 
split tensor buffer that splits matrices by rows across multiple devices +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); + +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); + +GGML_API void ggml_backend_sycl_print_sycl_devices(void); +GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len); +GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size); +GGML_API GGML_CALL int ggml_backend_sycl_get_device_count(); +GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); +GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id); + +// TODO: these are temporary +// ref: https://github.com/ggerganov/llama.cpp/pull/6022#issuecomment-1992615670 +GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index); +GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id); +GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode(); + +// SYCL doesn't support registering host memory, keep here for reference +// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); +// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer); +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ggml-vulkan.h b/bindings/ruby/ext/ggml-vulkan.h new file mode 100644 index 00000000000..af661c2d7d5 --- /dev/null +++ b/bindings/ruby/ext/ggml-vulkan.h @@ -0,0 +1,29 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_VK_NAME "Vulkan" +#define GGML_VK_MAX_DEVICES 16 + +GGML_API void ggml_vk_instance_init(void); + +// backend API +GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num); + +GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend); +GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void); +GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size); +GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total); + +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num); +// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9066ee87f21..9d9334539b8 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -305,11 +305,11 @@ static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, token_timestamps, value) } -static VALUE ruby_whisper_params_get_speed_up(VALUE self) { - BOOL_PARAMS_GETTER(self, speed_up) +static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { + BOOL_PARAMS_GETTER(self, split_on_word) } -static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, speed_up, value) +static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, split_on_word, value) } static VALUE 
ruby_whisper_params_get_diarize(VALUE self) { ruby_whisper_params *rwp; @@ -400,8 +400,8 @@ void Init_whisper() { rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); - rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0); - rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1); + rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); + rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index edb65bdbaa0..3700671bce6 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -110,11 +110,11 @@ def test_token_timestamps assert !@params.token_timestamps end - def test_speed_up - @params.speed_up = true - assert @params.speed_up - @params.speed_up = false - assert !@params.speed_up + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word end def test_whisper diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec new file mode 100644 index 00000000000..508a6a94052 --- /dev/null +++ b/bindings/ruby/whispercpp.gemspec @@ -0,0 +1,28 @@ +Gem::Specification.new do |s| + s.name = "whispercpp" + s.authors = ["Georgi Gerganov", "Todd A. Fisher"] + s.version = '1.3.0' + s.date = '2024-05-14' + s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} + s.email = 'todd.fisher@gmail.com' + s.extra_rdoc_files = ['LICENSE', 'README.md'] + + s.files = ["LICENSE", "README.md", "Rakefile", "ext/extconf.rb", "ext/ggml.c", "ext/ruby_whisper.cpp", "ext/whisper.cpp", "ext/dr_wav.h", "ext/ggml.h", "ext/ruby_whisper.h", "ext/whisper.h"] + + #### Load-time details + s.require_paths = ['lib','ext'] + s.summary = %q{Ruby whisper.cpp bindings} + s.test_files = ["tests/test_whisper.rb"] + + s.extensions << 'ext/extconf.rb' + + + #### Documentation and testing. + s.homepage = 'https://github.com/ggerganov/whisper.cpp' + s.rdoc_options = ['--main', '../../README.md'] + + + s.platform = Gem::Platform::RUBY + + s.licenses = ['MIT'] +end diff --git a/cmake/FindFFmpeg.cmake b/cmake/FindFFmpeg.cmake new file mode 100644 index 00000000000..19dc751605e --- /dev/null +++ b/cmake/FindFFmpeg.cmake @@ -0,0 +1,163 @@ +# From +# https://github.com/snikulov/cmake-modules/blob/master/FindFFmpeg.cmake +# +# vim: ts=2 sw=2 +# - Try to find the required ffmpeg components(default: AVFORMAT, AVUTIL, AVCODEC) +# +# Once done this will define +# FFMPEG_FOUND - System has the all required components. +# FFMPEG_INCLUDE_DIRS - Include directory necessary for using the required components headers. +# FFMPEG_LIBRARIES - Link these to use the required ffmpeg components. +# FFMPEG_DEFINITIONS - Compiler switches required for using the required ffmpeg components. +# +# For each of the components it will additionally set. 
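The `ggml-sycl.h` and `ggml-vulkan.h` headers vendored into the Ruby extension above expose the same small device-discovery surface as upstream ggml: a device count, per-device description and memory queries, and backend/buffer-type constructors. As a rough illustration (not part of this patch), a host program built with the Vulkan backend could probe the available devices along these lines:

```cpp
// Illustrative sketch only: assumes ggml was built with the Vulkan backend
// and that ggml-backend.h / ggml-vulkan.h are on the include path.
#include <cstdio>

#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    const int n_dev = ggml_backend_vk_get_device_count();
    printf("found %d Vulkan device(s)\n", n_dev);

    for (int i = 0; i < n_dev; ++i) {
        char   desc[256];
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
        ggml_backend_vk_get_device_memory(i, &free_mem, &total_mem);
        printf("  device %d: %s, %zu / %zu MiB free\n",
               i, desc, free_mem >> 20, total_mem >> 20);
    }

    if (n_dev > 0) {
        ggml_backend_t backend = ggml_backend_vk_init(0); // first device
        if (backend != nullptr) {
            printf("initialized backend: %s\n", ggml_backend_name(backend));
            ggml_backend_free(backend);
        }
    }
    return 0;
}
```

The SYCL header offers the analogous calls (`ggml_backend_sycl_get_device_count()`, `ggml_sycl_get_device_description()`, `ggml_backend_sycl_get_device_memory()`) for the oneAPI path.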
+# - AVCODEC +# - AVDEVICE +# - AVFORMAT +# - AVFILTER +# - AVUTIL +# - POSTPROC +# - SWSCALE +# the following variables will be defined +# _FOUND - System has +# _INCLUDE_DIRS - Include directory necessary for using the headers +# _LIBRARIES - Link these to use +# _DEFINITIONS - Compiler switches required for using +# _VERSION - The components version +# +# Copyright (c) 2006, Matthias Kretz, +# Copyright (c) 2008, Alexander Neundorf, +# Copyright (c) 2011, Michael Jansen, +# +# Redistribution and use is allowed according to the terms of the BSD license. +# For details see the accompanying COPYING-CMAKE-SCRIPTS file. + +include(FindPackageHandleStandardArgs) + +# The default components were taken from a survey over other FindFFMPEG.cmake files +if (NOT FFmpeg_FIND_COMPONENTS) + set(FFmpeg_FIND_COMPONENTS AVFORMAT AVCODEC AVUTIL SWRESAMPLE) +endif() + +# +### Macro: set_component_found +# +# Marks the given component as found if both *_LIBRARIES AND *_INCLUDE_DIRS is present. +# +macro(set_component_found _component ) + if (${_component}_LIBRARIES AND ${_component}_INCLUDE_DIRS) + message(DEBUG " - ${_component} found.") + set(${_component}_FOUND TRUE) + else () + message(DEBUG " - ${_component} not found.") + endif () +endmacro() + +# +### Macro: find_component +# +# Checks for the given component by invoking pkgconfig and then looking up the libraries and +# include directories. +# +macro(find_component _component _pkgconfig _library _header) + + if (NOT WIN32) + # use pkg-config to get the directories and then use these values + # in the FIND_PATH() and FIND_LIBRARY() calls + find_package(PkgConfig) + if (PKG_CONFIG_FOUND) + pkg_check_modules(PC_${_component} ${_pkgconfig}) + message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDEDIR}") + message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDE_DIRS}") + message(STATUS "${PC_${_component}_CFLAGS}") + endif () + endif (NOT WIN32) + + + find_path(${_component}_INCLUDE_DIRS ${_header} + HINTS + ${PC_${_component}_INCLUDEDIR} + ${PC_${_component}_INCLUDE_DIRS} + PATH_SUFFIXES + ffmpeg + ) + + # CMake's default is to search first for shared libraries and then for static libraries. + # Todo later: add option to prefer static libs over dynamic: + find_library(${_component}_LIBRARIES NAMES ${_library} lib${_library}.a + HINTS + ${PC_${_component}_LIBDIR} + ${PC_${_component}_LIBRARY_DIRS} + ) + + set(${_component}_DEFINITIONS ${PC_${_component}_CFLAGS_OTHER} CACHE STRING "The ${_component} CFLAGS.") + set(${_component}_VERSION ${PC_${_component}_VERSION} CACHE STRING "The ${_component} version number.") + + set_component_found(${_component}) + + mark_as_advanced( + ${_component}_INCLUDE_DIRS + ${_component}_LIBRARIES + ${_component}_DEFINITIONS + ${_component}_VERSION) + +endmacro() + + +# Check for cached results. If there are skip the costly part. +if (NOT FFMPEG_LIBRARIES) + + # Check for all possible component. 
+ find_component(AVCODEC libavcodec avcodec libavcodec/avcodec.h) + find_component(AVFORMAT libavformat avformat libavformat/avformat.h) + find_component(AVDEVICE libavdevice avdevice libavdevice/avdevice.h) + #find_component(AVRESAMPLE libavresample avresample libavresample/avresample.h) # old name for swresample + find_component(AVUTIL libavutil avutil libavutil/avutil.h) + find_component(AVFILTER libavfilter avfilter libavfilter/avfilter.h) + find_component(SWSCALE libswscale swscale libswscale/swscale.h) + find_component(POSTPROC libpostproc postproc libpostproc/postprocess.h) + find_component(SWRESAMPLE libswresample swresample libswresample/swresample.h) + + # Check if the required components were found and add their stuff to the FFMPEG_* vars. + foreach (_component ${FFmpeg_FIND_COMPONENTS}) + if (${_component}_FOUND) + # message(STATUS "Required component ${_component} present.") + set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} ${${_component}_LIBRARIES}) + set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} ${${_component}_DEFINITIONS}) + list(APPEND FFMPEG_INCLUDE_DIRS ${${_component}_INCLUDE_DIRS}) + else () + # message(STATUS "Required component ${_component} missing.") + endif () + endforeach () + + # Build the include path with duplicates removed. + if (FFMPEG_INCLUDE_DIRS) + list(REMOVE_DUPLICATES FFMPEG_INCLUDE_DIRS) + endif () + + # cache the vars. + set(FFMPEG_INCLUDE_DIRS ${FFMPEG_INCLUDE_DIRS} CACHE STRING "The FFmpeg include directories." FORCE) + set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} CACHE STRING "The FFmpeg libraries." FORCE) + set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} CACHE STRING "The FFmpeg cflags." FORCE) + + mark_as_advanced(FFMPEG_INCLUDE_DIRS + FFMPEG_LIBRARIES + FFMPEG_DEFINITIONS) + +endif () + +# Now set the noncached _FOUND vars for the components. +# whisper.cpp does not need SWSCALE +foreach (_component AVCODEC AVDEVICE AVFORMAT AVRESAMPLE AVUTIL POSTPROCESS) + set_component_found(${_component}) +endforeach () + +# Compile the list of required vars +set(_FFmpeg_REQUIRED_VARS FFMPEG_LIBRARIES FFMPEG_INCLUDE_DIRS) +foreach (_component ${FFmpeg_FIND_COMPONENTS}) + list(APPEND _FFmpeg_REQUIRED_VARS ${_component}_LIBRARIES ${_component}_INCLUDE_DIRS) +endforeach () + +# Give a nice error message if some of the required vars are missing. 
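This find module exists to support the new optional `WHISPER_FFMPEG` build path: when it is enabled, the examples' `common` library compiles `ffmpeg-transcode.cpp` and `read_wav()` falls back to ffmpeg decoding for inputs that are not plain WAV files (see the `examples/common.cpp` hunk later in this patch). Condensed into a standalone sketch, with error handling trimmed and the declaration aligned to the definition's `int` return (0 on success), the fallback looks roughly like this:

```cpp
// Rough sketch of the WHISPER_FFMPEG fallback added to read_wav() in
// examples/common.cpp. Not a drop-in replacement for the patched code.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

#include "dr_wav.h" // bundled single-header WAV reader used by the examples
                    // (DR_WAV_IMPLEMENTATION must be defined in one TU)

// implemented in examples/ffmpeg-transcode.cpp when WHISPER_FFMPEG is ON;
// fills wav_data with a complete 16 kHz mono 16-bit WAV image that drwav
// can then parse from memory
extern int ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data);

static bool load_audio(const std::string & fname, drwav & wav, std::vector<uint8_t> & wav_data) {
    if (drwav_init_file(&wav, fname.c_str(), nullptr)) {
        return true; // input was already a WAV file
    }
#ifdef WHISPER_FFMPEG
    if (ffmpeg_decode_audio(fname, wav_data) != 0) {
        fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str());
        return false;
    }
    return drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr);
#else
    fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
    return false;
#endif
}
```

The `main` example picks this path up because its CMakeLists now links `${FFMPEG_LIBRARIES}`, as shown further down in this patch.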
+find_package_handle_standard_args(FFmpeg DEFAULT_MSG ${_FFmpeg_REQUIRED_VARS}) + diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index ae99704f825..fd30ad59bd3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -22,6 +22,10 @@ endif() set(TARGET common) +if (WHISPER_FFMPEG) + set(COMMON_SOURCES_FFMPEG ffmpeg-transcode.cpp) +endif() + add_library(${TARGET} STATIC console.h console.cpp @@ -29,7 +33,9 @@ add_library(${TARGET} STATIC common.cpp common-ggml.h common-ggml.cpp + grammar-parser.h grammar-parser.cpp + ${COMMON_SOURCES_FFMPEG} ) include(DefaultTargetOptions) @@ -37,6 +43,7 @@ include(DefaultTargetOptions) target_link_libraries(${TARGET} PRIVATE whisper) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(${TARGET} PROPERTIES FOLDER "libs") if (WHISPER_SDL2) # common-sdl @@ -54,30 +61,63 @@ if (WHISPER_SDL2) target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES}) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) + set_target_properties(${TARGET} PROPERTIES FOLDER "libs") endif() +# add json lib +add_library(json_cpp INTERFACE) +target_include_directories(json_cpp INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) + # examples include_directories(${CMAKE_CURRENT_SOURCE_DIR}) if (EMSCRIPTEN) add_subdirectory(whisper.wasm) + set_target_properties(libmain PROPERTIES FOLDER "libs") add_subdirectory(stream.wasm) + set_target_properties(libstream PROPERTIES FOLDER "libs") add_subdirectory(command.wasm) + set_target_properties(libcommand PROPERTIES FOLDER "libs") add_subdirectory(talk.wasm) + set_target_properties(libtalk PROPERTIES FOLDER "libs") add_subdirectory(bench.wasm) + set_target_properties(libbench PROPERTIES FOLDER "libs") elseif(CMAKE_JS_VERSION) add_subdirectory(addon.node) + set_target_properties(addon.node PROPERTIES FOLDER "examples") else() add_subdirectory(main) + set_target_properties(main PROPERTIES FOLDER "examples") +if (WHISPER_SDL2) add_subdirectory(stream) + set_target_properties(stream PROPERTIES FOLDER "examples") +endif (WHISPER_SDL2) add_subdirectory(server) + set_target_properties(server PROPERTIES FOLDER "examples") +if (WHISPER_SDL2) add_subdirectory(command) + set_target_properties(command PROPERTIES FOLDER "examples") +endif (WHISPER_SDL2) add_subdirectory(bench) + set_target_properties(bench PROPERTIES FOLDER "examples") add_subdirectory(quantize) + set_target_properties(quantize PROPERTIES FOLDER "examples") +if (WHISPER_SDL2) add_subdirectory(talk) + set_target_properties(talk PROPERTIES FOLDER "examples") add_subdirectory(talk-llama) + set_target_properties(talk-llama PROPERTIES FOLDER "examples") add_subdirectory(lsp) + set_target_properties(lsp PROPERTIES FOLDER "examples") + if (LLAMA_SYCL) + add_subdirectory(sycl) + set_target_properties(sycl PROPERTIES FOLDER "examples") + endif() +endif (WHISPER_SDL2) endif() -add_subdirectory(wchess) +if (WHISPER_SDL2) + add_subdirectory(wchess) + set_target_properties(wchess PROPERTIES FOLDER "examples") +endif (WHISPER_SDL2) diff --git a/examples/addon.node/CMakeLists.txt b/examples/addon.node/CMakeLists.txt index aef7839eb77..29cb1a27d07 100644 --- a/examples/addon.node/CMakeLists.txt +++ b/examples/addon.node/CMakeLists.txt @@ -1,4 +1,4 @@ -set(TARGET whisper-addon) +set(TARGET addon.node) # Base settings #================================================================== diff --git a/examples/addon.node/README.md b/examples/addon.node/README.md index bdb1d256bec..16df7d95870 100644 --- a/examples/addon.node/README.md 
+++ b/examples/addon.node/README.md @@ -14,14 +14,14 @@ npm install Make sure it is in the project root directory and compiled with make-js. ```shell -npx cmake-js compile -T whisper-addon -B Release +npx cmake-js compile -T addon.node -B Release ``` For Electron addon and cmake-js options, you can see [cmake-js](https://github.com/cmake-js/cmake-js) and make very few configuration changes. > Such as appointing special cmake path: > ```shell -> npx cmake-js compile -c 'xxx/cmake' -T whisper-addon -B Release +> npx cmake-js compile -c 'xxx/cmake' -T addon.node -B Release > ``` ## Run diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index d102fe7624e..1ee888a1e00 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -1,7 +1,7 @@ const path = require("path"); const { whisper } = require(path.join( __dirname, - "../../../build/Release/whisper-addon" + "../../../build/Release/addon.node" )); const { promisify } = require("util"); @@ -12,6 +12,12 @@ const whisperParamsMock = { model: path.join(__dirname, "../../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), use_gpu: true, + flash_attn: false, + no_prints: true, + comma_in_time: false, + translate: true, + no_timestamps: false, + audio_ctx: 0, }; describe("Run whisper.node", () => { diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 30acbc6afd8..4ada6ca5084 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -19,12 +19,12 @@ struct whisper_params { int32_t max_len = 0; int32_t best_of = 5; int32_t beam_size = -1; + int32_t audio_ctx = 0; float word_thold = 0.01f; float entropy_thold = 2.4f; float logprob_thold = -1.0f; - bool speed_up = false; bool translate = false; bool diarize = false; bool output_txt = false; @@ -36,7 +36,10 @@ struct whisper_params { bool print_colors = false; bool print_progress = false; bool no_timestamps = false; + bool no_prints = false; bool use_gpu = true; + bool flash_attn = false; + bool comma_in_time = true; std::string language = "en"; std::string prompt; @@ -44,6 +47,8 @@ struct whisper_params { std::vector fname_inp = {}; std::vector fname_out = {}; + + std::vector pcmf32 = {}; // mono-channel F32 PCM }; struct whisper_print_user_data { @@ -52,27 +57,6 @@ struct whisper_print_user_data { const std::vector> * pcmf32s; }; -// 500 -> 00:05.000 -// 6000 -> 01:00.000 -std::string to_timestamp(int64_t t, bool comma = false) { - int64_t msec = t * 10; - int64_t hr = msec / (1000 * 60 * 60); - msec = msec - hr * (1000 * 60 * 60); - int64_t min = msec / (1000 * 60); - msec = msec - min * (1000 * 60); - int64_t sec = msec / 1000; - msec = msec - sec * 1000; - - char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); - - return std::string(buf); -} - -int timestamp_to_sample(int64_t t, int n_samples) { - return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); -} - void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -104,8 +88,8 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper if (params.diarize && pcmf32s.size() == 2) { const int64_t n_samples = pcmf32s[0].size(); - const int64_t is0 = timestamp_to_sample(t0, n_samples); - const int64_t is1 = timestamp_to_sample(t1, n_samples); + const int64_t is0 = timestamp_to_sample(t0, n_samples, WHISPER_SAMPLE_RATE); + const int64_t is1 = timestamp_to_sample(t1, n_samples, WHISPER_SAMPLE_RATE); double energy0 = 0.0f; double energy1 = 0.0f; @@ -141,9 +125,15 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } +void cb_log_disable(enum ggml_log_level, const char *, void *) {} + int run(whisper_params ¶ms, std::vector> &result) { - if (params.fname_inp.empty()) { - fprintf(stderr, "error: no input files specified\n"); + if (params.no_prints) { + whisper_log_set(cb_log_disable, NULL); + } + + if (params.fname_inp.empty() && params.pcmf32.empty()) { + fprintf(stderr, "error: no input files or audio buffer specified\n"); return 2; } @@ -154,8 +144,9 @@ int run(whisper_params ¶ms, std::vector> &result) { // whisper init - struct whisper_context_params cparams; + struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { @@ -163,6 +154,14 @@ int run(whisper_params ¶ms, std::vector> &result) { return 3; } + // if params.pcmf32 is provided, set params.fname_inp to "buffer" + // this is simpler than further modifications in the code + if (!params.pcmf32.empty()) { + fprintf(stderr, "info: using audio buffer as input\n"); + params.fname_inp.clear(); + params.fname_inp.emplace_back("buffer"); + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? 
params.fname_out[f] : params.fname_inp[f]; @@ -170,20 +169,25 @@ int run(whisper_params ¶ms, std::vector> &result) { std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM - if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { - fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); - continue; + // read the input audio file if params.pcmf32 is not provided + if (params.pcmf32.empty()) { + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { + fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); + continue; + } + } else { + pcmf32 = params.pcmf32; } // print system information - { + if (!params.no_prints) { fprintf(stderr, "\n"); fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); } // print some info about the processing - { + if (!params.no_prints) { fprintf(stderr, "\n"); if (!whisper_is_multilingual(ctx)) { if (params.language != "en" || params.translate) { @@ -192,12 +196,13 @@ int run(whisper_params ¶ms, std::vector> &result) { fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", - params.no_timestamps ? 0 : 1); + params.no_timestamps ? 0 : 1, + params.audio_ctx); fprintf(stderr, "\n"); } @@ -224,14 +229,15 @@ int run(whisper_params ¶ms, std::vector> &result) { wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 
60 : params.max_len; - - wparams.speed_up = params.speed_up; + wparams.audio_ctx = params.audio_ctx; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; wparams.initial_prompt = params.prompt.c_str(); + wparams.no_timestamps = params.no_timestamps; + whisper_print_user_data user_data = { ¶ms, &pcmf32s }; // this callback is called on each new segment @@ -267,8 +273,8 @@ int run(whisper_params ¶ms, std::vector> &result) { const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - result[i].emplace_back(to_timestamp(t0, true)); - result[i].emplace_back(to_timestamp(t1, true)); + result[i].emplace_back(to_timestamp(t0, params.comma_in_time)); + result[i].emplace_back(to_timestamp(t1, params.comma_in_time)); result[i].emplace_back(text); } @@ -319,11 +325,33 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { std::string model = whisper_params.Get("model").As(); std::string input = whisper_params.Get("fname_inp").As(); bool use_gpu = whisper_params.Get("use_gpu").As(); + bool flash_attn = whisper_params.Get("flash_attn").As(); + bool no_prints = whisper_params.Get("no_prints").As(); + bool no_timestamps = whisper_params.Get("no_timestamps").As(); + int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); + bool comma_in_time = whisper_params.Get("comma_in_time").As(); + + Napi::Value pcmf32Value = whisper_params.Get("pcmf32"); + std::vector pcmf32_vec; + if (pcmf32Value.IsTypedArray()) { + Napi::Float32Array pcmf32 = pcmf32Value.As(); + size_t length = pcmf32.ElementLength(); + pcmf32_vec.reserve(length); + for (size_t i = 0; i < length; i++) { + pcmf32_vec.push_back(pcmf32[i]); + } + } params.language = language; params.model = model; params.fname_inp.emplace_back(input); params.use_gpu = use_gpu; + params.flash_attn = flash_attn; + params.no_prints = no_prints; + params.no_timestamps = no_timestamps; + params.audio_ctx = audio_ctx; + params.pcmf32 = pcmf32_vec; + params.comma_in_time = comma_in_time; Napi::Function callback = info[1].As(); Worker* worker = new Worker(callback, params); diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 3c6429375ab..643ee756452 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -1,7 +1,7 @@ const path = require("path"); const { whisper } = require(path.join( __dirname, - "../../build/Release/whisper-addon" + "../../build/Release/addon.node" )); const { promisify } = require("util"); @@ -10,15 +10,27 @@ const whisperAsync = promisify(whisper); const whisperParams = { language: "en", model: path.join(__dirname, "../../models/ggml-base.en.bin"), - fname_inp: "../../samples/jfk.wav", + fname_inp: path.join(__dirname, "../../samples/jfk.wav"), use_gpu: true, + flash_attn: false, + no_prints: true, + comma_in_time: false, + translate: true, + no_timestamps: false, + audio_ctx: 0, }; const arguments = process.argv.slice(2); const params = Object.fromEntries( arguments.reduce((pre, item) => { if (item.startsWith("--")) { - return [...pre, item.slice(2).split("=")]; + const [key, value] = item.slice(2).split("="); + if (key === "audio_ctx") { + whisperParams[key] = parseInt(value); + } else { + whisperParams[key] = value; + } + return pre; } return pre; }, []) @@ -33,5 +45,6 @@ for (const key in params) { console.log("whisperParams =", whisperParams); whisperAsync(whisperParams).then((result) => { - console.log(`Result from whisper: ${result}`); + console.log(); + console.log(result); }); diff --git 
a/examples/addon.node/package.json b/examples/addon.node/package.json index bf51f0bba9f..50046bf1f56 100644 --- a/examples/addon.node/package.json +++ b/examples/addon.node/package.json @@ -1,5 +1,5 @@ { - "name": "whisper-addon", + "name": "addon.node", "version": "0.0.0", "description": "", "main": "index.js", diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 949e5737167..cac9385c82f 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -8,11 +8,12 @@ // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat + int32_t what = 0; // what to benchmark: 0 - whisper encoder, 1 - memcpy, 2 - ggml_mul_mat std::string model = "models/ggml-base.en.bin"; - bool use_gpu = true; + bool use_gpu = true; + bool flash_attn = false; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " %-7s 0 - whisper\n", ""); fprintf(stderr, " %-7s 1 - memcpy\n", ""); fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); @@ -58,8 +61,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para int whisper_bench_full(const whisper_params & params) { // whisper init - struct whisper_context_params cparams; - cparams.use_gpu = params.use_gpu; + struct whisper_context_params cparams = whisper_context_default_params(); + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); diff --git a/examples/command/README.md b/examples/command/README.md index d291d1aa6a3..46b14e93277 100644 --- a/examples/command/README.md +++ b/examples/command/README.md @@ -37,9 +37,13 @@ https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9 The `command` tool depends on SDL2 library to capture audio from the microphone. 
You can build it like this: ```bash -# Install SDL2 on Linux +# Install SDL2 +# On Debian based linux distributions: sudo apt-get install libsdl2-dev +# On Fedora Linux: +sudo dnf install SDL2 SDL2-devel + # Install SDL2 on Mac OS brew install sdl2 diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 776635720cc..dd8d8974069 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -23,11 +23,6 @@ #include #include -bool file_exists(const std::string & fname) { - std::ifstream f(fname.c_str()); - return f.good(); -} - // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -44,12 +39,12 @@ struct whisper_params { grammar_parser::parse_state grammar_parsed; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; bool no_timestamps = true; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -58,6 +53,9 @@ struct whisper_params { std::string prompt; std::string context; std::string grammar; + + // A regular expression that matches tokens to suppress + std::string suppress_regex; }; void whisper_print_usage(int argc, const char ** argv, const whisper_params & params); @@ -78,11 +76,11 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -91,6 +89,7 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -115,11 +114,11 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? 
"true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -128,6 +127,7 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, "\n"); } @@ -163,7 +163,6 @@ std::string transcribe( wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.temperature = 0.4f; wparams.temperature_inc = 1.0f; @@ -173,6 +172,8 @@ std::string transcribe( wparams.initial_prompt = params.context.data(); + wparams.suppress_regex = params.suppress_regex.c_str(); + const auto & grammar_parsed = params.grammar_parsed; auto grammar_rules = grammar_parsed.c_rules(); @@ -367,7 +368,6 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.prompt_tokens = k_tokens.data(); wparams.prompt_n_tokens = k_tokens.size(); @@ -694,8 +694,10 @@ int run(int argc, const char ** argv) { // whisper init - struct whisper_context_params cparams; - cparams.use_gpu = params.use_gpu; + struct whisper_context_params cparams = whisper_context_default_params(); + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); @@ -737,7 +739,7 @@ int run(int argc, const char ** argv) { if (!params.grammar.empty()) { auto & grammar = params.grammar_parsed; - if (file_exists(params.grammar.c_str())) { + if (is_file_exist(params.grammar.c_str())) { // read grammar from file std::ifstream ifs(params.grammar.c_str()); const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index affc2827bb9..d8dbc88a01e 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -64,7 +64,14 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: case GGML_FTYPE_MOSTLY_IQ2_XXS: case GGML_FTYPE_MOSTLY_IQ2_XS: + case GGML_FTYPE_MOSTLY_IQ2_S: case GGML_FTYPE_MOSTLY_IQ3_XXS: + case GGML_FTYPE_MOSTLY_IQ3_S: + case GGML_FTYPE_MOSTLY_IQ1_S: + case GGML_FTYPE_MOSTLY_IQ4_NL: + case GGML_FTYPE_MOSTLY_IQ4_XS: + case GGML_FTYPE_MOSTLY_IQ1_M: 
+ case GGML_FTYPE_MOSTLY_BF16: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -85,8 +92,6 @@ bool ggml_common_quantize_0( std::vector data_f16; std::vector data_f32; - std::vector hist_all(1 << 4, 0); - while (true) { int32_t n_dims; int32_t length; @@ -171,8 +176,6 @@ bool ggml_common_quantize_0( work.resize(nelements); // for quantization size_t cur_size = 0; - std::vector hist_cur(1 << 4, 0); - switch ((ggml_type) ttype) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -185,18 +188,27 @@ bool ggml_common_quantize_0( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: { - cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements/ne[0], ne[0], hist_cur.data(), nullptr); + cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements/ne[0], ne[0], nullptr); } break; case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: + case GGML_TYPE_I64: + case GGML_TYPE_F64: case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_K: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_BF16: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); @@ -207,15 +219,7 @@ bool ggml_common_quantize_0( fout.write(reinterpret_cast(work.data()), cur_size); total_size_new += cur_size; - printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); - for (int i = 0; i < (int) hist_cur.size(); ++i) { - hist_all[i] += hist_cur[i]; - } - - for (int i = 0; i < (int) hist_cur.size(); ++i) { - printf("%5.3f ", hist_cur[i] / (float)nelements); - } - printf("\n"); + printf("size = %8.2f MB -> %8.2f MB\n", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); } else { printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); fout.write(reinterpret_cast(data_u8.data()), data_u8.size()); @@ -228,18 +232,5 @@ bool ggml_common_quantize_0( printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); printf("%s: quant size = %8.2f MB | ftype = %d (%s)\n", __func__, total_size_new/1024.0/1024.0, ftype, ggml_type_name(qtype)); - { - int64_t sum_all = 0; - for (int i = 0; i < (int) hist_all.size(); ++i) { - sum_all += hist_all[i]; - } - - printf("%s: hist: ", __func__); - for (int i = 0; i < (int) hist_all.size(); ++i) { - printf("%5.3f ", hist_all[i] / (float)sum_all); - } - printf("\n"); - } - return true; } diff --git a/examples/common.cpp b/examples/common.cpp index 3dd2a248d86..7b2556eb486 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -20,6 +20,16 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +#ifdef _WIN32 +#include +#include +#endif + +#ifdef WHISPER_FFMPEG +// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support +extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); +#endif + // Function to check if the next argument exists std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { if (i + 1 < argc && argv[i + 1][0] != '-') { @@ -633,10 +643,14 @@ bool is_wav_buffer(const std::string buf) { bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { drwav wav; - 
std::vector wav_data; // used for pipe input from stdin + std::vector wav_data; // used for pipe input from stdin or ffmpeg decoding output if (fname == "-") { { + #ifdef _WIN32 + _setmode(_fileno(stdin), _O_BINARY); + #endif + uint8_t buf[1024]; while (true) { @@ -666,27 +680,42 @@ bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector #else else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { #endif +#if defined(WHISPER_FFMPEG) + if (ffmpeg_decode_audio(fname, wav_data) != 0) { + fprintf(stderr, "error: failed to ffmpeg decode '%s' \n", fname.c_str()); + return false; + } + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to read wav data as wav \n"); + return false; + } +#else fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); return false; +#endif } if (wav.channels != 1 && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str()); + drwav_uninit(&wav); return false; } if (stereo && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str()); + drwav_uninit(&wav); return false; } if (wav.sampleRate != COMMON_SAMPLE_RATE) { fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000); + drwav_uninit(&wav); return false; } if (wav.bitsPerSample != 16) { fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str()); + drwav_uninit(&wav); return false; } @@ -841,3 +870,48 @@ void sam_print_usage(int /*argc*/, char ** argv, const sam_params & params) { fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str()); fprintf(stderr, "\n"); } + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +std::string to_timestamp(int64_t t, bool comma) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); + + return std::string(buf); +} + +int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); +} + +bool is_file_exist(const char *fileName) +{ + std::ifstream infile(fileName); + return infile.good(); +} + +bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) +{ + std::ofstream speak_file(path.c_str()); + if (speak_file.fail()) { + fprintf(stderr, "%s: failed to open speak_file\n", __func__); + return false; + } else { + speak_file.write(text.c_str(), text.size()); + speak_file.close(); + int ret = system((command + " " + std::to_string(voice_id) + " " + path).c_str()); + if (ret != 0) { + fprintf(stderr, "%s: failed to speak\n", __func__); + return false; + } + } + return true; +} diff --git a/examples/common.h b/examples/common.h index 09094a1b8a1..89ab37457e2 100644 --- a/examples/common.h +++ b/examples/common.h @@ -21,7 +21,7 @@ struct gpt_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 200; // new tokens to predict int32_t n_parallel = 1; // number of parallel streams - int32_t n_batch = 8; // batch size for prompt processing + int32_t n_batch = 32; // batch size for prompt processing int32_t n_ctx = 2048; // context size (this is the KV cache max size) int32_t n_gpu_layers = 0; // number of layers to offlload to the GPU @@ -185,7 +185,7 @@ class wav_writer { // It is assumed that PCM data is normalized to a range from -1 to 1 bool write_audio(const float * data, size_t length) { for (size_t i = 0; i < length; ++i) { - const int16_t intSample = data[i] * 32767; + const int16_t intSample = int16_t(data[i] * 32767); file.write(reinterpret_cast(&intSample), sizeof(int16_t)); dataSize += sizeof(int16_t); } @@ -281,3 +281,31 @@ struct sam_params { bool sam_params_parse(int argc, char ** argv, sam_params & params); void sam_print_usage(int argc, char ** argv, const sam_params & params); + +// +// Terminal utils +// + + +// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] +// Lowest is red, middle is yellow, highest is green. 
+const std::vector k_colors = { + "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", + "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", +}; + +// +// Other utils +// + +// convert timestamp to string, 6000 -> 01:00.000 +std::string to_timestamp(int64_t t, bool comma = false); + +// given a timestamp get the sample +int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate); + +// check if file exists using ifstream +bool is_file_exist(const char *fileName); + +// write text to file, and call system("command voice_id file") +bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp new file mode 100644 index 00000000000..910cdf5700b --- /dev/null +++ b/examples/ffmpeg-transcode.cpp @@ -0,0 +1,350 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +/* + * transcode.c - convert audio file to WAVE + * + * Copyright (C) 2019 Andrew Clayton + * Copyright (C) 2024 William Tambellini + */ + +// Just for conveninent C++ API +#include +#include + +// C +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern "C" { +#include +#include +#include +#include +} + +typedef uint64_t u64; +typedef int64_t s64; +typedef uint32_t u32; +typedef int32_t s32; +typedef uint16_t u16; +typedef int16_t s16; +typedef uint8_t u8; +typedef int8_t s8; + +#define WAVE_SAMPLE_RATE 16000 +#define AVIO_CTX_BUF_SZ 4096 + +static const char* ffmpegLog = getenv("FFMPEG_LOG"); +// Todo: add __FILE__ __LINE__ +#define LOG(...) \ + do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99 + +/* + * WAVE file header based on definition from + * https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f + * + * We must ensure this structure doesn't have any holes or + * padding so we can just map it straight to the WAVE data. + */ +struct wave_hdr { + /* RIFF Header: "RIFF" */ + char riff_header[4]; + /* size of audio data + sizeof(struct wave_hdr) - 8 */ + int wav_size; + /* "WAVE" */ + char wav_header[4]; + + /* Format Header */ + /* "fmt " (includes trailing space) */ + char fmt_header[4]; + /* Should be 16 for PCM */ + int fmt_chunk_size; + /* Should be 1 for PCM. 
3 for IEEE Float */ + s16 audio_format; + s16 num_channels; + int sample_rate; + /* + * Number of bytes per second + * sample_rate * num_channels * bit_depth/8 + */ + int byte_rate; + /* num_channels * bytes per sample */ + s16 sample_alignment; + /* bits per sample */ + s16 bit_depth; + + /* Data Header */ + /* "data" */ + char data_header[4]; + /* + * size of audio + * number of samples * num_channels * bit_depth/8 + */ + int data_bytes; +} __attribute__((__packed__)); + +struct audio_buffer { + u8 *ptr; + int size; /* size left in the buffer */ +}; + +static void set_wave_hdr(wave_hdr& wh, size_t size) { + memcpy(&wh.riff_header, "RIFF", 4); + wh.wav_size = size + sizeof(struct wave_hdr) - 8; + memcpy(&wh.wav_header, "WAVE", 4); + memcpy(&wh.fmt_header, "fmt ", 4); + wh.fmt_chunk_size = 16; + wh.audio_format = 1; + wh.num_channels = 1; + wh.sample_rate = WAVE_SAMPLE_RATE; + wh.sample_alignment = 2; + wh.bit_depth = 16; + wh.byte_rate = wh.sample_rate * wh.sample_alignment; + memcpy(&wh.data_header, "data", 4); + wh.data_bytes = size; +} + +static void write_wave_hdr(int fd, size_t size) { + struct wave_hdr wh; + set_wave_hdr(wh, size); + write(fd, &wh, sizeof(struct wave_hdr)); +} + +static int map_file(int fd, u8 **ptr, size_t *size) +{ + struct stat sb; + + fstat(fd, &sb); + *size = sb.st_size; + + *ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); + if (*ptr == MAP_FAILED) { + perror("mmap"); + return -1; + } + + return 0; +} + +static int read_packet(void *opaque, u8 *buf, int buf_size) +{ + struct audio_buffer *audio_buf = (audio_buffer*)opaque; + + buf_size = FFMIN(buf_size, audio_buf->size); + + /* copy internal buffer data to buf */ + memcpy(buf, audio_buf->ptr, buf_size); + audio_buf->ptr += buf_size; + audio_buf->size -= buf_size; + + return buf_size; +} + +static void convert_frame(struct SwrContext *swr, AVCodecContext *codec, + AVFrame *frame, s16 **data, int *size, bool flush) +{ + int nr_samples; + s64 delay; + u8 *buffer; + + delay = swr_get_delay(swr, codec->sample_rate); + nr_samples = av_rescale_rnd(delay + frame->nb_samples, + WAVE_SAMPLE_RATE, codec->sample_rate, + AV_ROUND_UP); + av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0); + + /* + * !flush is used to check if we are flushing any remaining + * conversion buffers... + */ + nr_samples = swr_convert(swr, &buffer, nr_samples, + !flush ? (const u8 **)frame->data : NULL, + !flush ? frame->nb_samples : 0); + + *data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16)); + memcpy(*data + *size, buffer, nr_samples * sizeof(s16)); + *size += nr_samples; + av_freep(&buffer); +} + +static bool is_audio_stream(const AVStream *stream) +{ + if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) + return true; + + return false; +} + +// Return non zero on error, 0 on success +// audio_buffer: input memory +// data: decoded output audio data (wav file) +// size: size of output data +static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size) +{ + LOG("decode_audio: input size: %d\n", audio_buf->size); + AVFormatContext *fmt_ctx; + AVIOContext *avio_ctx; + AVStream *stream; + AVCodecContext *codec; + AVPacket packet; + AVFrame *frame; + struct SwrContext *swr; + u8 *avio_ctx_buffer; + unsigned int i; + int stream_index = -1; + int err; + const size_t errbuffsize = 1024; + char errbuff[errbuffsize]; + + av_register_all(); // from avformat. Still a must-have call for ffmpeg v3! 
(can be skipped for later versions) + + fmt_ctx = avformat_alloc_context(); + avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ); + LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ); + avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL); + fmt_ctx->pb = avio_ctx; + + // open the input stream and read header + err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL); + if (err) { + LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err)); + return err; + } + + err = avformat_find_stream_info(fmt_ctx, NULL); + if (err < 0) { + LOG("Could not retrieve stream info from audio buffer: %d\n", err); + return err; + } + + for (i = 0; i < fmt_ctx->nb_streams; i++) { + if (is_audio_stream(fmt_ctx->streams[i])) { + stream_index = i; + break; + } + } + + if (stream_index == -1) { + LOG("Could not retrieve audio stream from buffer\n"); + return -1; + } + + stream = fmt_ctx->streams[stream_index]; + codec = avcodec_alloc_context3( + avcodec_find_decoder(stream->codecpar->codec_id)); + avcodec_parameters_to_context(codec, stream->codecpar); + err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), + NULL); + if (err) { + LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index); + return err; + } + + /* prepare resampler */ + swr = swr_alloc(); + + av_opt_set_int(swr, "in_channel_count", codec->channels, 0); + av_opt_set_int(swr, "out_channel_count", 1, 0); + av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); + av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0); + av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); + av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); + av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); + av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); + + swr_init(swr); + if (!swr_is_initialized(swr)) { + LOG("Resampler has not been properly initialized\n"); + return -1; + } + + av_init_packet(&packet); + frame = av_frame_alloc(); + if (!frame) { + LOG("Error allocating the frame\n"); + return -1; + } + + /* iterate through frames */ + *data = NULL; + *size = 0; + while (av_read_frame(fmt_ctx, &packet) >= 0) { + avcodec_send_packet(codec, &packet); + + err = avcodec_receive_frame(codec, frame); + if (err == AVERROR(EAGAIN)) + continue; + + convert_frame(swr, codec, frame, data, size, false); + } + /* Flush any remaining conversion buffers... */ + convert_frame(swr, codec, frame, data, size, true); + + av_frame_free(&frame); + swr_free(&swr); + //avio_context_free(); // todo? + avcodec_close(codec); + avformat_close_input(&fmt_ctx); + avformat_free_context(fmt_ctx); + + if (avio_ctx) { + av_freep(&avio_ctx->buffer); + av_freep(&avio_ctx); + } + + return 0; +} + +// in mem decoding/conversion/resampling: +// ifname: input file path +// owav_data: in mem wav file. 
Can be forwarded as it to whisper/drwav +// return 0 on success +int ffmpeg_decode_audio(const std::string &ifname, std::vector& owav_data) { + LOG("ffmpeg_decode_audio: %s\n", ifname.c_str()); + int ifd = open(ifname.c_str(), O_RDONLY); + if (ifd == -1) { + fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str()); + return -1; + } + u8 *ibuf = NULL; + size_t ibuf_size; + int err = map_file(ifd, &ibuf, &ibuf_size); + if (err) { + LOG("Couldn't map input file %s\n", ifname.c_str()); + return err; + } + LOG("Mapped input file: %x size: %d\n", ibuf, ibuf_size); + struct audio_buffer inaudio_buf; + inaudio_buf.ptr = ibuf; + inaudio_buf.size = ibuf_size; + + s16 *odata=NULL; + int osize=0; + + err = decode_audio(&inaudio_buf, &odata, &osize); + LOG("decode_audio returned %d \n", err); + if (err != 0) { + LOG("decode_audio failed\n"); + return err; + } + LOG("decode_audio output size: %d\n", osize); + + wave_hdr wh; + const size_t outdatasize = osize * sizeof(s16); + set_wave_hdr(wh, outdatasize); + owav_data.resize(sizeof(wave_hdr) + outdatasize); + // header: + memcpy(owav_data.data(), &wh, sizeof(wave_hdr)); + // the data: + memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16)); + + return 0; +} diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 2daaaef4504..6133b8c7fc2 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -190,7 +190,7 @@ namespace grammar_parser { pos = parse_space(pos + 1, is_nested); } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + throw std::runtime_error(std::string("expecting preceding item to */+/? 
at ") + pos); } // apply transformation to previous symbol (last_sym_start to end) according to diff --git a/examples/helpers.js b/examples/helpers.js index 23f18ab5697..423bada60ba 100644 --- a/examples/helpers.js +++ b/examples/helpers.js @@ -34,9 +34,6 @@ async function fetchRemote(url, cbProgress, cbPrint) { url, { method: 'GET', - headers: { - 'Content-Type': 'application/octet-stream', - }, } ); diff --git a/examples/lsp/json.hpp b/examples/json.hpp similarity index 100% rename from examples/lsp/json.hpp rename to examples/json.hpp diff --git a/examples/lsp/CMakeLists.txt b/examples/lsp/CMakeLists.txt index e7ac2511836..15b5be18783 100644 --- a/examples/lsp/CMakeLists.txt +++ b/examples/lsp/CMakeLists.txt @@ -5,5 +5,5 @@ if (WHISPER_SDL2) include(DefaultTargetOptions) - target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT}) + target_link_libraries(${TARGET} PRIVATE common json_cpp common-sdl whisper ${CMAKE_THREAD_LIBS_INIT}) endif () diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index 8d8b6ffa238..8cca87151bf 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -26,11 +26,11 @@ struct whisper_params { float vad_thold = 0.6f; float freq_thold = 100.0f; - bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -69,11 +69,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else { @@ -100,11 +100,11 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, "\n"); @@ -181,7 +181,6 @@ json unguided_transcription(struct whisper_context * ctx, audio_async &audio, js wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; wparams.suppress_non_speech_tokens = true; // run the transformer and a single decoding pass if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { @@ -220,7 +219,6 @@ json guided_transcription(struct whisper_context * ctx, audio_async &audio, cons wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; // TODO: Do some time testing. Does an overly long prompt slow down processing? // Set up command sets/precompute prompts @@ -435,8 +433,11 @@ int main(int argc, char ** argv) { } // whisper init - struct whisper_context_params cparams; - cparams.use_gpu = params.use_gpu; + struct whisper_context_params cparams = whisper_context_default_params(); + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); // init audio diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt index 1bb16f58214..1e66e4b5cc8 100644 --- a/examples/main/CMakeLists.txt +++ b/examples/main/CMakeLists.txt @@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp) include(DefaultTargetOptions) -target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common whisper ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1b7fc9c78c4..78d2e6e7ab6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -2,10 +2,12 @@ #include "console.h" #include "whisper.h" +#include "grammar-parser.h" #include #include #include +#include #include #include #include @@ -15,34 +17,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9] -// Lowest is red, middle is yellow, highest is green. -const std::vector k_colors = { - "\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m", - "\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m", -}; - -// 500 -> 00:05.000 -// 6000 -> 01:00.000 -std::string to_timestamp(int64_t t, bool comma = false) { - int64_t msec = t * 10; - int64_t hr = msec / (1000 * 60 * 60); - msec = msec - hr * (1000 * 60 * 60); - int64_t min = msec / (1000 * 60); - msec = msec - min * (1000 * 60); - int64_t sec = msec / 1000; - msec = msec - sec * 1000; - - char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); - - return std::string(buf); -} - -int timestamp_to_sample(int64_t t, int n_samples) { - return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); -} - // helper function to replace substrings void replace_all(std::string & s, const std::string & search, const std::string & replace) { for (size_t pos = 0; ; pos += replace.length()) { @@ -55,23 +29,27 @@ void replace_all(std::string & s, const std::string & search, const std::string // command-line parameters struct whisper_params { - int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - int32_t n_processors = 1; - int32_t offset_t_ms = 0; - int32_t offset_n = 0; - int32_t duration_ms = 0; - int32_t progress_step = 5; - int32_t max_context = -1; - int32_t max_len = 0; - int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of; - int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t n_processors = 1; + int32_t offset_t_ms = 0; + int32_t offset_n = 0; + int32_t duration_ms = 0; + int32_t progress_step = 5; + int32_t max_context = -1; + int32_t max_len = 0; + int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of; + int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; + int32_t audio_ctx = 0; float word_thold = 0.01f; float entropy_thold = 2.40f; float logprob_thold = -1.00f; float no_speech_thold = 0.60f; + float grammar_penalty = 100.0f; + float temperature = 0.0f; + float temperature_inc = 0.2f; - bool speed_up = false; + bool debug_mode = false; bool translate = false; bool detect_language = false; bool diarize = false; @@ -94,23 +72,41 @@ struct whisper_params { bool heuristic = true; bool log_score = false; bool use_gpu = true; + bool flash_attn = false; std::string language = "en"; std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; std::string model = "models/ggml-base.en.bin"; + std::string grammar; + std::string grammar_rule; // [TDRZ] speaker turn string std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line + // A regular expression that matches tokens to suppress + std::string suppress_regex; + std::string openvino_encode_device = "CPU"; + std::string dtw = ""; + std::vector fname_inp = {}; std::vector fname_out = {}; + + grammar_parser::parse_state grammar_parsed; }; void whisper_print_usage(int argc, const char ** argv, const whisper_params & params); +char* whisper_param_turn_lowercase(char* in){ + int string_len = strlen(in); + for(int i = 0; i < string_len; i++){ + *(in+i) = tolower((unsigned char)*(in+i)); + } + return in; +} + bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -138,13 +134,16 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if 
(arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-nst" || arg == "--nospeech-thold") { params.no_speech_thold = std::stof(argv[++i]); } - // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-nsnst"|| arg == "--no-suppress-nst") { params.suppress_nst = false; } else if (arg == "-nh" || arg == "--no-heuristic") { params.heuristic = false; } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } @@ -164,14 +163,20 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); } else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } else if ( arg == "--prompt") { params.prompt = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } + else if ( arg == "--grammar") { params.grammar = argv[++i]; } + else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -197,13 +202,16 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", 
params.logprob_thold); fprintf(stderr, " -nst N, --nospeech-thold N [%-7.2f] no-speech threshold for decoder fail\n", params.no_speech_thold); - // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -nsnst, --no-suppress-nst [%-7s] do not suppress non-speech tokens\n", params.suppress_nst ? "false" : "true"); fprintf(stderr, " -nh, --no-heuristic [%-7s] do not use heuristic while decoding\n", params.heuristic ? "false" : "true"); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); @@ -225,12 +233,18 @@ void whisper_print_usage(int /*argc*/, const char ** argv, const whisper_params fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); fprintf(stderr, "\n"); } @@ -245,8 +259,8 @@ std::string estimate_diarization_speaker(std::vector> pcmf32s std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); - const int64_t is0 = timestamp_to_sample(t0, n_samples); - const int64_t is1 = timestamp_to_sample(t1, n_samples); + const int64_t is0 = timestamp_to_sample(t0, n_samples, WHISPER_SAMPLE_RATE); + const int64_t is1 = timestamp_to_sample(t1, n_samples, WHISPER_SAMPLE_RATE); double energy0 = 0.0f; double energy1 = 0.0f; @@ -492,6 +506,38 @@ char *escape_double_quotes_and_backslashes(const char *str) { return escaped; } +// double quote should be escaped by another double quote. (rfc4180) +char *escape_double_quotes_in_csv(const char *str) { + if (str == NULL) { + return NULL; + } + + size_t escaped_length = strlen(str) + 1; + + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"') { + escaped_length++; + } + } + + char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed + if (escaped == NULL) { + return NULL; + } + + size_t pos = 0; + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"') { + escaped[pos++] = '"'; + } + escaped[pos++] = str[i]; + } + + // no need to set zero due to calloc() being used prior + + return escaped; +} + bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { #if _WIN32 std::ofstream fout(console::UTF8toUTF16(fname)); @@ -517,7 +563,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_ const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - char * text_escaped = escape_double_quotes_and_backslashes(text); + char * text_escaped = escape_double_quotes_in_csv(text); //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. fout << 10 * t0 << "," << 10 * t1 << ","; @@ -704,7 +750,8 @@ bool output_json( times_o(token.t0, token.t1, false); } value_i("id", token.id, false); - value_f("p", token.p, true); + value_f("p", token.p, false); + value_f("t_dtw", token.t_dtw, true); end_obj(j == (n - 1)); } end_arr(!params.diarize && !params.tinydiarize); @@ -907,11 +954,53 @@ void cb_log_disable(enum ggml_log_level , const char * , void * ) { } int run(int argc, const char ** argv) { whisper_params params; + // If the only argument starts with "@", read arguments line-by-line + // from the given file. + std::vector vec_args; + if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@') { + // Save the name of the executable. + vec_args.push_back(argv[0]); + + // Open the response file. + char const * rspfile = argv[1] + sizeof(char); + std::ifstream fin(rspfile); + if (fin.is_open() == false) { + fprintf(stderr, "error: response file '%s' not found\n", rspfile); + return 1; + } + + // Read the entire response file. 
+ std::string line; + while (std::getline(fin, line)) { + vec_args.push_back(line); + } + + // Use the contents of the response file as the command-line arguments. + argc = static_cast(vec_args.size()); + argv = static_cast(alloca(argc * sizeof (char *))); + for (int i = 0; i < argc; ++i) { + argv[i] = const_cast(vec_args[i].c_str()); + } + } + if (whisper_params_parse(argc, argv, params) == false) { whisper_print_usage(argc, argv, params); return 1; } + // remove non-existent files + for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();) { + const auto fname_inp = it->c_str(); + + if (*it != "-" && !is_file_exist(fname_inp)) { + fprintf(stderr, "error: input file not found '%s'\n", fname_inp); + it = params.fname_inp.erase(it); + continue; + } + + it++; + } + if (params.fname_inp.empty()) { fprintf(stderr, "error: no input files specified\n"); whisper_print_usage(argc, argv, params); @@ -936,8 +1025,32 @@ int run(int argc, const char ** argv) { // whisper init - struct whisper_context_params cparams; - cparams.use_gpu = params.use_gpu; + struct whisper_context_params cparams = whisper_context_default_params(); + + cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; + + if (!params.dtw.empty()) { + cparams.dtw_token_timestamps = true; + cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; + + if (params.dtw == "tiny") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY; + if (params.dtw == "tiny.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN; + if (params.dtw == "base") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE; + if (params.dtw == "base.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN; + if (params.dtw == "small") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL; + if (params.dtw == "small.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN; + if (params.dtw == "medium") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM; + if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN; + if (params.dtw == "large.v1") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1; + if (params.dtw == "large.v2") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2; + if (params.dtw == "large.v3") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; + + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { + fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); + return 3; + } + } struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); @@ -949,6 +1062,29 @@ int run(int argc, const char ** argv) { // initialize openvino encoder. 
this has no effect on whisper.cpp builds that don't have OpenVINO configured whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + if (!params.grammar.empty()) { + auto & grammar = params.grammar_parsed; + if (is_file_exist(params.grammar.c_str())) { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } else { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + + // will be empty (default) if there are parse errors + if (grammar.rules.empty()) { + fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str()); + return 4; + } else { + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, grammar); + fprintf(stderr, "\n"); + } + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -995,7 +1131,8 @@ int run(int argc, const char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty()); + wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; wparams.print_realtime = false; wparams.print_progress = params.print_progress; @@ -1012,18 +1149,25 @@ int run(int argc, const char ** argv) { wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0; wparams.thold_pt = params.word_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; - - wparams.speed_up = params.speed_up; + wparams.heuristic = params.heuristic; + wparams.split_on_word = params.split_on_word; + wparams.audio_ctx = params.audio_ctx; + + wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str(); + wparams.initial_prompt = params.prompt.c_str(); wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc; + wparams.temperature_inc = params.no_fallback ? 
0.0f : params.temperature_inc; + wparams.temperature = params.temperature; + wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; wparams.no_speech_thold = params.no_speech_thold; @@ -1033,6 +1177,20 @@ int run(int argc, const char ** argv) { whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; + const auto & grammar_parsed = params.grammar_parsed; + auto grammar_rules = grammar_parsed.c_rules(); + + if (use_grammar) { + if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) { + fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str()); + } else { + wparams.grammar_rules = grammar_rules.data(); + wparams.n_grammar_rules = grammar_rules.size(); + wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule); + wparams.grammar_penalty = params.grammar_penalty; + } + } + // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback; @@ -1129,7 +1287,9 @@ int run(int argc, const char ** argv) { } } - whisper_print_timings(ctx); + if (!params.no_prints) { + whisper_print_timings(ctx); + } whisper_free(ctx); return 0; diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 1e8c921323f..96dd97f7654 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,9 +1,9 @@ set(TARGET server) -add_executable(${TARGET} server.cpp httplib.h json.hpp) +add_executable(${TARGET} server.cpp httplib.h) include(DefaultTargetOptions) -target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common json_cpp whisper ${CMAKE_THREAD_LIBS_INIT}) if (WIN32) target_link_libraries(${TARGET} PRIVATE ws2_32) diff --git a/examples/server/json.hpp b/examples/server/json.hpp deleted file mode 100644 index 4d1a37ad7cb..00000000000 --- a/examples/server/json.hpp +++ /dev/null @@ -1,24596 +0,0 @@ -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - -/****************************************************************************\ - * Note on documentation: The source files contain links to the online * - * documentation of the public API at https://json.nlohmann.me. This URL * - * contains the most recent documentation and should also be applicable to * - * previous versions; documentation for deprecated functions is not * - * removed, but marked deprecated. See "Generate documentation" section in * - * file docs/README.md. 
* -\****************************************************************************/ - -#ifndef INCLUDE_NLOHMANN_JSON_HPP_ -#define INCLUDE_NLOHMANN_JSON_HPP_ - -#include // all_of, find, for_each -#include // nullptr_t, ptrdiff_t, size_t -#include // hash, less -#include // initializer_list -#ifndef JSON_NO_IO - #include // istream, ostream -#endif // JSON_NO_IO -#include // random_access_iterator_tag -#include // unique_ptr -#include // accumulate -#include // string, stoi, to_string -#include // declval, forward, move, pair, swap -#include // vector - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -// This file contains all macro definitions affecting or depending on the ABI - -#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK - #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) - #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 2 - #warning "Already included a different version of the library!" - #endif - #endif -#endif - -#define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) -#define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) -#define NLOHMANN_JSON_VERSION_PATCH 2 // NOLINT(modernize-macro-to-enum) - -#ifndef JSON_DIAGNOSTICS - #define JSON_DIAGNOSTICS 0 -#endif - -#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON - #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0 -#endif - -#if JSON_DIAGNOSTICS - #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag -#else - #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS -#endif - -#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON - #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp -#else - #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON -#endif - -#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION - #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0 -#endif - -// Construct the namespace ABI tags component -#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b -#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \ - NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) - -#define NLOHMANN_JSON_ABI_TAGS \ - NLOHMANN_JSON_ABI_TAGS_CONCAT( \ - NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \ - NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON) - -// Construct the namespace version component -#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \ - _v ## major ## _ ## minor ## _ ## patch -#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \ - NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) - -#if NLOHMANN_JSON_NAMESPACE_NO_VERSION -#define NLOHMANN_JSON_NAMESPACE_VERSION -#else -#define NLOHMANN_JSON_NAMESPACE_VERSION \ - NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \ - NLOHMANN_JSON_VERSION_MINOR, \ - NLOHMANN_JSON_VERSION_PATCH) -#endif - -// Combine namespace components -#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b -#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \ - NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) - 
-#ifndef NLOHMANN_JSON_NAMESPACE -#define NLOHMANN_JSON_NAMESPACE \ - nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \ - NLOHMANN_JSON_ABI_TAGS, \ - NLOHMANN_JSON_NAMESPACE_VERSION) -#endif - -#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN -#define NLOHMANN_JSON_NAMESPACE_BEGIN \ - namespace nlohmann \ - { \ - inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \ - NLOHMANN_JSON_ABI_TAGS, \ - NLOHMANN_JSON_NAMESPACE_VERSION) \ - { -#endif - -#ifndef NLOHMANN_JSON_NAMESPACE_END -#define NLOHMANN_JSON_NAMESPACE_END \ - } /* namespace (inline namespace) NOLINT(readability/namespace) */ \ - } // namespace nlohmann -#endif - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // transform -#include // array -#include // forward_list -#include // inserter, front_inserter, end -#include // map -#include // string -#include // tuple, make_tuple -#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible -#include // unordered_map -#include // pair, declval -#include // valarray - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // nullptr_t -#include // exception -#include // runtime_error -#include // to_string -#include // vector - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // array -#include // size_t -#include // uint8_t -#include // string - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // declval, pair -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -template struct make_void -{ - using type = void; -}; -template using void_t = typename make_void::type; - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -// https://en.cppreference.com/w/cpp/experimental/is_detected -struct nonesuch -{ - nonesuch() = delete; - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; - nonesuch(nonesuch const&&) = delete; - void operator=(nonesuch const&) = delete; - void operator=(nonesuch&&) = delete; -}; - -template class Op, - class... 
Args> -struct detector -{ - using value_t = std::false_type; - using type = Default; -}; - -template class Op, class... Args> -struct detector>, Op, Args...> -{ - using value_t = std::true_type; - using type = Op; -}; - -template class Op, class... Args> -using is_detected = typename detector::value_t; - -template class Op, class... Args> -struct is_detected_lazy : is_detected { }; - -template class Op, class... Args> -using detected_t = typename detector::type; - -template class Op, class... Args> -using detected_or = detector; - -template class Op, class... Args> -using detected_or_t = typename detected_or::type; - -template class Op, class... Args> -using is_detected_exact = std::is_same>; - -template class Op, class... Args> -using is_detected_convertible = - std::is_convertible, To>; - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - -// #include - - -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-FileCopyrightText: 2016-2021 Evan Nemerson -// SPDX-License-Identifier: MIT - -/* Hedley - https://nemequ.github.io/hedley - * Created by Evan Nemerson - */ - -#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 15) -#if defined(JSON_HEDLEY_VERSION) - #undef JSON_HEDLEY_VERSION -#endif -#define JSON_HEDLEY_VERSION 15 - -#if defined(JSON_HEDLEY_STRINGIFY_EX) - #undef JSON_HEDLEY_STRINGIFY_EX -#endif -#define JSON_HEDLEY_STRINGIFY_EX(x) #x - -#if defined(JSON_HEDLEY_STRINGIFY) - #undef JSON_HEDLEY_STRINGIFY -#endif -#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x) - -#if defined(JSON_HEDLEY_CONCAT_EX) - #undef JSON_HEDLEY_CONCAT_EX -#endif -#define JSON_HEDLEY_CONCAT_EX(a,b) a##b - -#if defined(JSON_HEDLEY_CONCAT) - #undef JSON_HEDLEY_CONCAT -#endif -#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b) - -#if defined(JSON_HEDLEY_CONCAT3_EX) - #undef JSON_HEDLEY_CONCAT3_EX -#endif -#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c - -#if defined(JSON_HEDLEY_CONCAT3) - #undef JSON_HEDLEY_CONCAT3 -#endif -#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c) - -#if defined(JSON_HEDLEY_VERSION_ENCODE) - #undef JSON_HEDLEY_VERSION_ENCODE -#endif -#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision)) - -#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR) - #undef JSON_HEDLEY_VERSION_DECODE_MAJOR -#endif -#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000) - -#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR) - #undef JSON_HEDLEY_VERSION_DECODE_MINOR -#endif -#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000) - -#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION) - #undef JSON_HEDLEY_VERSION_DECODE_REVISION -#endif -#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000) - -#if defined(JSON_HEDLEY_GNUC_VERSION) - #undef JSON_HEDLEY_GNUC_VERSION -#endif -#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__) - #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#elif defined(__GNUC__) - #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0) -#endif - -#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK) - #undef JSON_HEDLEY_GNUC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_GNUC_VERSION) - #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= 
JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_MSVC_VERSION) - #undef JSON_HEDLEY_MSVC_VERSION -#endif -#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) && !defined(__ICL) - #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100) -#elif defined(_MSC_FULL_VER) && !defined(__ICL) - #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10) -#elif defined(_MSC_VER) && !defined(__ICL) - #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0) -#endif - -#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK) - #undef JSON_HEDLEY_MSVC_VERSION_CHECK -#endif -#if !defined(JSON_HEDLEY_MSVC_VERSION) - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0) -#elif defined(_MSC_VER) && (_MSC_VER >= 1400) - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch))) -#elif defined(_MSC_VER) && (_MSC_VER >= 1200) - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch))) -#else - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor))) -#endif - -#if defined(JSON_HEDLEY_INTEL_VERSION) - #undef JSON_HEDLEY_INTEL_VERSION -#endif -#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && !defined(__ICL) - #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE) -#elif defined(__INTEL_COMPILER) && !defined(__ICL) - #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) -#endif - -#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK) - #undef JSON_HEDLEY_INTEL_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_INTEL_VERSION) - #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_INTEL_CL_VERSION) - #undef JSON_HEDLEY_INTEL_CL_VERSION -#endif -#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && defined(__ICL) - #define JSON_HEDLEY_INTEL_CL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER, __INTEL_COMPILER_UPDATE, 0) -#endif - -#if defined(JSON_HEDLEY_INTEL_CL_VERSION_CHECK) - #undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_INTEL_CL_VERSION) - #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_CL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_PGI_VERSION) - #undef JSON_HEDLEY_PGI_VERSION -#endif -#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__) - #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__) -#endif - -#if defined(JSON_HEDLEY_PGI_VERSION_CHECK) - #undef JSON_HEDLEY_PGI_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_PGI_VERSION) - #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define 
JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_SUNPRO_VERSION) - #undef JSON_HEDLEY_SUNPRO_VERSION -#endif -#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10) -#elif defined(__SUNPRO_C) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf) -#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10) -#elif defined(__SUNPRO_CC) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf) -#endif - -#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK) - #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_SUNPRO_VERSION) - #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) - #undef JSON_HEDLEY_EMSCRIPTEN_VERSION -#endif -#if defined(__EMSCRIPTEN__) - #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__) -#endif - -#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK) - #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) - #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_ARM_VERSION) - #undef JSON_HEDLEY_ARM_VERSION -#endif -#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION) - #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100) -#elif defined(__CC_ARM) && defined(__ARMCC_VERSION) - #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100) -#endif - -#if defined(JSON_HEDLEY_ARM_VERSION_CHECK) - #undef JSON_HEDLEY_ARM_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_ARM_VERSION) - #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_IBM_VERSION) - #undef JSON_HEDLEY_IBM_VERSION -#endif -#if defined(__ibmxl__) - #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__) -#elif defined(__xlC__) && defined(__xlC_ver__) - #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff) -#elif defined(__xlC__) - #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0) -#endif - -#if defined(JSON_HEDLEY_IBM_VERSION_CHECK) - #undef JSON_HEDLEY_IBM_VERSION_CHECK -#endif -#if 
defined(JSON_HEDLEY_IBM_VERSION) - #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_VERSION) - #undef JSON_HEDLEY_TI_VERSION -#endif -#if \ - defined(__TI_COMPILER_VERSION__) && \ - ( \ - defined(__TMS470__) || defined(__TI_ARM__) || \ - defined(__MSP430__) || \ - defined(__TMS320C2000__) \ - ) -#if (__TI_COMPILER_VERSION__ >= 16000000) - #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif -#endif - -#if defined(JSON_HEDLEY_TI_VERSION_CHECK) - #undef JSON_HEDLEY_TI_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_VERSION) - #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL2000_VERSION) - #undef JSON_HEDLEY_TI_CL2000_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__) - #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL2000_VERSION) - #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL430_VERSION) - #undef JSON_HEDLEY_TI_CL430_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__) - #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL430_VERSION) - #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) - #undef JSON_HEDLEY_TI_ARMCL_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__)) - #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK) - #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) - #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL6X_VERSION) - #undef JSON_HEDLEY_TI_CL6X_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__) - #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, 
(__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL6X_VERSION) - #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL7X_VERSION) - #undef JSON_HEDLEY_TI_CL7X_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__) - #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL7X_VERSION) - #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) - #undef JSON_HEDLEY_TI_CLPRU_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__PRU__) - #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) - #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_CRAY_VERSION) - #undef JSON_HEDLEY_CRAY_VERSION -#endif -#if defined(_CRAYC) - #if defined(_RELEASE_PATCHLEVEL) - #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL) - #else - #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0) - #endif -#endif - -#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK) - #undef JSON_HEDLEY_CRAY_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_CRAY_VERSION) - #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_IAR_VERSION) - #undef JSON_HEDLEY_IAR_VERSION -#endif -#if defined(__IAR_SYSTEMS_ICC__) - #if __VER__ > 1000 - #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000)) - #else - #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(__VER__ / 100, __VER__ % 100, 0) - #endif -#endif - -#if defined(JSON_HEDLEY_IAR_VERSION_CHECK) - #undef JSON_HEDLEY_IAR_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_IAR_VERSION) - #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TINYC_VERSION) - #undef JSON_HEDLEY_TINYC_VERSION -#endif -#if defined(__TINYC__) - #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100) -#endif - -#if 
defined(JSON_HEDLEY_TINYC_VERSION_CHECK) - #undef JSON_HEDLEY_TINYC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TINYC_VERSION) - #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_DMC_VERSION) - #undef JSON_HEDLEY_DMC_VERSION -#endif -#if defined(__DMC__) - #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf) -#endif - -#if defined(JSON_HEDLEY_DMC_VERSION_CHECK) - #undef JSON_HEDLEY_DMC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_DMC_VERSION) - #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_COMPCERT_VERSION) - #undef JSON_HEDLEY_COMPCERT_VERSION -#endif -#if defined(__COMPCERT_VERSION__) - #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100) -#endif - -#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK) - #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_COMPCERT_VERSION) - #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_PELLES_VERSION) - #undef JSON_HEDLEY_PELLES_VERSION -#endif -#if defined(__POCC__) - #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0) -#endif - -#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK) - #undef JSON_HEDLEY_PELLES_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_PELLES_VERSION) - #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_MCST_LCC_VERSION) - #undef JSON_HEDLEY_MCST_LCC_VERSION -#endif -#if defined(__LCC__) && defined(__LCC_MINOR__) - #define JSON_HEDLEY_MCST_LCC_VERSION JSON_HEDLEY_VERSION_ENCODE(__LCC__ / 100, __LCC__ % 100, __LCC_MINOR__) -#endif - -#if defined(JSON_HEDLEY_MCST_LCC_VERSION_CHECK) - #undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_MCST_LCC_VERSION) - #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_MCST_LCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_GCC_VERSION) - #undef JSON_HEDLEY_GCC_VERSION -#endif -#if \ - defined(JSON_HEDLEY_GNUC_VERSION) && \ - !defined(__clang__) && \ - !defined(JSON_HEDLEY_INTEL_VERSION) && \ - !defined(JSON_HEDLEY_PGI_VERSION) && \ - !defined(JSON_HEDLEY_ARM_VERSION) && \ - !defined(JSON_HEDLEY_CRAY_VERSION) && \ - !defined(JSON_HEDLEY_TI_VERSION) && \ - !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL430_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \ - !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \ - !defined(__COMPCERT__) && \ - !defined(JSON_HEDLEY_MCST_LCC_VERSION) - #define JSON_HEDLEY_GCC_VERSION 
[vendored nlohmann/json header hunk elided: deletion of the JSON_HEDLEY_* compiler-compatibility macro block (compiler version checks, __has_attribute/__has_builtin/__has_feature/__has_warning detection, diagnostic push/pop and suppression, deprecation, noreturn/assume/unreachable, likely/unlikely prediction, malloc/pure/const/restrict, inline/always-inline/never-inline, visibility, fallthrough, static_assert, message/warning, flags, empty_bases, and deprecated aliases) and of the library's internal macro definitions (unsupported-compiler check, C++11/14/17/20 standard detection, filesystem/experimental-filesystem availability tests, three-way-comparison and ranges feature tests, inline-variable and no_unique_address helpers, clang documentation-warning suppression, and exception-handling macros)]
try - #define JSON_CATCH(exception) catch(exception) - #define JSON_INTERNAL_CATCH(exception) catch(exception) -#else - #include - #define JSON_THROW(exception) std::abort() - #define JSON_TRY if(true) - #define JSON_CATCH(exception) if(false) - #define JSON_INTERNAL_CATCH(exception) if(false) -#endif - -// override exception macros -#if defined(JSON_THROW_USER) - #undef JSON_THROW - #define JSON_THROW JSON_THROW_USER -#endif -#if defined(JSON_TRY_USER) - #undef JSON_TRY - #define JSON_TRY JSON_TRY_USER -#endif -#if defined(JSON_CATCH_USER) - #undef JSON_CATCH - #define JSON_CATCH JSON_CATCH_USER - #undef JSON_INTERNAL_CATCH - #define JSON_INTERNAL_CATCH JSON_CATCH_USER -#endif -#if defined(JSON_INTERNAL_CATCH_USER) - #undef JSON_INTERNAL_CATCH - #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER -#endif - -// allow overriding assert -#if !defined(JSON_ASSERT) - #include // assert - #define JSON_ASSERT(x) assert(x) -#endif - -// allow to access some private functions (needed by the test suite) -#if defined(JSON_TESTS_PRIVATE) - #define JSON_PRIVATE_UNLESS_TESTED public -#else - #define JSON_PRIVATE_UNLESS_TESTED private -#endif - -/*! -@brief macro to briefly define a mapping between an enum and JSON -@def NLOHMANN_JSON_SERIALIZE_ENUM -@since version 3.4.0 -*/ -#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) \ - template \ - inline void to_json(BasicJsonType& j, const ENUM_TYPE& e) \ - { \ - static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ - static const std::pair m[] = __VA_ARGS__; \ - auto it = std::find_if(std::begin(m), std::end(m), \ - [e](const std::pair& ej_pair) -> bool \ - { \ - return ej_pair.first == e; \ - }); \ - j = ((it != std::end(m)) ? it : std::begin(m))->second; \ - } \ - template \ - inline void from_json(const BasicJsonType& j, ENUM_TYPE& e) \ - { \ - static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ - static const std::pair m[] = __VA_ARGS__; \ - auto it = std::find_if(std::begin(m), std::end(m), \ - [&j](const std::pair& ej_pair) -> bool \ - { \ - return ej_pair.second == j; \ - }); \ - e = ((it != std::end(m)) ? it : std::begin(m))->first; \ - } - -// Ugly macros to avoid uglier copy-paste when specializing basic_json. They -// may be removed in the future once the class is split. - -#define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ - template class ObjectType, \ - template class ArrayType, \ - class StringType, class BooleanType, class NumberIntegerType, \ - class NumberUnsignedType, class NumberFloatType, \ - template class AllocatorType, \ - template class JSONSerializer, \ - class BinaryType> - -#define NLOHMANN_BASIC_JSON_TPL \ - basic_json - -// Macros to simplify conversion from/to types - -#define NLOHMANN_JSON_EXPAND( x ) x -#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME -#define NLOHMANN_JSON_PASTE(...) 
NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \ - NLOHMANN_JSON_PASTE64, \ - NLOHMANN_JSON_PASTE63, \ - NLOHMANN_JSON_PASTE62, \ - NLOHMANN_JSON_PASTE61, \ - NLOHMANN_JSON_PASTE60, \ - NLOHMANN_JSON_PASTE59, \ - NLOHMANN_JSON_PASTE58, \ - NLOHMANN_JSON_PASTE57, \ - NLOHMANN_JSON_PASTE56, \ - NLOHMANN_JSON_PASTE55, \ - NLOHMANN_JSON_PASTE54, \ - NLOHMANN_JSON_PASTE53, \ - NLOHMANN_JSON_PASTE52, \ - NLOHMANN_JSON_PASTE51, \ - NLOHMANN_JSON_PASTE50, \ - NLOHMANN_JSON_PASTE49, \ - NLOHMANN_JSON_PASTE48, \ - NLOHMANN_JSON_PASTE47, \ - NLOHMANN_JSON_PASTE46, \ - NLOHMANN_JSON_PASTE45, \ - NLOHMANN_JSON_PASTE44, \ - NLOHMANN_JSON_PASTE43, \ - NLOHMANN_JSON_PASTE42, \ - NLOHMANN_JSON_PASTE41, \ - NLOHMANN_JSON_PASTE40, \ - NLOHMANN_JSON_PASTE39, \ - NLOHMANN_JSON_PASTE38, \ - NLOHMANN_JSON_PASTE37, \ - NLOHMANN_JSON_PASTE36, \ - NLOHMANN_JSON_PASTE35, \ - NLOHMANN_JSON_PASTE34, \ - NLOHMANN_JSON_PASTE33, \ - NLOHMANN_JSON_PASTE32, \ - NLOHMANN_JSON_PASTE31, \ - NLOHMANN_JSON_PASTE30, \ - NLOHMANN_JSON_PASTE29, \ - NLOHMANN_JSON_PASTE28, \ - NLOHMANN_JSON_PASTE27, \ - NLOHMANN_JSON_PASTE26, \ - NLOHMANN_JSON_PASTE25, \ - NLOHMANN_JSON_PASTE24, \ - NLOHMANN_JSON_PASTE23, \ - NLOHMANN_JSON_PASTE22, \ - NLOHMANN_JSON_PASTE21, \ - NLOHMANN_JSON_PASTE20, \ - NLOHMANN_JSON_PASTE19, \ - NLOHMANN_JSON_PASTE18, \ - NLOHMANN_JSON_PASTE17, \ - NLOHMANN_JSON_PASTE16, \ - NLOHMANN_JSON_PASTE15, \ - NLOHMANN_JSON_PASTE14, \ - NLOHMANN_JSON_PASTE13, \ - NLOHMANN_JSON_PASTE12, \ - NLOHMANN_JSON_PASTE11, \ - NLOHMANN_JSON_PASTE10, \ - NLOHMANN_JSON_PASTE9, \ - NLOHMANN_JSON_PASTE8, \ - NLOHMANN_JSON_PASTE7, \ - NLOHMANN_JSON_PASTE6, \ - NLOHMANN_JSON_PASTE5, \ - NLOHMANN_JSON_PASTE4, \ - NLOHMANN_JSON_PASTE3, \ - NLOHMANN_JSON_PASTE2, \ - NLOHMANN_JSON_PASTE1)(__VA_ARGS__)) -#define NLOHMANN_JSON_PASTE2(func, v1) func(v1) -#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2) -#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3) -#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4) -#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5) -#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6) -#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7) -#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8) -#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9) -#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10) -#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) -#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) -#define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, 
v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) -#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) -#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) -#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) -#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) -#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) -#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) -#define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) -#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) -#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) -#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) -#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) -#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) -#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) 
NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) -#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) -#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) -#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) -#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) -#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) -#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) -#define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) -#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) -#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, 
v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) -#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) -#define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) -#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) -#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) -#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) -#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) -#define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) -#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) 
NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) -#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) -#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) -#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) -#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) -#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) -#define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) -#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, 
v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) -#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) -#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) -#define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) -#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) -#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) -#define 
NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) -#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) -#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) -#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) -#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) -#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, 
v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) -#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) -#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) - -#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1; -#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1); -#define NLOHMANN_JSON_FROM_WITH_DEFAULT(v1) nlohmann_json_t.v1 = nlohmann_json_j.value(#v1, nlohmann_json_default_obj.v1); - -/*! -@brief macro -@def NLOHMANN_DEFINE_TYPE_INTRUSIVE -@since version 3.9.0 -*/ -#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) \ - friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ - friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } - -#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Type, ...) \ - friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ - friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { Type nlohmann_json_default_obj; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } - -/*! -@brief macro -@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE -@since version 3.9.0 -*/ -#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...) \ - inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ - inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } - -#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, ...) 
\ - inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ - inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { Type nlohmann_json_default_obj; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } - - -// inspired from https://stackoverflow.com/a/26745591 -// allows to call any std function as if (e.g. with begin): -// using std::begin; begin(x); -// -// it allows using the detected idiom to retrieve the return type -// of such an expression -#define NLOHMANN_CAN_CALL_STD_FUNC_IMPL(std_name) \ - namespace detail { \ - using std::std_name; \ - \ - template \ - using result_of_##std_name = decltype(std_name(std::declval()...)); \ - } \ - \ - namespace detail2 { \ - struct std_name##_tag \ - { \ - }; \ - \ - template \ - std_name##_tag std_name(T&&...); \ - \ - template \ - using result_of_##std_name = decltype(std_name(std::declval()...)); \ - \ - template \ - struct would_call_std_##std_name \ - { \ - static constexpr auto const value = ::nlohmann::detail:: \ - is_detected_exact::value; \ - }; \ - } /* namespace detail2 */ \ - \ - template \ - struct would_call_std_##std_name : detail2::would_call_std_##std_name \ - { \ - } - -#ifndef JSON_USE_IMPLICIT_CONVERSIONS - #define JSON_USE_IMPLICIT_CONVERSIONS 1 -#endif - -#if JSON_USE_IMPLICIT_CONVERSIONS - #define JSON_EXPLICIT -#else - #define JSON_EXPLICIT explicit -#endif - -#ifndef JSON_DISABLE_ENUM_SERIALIZATION - #define JSON_DISABLE_ENUM_SERIALIZATION 0 -#endif - -#ifndef JSON_USE_GLOBAL_UDLS - #define JSON_USE_GLOBAL_UDLS 1 -#endif - -#if JSON_HAS_THREE_WAY_COMPARISON - #include // partial_ordering -#endif - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -/////////////////////////// -// JSON type enumeration // -/////////////////////////// - -/*! -@brief the JSON type enumeration - -This enumeration collects the different JSON types. It is internally used to -distinguish the stored values, and the functions @ref basic_json::is_null(), -@ref basic_json::is_object(), @ref basic_json::is_array(), -@ref basic_json::is_string(), @ref basic_json::is_boolean(), -@ref basic_json::is_number() (with @ref basic_json::is_number_integer(), -@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), -@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and -@ref basic_json::is_structured() rely on it. - -@note There are three enumeration entries (number_integer, number_unsigned, and -number_float), because the library distinguishes these three types for numbers: -@ref basic_json::number_unsigned_t is used for unsigned integers, -@ref basic_json::number_integer_t is used for signed integers, and -@ref basic_json::number_float_t is used for floating-point numbers or to -approximate integers which do not fit in the limits of their respective type. 
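// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the vendored header in this diff):
// how the three number categories described in the note above surface through
// the parser. Assumes the single-header json.hpp touched by this diff is on
// the include path; "json.hpp" below is an illustrative path.
// ---------------------------------------------------------------------------
#include <cassert>
#include "json.hpp"

int main()
{
    auto j = nlohmann::json::parse(R"([ -5, 5, 3.14 ])");
    assert(j[0].is_number_integer() && !j[0].is_number_unsigned()); // signed integer
    assert(j[1].is_number_unsigned());                              // unsigned integer
    assert(j[2].is_number_float());                                 // floating point
}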
- -@sa see @ref basic_json::basic_json(const value_t value_type) -- create a JSON -value with the default value for a given type - -@since version 1.0.0 -*/ -enum class value_t : std::uint8_t -{ - null, ///< null value - object, ///< object (unordered set of name/value pairs) - array, ///< array (ordered collection of values) - string, ///< string value - boolean, ///< boolean value - number_integer, ///< number value (signed integer) - number_unsigned, ///< number value (unsigned integer) - number_float, ///< number value (floating-point) - binary, ///< binary array (ordered collection of bytes) - discarded ///< discarded by the parser callback function -}; - -/*! -@brief comparison operator for JSON types - -Returns an ordering that is similar to Python: -- order: null < boolean < number < object < array < string < binary -- furthermore, each type is not smaller than itself -- discarded values are not comparable -- binary is represented as a b"" string in python and directly comparable to a - string; however, making a binary array directly comparable with a string would - be surprising behavior in a JSON file. - -@since version 1.0.0 -*/ -#if JSON_HAS_THREE_WAY_COMPARISON - inline std::partial_ordering operator<=>(const value_t lhs, const value_t rhs) noexcept // *NOPAD* -#else - inline bool operator<(const value_t lhs, const value_t rhs) noexcept -#endif -{ - static constexpr std::array order = {{ - 0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */, - 1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */, - 6 /* binary */ - } - }; - - const auto l_index = static_cast(lhs); - const auto r_index = static_cast(rhs); -#if JSON_HAS_THREE_WAY_COMPARISON - if (l_index < order.size() && r_index < order.size()) - { - return order[l_index] <=> order[r_index]; // *NOPAD* - } - return std::partial_ordering::unordered; -#else - return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index]; -#endif -} - -// GCC selects the built-in operator< over an operator rewritten from -// a user-defined spaceship operator -// Clang, MSVC, and ICC select the rewritten candidate -// (see GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105200) -#if JSON_HAS_THREE_WAY_COMPARISON && defined(__GNUC__) -inline bool operator<(const value_t lhs, const value_t rhs) noexcept -{ - return std::is_lt(lhs <=> rhs); // *NOPAD* -} -#endif - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -/*! -@brief replace all occurrences of a substring by another string - -@param[in,out] s the string to manipulate; changed so that all - occurrences of @a f are replaced with @a t -@param[in] f the substring to replace with @a t -@param[in] t the string to replace @a f - -@pre The search string @a f must not be empty. 
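// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the vendored header): the type order
// encoded in the array above also drives relational comparison of two json
// values of different types, i.e. null < boolean < number < object < array <
// string < binary. Assumes "json.hpp" (illustrative path) is this header.
// ---------------------------------------------------------------------------
#include <cassert>
#include "json.hpp"
using nlohmann::json;

int main()
{
    assert(json(nullptr)  < json(false));     // null    < boolean
    assert(json(true)     < json(42));        // boolean < number
    assert(json(3.14)     < json::object());  // number  < object
    assert(json::object() < json::array());   // object  < array
    assert(json::array()  < json("text"));    // array   < string
}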
**This precondition is -enforced with an assertion.** - -@since version 2.0.0 -*/ -template -inline void replace_substring(StringType& s, const StringType& f, - const StringType& t) -{ - JSON_ASSERT(!f.empty()); - for (auto pos = s.find(f); // find first occurrence of f - pos != StringType::npos; // make sure f was found - s.replace(pos, f.size(), t), // replace with t, and - pos = s.find(f, pos + t.size())) // find next occurrence of f - {} -} - -/*! - * @brief string escaping as described in RFC 6901 (Sect. 4) - * @param[in] s string to escape - * @return escaped string - * - * Note the order of escaping "~" to "~0" and "/" to "~1" is important. - */ -template -inline StringType escape(StringType s) -{ - replace_substring(s, StringType{"~"}, StringType{"~0"}); - replace_substring(s, StringType{"/"}, StringType{"~1"}); - return s; -} - -/*! - * @brief string unescaping as described in RFC 6901 (Sect. 4) - * @param[in] s string to unescape - * @return unescaped string - * - * Note the order of escaping "~1" to "/" and "~0" to "~" is important. - */ -template -static void unescape(StringType& s) -{ - replace_substring(s, StringType{"~1"}, StringType{"/"}); - replace_substring(s, StringType{"~0"}, StringType{"~"}); -} - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // size_t - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -/// struct to capture the start position of the current token -struct position_t -{ - /// the total number of characters read - std::size_t chars_read_total = 0; - /// the number of characters read in the current line - std::size_t chars_read_current_line = 0; - /// the number of lines read - std::size_t lines_read = 0; - - /// conversion to size_t to preserve SAX interface - constexpr operator size_t() const - { - return chars_read_total; - } -}; - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - -// #include - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-FileCopyrightText: 2018 The Abseil Authors -// SPDX-License-Identifier: MIT - - - -#include // array -#include // size_t -#include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type -#include // index_sequence, make_index_sequence, index_sequence_for - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -template -using uncvref_t = typename std::remove_cv::type>::type; - -#ifdef JSON_HAS_CPP_14 - -// the following utilities are natively available in C++14 -using std::enable_if_t; -using std::index_sequence; -using std::make_index_sequence; -using std::index_sequence_for; - -#else - -// alias templates to reduce boilerplate -template -using enable_if_t = typename std::enable_if::type; - -// The following code is taken from https://github.com/abseil/abseil-cpp/blob/10cb35e459f5ecca5b2ff107635da0bfa41011b4/absl/utility/utility.h -// which is part of Google Abseil (https://github.com/abseil/abseil-cpp), licensed under the Apache License 2.0. 
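// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the vendored header): a stand-alone
// re-implementation of the RFC 6901 escaping rule used by the escape() /
// unescape() helpers earlier in this hunk. "~" must be rewritten to "~0"
// before "/" is rewritten to "~1"; the reverse order would corrupt the
// freshly inserted "~1" sequences. All names below are hypothetical.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstddef>
#include <string>

static std::string escape_pointer_token(std::string s)
{
    // restart each search two characters further on, past the escape just written
    for (std::size_t pos = s.find('~'); pos != std::string::npos; pos = s.find('~', pos + 2))
    {
        s.replace(pos, 1, "~0");
    }
    for (std::size_t pos = s.find('/'); pos != std::string::npos; pos = s.find('/', pos + 2))
    {
        s.replace(pos, 1, "~1");
    }
    return s;
}

int main()
{
    assert(escape_pointer_token("a/b~c") == "a~1b~0c"); // "/" -> "~1", "~" -> "~0"
}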
- -//// START OF CODE FROM GOOGLE ABSEIL - -// integer_sequence -// -// Class template representing a compile-time integer sequence. An instantiation -// of `integer_sequence` has a sequence of integers encoded in its -// type through its template arguments (which is a common need when -// working with C++11 variadic templates). `absl::integer_sequence` is designed -// to be a drop-in replacement for C++14's `std::integer_sequence`. -// -// Example: -// -// template< class T, T... Ints > -// void user_function(integer_sequence); -// -// int main() -// { -// // user_function's `T` will be deduced to `int` and `Ints...` -// // will be deduced to `0, 1, 2, 3, 4`. -// user_function(make_integer_sequence()); -// } -template -struct integer_sequence -{ - using value_type = T; - static constexpr std::size_t size() noexcept - { - return sizeof...(Ints); - } -}; - -// index_sequence -// -// A helper template for an `integer_sequence` of `size_t`, -// `absl::index_sequence` is designed to be a drop-in replacement for C++14's -// `std::index_sequence`. -template -using index_sequence = integer_sequence; - -namespace utility_internal -{ - -template -struct Extend; - -// Note that SeqSize == sizeof...(Ints). It's passed explicitly for efficiency. -template -struct Extend, SeqSize, 0> -{ - using type = integer_sequence < T, Ints..., (Ints + SeqSize)... >; -}; - -template -struct Extend, SeqSize, 1> -{ - using type = integer_sequence < T, Ints..., (Ints + SeqSize)..., 2 * SeqSize >; -}; - -// Recursion helper for 'make_integer_sequence'. -// 'Gen::type' is an alias for 'integer_sequence'. -template -struct Gen -{ - using type = - typename Extend < typename Gen < T, N / 2 >::type, N / 2, N % 2 >::type; -}; - -template -struct Gen -{ - using type = integer_sequence; -}; - -} // namespace utility_internal - -// Compile-time sequences of integers - -// make_integer_sequence -// -// This template alias is equivalent to -// `integer_sequence`, and is designed to be a drop-in -// replacement for C++14's `std::make_integer_sequence`. -template -using make_integer_sequence = typename utility_internal::Gen::type; - -// make_index_sequence -// -// This template alias is equivalent to `index_sequence<0, 1, ..., N-1>`, -// and is designed to be a drop-in replacement for C++14's -// `std::make_index_sequence`. -template -using make_index_sequence = make_integer_sequence; - -// index_sequence_for -// -// Converts a typename pack into an index sequence of the same length, and -// is designed to be a drop-in replacement for C++14's -// `std::index_sequence_for()` -template -using index_sequence_for = make_index_sequence; - -//// END OF CODE FROM GOOGLE ABSEIL - -#endif - -// dispatch utility (taken from ranges-v3) -template struct priority_tag : priority_tag < N - 1 > {}; -template<> struct priority_tag<0> {}; - -// taken from ranges-v3 -template -struct static_const -{ - static JSON_INLINE_VARIABLE constexpr T value{}; -}; - -#ifndef JSON_HAS_CPP_17 - template - constexpr T static_const::value; -#endif - -template -inline constexpr std::array make_array(Args&& ... 
args) -{ - return std::array {{static_cast(std::forward(args))...}}; -} - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // numeric_limits -#include // false_type, is_constructible, is_integral, is_same, true_type -#include // declval -#include // tuple - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -#include // random_access_iterator_tag - -// #include - -// #include - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN -namespace detail -{ - -template -struct iterator_types {}; - -template -struct iterator_types < - It, - void_t> -{ - using difference_type = typename It::difference_type; - using value_type = typename It::value_type; - using pointer = typename It::pointer; - using reference = typename It::reference; - using iterator_category = typename It::iterator_category; -}; - -// This is required as some compilers implement std::iterator_traits in a way that -// doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. -template -struct iterator_traits -{ -}; - -template -struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> - : iterator_types -{ -}; - -template -struct iterator_traits::value>> -{ - using iterator_category = std::random_access_iterator_tag; - using value_type = T; - using difference_type = ptrdiff_t; - using pointer = T*; - using reference = T&; -}; - -} // namespace detail -NLOHMANN_JSON_NAMESPACE_END - -// #include - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN - -NLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin); - -NLOHMANN_JSON_NAMESPACE_END - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - - - -// #include - - -NLOHMANN_JSON_NAMESPACE_BEGIN - -NLOHMANN_CAN_CALL_STD_FUNC_IMPL(end); - -NLOHMANN_JSON_NAMESPACE_END - -// #include - -// #include - -// #include -// __ _____ _____ _____ -// __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 -// |_____|_____|_____|_|___| https://github.com/nlohmann/json -// -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann -// SPDX-License-Identifier: MIT - -#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ - #define INCLUDE_NLOHMANN_JSON_FWD_HPP_ - - #include // int64_t, uint64_t - #include // map - #include // allocator - #include // string - #include // vector - - // #include - - - /*! - @brief namespace for Niels Lohmann - @see https://github.com/nlohmann - @since version 1.0.0 - */ - NLOHMANN_JSON_NAMESPACE_BEGIN - - /*! 
- @brief default JSONSerializer template argument - - This serializer ignores the template arguments and uses ADL - ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) - for serialization. - */ - template - struct adl_serializer; - - /// a class to store JSON values - /// @sa https://json.nlohmann.me/api/basic_json/ - template class ObjectType = - std::map, - template class ArrayType = std::vector, - class StringType = std::string, class BooleanType = bool, - class NumberIntegerType = std::int64_t, - class NumberUnsignedType = std::uint64_t, - class NumberFloatType = double, - template class AllocatorType = std::allocator, - template class JSONSerializer = - adl_serializer, - class BinaryType = std::vector> - class basic_json; - - /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document - /// @sa https://json.nlohmann.me/api/json_pointer/ - template - class json_pointer; - - /*! - @brief default specialization - @sa https://json.nlohmann.me/api/json/ - */ - using json = basic_json<>; - - /// @brief a minimal map-like container that preserves insertion order - /// @sa https://json.nlohmann.me/api/ordered_map/ - template - struct ordered_map; - - /// @brief specialization that maintains the insertion order of object keys - /// @sa https://json.nlohmann.me/api/ordered_json/ - using ordered_json = basic_json; - - NLOHMANN_JSON_NAMESPACE_END - -#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ - - -NLOHMANN_JSON_NAMESPACE_BEGIN -/*! -@brief detail namespace with internal helper functions - -This namespace collects functions that should not be exposed, -implementations of some @ref basic_json methods, and meta-programming helpers. - -@since version 2.1.0 -*/ -namespace detail -{ - -///////////// -// helpers // -///////////// - -// Note to maintainers: -// -// Every trait in this file expects a non CV-qualified type. -// The only exceptions are in the 'aliases for detected' section -// (i.e. those of the form: decltype(T::member_function(std::declval()))) -// -// In this case, T has to be properly CV-qualified to constraint the function arguments -// (e.g. 
to_json(BasicJsonType&, const T&)) - -template struct is_basic_json : std::false_type {}; - -NLOHMANN_BASIC_JSON_TPL_DECLARATION -struct is_basic_json : std::true_type {}; - -// used by exceptions create() member functions -// true_type for pointer to possibly cv-qualified basic_json or std::nullptr_t -// false_type otherwise -template -struct is_basic_json_context : - std::integral_constant < bool, - is_basic_json::type>::type>::value - || std::is_same::value > -{}; - -////////////////////// -// json_ref helpers // -////////////////////// - -template -class json_ref; - -template -struct is_json_ref : std::false_type {}; - -template -struct is_json_ref> : std::true_type {}; - -////////////////////////// -// aliases for detected // -////////////////////////// - -template -using mapped_type_t = typename T::mapped_type; - -template -using key_type_t = typename T::key_type; - -template -using value_type_t = typename T::value_type; - -template -using difference_type_t = typename T::difference_type; - -template -using pointer_t = typename T::pointer; - -template -using reference_t = typename T::reference; - -template -using iterator_category_t = typename T::iterator_category; - -template -using to_json_function = decltype(T::to_json(std::declval()...)); - -template -using from_json_function = decltype(T::from_json(std::declval()...)); - -template -using get_template_function = decltype(std::declval().template get()); - -// trait checking if JSONSerializer::from_json(json const&, udt&) exists -template -struct has_from_json : std::false_type {}; - -// trait checking if j.get is valid -// use this trait instead of std::is_constructible or std::is_convertible, -// both rely on, or make use of implicit conversions, and thus fail when T -// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958) -template -struct is_getable -{ - static constexpr bool value = is_detected::value; -}; - -template -struct has_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> -{ - using serializer = typename BasicJsonType::template json_serializer; - - static constexpr bool value = - is_detected_exact::value; -}; - -// This trait checks if JSONSerializer::from_json(json const&) exists -// this overload is used for non-default-constructible user-defined-types -template -struct has_non_default_from_json : std::false_type {}; - -template -struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> -{ - using serializer = typename BasicJsonType::template json_serializer; - - static constexpr bool value = - is_detected_exact::value; -}; - -// This trait checks if BasicJsonType::json_serializer::to_json exists -// Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion. 
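// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the vendored header): the
// has_from_json / has_to_json detection traits in this hunk are what let
// basic_json pick up the to_json/from_json functions generated by
// NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE (defined earlier in this header) through
// adl_serializer. The struct, field names and values below are hypothetical;
// "json.hpp" is an illustrative include path.
// ---------------------------------------------------------------------------
#include <cassert>
#include <string>
#include "json.hpp"

namespace demo {
struct params { std::string model; int threads; };
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(params, model, threads)
} // namespace demo

int main()
{
    nlohmann::json j = demo::params{"base.en", 8};  // generated to_json, found via ADL
    demo::params p = j.get<demo::params>();         // generated from_json, found via ADL
    assert(p.model == "base.en" && p.threads == 8);
}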
-template -struct has_to_json : std::false_type {}; - -template -struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> -{ - using serializer = typename BasicJsonType::template json_serializer; - - static constexpr bool value = - is_detected_exact::value; -}; - -template -using detect_key_compare = typename T::key_compare; - -template -struct has_key_compare : std::integral_constant::value> {}; - -// obtains the actual object key comparator -template -struct actual_object_comparator -{ - using object_t = typename BasicJsonType::object_t; - using object_comparator_t = typename BasicJsonType::default_object_comparator_t; - using type = typename std::conditional < has_key_compare::value, - typename object_t::key_compare, object_comparator_t>::type; -}; - -template -using actual_object_comparator_t = typename actual_object_comparator::type; - -/////////////////// -// is_ functions // -/////////////////// - -// https://en.cppreference.com/w/cpp/types/conjunction -template struct conjunction : std::true_type { }; -template struct conjunction : B { }; -template -struct conjunction -: std::conditional(B::value), conjunction, B>::type {}; - -// https://en.cppreference.com/w/cpp/types/negation -template struct negation : std::integral_constant < bool, !B::value > { }; - -// Reimplementation of is_constructible and is_default_constructible, due to them being broken for -// std::pair and std::tuple until LWG 2367 fix (see https://cplusplus.github.io/LWG/lwg-defects.html#2367). -// This causes compile errors in e.g. clang 3.5 or gcc 4.9. -template -struct is_default_constructible : std::is_default_constructible {}; - -template -struct is_default_constructible> - : conjunction, is_default_constructible> {}; - -template -struct is_default_constructible> - : conjunction, is_default_constructible> {}; - -template -struct is_default_constructible> - : conjunction...> {}; - -template -struct is_default_constructible> - : conjunction...> {}; - - -template -struct is_constructible : std::is_constructible {}; - -template -struct is_constructible> : is_default_constructible> {}; - -template -struct is_constructible> : is_default_constructible> {}; - -template -struct is_constructible> : is_default_constructible> {}; - -template -struct is_constructible> : is_default_constructible> {}; - - -template -struct is_iterator_traits : std::false_type {}; - -template -struct is_iterator_traits> -{ - private: - using traits = iterator_traits; - - public: - static constexpr auto value = - is_detected::value && - is_detected::value && - is_detected::value && - is_detected::value && - is_detected::value; -}; - -template -struct is_range -{ - private: - using t_ref = typename std::add_lvalue_reference::type; - - using iterator = detected_t; - using sentinel = detected_t; - - // to be 100% correct, it should use https://en.cppreference.com/w/cpp/iterator/input_or_output_iterator - // and https://en.cppreference.com/w/cpp/iterator/sentinel_for - // but reimplementing these would be too much work, as a lot of other concepts are used underneath - static constexpr auto is_iterator_begin = - is_iterator_traits>::value; - - public: - static constexpr bool value = !std::is_same::value && !std::is_same::value && is_iterator_begin; -}; - -template -using iterator_t = enable_if_t::value, result_of_begin())>>; - -template -using range_value_t = value_type_t>>; - -// The following implementation of is_complete_type is taken from -// 
https://blogs.msdn.microsoft.com/vcblog/2015/12/02/partial-support-for-expression-sfinae-in-vs-2015-update-1/ -// and is written by Xiang Fan who agreed to using it in this library. - -template -struct is_complete_type : std::false_type {}; - -template -struct is_complete_type : std::true_type {}; - -template -struct is_compatible_object_type_impl : std::false_type {}; - -template -struct is_compatible_object_type_impl < - BasicJsonType, CompatibleObjectType, - enable_if_t < is_detected::value&& - is_detected::value >> -{ - using object_t = typename BasicJsonType::object_t; - - // macOS's is_constructible does not play well with nonesuch... - static constexpr bool value = - is_constructible::value && - is_constructible::value; -}; - -template -struct is_compatible_object_type - : is_compatible_object_type_impl {}; - -template -struct is_constructible_object_type_impl : std::false_type {}; - -template -struct is_constructible_object_type_impl < - BasicJsonType, ConstructibleObjectType, - enable_if_t < is_detected::value&& - is_detected::value >> -{ - using object_t = typename BasicJsonType::object_t; - - static constexpr bool value = - (is_default_constructible::value && - (std::is_move_assignable::value || - std::is_copy_assignable::value) && - (is_constructible::value && - std::is_same < - typename object_t::mapped_type, - typename ConstructibleObjectType::mapped_type >::value)) || - (has_from_json::value || - has_non_default_from_json < - BasicJsonType, - typename ConstructibleObjectType::mapped_type >::value); -}; - -template -struct is_constructible_object_type - : is_constructible_object_type_impl {}; - -template -struct is_compatible_string_type -{ - static constexpr auto value = - is_constructible::value; -}; - -template -struct is_constructible_string_type -{ - // launder type through decltype() to fix compilation failure on ICPC -#ifdef __INTEL_COMPILER - using laundered_type = decltype(std::declval()); -#else - using laundered_type = ConstructibleStringType; -#endif - - static constexpr auto value = - conjunction < - is_constructible, - is_detected_exact>::value; -}; - -template -struct is_compatible_array_type_impl : std::false_type {}; - -template -struct is_compatible_array_type_impl < - BasicJsonType, CompatibleArrayType, - enable_if_t < - is_detected::value&& - is_iterator_traits>>::value&& -// special case for types like std::filesystem::path whose iterator's value_type are themselves -// c.f. https://github.com/nlohmann/json/pull/3073 - !std::is_same>::value >> -{ - static constexpr bool value = - is_constructible>::value; -}; - -template -struct is_compatible_array_type - : is_compatible_array_type_impl {}; - -template -struct is_constructible_array_type_impl : std::false_type {}; - -template -struct is_constructible_array_type_impl < - BasicJsonType, ConstructibleArrayType, - enable_if_t::value >> - : std::true_type {}; - -template -struct is_constructible_array_type_impl < - BasicJsonType, ConstructibleArrayType, - enable_if_t < !std::is_same::value&& - !is_compatible_string_type::value&& - is_default_constructible::value&& -(std::is_move_assignable::value || - std::is_copy_assignable::value)&& -is_detected::value&& -is_iterator_traits>>::value&& -is_detected::value&& -// special case for types like std::filesystem::path whose iterator's value_type are themselves -// c.f. 
https://github.com/nlohmann/json/pull/3073 -!std::is_same>::value&& - is_complete_type < - detected_t>::value >> -{ - using value_type = range_value_t; - - static constexpr bool value = - std::is_same::value || - has_from_json::value || - has_non_default_from_json < - BasicJsonType, - value_type >::value; -}; - -template -struct is_constructible_array_type - : is_constructible_array_type_impl {}; - -template -struct is_compatible_integer_type_impl : std::false_type {}; - -template -struct is_compatible_integer_type_impl < - RealIntegerType, CompatibleNumberIntegerType, - enable_if_t < std::is_integral::value&& - std::is_integral::value&& - !std::is_same::value >> -{ - // is there an assert somewhere on overflows? - using RealLimits = std::numeric_limits; - using CompatibleLimits = std::numeric_limits; - - static constexpr auto value = - is_constructible::value && - CompatibleLimits::is_integer && - RealLimits::is_signed == CompatibleLimits::is_signed; -}; - -template -struct is_compatible_integer_type - : is_compatible_integer_type_impl {}; - -template -struct is_compatible_type_impl: std::false_type {}; - -template -struct is_compatible_type_impl < - BasicJsonType, CompatibleType, - enable_if_t::value >> -{ - static constexpr bool value = - has_to_json::value; -}; - -template -struct is_compatible_type - : is_compatible_type_impl {}; - -template -struct is_constructible_tuple : std::false_type {}; - -template -struct is_constructible_tuple> : conjunction...> {}; - -template -struct is_json_iterator_of : std::false_type {}; - -template -struct is_json_iterator_of : std::true_type {}; - -template -struct is_json_iterator_of : std::true_type -{}; - -// checks if a given type T is a template specialization of Primary -template