Very poor gemmt performance compared to gemm and syrk #4921

Open
david-cortes opened this issue Oct 6, 2024 · 3 comments
@david-cortes
Contributor

I'm timing the operation t(X)*X on row-major matrices that have many more rows than columns.

For these inputs, gemmt is slower than the equivalent syrk or gemm call by a very wide margin.

Timings in milliseconds for an input of size 1,000,000 x 32 on an Intel i7-12700H, averaged over 3 runs:

  • gemmt: 216.178
  • syrk: 41.0468
  • gemm: 39.55553

Version: OpenBLAS 0.3.28, built with OpenMP, compiled from source (gcc, using the CMake build system). The same issue happens with pthreads, and the same timing difference is observed when running single-threaded.

For reference, timings for other libraries:

  • MKL gemmt: 25.66533
  • MKL syrk: 12.57197
  • MKL gemm: 15.69447
  • tabmat's "sandwich" op: 29.3

Code to reproduce:

#include <iostream>
#include <chrono>
#include <random>
#include <memory>


#include <cblas.h>
extern "C" void cblas_dgemmt(const CBLAS_LAYOUT Layout, const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc);

using std::chrono::high_resolution_clock;
using std::chrono::duration_cast;
using std::chrono::duration;
using std::chrono::milliseconds;

int main()
{
    const size_t nrows = 1'000'000;
    const size_t ncols = 32;
    const size_t tot = nrows * ncols;

    std::mt19937 rng{123};
    std::normal_distribution norm_distr{0.0, 1.0};

    std::unique_ptr<double[]> X(new double[tot]);
    std::unique_ptr<double[]> out(new double[ncols*ncols]());
    for (size_t ix = 0; ix < tot; ix++) X[ix] = norm_distr(rng);

    auto t1 = high_resolution_clock::now();
    cblas_dgemmt(
        CblasRowMajor, CblasUpper, CblasTrans, CblasNoTrans,
        ncols, nrows,
        1., X.get(), ncols,
        X.get(), ncols,
        0., out.get(), ncols
    );
    auto t2 = high_resolution_clock::now();
    duration<double, std::milli> ms_double = t2 - t1;

    double sum_res = 0.;
    for (size_t ix = 0; ix < ncols*ncols; ix++) sum_res += out[ix];

    std::cout << "time gemmt:" << ms_double.count() << std::endl;
    std::cout << "sum gemmt:" << sum_res << std::endl;

    t1 = high_resolution_clock::now();
    cblas_dsyrk(
        CblasRowMajor, CblasUpper, CblasTrans,
        ncols, nrows,
        1., X.get(), ncols,
        0., out.get(), ncols
    );
    t2 = high_resolution_clock::now();
    ms_double = t2 - t1;
    sum_res = 0.;
    for (size_t ix = 0; ix < ncols*ncols; ix++) sum_res += out[ix];
    std::cout << "time syrk:" << ms_double.count() << std::endl;
    std::cout << "sum syrk:" << sum_res << std::endl;

    t1 = high_resolution_clock::now();
    cblas_dgemm(
        CblasRowMajor, CblasTrans, CblasNoTrans,
        ncols, ncols, nrows,
        1., X.get(), ncols,
        X.get(), ncols,
        0., out.get(), ncols
    );
    t2 = high_resolution_clock::now();
    ms_double = t2 - t1;
    sum_res = 0.;
    for (size_t ix = 0; ix < ncols*ncols; ix++) sum_res += out[ix];
    std::cout << "time gemm:" << ms_double.count() << std::endl;
    std::cout << "sum gemm:" << sum_res << std::endl;

    return 0;
}

@martin-frbg
Collaborator

The current GEMMT implementation is just a loop around GEMV, so its performance depends largely on that of the individual optimized GEMV kernels. It is provided for compatibility, but is not yet optimized for speed. As the Reference BLAS looks to be adding its own interpretation of what used to be an unofficial extension, a total rework may be necessary at some point in any case.
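
To make the "loop around GEMV" structure concrete, here is a minimal sketch in plain C++ with no BLAS calls. It is not the code in interface/gemmt.c; the names `gemv_t_sketch` and `gemmt_upper_naive` are hypothetical, and the sketch assumes row-major storage, the upper triangle, and the `CblasTrans`/`CblasNoTrans` case from the reproducer above. Each row of the upper triangle is produced by one matrix-vector product, so every pass streams over the whole input again.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a GEMV kernel (not an OpenBLAS function):
// y = alpha * M^T * x + beta * y, where M is k x m, row-major with leading
// dimension ldm, and x has k elements with stride incx.
static void gemv_t_sketch(std::size_t k, std::size_t m, double alpha,
                          const double* M, std::size_t ldm,
                          const double* x, std::size_t incx,
                          double beta, double* y)
{
    for (std::size_t j = 0; j < m; j++) {
        double acc = 0.0;
        for (std::size_t l = 0; l < k; l++) acc += M[l * ldm + j] * x[l * incx];
        // With beta == 0 the old contents of y are ignored, as in BLAS.
        y[j] = alpha * acc + (beta == 0.0 ? 0.0 : beta * y[j]);
    }
}

// Upper triangle of C = alpha * A^T * B + beta * C.
// A and B are k x n, C is n x n, all row-major. The strictly lower
// triangle of C is left untouched.
void gemmt_upper_naive(std::size_t n, std::size_t k, double alpha,
                       const double* A, std::size_t lda,
                       const double* B, std::size_t ldb,
                       double beta, double* C, std::size_t ldc)
{
    assert(lda >= n && ldb >= n && ldc >= n);
    for (std::size_t i = 0; i < n; i++) {
        // Row i of C, columns i..n-1: one GEMV over the trailing columns
        // of B, with column i of A as the input vector.
        gemv_t_sketch(k, n - i, alpha, B + i, ldb, A + i, lda,
                      beta, C + i * ldc + i);
    }
}
```

With n = 32 and k = 1,000,000 as in the report, this structure issues 32 separate passes over the data, each doing very little work per element loaded, which is one plausible reason it falls behind a single blocked GEMM or SYRK call.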

@martin-frbg
Collaborator

On ARM, the performance of this somewhat naive gemmt implementation is about on par with gemm and clearly better than syrk, provided the number of threads is capped at about 30 (on 64 cores, gemmt comes out horrendously bad again, taking about ten times as long as gemm). Interestingly, the obvious optimization of allocating the memory buffer only once, instead of allocating and freeing it for every individual gemv step in interface/gemmt.c, does not result in a significant improvement.

@angsch
Contributor

angsch commented Jan 10, 2025

If someone wants to add a generic implementation with reasonable performance, the LAPACK implementation prior to the introduction of GEMMT (aka GEMMTR in LAPACK) as part of the BLAS may serve as inspiration. It essentially reduces the problem to GEMM by blocking into panels; only the small triangular part is computed with GEMV.

Reference-LAPACK/lapack@09cb849
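
The blocking idea described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the LAPACK code from the linked commit: the names `dot_cols` and `gemmt_upper_blocked` and the block size parameter `nb` are hypothetical, the layout is row-major with the upper triangle as in the reproducer, and plain loops stand in for the GEMV and GEMM calls a real implementation would make.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: dot product of column i of A with column j of B,
// both k x n row-major matrices.
static double dot_cols(std::size_t k, const double* A, std::size_t lda, std::size_t i,
                       const double* B, std::size_t ldb, std::size_t j)
{
    double acc = 0.0;
    for (std::size_t l = 0; l < k; l++) acc += A[l * lda + i] * B[l * ldb + j];
    return acc;
}

// Upper triangle of C = alpha * A^T * B + beta * C, processed in row panels
// of height nb. The strictly lower triangle of C is left untouched.
void gemmt_upper_blocked(std::size_t n, std::size_t k, double alpha,
                         const double* A, std::size_t lda,
                         const double* B, std::size_t ldb,
                         double beta, double* C, std::size_t ldc,
                         std::size_t nb)
{
    assert(nb > 0);
    auto update = [&](std::size_t i, std::size_t j) {
        double& c = C[i * ldc + j];
        c = alpha * dot_cols(k, A, lda, i, B, ldb, j) + (beta == 0.0 ? 0.0 : beta * c);
    };
    for (std::size_t i0 = 0; i0 < n; i0 += nb) {
        const std::size_t i1 = std::min(i0 + nb, n);
        // Small triangular diagonal block: the only part that needs
        // triangle-aware code (GEMV in the description above).
        for (std::size_t i = i0; i < i1; i++)
            for (std::size_t j = i; j < i1; j++) update(i, j);
        // Everything to the right of the diagonal block is a full rectangle,
        // so a real implementation could hand it to a single GEMM call.
        for (std::size_t i = i0; i < i1; i++)
            for (std::size_t j = i1; j < n; j++) update(i, j);
    }
}
```

For a small n such as 32, a block size near n leaves almost all of the work in the triangular part, so the benefit of this scheme depends on choosing `nb` well below n or on a fast path for the diagonal blocks.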
