
[Bug] Inefficient posterior evaluation of SaasFullyBayesianSingleTaskGP when q=1 #2310

Open
slishak-PX opened this issue Apr 26, 2024 · 9 comments
Labels
bug Something isn't working


@slishak-PX
Contributor

slishak-PX commented Apr 26, 2024

🐛 Bug

Evaluating an acquisition function with q=1 on a SaasFullyBayesianSingleTaskGP requires an unnecessarily large amount of memory due to an inefficient broadcasted matmul operation.

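In the example below, the following line multiplies a tensor of size [256, 16, 1, 2048] by a tensor of size [16, 2048, 2048]. Because the batch dimensions are broadcast, the second tensor is effectively expanded to [256, 16, 2048, 2048], which in double precision requires allocating 128 GiB of memory: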
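The 128 GiB figure follows directly from the broadcast shape. The helper below reproduces the arithmetic; its name `expanded_bytes` is hypothetical, and it assumes (as the OOM message suggests) that the allocation is dominated by materializing an expanded copy of the second operand.

```python
import math

import torch


def expanded_bytes(a_shape, b_shape, dtype=torch.double):
    """Size in bytes of b after expanding its batch dims to the broadcast batch shape.

    Hypothetical helper: assumes the cost is dominated by materializing the
    expanded copy of b, not by the matmul output itself.
    """
    batch = torch.broadcast_shapes(tuple(a_shape[:-2]), tuple(b_shape[:-2]))
    n_elems = math.prod(batch) * b_shape[-2] * b_shape[-1]
    return n_elems * torch.empty(0, dtype=dtype).element_size()


gib = expanded_bytes((256, 16, 1, 2048), (16, 2048, 2048)) / 1024**3
print(f"{gib:.0f} GiB")  # 128 GiB, matching the failed allocation in the traceback
```

The same formula with the float32 shapes from the profiling script further down ([256, 16, 1, 1024] and [16, 1024, 1024]) gives 16 GiB, in line with the 16.02GB measured there for plain matmul.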
https://github.com/cornellius-gp/gpytorch/blob/9551eba889adf835b69cfd86e9a5d584fb61cdcc/gpytorch/models/exact_prediction_strategies.py#L118

To reproduce

Code snippet to reproduce

import torch
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.transforms import Standardize
from botorch import fit_fully_bayesian_model_nuts
from botorch.acquisition import UpperConfidenceBound

n_train = 2048
n_test = 256
d = 256

tkwargs = {
    "device": torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.double,
}

train_X = torch.rand(n_train, d, **tkwargs)
test_X = torch.rand(n_test, d, **tkwargs)
train_Y = torch.sin(train_X[:, :1])
test_Y = torch.sin(test_X[:, :1])

gp = SaasFullyBayesianSingleTaskGP(
    train_X=train_X, 
    train_Y=train_Y, 
    outcome_transform=Standardize(m=1),
)
fit_fully_bayesian_model_nuts(
    gp,
    warmup_steps=4,
    num_samples=16,
    thinning=1,
)

ucb = UpperConfidenceBound(gp, beta=2.5)
acq_values = ucb(test_X[:, None, :])

Stack trace/error message

Traceback (most recent call last):
  File "/tmp/ipykernel_3377365/3398296989.py", line 3, in <module>
    acq_values = ucb(test_X[:, None, :])
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 786, in forward
    mean, sigma = self._mean_and_sigma(X)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 106, in _mean_and_sigma
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/fully_bayesian.py", line 536, in posterior
    posterior = super().posterior(
                ^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
          ^^^^^^^
...
    return test_train_covar.matmul(precomputed_cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB. GPU 3 has a total capacity of 79.15 GiB of which 44.45 GiB is free. Including non-PyTorch memory, this process has 34.69 GiB memory in use. Of the allocated memory 23.74 GiB is allocated by PyTorch, and 10.42 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Expected Behavior

The memory usage for this operation is very high because torch.matmul handles this kind of batched vector-matrix product inefficiently when one of the batch dimensions is broadcast. If the same operation is written as an einsum, or the tensors are transposed so that it becomes a matrix-matrix multiplication, the memory usage and computation time are substantially reduced.

For example, below is a demonstration of two alternative operations which reduce memory usage by orders of magnitude and also reduce computation time:

import torch
device = "cuda:3"

# Matrices to multiply
torch.manual_seed(50)
a = torch.randn((256, 16, 1, 1024), device=device)
b = torch.randn((16, 1024, 1024), device=device)

def profile(func):
    torch.cuda.reset_peak_memory_stats(device=device)
    m0 = torch.cuda.max_memory_allocated(device=device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    out = func()
    end.record()

    torch.cuda.synchronize()
    t = start.elapsed_time(end)

    m1 = torch.cuda.max_memory_allocated(device=device)

    print(f"Memory used: {(m1 - m0) / 1024**3:.2f}GB")
    # Event.elapsed_time() already returns milliseconds, so no scaling is needed
    print(f"Time: {t:.6f} ms")

    return out

with torch.no_grad():
    print("matmul")
    c = profile(lambda: torch.matmul(a, b))

    print("\neinsum")
    c_einsum = profile(lambda: torch.einsum("...ij,...jk", a, b))
    print(f"Max error: {(c_einsum - c).abs().max().cpu().item():.7f}")

    print("\ntransposed matmul")
    c_transpose = profile(lambda: torch.matmul(a.transpose(0, 2), b).transpose(0, 2))
    print(f"Max error: {(c_transpose - c).abs().max().cpu().item():.7f}")
Output:

matmul
Memory used: 16.02GB
Time: 261.343986 ms

einsum
Memory used: 0.02GB
Time: 160.416007 ms
Max error: 0.0002327

transposed matmul
Memory used: 0.02GB
Time: 118.303999 ms
Max error: 0.0002327

System information


  • BoTorch Version 1.11
  • GPyTorch Version 0.9.5
  • PyTorch Version 2.2.0+cu121
  • Computer OS: Rocky Linux release 8.9
  • GPU: NVIDIA A100 80GB PCIe
@slishak-PX slishak-PX added the bug Something isn't working label Apr 26, 2024
@Balandat
Contributor

Thanks for raising this, this is a great catch. Since this call happens in gpytorch, we'll have to make a change there and ensure that it is compatible with all kinds of other (non-fully-Bayesian) scenarios (I'm not sure exactly what kinds of shapes are encountered in this call), but we will definitely fix this.

cc @dme65, @esantorella

@Balandat Balandat self-assigned this Apr 26, 2024
@slishak-PX
Contributor Author

Thank you! I did initially try and come up with something to contribute to GPyTorch and/or linear_operator, but it was harder than anticipated to make it compatible and not introduce slowdowns in other situations, so I thought I'd report it here for now. For example, einsum is faster than matmul in this specific situation but has the potential to be much slower in other situations.

(n.b. I've just corrected a small mistake in the profiling code in the issue above)

@Balandat
Contributor

I did initially try and come up with something to contribute to GPyTorch and/or linear_operator, but it was harder than anticipated to make it compatible and not introduce slowdowns in other situations, so I thought I'd report it here for now.

Interesting. Do you happen to have some artifacts of those attempts that you could share? That would be very helpful.

@slishak-PX
Contributor Author

What I've done at the moment is just replace the aforementioned matmul with the equivalent einsum to unblock my work, but that of course only works when neither tensor is a LinearOperator (as einsum is not implemented for LinearOperator).
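A minimal sketch of that kind of workaround, assuming it wraps the `test_train_covar.matmul(precomputed_cache)` call from the traceback: use einsum only when both operands are plain tensors and keep the original matmul call otherwise. The function name `cache_product` is hypothetical, and this relies on LinearOperator objects not being torch.Tensor instances. As noted earlier in the thread, switching to einsum unconditionally can be slower for other shapes.

```python
import torch


def cache_product(test_train_covar, precomputed_cache):
    """Hypothetical drop-in for `test_train_covar.matmul(precomputed_cache)`."""
    if isinstance(test_train_covar, torch.Tensor) and isinstance(
        precomputed_cache, torch.Tensor
    ):
        # Same result as torch.matmul, including broadcasting of the "..." batch dims,
        # but without expanding the broadcast operand first.
        return torch.einsum("...ij,...jk->...ik", test_train_covar, precomputed_cache)
    # LinearOperator (or any other non-Tensor) operands keep the original code path.
    return test_train_covar.matmul(precomputed_cache)


a = torch.randn(8, 4, 1, 32, dtype=torch.double)
b = torch.randn(4, 32, 32, dtype=torch.double)
assert torch.allclose(cache_product(a, b), torch.matmul(a, b))
```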

Below is a script that benchmarks matmul against einsum and shows that matmul is generally faster except in some specific situations (I had to restrict the tensor sizes to run on a laptop, as running on a shared server was interfering with the timing results too much). It does not report memory usage.

Those specific situations when matmul is inefficient appear to be:

  • a is a (batched) row vector and one of the batch dimensions of b is broadcasted
  • b is a (batched) column vector and one of the batch dimensions of a is broadcasted
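These two conditions can be written as a shape check that a dispatcher could use to decide when to avoid plain matmul. This is a hypothetical sketch (the function names are illustrative). Checking it against the timing tables in this comment suggests one refinement: when the broadcast operand's batch dimensions are all 1 (e.g. `(256, 1, 1, 128)x(1, 128, 128)`), matmul was not slow, so the check also requires that operand to have a non-singleton batch dimension.

```python
from typing import Sequence, Tuple


def _left_pad(batch: Sequence[int], ndim: int) -> Tuple[int, ...]:
    # Broadcasting aligns batch dims from the right, so pad missing dims with 1.
    return (1,) * (ndim - len(batch)) + tuple(batch)


def needs_costly_expand(x_batch: Sequence[int], y_batch: Sequence[int]) -> bool:
    """True if x has a non-singleton batch dim AND must be expanded along another."""
    ndim = max(len(x_batch), len(y_batch))
    x, y = _left_pad(x_batch, ndim), _left_pad(y_batch, ndim)
    has_real_batch = any(xs > 1 for xs in x)
    is_expanded = any(xs == 1 and ys > 1 for xs, ys in zip(x, y))
    return has_real_batch and is_expanded


def matmul_may_blow_up(a_shape: Sequence[int], b_shape: Sequence[int]) -> bool:
    """Flag the two slow matmul patterns; both shapes must be at least 2-D."""
    a_batch, b_batch = tuple(a_shape[:-2]), tuple(b_shape[:-2])
    # a is a (batched) row vector and b's batch dims are broadcast.
    row_vector_case = a_shape[-2] == 1 and needs_costly_expand(b_batch, a_batch)
    # b is a (batched) column vector and a's batch dims are broadcast.
    col_vector_case = b_shape[-1] == 1 and needs_costly_expand(a_batch, b_batch)
    return row_vector_case or col_vector_case


print(matmul_may_blow_up((256, 16, 1, 2048), (16, 2048, 2048)))  # True (the issue's shapes)
print(matmul_may_blow_up((256, 1, 1, 128), (1, 128, 128)))  # False
```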

In both situations, I'd guess that the appropriate set of matrix transposes and matmul would be more efficient than einsum, but I haven't tested this.
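For the two fixed layouts above, the transposes can be sketched as follows. The helper names are hypothetical, and the sketch only checks that the results match torch.matmul; it says nothing about speed. As in the transposed matmul from the issue description, the idea is to move the broadcast batch dimension into the row (or column) dimension of the vector operand, so the other operand never has to be expanded.

```python
import torch


def rowvec_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """a @ b for a: (B, S, 1, k) and b: (S, k, n), without expanding b to (B, S, k, n)."""
    # (B, S, 1, k) -> (S, B, k): the broadcast dim B becomes the row dim of a.
    a_rows = a.squeeze(-2).transpose(0, 1)
    # Batch dims now match exactly: (S, B, k) @ (S, k, n) -> (S, B, n).
    out = torch.matmul(a_rows, b)
    # (S, B, n) -> (B, S, n) -> (B, S, 1, n); these are views, no copy.
    return out.transpose(0, 1).unsqueeze(-2)


def colvec_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """a @ b for a: (S, m, k) and b: (B, S, k, 1), without expanding a to (B, S, m, k)."""
    # (B, S, k, 1) -> (S, k, B): the broadcast dim B becomes the column dim of b.
    b_cols = b.squeeze(-1).permute(1, 2, 0)
    # (S, m, k) @ (S, k, B) -> (S, m, B).
    out = torch.matmul(a, b_cols)
    # (S, m, B) -> (B, S, m) -> (B, S, m, 1).
    return out.permute(2, 0, 1).unsqueeze(-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    B, S, k, n = 8, 4, 32, 16
    a = torch.randn(B, S, 1, k, dtype=torch.double)
    b = torch.randn(S, k, n, dtype=torch.double)
    assert torch.allclose(rowvec_matmul(a, b), torch.matmul(a, b))

    a = torch.randn(S, n, k, dtype=torch.double)
    b = torch.randn(B, S, k, 1, dtype=torch.double)
    assert torch.allclose(colvec_matmul(a, b), torch.matmul(a, b))
```

Both helpers return non-contiguous views, so a caller that needs contiguous memory would have to add `.contiguous()`.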

Code
import torch
import torch.utils.benchmark as benchmark
from tqdm import tqdm

device = "cuda:0"
f_out = "einsum.txt"


def wrap(f):
    def wrapped(a_sz, b_sz, device):
        a = torch.randn(a_sz, device=device)
        b = torch.randn(b_sz, device=device)
        return f(a, b)

    return wrapped


matmul = wrap(torch.matmul)


@wrap
def einsum(a, b):
    return torch.einsum("...ik,...kj", a, b)


if __name__ == "__main__":
    sizes = []

    batch_funcs = [
        ("Full", lambda i_batch, j_batch: ((i_batch, j_batch), (i_batch, j_batch))),
        ("No a0", lambda i_batch, j_batch: ((j_batch,), (i_batch, j_batch))),
        ("No b0", lambda i_batch, j_batch: ((i_batch, j_batch), (j_batch,))),
    ]

    for batch_type, batch_func in batch_funcs:
        for i_batch in (256, 32, 1):
            for j_batch in (32, 16, 1):
                a_batch, b_batch = batch_func(i_batch, j_batch)
                if None in (a_batch, b_batch):
                    continue
                for i_size in (128, 64, 4):
                    sizes.append(
                        (
                            batch_type,
                            "Matrix-matrix product",
                            a_batch + (i_size, i_size),
                            b_batch + (i_size, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Matrix-vector product",
                            a_batch + (i_size, i_size),
                            b_batch + (i_size, 1),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Transposed MVP",
                            a_batch + (1, i_size),
                            b_batch + (i_size, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Vector outer product",
                            a_batch + (i_size, 1),
                            b_batch + (1, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Vector inner product",
                            a_batch + (1, i_size),
                            b_batch + (i_size, 1),
                        )
                    )

    results = []
    with torch.no_grad():
        pbar = tqdm(sizes)
        for env, label, a_sz, b_sz in pbar:
            sub_label = f"{a_sz}x{b_sz}"
            pbar.set_description(sub_label)

            timers = [
                benchmark.Timer(
                    stmt="matmul(a_sz, b_sz, device)",
                    globals={"device": device, "a_sz": a_sz, "b_sz": b_sz},
                    description="matmul",
                    setup="from __main__ import matmul",
                    label=label,
                    sub_label=sub_label,
                    env=env,
                ),
                benchmark.Timer(
                    stmt="einsum(a_sz, b_sz, device)",
                    globals={"device": device, "a_sz": a_sz, "b_sz": b_sz},
                    description="einsum",
                    setup="from __main__ import einsum",
                    label=label,
                    sub_label=sub_label,
                    env=env,
                ),
            ]

            for timer in timers:
                result = timer.adaptive_autorange(min_run_time=1)
                results.append(result)
                pbar.write(str(result))
                compare = benchmark.Compare(results)
                compare.colorize(rowwise=True)
                with open(f_out, "wt") as f:
                    f.write(str(compare))
Timing results
[------------------------ Matrix-matrix product -------------------------]
                                                    |   matmul  |   einsum
1 threads: ---------------------------------------------------------------
  (Full)   (256, 32, 128, 128)x(256, 32, 128, 128)  |  39434.7  |  39591.2
           (256, 32, 64, 64)x(256, 32, 64, 64)      |   5629.9  |   5642.9
           (256, 32, 4, 4)x(256, 32, 4, 4)          |    170.2  |    170.6
           (256, 16, 128, 128)x(256, 16, 128, 128)  |  20098.4  |  20091.8
           (256, 16, 64, 64)x(256, 16, 64, 64)      |   2838.9  |   2837.1
           (256, 16, 4, 4)x(256, 16, 4, 4)          |     82.0  |     82.4
           (256, 1, 128, 128)x(256, 1, 128, 128)    |   1260.6  |   1262.0
           (256, 1, 64, 64)x(256, 1, 64, 64)        |    196.1  |    196.2
           (256, 1, 4, 4)x(256, 1, 4, 4)            |     59.6  |     81.5
           (32, 32, 128, 128)x(32, 32, 128, 128)    |   4865.4  |   5037.6
           (32, 32, 64, 64)x(32, 32, 64, 64)        |    724.0  |    724.2
           (32, 32, 4, 4)x(32, 32, 4, 4)            |     52.7  |     92.4
           (32, 16, 128, 128)x(32, 16, 128, 128)    |   2511.0  |   2523.9
           (32, 16, 64, 64)x(32, 16, 64, 64)        |    371.7  |    371.6
           (32, 16, 4, 4)x(32, 16, 4, 4)            |     64.1  |     88.7
           (32, 1, 128, 128)x(32, 1, 128, 128)      |    170.6  |    172.2
           (32, 1, 64, 64)x(32, 1, 64, 64)          |     62.5  |     93.1
           (32, 1, 4, 4)x(32, 1, 4, 4)              |     57.9  |     80.1
           (1, 32, 128, 128)x(1, 32, 128, 128)      |    170.2  |    172.1
           (1, 32, 64, 64)x(1, 32, 64, 64)          |     51.7  |     81.9
           (1, 32, 4, 4)x(1, 32, 4, 4)              |     55.1  |     79.1
           (1, 16, 128, 128)x(1, 16, 128, 128)      |     75.1  |     91.9
           (1, 16, 64, 64)x(1, 16, 64, 64)          |     64.0  |     79.5
           (1, 16, 4, 4)x(1, 16, 4, 4)              |     56.4  |     79.8
           (1, 1, 128, 128)x(1, 1, 128, 128)        |     66.8  |     82.9
           (1, 1, 64, 64)x(1, 1, 64, 64)            |     68.6  |     83.6
           (1, 1, 4, 4)x(1, 1, 4, 4)                |     64.2  |     82.8
  (No a0)  (32, 128, 128)x(256, 32, 128, 128)       |  41409.0  |  41664.8
           (32, 64, 64)x(256, 32, 64, 64)           |   5943.2  |   5717.3
           (32, 4, 4)x(256, 32, 4, 4)               |    170.0  |    108.7
           (16, 128, 128)x(256, 16, 128, 128)       |  20606.6  |  20651.3
           (16, 64, 64)x(256, 16, 64, 64)           |   2697.3  |   2869.0
           (16, 4, 4)x(256, 16, 4, 4)               |     80.8  |     99.6
           (1, 128, 128)x(256, 1, 128, 128)         |   1102.7  |    955.0
           (1, 64, 64)x(256, 1, 64, 64)             |    136.9  |    188.1
           (1, 4, 4)x(256, 1, 4, 4)                 |     65.9  |    104.4
           (32, 128, 128)x(32, 32, 128, 128)        |   5099.2  |   5208.4
           (32, 64, 64)x(32, 32, 64, 64)            |    771.4  |    742.5
           (32, 4, 4)x(32, 32, 4, 4)                |     86.0  |    102.9
           (16, 128, 128)x(32, 16, 128, 128)        |   2587.0  |   2606.2
           (16, 64, 64)x(32, 16, 64, 64)            |    359.4  |    378.0
           (16, 4, 4)x(32, 16, 4, 4)                |     92.2  |    117.4
           (1, 128, 128)x(32, 1, 128, 128)          |    151.2  |    137.0
           (1, 64, 64)x(32, 1, 64, 64)              |     65.3  |    109.1
           (1, 4, 4)x(32, 1, 4, 4)                  |     73.6  |    115.5
           (32, 128, 128)x(1, 32, 128, 128)         |    170.8  |    172.3
           (32, 64, 64)x(1, 32, 64, 64)             |     55.8  |     83.5
           (32, 4, 4)x(1, 32, 4, 4)                 |     60.1  |     81.1
           (16, 128, 128)x(1, 16, 128, 128)         |     75.4  |     88.6
           (16, 64, 64)x(1, 16, 64, 64)             |     57.9  |     81.4
           (16, 4, 4)x(1, 16, 4, 4)                 |     58.3  |     80.7
           (1, 128, 128)x(1, 1, 128, 128)           |     68.5  |     81.7
           (1, 64, 64)x(1, 1, 64, 64)               |     69.5  |     87.2
           (1, 4, 4)x(1, 1, 4, 4)                   |     73.2  |     97.5
  (No b0)  (256, 32, 128, 128)x(32, 128, 128)       |  41350.2  |  41628.3
           (256, 32, 64, 64)x(32, 64, 64)           |   5935.2  |   6984.1
           (256, 32, 4, 4)x(32, 4, 4)               |    169.9  |    118.6
           (256, 16, 128, 128)x(16, 128, 128)       |  20609.6  |  20646.5
           (256, 16, 64, 64)x(16, 64, 64)           |   2692.2  |   3489.1
           (256, 16, 4, 4)x(16, 4, 4)               |     80.9  |    112.8
           (256, 1, 128, 128)x(1, 128, 128)         |   1101.0  |    753.1
           (256, 1, 64, 64)x(1, 64, 64)             |    139.5  |    178.4
           (256, 1, 4, 4)x(1, 4, 4)                 |     66.5  |     81.2
           (32, 32, 128, 128)x(32, 128, 128)        |   5085.8  |   5191.4
           (32, 32, 64, 64)x(32, 64, 64)            |    766.8  |    894.6
           (32, 32, 4, 4)x(32, 4, 4)                |     96.6  |    163.8
           (32, 16, 128, 128)x(16, 128, 128)        |   2607.6  |   2612.5
           (32, 16, 64, 64)x(16, 64, 64)            |    358.3  |    455.5
           (32, 16, 4, 4)x(16, 4, 4)                |     70.5  |    105.0
           (32, 1, 128, 128)x(1, 128, 128)          |    150.8  |    109.2
           (32, 1, 64, 64)x(1, 64, 64)              |     77.3  |     94.8
           (32, 1, 4, 4)x(1, 4, 4)                  |     65.8  |     89.3
           (1, 32, 128, 128)x(32, 128, 128)         |    170.2  |    171.9
           (1, 32, 64, 64)x(32, 64, 64)             |     51.6  |     83.7
           (1, 32, 4, 4)x(32, 4, 4)                 |     68.6  |     98.6
           (1, 16, 128, 128)x(16, 128, 128)         |     74.6  |     86.9
           (1, 16, 64, 64)x(16, 64, 64)             |     73.1  |    100.9
           (1, 16, 4, 4)x(16, 4, 4)                 |     71.5  |     98.6
           (1, 1, 128, 128)x(1, 128, 128)           |     78.0  |     99.1
           (1, 1, 64, 64)x(1, 64, 64)               |     68.6  |    100.3
           (1, 1, 4, 4)x(1, 4, 4)                   |     73.6  |     95.4

Times are in microseconds (us).

[----------------------- Matrix-vector product ------------------------]
                                                  |   matmul  |   einsum
1 threads: -------------------------------------------------------------
  (Full)   (256, 32, 128, 128)x(256, 32, 128, 1)  |   8758.8  |   8771.6
           (256, 32, 64, 64)x(256, 32, 64, 1)     |   2265.5  |   2257.6
           (256, 32, 4, 4)x(256, 32, 4, 1)        |     64.4  |     84.1
           (256, 16, 128, 128)x(256, 16, 128, 1)  |   4432.1  |   4430.9
           (256, 16, 64, 64)x(256, 16, 64, 1)     |   1132.3  |   1131.9
           (256, 16, 4, 4)x(256, 16, 4, 1)        |     57.5  |     80.1
           (256, 1, 128, 128)x(256, 1, 128, 1)    |    287.1  |    287.1
           (256, 1, 64, 64)x(256, 1, 64, 1)       |     82.3  |     85.1
           (256, 1, 4, 4)x(256, 1, 4, 1)          |     57.9  |     81.1
           (32, 32, 128, 128)x(32, 32, 128, 1)    |   1098.5  |   1108.3
           (32, 32, 64, 64)x(32, 32, 64, 1)       |    295.7  |    295.8
           (32, 32, 4, 4)x(32, 32, 4, 1)          |     64.2  |     82.5
           (32, 16, 128, 128)x(32, 16, 128, 1)    |    558.9  |    558.7
           (32, 16, 64, 64)x(32, 16, 64, 1)       |    152.5  |    152.5
           (32, 16, 4, 4)x(32, 16, 4, 1)          |     64.8  |     80.1
           (32, 1, 128, 128)x(32, 1, 128, 1)      |     61.9  |     92.6
           (32, 1, 64, 64)x(32, 1, 64, 1)         |     64.6  |     87.5
           (32, 1, 4, 4)x(32, 1, 4, 1)            |     60.7  |     80.5
           (1, 32, 128, 128)x(1, 32, 128, 1)      |     50.8  |     87.1
           (1, 32, 64, 64)x(1, 32, 64, 1)         |     55.7  |     79.5
           (1, 32, 4, 4)x(1, 32, 4, 1)            |     61.6  |     80.8
           (1, 16, 128, 128)x(1, 16, 128, 1)      |     55.0  |     80.2
           (1, 16, 64, 64)x(1, 16, 64, 1)         |     55.8  |     80.5
           (1, 16, 4, 4)x(1, 16, 4, 1)            |     58.6  |     80.1
           (1, 1, 128, 128)x(1, 1, 128, 1)        |     55.8  |     81.2
           (1, 1, 64, 64)x(1, 1, 64, 1)           |     58.5  |     81.3
           (1, 1, 4, 4)x(1, 1, 4, 1)              |     71.3  |    103.5
  (No a0)  (32, 128, 128)x(256, 32, 128, 1)       |  11368.5  |    317.5
           (32, 64, 64)x(256, 32, 64, 1)          |   2704.4  |     89.7
           (32, 4, 4)x(256, 32, 4, 1)             |     78.8  |     83.6
           (16, 128, 128)x(256, 16, 128, 1)       |   5688.9  |    167.6
           (16, 64, 64)x(256, 16, 64, 1)          |   1029.9  |     84.1
           (16, 4, 4)x(256, 16, 4, 1)             |     79.0  |     89.9
           (1, 128, 128)x(256, 1, 128, 1)         |    126.6  |     89.2
           (1, 64, 64)x(256, 1, 64, 1)            |     80.8  |    106.6
           (1, 4, 4)x(256, 1, 4, 1)               |     66.1  |     89.6
           (32, 128, 128)x(32, 32, 128, 1)        |   1448.9  |     85.4
           (32, 64, 64)x(32, 32, 64, 1)           |    354.7  |     84.9
           (32, 4, 4)x(32, 32, 4, 1)              |     72.8  |     84.1
           (16, 128, 128)x(32, 16, 128, 1)        |    733.7  |     83.7
           (16, 64, 64)x(32, 16, 64, 1)           |    141.3  |    100.6
           (16, 4, 4)x(32, 16, 4, 1)              |     92.3  |    104.4
           (1, 128, 128)x(32, 1, 128, 1)          |     84.9  |     94.6
           (1, 64, 64)x(32, 1, 64, 1)             |     73.6  |    105.4
           (1, 4, 4)x(32, 1, 4, 1)                |     75.4  |    101.3
           (32, 128, 128)x(1, 32, 128, 1)         |     67.1  |     91.3
           (32, 64, 64)x(1, 32, 64, 1)            |     59.2  |     82.1
           (32, 4, 4)x(1, 32, 4, 1)               |     63.9  |     81.2
           (16, 128, 128)x(1, 16, 128, 1)         |     57.5  |     83.0
           (16, 64, 64)x(1, 16, 64, 1)            |     57.7  |     83.1
           (16, 4, 4)x(1, 16, 4, 1)               |     65.2  |     81.3
           (1, 128, 128)x(1, 1, 128, 1)           |     57.7  |     81.2
           (1, 64, 64)x(1, 1, 64, 1)              |     58.4  |     80.3
           (1, 4, 4)x(1, 1, 4, 1)                 |     80.6  |    107.0
  (No b0)  (256, 32, 128, 128)x(32, 128, 1)       |   8751.5  |  15982.3
           (256, 32, 64, 64)x(32, 64, 1)          |   2280.4  |   4043.9
           (256, 32, 4, 4)x(32, 4, 1)             |     84.3  |    124.5
           (256, 16, 128, 128)x(16, 128, 1)       |   4419.8  |   8036.1
           (256, 16, 64, 64)x(16, 64, 1)          |   1131.8  |   2028.7
           (256, 16, 4, 4)x(16, 4, 1)             |     72.7  |    107.3
           (256, 1, 128, 128)x(1, 128, 1)         |    283.3  |    271.1
           (256, 1, 64, 64)x(1, 64, 1)            |     79.8  |     81.9
           (256, 1, 4, 4)x(1, 4, 1)               |     67.4  |     96.3
           (32, 32, 128, 128)x(32, 128, 1)        |   1105.2  |   1999.8
           (32, 32, 64, 64)x(32, 64, 1)           |    292.8  |    518.5
           (32, 32, 4, 4)x(32, 4, 1)              |    100.1  |    134.8
           (32, 16, 128, 128)x(16, 128, 1)        |    559.8  |   1013.7
           (32, 16, 64, 64)x(16, 64, 1)           |    152.6  |    265.0
           (32, 16, 4, 4)x(16, 4, 1)              |     82.8  |    111.6
           (32, 1, 128, 128)x(1, 128, 1)          |     57.8  |     82.0
           (32, 1, 64, 64)x(1, 64, 1)             |     69.1  |     82.4
           (32, 1, 4, 4)x(1, 4, 1)                |     72.8  |     95.6
           (1, 32, 128, 128)x(32, 128, 1)         |     53.0  |     82.7
           (1, 32, 64, 64)x(32, 64, 1)            |     56.4  |     98.6
           (1, 32, 4, 4)x(32, 4, 1)               |     76.3  |     97.1
           (1, 16, 128, 128)x(16, 128, 1)         |     55.6  |     98.5
           (1, 16, 64, 64)x(16, 64, 1)            |     67.6  |     95.9
           (1, 16, 4, 4)x(16, 4, 1)               |     76.1  |     97.6
           (1, 1, 128, 128)x(1, 128, 1)           |     65.4  |     96.5
           (1, 1, 64, 64)x(1, 64, 1)              |     69.7  |     97.9
           (1, 1, 4, 4)x(1, 4, 1)                 |     80.8  |    109.2

Times are in microseconds (us).

[--------------------------- Transposed MVP ---------------------------]
                                                  |   matmul  |   einsum
1 threads: -------------------------------------------------------------
  (Full)   (256, 32, 1, 128)x(256, 32, 128, 128)  |   8156.9  |   8164.4
           (256, 32, 1, 64)x(256, 32, 64, 64)     |   2079.4  |   2079.0
           (256, 32, 1, 4)x(256, 32, 4, 4)        |     69.6  |     88.3
           (256, 16, 1, 128)x(256, 16, 128, 128)  |   4097.1  |   4095.3
           (256, 16, 1, 64)x(256, 16, 64, 64)     |   1049.0  |   1048.5
           (256, 16, 1, 4)x(256, 16, 4, 4)        |     54.5  |     78.8
           (256, 1, 1, 128)x(256, 1, 128, 128)    |    270.3  |    270.5
           (256, 1, 1, 64)x(256, 1, 64, 64)       |     74.5  |     75.1
           (256, 1, 1, 4)x(256, 1, 4, 4)          |     57.6  |     80.5
           (32, 32, 1, 128)x(32, 32, 128, 128)    |   1036.2  |   1036.7
           (32, 32, 1, 64)x(32, 32, 64, 64)       |    274.2  |    274.4
           (32, 32, 1, 4)x(32, 32, 4, 4)          |     54.2  |     81.5
           (32, 16, 1, 128)x(32, 16, 128, 128)    |    530.1  |    529.6
           (32, 16, 1, 64)x(32, 16, 64, 64)       |    141.4  |    141.7
           (32, 16, 1, 4)x(32, 16, 4, 4)          |     65.6  |     85.6
           (32, 1, 1, 128)x(32, 1, 128, 128)      |     69.3  |    121.2
           (32, 1, 1, 64)x(32, 1, 64, 64)         |     62.4  |     80.2
           (32, 1, 1, 4)x(32, 1, 4, 4)            |     58.5  |     80.4
           (1, 32, 1, 128)x(1, 32, 128, 128)      |     54.5  |     90.1
           (1, 32, 1, 64)x(1, 32, 64, 64)         |     58.0  |     81.3
           (1, 32, 1, 4)x(1, 32, 4, 4)            |     57.5  |     79.8
           (1, 16, 1, 128)x(1, 16, 128, 128)      |     65.3  |     81.9
           (1, 16, 1, 64)x(1, 16, 64, 64)         |     57.5  |     90.7
           (1, 16, 1, 4)x(1, 16, 4, 4)            |     59.5  |     79.7
           (1, 1, 1, 128)x(1, 1, 128, 128)        |     69.3  |     79.4
           (1, 1, 1, 64)x(1, 1, 64, 64)           |     68.1  |     82.0
           (1, 1, 1, 4)x(1, 1, 4, 4)              |     73.4  |     82.1
  (No a0)  (32, 1, 128)x(256, 32, 128, 128)       |   8155.8  |  15325.0
           (32, 1, 64)x(256, 32, 64, 64)          |   2071.2  |   3938.2
           (32, 1, 4)x(256, 32, 4, 4)             |     71.9  |    103.3
           (16, 1, 128)x(256, 16, 128, 128)       |   4085.8  |   7668.6
           (16, 1, 64)x(256, 16, 64, 64)          |   1045.5  |   1971.8
           (16, 1, 4)x(256, 16, 4, 4)             |     79.8  |    102.1
           (1, 1, 128)x(256, 1, 128, 128)         |    264.5  |    496.9
           (1, 1, 64)x(256, 1, 64, 64)            |     72.9  |    143.9
           (1, 1, 4)x(256, 1, 4, 4)               |     68.6  |    110.5
           (32, 1, 128)x(32, 32, 128, 128)        |   1032.5  |   1931.6
           (32, 1, 64)x(32, 32, 64, 64)           |    271.5  |    509.4
           (32, 1, 4)x(32, 32, 4, 4)              |     80.5  |    120.2
           (16, 1, 128)x(32, 16, 128, 128)        |    524.2  |    979.9
           (16, 1, 64)x(32, 16, 64, 64)           |    142.5  |    256.6
           (16, 1, 4)x(32, 16, 4, 4)              |    103.2  |    119.8
           (1, 1, 128)x(32, 1, 128, 128)          |     61.2  |    102.4
           (1, 1, 64)x(32, 1, 64, 64)             |     69.2  |    121.2
           (1, 1, 4)x(32, 1, 4, 4)                |     75.3  |    119.2
           (32, 1, 128)x(1, 32, 128, 128)         |     70.1  |     92.3
           (32, 1, 64)x(1, 32, 64, 64)            |     66.5  |     81.0
           (32, 1, 4)x(1, 32, 4, 4)               |     66.7  |     81.0
           (16, 1, 128)x(1, 16, 128, 128)         |     66.1  |     85.0
           (16, 1, 64)x(1, 16, 64, 64)            |     58.3  |     83.0
           (16, 1, 4)x(1, 16, 4, 4)               |     61.5  |     82.2
           (1, 1, 128)x(1, 1, 128, 128)           |     69.4  |     79.5
           (1, 1, 64)x(1, 1, 64, 64)              |     67.1  |     79.2
           (1, 1, 4)x(1, 1, 4, 4)                 |     76.8  |     97.8
  (No b0)  (256, 32, 1, 128)x(32, 128, 128)       |  11310.7  |    307.8
           (256, 32, 1, 64)x(32, 64, 64)          |   2654.2  |    103.4
           (256, 32, 1, 4)x(32, 4, 4)             |     88.4  |    101.2
           (256, 16, 1, 128)x(16, 128, 128)       |   5672.2  |    161.3
           (256, 16, 1, 64)x(16, 64, 64)          |    971.0  |     97.6
           (256, 16, 1, 4)x(16, 4, 4)             |     71.9  |     93.0
           (256, 1, 1, 128)x(1, 128, 128)         |     63.8  |    104.2
           (256, 1, 1, 64)x(1, 64, 64)            |     72.3  |     84.2
           (256, 1, 1, 4)x(1, 4, 4)               |     58.9  |     98.4
           (32, 32, 1, 128)x(32, 128, 128)        |   1444.0  |     97.6
           (32, 32, 1, 64)x(32, 64, 64)           |    346.9  |    119.0
           (32, 32, 1, 4)x(32, 4, 4)              |    101.3  |    119.3
           (32, 16, 1, 128)x(16, 128, 128)        |    731.4  |    104.3
           (32, 16, 1, 64)x(16, 64, 64)           |    134.4  |     95.6
           (32, 16, 1, 4)x(16, 4, 4)              |     79.5  |     96.8
           (32, 1, 1, 128)x(1, 128, 128)          |     70.0  |     83.4
           (32, 1, 1, 64)x(1, 64, 64)             |     67.0  |     98.1
           (32, 1, 1, 4)x(1, 4, 4)                |     74.5  |     94.4
           (1, 32, 1, 128)x(32, 128, 128)         |     54.4  |     90.8
           (1, 32, 1, 64)x(32, 64, 64)            |     64.4  |     92.0
           (1, 32, 1, 4)x(32, 4, 4)               |     71.5  |    102.9
           (1, 16, 1, 128)x(16, 128, 128)         |     64.8  |    100.4
           (1, 16, 1, 64)x(16, 64, 64)            |     69.4  |     97.8
           (1, 16, 1, 4)x(16, 4, 4)               |     75.3  |     98.7
           (1, 1, 1, 128)x(1, 128, 128)           |     69.8  |     99.9
           (1, 1, 1, 64)x(1, 64, 64)              |     76.9  |     95.5
           (1, 1, 1, 4)x(1, 4, 4)                 |     76.7  |    108.9

Times are in microseconds (us).
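You can read the size of the effect in the two slow patterns straight off the tables above. A quick calculation using four rows copied from the Matrix-vector product and Transposed MVP tables (the dictionary keys are only labels) gives einsum speedups of roughly 34x to 37x over matmul:

```python
# matmul and einsum timings in microseconds, copied from the tables above for rows
# where one operand is a vector and the other operand's batch dims are broadcast.
rows = {
    "Matrix-vector product (32, 128, 128)x(256, 32, 128, 1)": (11368.5, 317.5),
    "Matrix-vector product (16, 128, 128)x(256, 16, 128, 1)": (5688.9, 167.6),
    "Transposed MVP (256, 32, 1, 128)x(32, 128, 128)": (11310.7, 307.8),
    "Transposed MVP (256, 16, 1, 128)x(16, 128, 128)": (5672.2, 161.3),
}


def speedups(table):
    """Return {row label: matmul time / einsum time}."""
    return {name: mm / es for name, (mm, es) in table.items()}


for name, ratio in speedups(rows).items():
    print(f"{name}: einsum is {ratio:.1f}x faster")
```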

[---------------------- Vector outer product ----------------------]
                                                |  matmul  |  einsum
1 threads: ---------------------------------------------------------
  (Full)   (256, 32, 128, 1)x(256, 32, 1, 128)  |  3857.8  |  5171.4
           (256, 32, 64, 1)x(256, 32, 1, 64)    |  1042.8  |  1337.9
           (256, 32, 4, 1)x(256, 32, 1, 4)      |    93.5  |    62.8
           (256, 16, 128, 1)x(256, 16, 1, 128)  |  1938.2  |  2597.6
           (256, 16, 64, 1)x(256, 16, 1, 64)    |   528.5  |   670.7
           (256, 16, 4, 1)x(256, 16, 1, 4)      |    52.2  |    63.0
           (256, 1, 128, 1)x(256, 1, 1, 128)    |   131.4  |   168.9
           (256, 1, 64, 1)x(256, 1, 1, 64)      |    56.0  |    66.5
           (256, 1, 4, 1)x(256, 1, 1, 4)        |    64.9  |    67.7
           (32, 32, 128, 1)x(32, 32, 1, 128)    |   497.3  |   661.1
           (32, 32, 64, 1)x(32, 32, 1, 64)      |   141.3  |   174.8
           (32, 32, 4, 1)x(32, 32, 1, 4)        |    54.7  |    62.3
           (32, 16, 128, 1)x(32, 16, 1, 128)    |   254.0  |   334.6
           (32, 16, 64, 1)x(32, 16, 1, 64)      |    75.3  |    90.6
           (32, 16, 4, 1)x(32, 16, 1, 4)        |    57.8  |    63.6
           (32, 1, 128, 1)x(32, 1, 1, 128)      |    78.6  |    78.3
           (32, 1, 64, 1)x(32, 1, 1, 64)        |    55.1  |    62.6
           (32, 1, 4, 1)x(32, 1, 1, 4)          |    58.0  |    66.9
           (1, 32, 128, 1)x(1, 32, 1, 128)      |    54.7  |    63.0
           (1, 32, 64, 1)x(1, 32, 1, 64)        |    64.1  |    62.1
           (1, 32, 4, 1)x(1, 32, 1, 4)          |    58.2  |    63.1
           (1, 16, 128, 1)x(1, 16, 1, 128)      |    59.4  |    64.1
           (1, 16, 64, 1)x(1, 16, 1, 64)        |    59.2  |    71.7
           (1, 16, 4, 1)x(1, 16, 1, 4)          |    60.0  |    64.8
           (1, 1, 128, 1)x(1, 1, 1, 128)        |    62.0  |    62.8
           (1, 1, 64, 1)x(1, 1, 1, 64)          |    62.2  |    66.4
           (1, 1, 4, 1)x(1, 1, 1, 4)            |    61.1  |    65.9
  (No a0)  (32, 128, 1)x(256, 32, 1, 128)       |  3853.3  |  5370.5
           (32, 64, 1)x(256, 32, 1, 64)         |  1041.9  |  1391.3
           (32, 4, 1)x(256, 32, 1, 4)           |    94.2  |    63.8
           (16, 128, 1)x(256, 16, 1, 128)       |  1937.8  |  2583.4
           (16, 64, 1)x(256, 16, 1, 64)         |   529.2  |   700.7
           (16, 4, 1)x(256, 16, 1, 4)           |    80.1  |    63.6
           (1, 128, 1)x(256, 1, 1, 128)         |   123.9  |   158.6
           (1, 64, 1)x(256, 1, 1, 64)           |    61.5  |    63.0
           (1, 4, 1)x(256, 1, 1, 4)             |    62.0  |    70.0
           (32, 128, 1)x(32, 32, 1, 128)        |   497.1  |   681.7
           (32, 64, 1)x(32, 32, 1, 64)          |   142.2  |   181.4
           (32, 4, 1)x(32, 32, 1, 4)            |    88.5  |    76.3
           (16, 128, 1)x(32, 16, 1, 128)        |   255.0  |   330.6
           (16, 64, 1)x(32, 16, 1, 64)          |    82.6  |    94.2
           (16, 4, 1)x(32, 16, 1, 4)            |    94.0  |    77.1
           (1, 128, 1)x(32, 1, 1, 128)          |    61.8  |    64.7
           (1, 64, 1)x(32, 1, 1, 64)            |    80.0  |    83.9
           (1, 4, 1)x(32, 1, 1, 4)              |    74.9  |    79.5
           (32, 128, 1)x(1, 32, 1, 128)         |    66.2  |    74.2
           (32, 64, 1)x(1, 32, 1, 64)           |    59.5  |    64.3
           (32, 4, 1)x(1, 32, 1, 4)             |    63.0  |    67.1
           (16, 128, 1)x(1, 16, 1, 128)         |    61.8  |    63.3
           (16, 64, 1)x(1, 16, 1, 64)           |    61.1  |    69.9
           (16, 4, 1)x(1, 16, 1, 4)             |    64.5  |    67.6
           (1, 128, 1)x(1, 1, 1, 128)           |    65.6  |    65.8
           (1, 64, 1)x(1, 1, 1, 64)             |    60.8  |    66.4
           (1, 4, 1)x(1, 1, 1, 4)               |    76.1  |    78.9
  (No b0)  (256, 32, 128, 1)x(32, 1, 128)       |  3853.5  |  5463.5
           (256, 32, 64, 1)x(32, 1, 64)         |  1041.5  |  1371.9
           (256, 32, 4, 1)x(32, 1, 4)           |    94.3  |    76.2
           (256, 16, 128, 1)x(16, 1, 128)       |  1937.0  |  2714.4
           (256, 16, 64, 1)x(16, 1, 64)         |   520.2  |   691.9
           (256, 16, 4, 1)x(16, 1, 4)           |    81.0  |    73.1
           (256, 1, 128, 1)x(1, 1, 128)         |   124.1  |   152.4
           (256, 1, 64, 1)x(1, 1, 64)           |    58.2  |    62.4
           (256, 1, 4, 1)x(1, 1, 4)             |    76.5  |    82.5
           (32, 32, 128, 1)x(32, 1, 128)        |   495.6  |   687.8
           (32, 32, 64, 1)x(32, 1, 64)          |   141.8  |   179.0
           (32, 32, 4, 1)x(32, 1, 4)            |    77.4  |    66.9
           (32, 16, 128, 1)x(16, 1, 128)        |   254.5  |   345.6
           (32, 16, 64, 1)x(16, 1, 64)          |    79.6  |    92.6
           (32, 16, 4, 1)x(16, 1, 4)            |    81.8  |    82.6
           (32, 1, 128, 1)x(1, 1, 128)          |    59.4  |    77.1
           (32, 1, 64, 1)x(1, 1, 64)            |    59.5  |    70.9
           (32, 1, 4, 1)x(1, 1, 4)              |    69.8  |    74.9
           (1, 32, 128, 1)x(32, 1, 128)         |    54.6  |    63.7
           (1, 32, 64, 1)x(32, 1, 64)           |    71.6  |    82.0
           (1, 32, 4, 1)x(32, 1, 4)             |    72.1  |    79.2
           (1, 16, 128, 1)x(16, 1, 128)         |    78.2  |    87.9
           (1, 16, 64, 1)x(16, 1, 64)           |    68.8  |    78.0
           (1, 16, 4, 1)x(16, 1, 4)             |    77.0  |    79.3
           (1, 1, 128, 1)x(1, 1, 128)           |    75.9  |    78.9
           (1, 1, 64, 1)x(1, 1, 64)             |    75.2  |    76.7
           (1, 1, 4, 1)x(1, 1, 4)               |    76.1  |    78.1

Times are in microseconds (us).

[---------------------- Vector inner product ----------------------]
                                                |  matmul  |  einsum
1 threads: ---------------------------------------------------------
  (Full)   (256, 32, 1, 128)x(256, 32, 128, 1)  |  256.9   |  257.1 
           (256, 32, 1, 64)x(256, 32, 64, 1)    |  141.1   |  141.7 
           (256, 32, 1, 4)x(256, 32, 4, 1)      |   57.9   |   86.4
           (256, 16, 1, 128)x(256, 16, 128, 1)  |  137.0   |  137.0 
           (256, 16, 1, 64)x(256, 16, 64, 1)    |   74.9   |   85.8
           (256, 16, 1, 4)x(256, 16, 4, 1)      |   54.5   |   78.6
           (256, 1, 1, 128)x(256, 1, 128, 1)    |   65.3   |   90.8
           (256, 1, 1, 64)x(256, 1, 64, 1)      |   59.7   |   83.2
           (256, 1, 1, 4)x(256, 1, 4, 1)        |   55.1   |   80.6
           (32, 32, 1, 128)x(32, 32, 128, 1)    |   52.2   |   90.6
           (32, 32, 1, 64)x(32, 32, 64, 1)      |   65.9   |   80.0
           (32, 32, 1, 4)x(32, 32, 4, 1)        |   55.5   |   82.7
           (32, 16, 1, 128)x(32, 16, 128, 1)    |   67.0   |   80.0
           (32, 16, 1, 64)x(32, 16, 64, 1)      |   66.2   |   97.3
           (32, 16, 1, 4)x(32, 16, 4, 1)        |   68.4   |   85.7
           (32, 1, 1, 128)x(32, 1, 128, 1)      |   70.1   |   95.6
           (32, 1, 1, 64)x(32, 1, 64, 1)        |   67.7   |   79.5
           (32, 1, 1, 4)x(32, 1, 4, 1)          |   61.2   |   79.5
           (1, 32, 1, 128)x(1, 32, 128, 1)      |   56.5   |   82.0
           (1, 32, 1, 64)x(1, 32, 64, 1)        |   62.8   |   79.3
           (1, 32, 1, 4)x(1, 32, 4, 1)          |   61.6   |   79.3
           (1, 16, 1, 128)x(1, 16, 128, 1)      |   57.8   |   80.2
           (1, 16, 1, 64)x(1, 16, 64, 1)        |   64.5   |   79.7
           (1, 16, 1, 4)x(1, 16, 4, 1)          |   57.7   |   79.9
           (1, 1, 1, 128)x(1, 1, 128, 1)        |   61.4   |   88.4
           (1, 1, 1, 64)x(1, 1, 64, 1)          |   67.6   |   91.3
           (1, 1, 1, 4)x(1, 1, 4, 1)            |   63.5   |   81.8
  (No a0)  (32, 1, 128)x(256, 32, 128, 1)       |  251.3   |   79.1 
           (32, 1, 64)x(256, 32, 64, 1)         |  136.7   |   90.3 
           (32, 1, 4)x(256, 32, 4, 1)           |   73.0   |   86.9
           (16, 1, 128)x(256, 16, 128, 1)       |  133.2   |  102.1 
           (16, 1, 64)x(256, 16, 64, 1)         |   76.2   |   83.9
           (16, 1, 4)x(256, 16, 4, 1)           |   80.1   |   87.5 
           (1, 1, 128)x(256, 1, 128, 1)         |   89.5   |   96.2 
           (1, 1, 64)x(256, 1, 64, 1)           |   62.3   |   90.7
           (1, 1, 4)x(256, 1, 4, 1)             |   62.9   |   88.8
           (32, 1, 128)x(32, 32, 128, 1)        |   77.6   |   89.8
           (32, 1, 64)x(32, 32, 64, 1)          |   87.9   |   89.7 
           (32, 1, 4)x(32, 32, 4, 1)            |   84.5   |  104.8
           (16, 1, 128)x(32, 16, 128, 1)        |   90.0   |  111.5
           (16, 1, 64)x(32, 16, 64, 1)          |  106.2   |  123.3
           (16, 1, 4)x(32, 16, 4, 1)            |  104.2   |  103.5 
           (1, 1, 128)x(32, 1, 128, 1)          |   71.5   |   86.5
           (1, 1, 64)x(32, 1, 64, 1)            |   77.2   |   97.9
           (1, 1, 4)x(32, 1, 4, 1)              |   74.8   |  102.6
           (32, 1, 128)x(1, 32, 128, 1)         |   65.6   |   82.8
           (32, 1, 64)x(1, 32, 64, 1)           |   66.9   |   82.2
           (32, 1, 4)x(1, 32, 4, 1)             |   64.8   |   80.6
           (16, 1, 128)x(1, 16, 128, 1)         |   60.4   |   90.8
           (16, 1, 64)x(1, 16, 64, 1)           |   61.3   |   82.1
           (16, 1, 4)x(1, 16, 4, 1)             |   65.0   |   81.9
           (1, 1, 128)x(1, 1, 128, 1)           |   63.0   |   87.5
           (1, 1, 64)x(1, 1, 64, 1)             |   62.5   |   88.9
           (1, 1, 4)x(1, 1, 4, 1)               |   76.9   |   98.4
  (No b0)  (256, 32, 1, 128)x(32, 128, 1)       |  251.1   |   98.1 
           (256, 32, 1, 64)x(32, 64, 1)         |  137.6   |  102.3 
           (256, 32, 1, 4)x(32, 4, 1)           |   83.4   |  103.7
           (256, 16, 1, 128)x(16, 128, 1)       |  133.1   |   89.2 
           (256, 16, 1, 64)x(16, 64, 1)         |   76.2   |   86.6
           (256, 16, 1, 4)x(16, 4, 1)           |   88.4   |  103.9
           (256, 1, 1, 128)x(1, 128, 1)         |   77.3   |   95.9
           (256, 1, 1, 64)x(1, 64, 1)           |   58.9   |   82.1
           (256, 1, 1, 4)x(1, 4, 1)             |   78.2   |   96.4
           (32, 32, 1, 128)x(32, 128, 1)        |   94.8   |  105.0
           (32, 32, 1, 64)x(32, 64, 1)          |  104.9   |  111.4 
           (32, 32, 1, 4)x(32, 4, 1)            |   88.8   |  104.2
           (32, 16, 1, 128)x(16, 128, 1)        |   82.6   |   89.2 
           (32, 16, 1, 64)x(16, 64, 1)          |   81.1   |   92.9
           (32, 16, 1, 4)x(16, 4, 1)            |   94.1   |  100.6 
           (32, 1, 1, 128)x(1, 128, 1)          |   70.2   |  100.2
           (32, 1, 1, 64)x(1, 64, 1)            |   71.3   |   83.7
           (32, 1, 1, 4)x(1, 4, 1)              |   70.2   |   92.5
           (1, 32, 1, 128)x(32, 128, 1)         |   57.1   |   83.7
           (1, 32, 1, 64)x(32, 64, 1)           |   71.2   |   98.5
           (1, 32, 1, 4)x(32, 4, 1)             |   71.6   |   90.1
           (1, 16, 1, 128)x(16, 128, 1)         |   67.3   |   90.7
           (1, 16, 1, 64)x(16, 64, 1)           |   73.2   |   98.5
           (1, 16, 1, 4)x(16, 4, 1)             |   74.8   |   98.3
           (1, 1, 1, 128)x(1, 128, 1)           |   76.7   |  105.1
           (1, 1, 1, 64)x(1, 64, 1)             |   81.4   |  105.8
           (1, 1, 1, 4)x(1, 4, 1)               |   75.7   |  107.8

Times are in microseconds (us).
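These tables have the layout printed by `torch.utils.benchmark.Compare`. The script that generated them is not part of this thread, so what follows is only a sketch of how one "(No a0)" row of the inner-product table could be re-measured. It assumes single-threaded `benchmark.Timer` runs, and the `einsum` subscript string is our own choice. The `assert` confirms that both expressions compute the same broadcasted product before they are timed.

```python
import torch
from torch.utils import benchmark

# Shapes mirroring one "(No a0)" row of the inner-product table:
# `a` lacks the leading batch dimension that `b` has.
a = torch.randn(32, 1, 128, dtype=torch.float64)
b = torch.randn(256, 32, 128, 1, dtype=torch.float64)

# Both expressions compute the same broadcasted product, shape (256, 32, 1, 1)
assert torch.allclose(a @ b, torch.einsum("...ij,...jk->...ik", a, b))

results = []
for stmt, desc in [
    ("a @ b", "matmul"),
    ("torch.einsum('...ij,...jk->...ik', a, b)", "einsum"),
]:
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"torch": torch, "a": a, "b": b},
        label="Vector inner product",
        sub_label=f"{tuple(a.shape)}x{tuple(b.shape)}",
        description=desc,
        num_threads=1,
    )
    # Repeat the statement until at least 0.1 s of timing data is collected
    results.append(timer.blocked_autorange(min_run_time=0.1))

# Prints a table in the same format as the ones above
benchmark.Compare(results).print()
```

Absolute timings depend on the hardware. What matters is the relative cost of `matmul` and `einsum` when `a` lacks the leading batch dimension.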

@slishak-PX
Contributor Author

slishak-PX commented Apr 29, 2024

Along the same lines as above, I've found another (less significant) memory saving. If I try to optimise the acquisition function from the original issue with the setup below, I still run out of memory on an 80GB A100 while computing the distance matrix, even after replacing the matmul with einsum.

from botorch.optim import optimize_acqf

candidates, acq_values = optimize_acqf(
    ucb,
    bounds=torch.cat((torch.zeros(1, d), torch.ones(1, d))).to(**tkwargs),
    q=1,
    num_restarts=10,
    raw_samples=1024,
)

Traceback:

Traceback (most recent call last):
  File "/tmp/ipykernel_2227345/3762668224.py", line 5, in <module>
    candidates, acq_values = optimize_acqf(
                             ^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
    batch_initial_conditions = opt_inputs.get_ic_generator()(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
                 ^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 786, in forward
    mean, sigma = self._mean_and_sigma(X)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/acquisition/analytic.py", line 106, in _mean_and_sigma
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/fully_bayesian.py", line 536, in posterior
    posterior = super().posterior(
                ^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
          ^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py", line 286, in exact_prediction
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/kernel.py", line 532, in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File "/home/.../lib/python3.11/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/matern_kernel.py", line 99, in forward
    distance = self.covar_dist(x1_, x2_, diag=diag, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/kernel.py", line 359, in covar_dist
    return dist_func(x1, x2, x1_eq_x2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/gpytorch/kernels/kernel.py", line 59, in dist
    res = torch.cdist(x1, x2)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../lib/python3.11/site-packages/torch/functional.py", line 1330, in cdist
    return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 64.03 GiB. GPU 3 has a total capacity of 79.15 GiB of which 9.08 GiB is free. Process 3788106 has 706.00 MiB memory in use. Process 3860132 has 706.00 MiB memory in use. Including non-PyTorch memory, this process has 68.67 GiB memory in use. Of the allocated memory 68.09 GiB is allocated by PyTorch, and 51.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

If I modify the call to torch.cdist to the following, the memory usage is reduced just enough for it to work:

res = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")

For reference: at this point, x1 and x2 have shapes [1024, 16, 1, 256] and [1024, 16, 2049, 256] respectively.
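The numbers above support a rough sanity check, though this is our inference and not something the traceback states. Assuming float64 inputs, as set in `tkwargs`, a single tensor with `x2`'s shape occupies exactly the 64.03 GiB that `torch.cdist` tried to allocate. The failing allocation is therefore a buffer the same size as `x2`. The 2049 rows are presumably the 2048 training points plus the single test point.

```python
import math

# Shape of x2 reported above; float64 uses 8 bytes per element
x2_shape = (1024, 16, 2049, 256)
bytes_per_elem = 8

n_bytes = math.prod(x2_shape) * bytes_per_elem
gib = n_bytes / 2**30
print(f"{gib:.2f} GiB")  # 64.03 GiB
```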

The PyTorch documentation on compute_mode is quite vague, but pytorch/pytorch#42479 suggests that this option is slower but slightly more accurate.
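The following is a small-scale sketch of the two code paths. The two `compute_mode` values are documented options of `torch.cdist`. The shapes are hypothetical, scaled-down stand-ins for the `x1`/`x2` shapes above. On well-conditioned float64 inputs, both modes should agree within `torch.allclose` tolerances. Any memory or speed difference between them would only be noticeable at sizes like those in the traceback.

```python
import torch

torch.manual_seed(0)
# Scaled-down stand-ins for x1/x2 above: (batch, mcmc_samples, points, d)
x1 = torch.rand(8, 4, 1, 32, dtype=torch.float64)
x2 = torch.rand(8, 4, 65, 32, dtype=torch.float64)

# Force the matrix-multiplication-based Euclidean distance computation
d_mm = torch.cdist(x1, x2, compute_mode="use_mm_for_euclid_dist")
# Direct pairwise-difference computation, as in the workaround above
d_direct = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")

assert d_mm.shape == d_direct.shape == (8, 4, 1, 65)
print((d_mm - d_direct).abs().max())  # small floating-point discrepancy
```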

(This issue is also more of a GPyTorch issue than a BoTorch issue, but it's closely related to this one, so I'm just adding it here for now; let me know if you want me to create a new issue.)

@slishak-PX
Contributor Author

slishak-PX commented Jul 31, 2024

There is a very similar issue with KroneckerMultiTaskGP. Below is a minimal reproducible example.

import torch
from botorch.models import KroneckerMultiTaskGP

n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")

train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)

test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

gp = KroneckerMultiTaskGP(train_x, train_y)

posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))
Stack trace
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[7], line 18
     15 gp = KroneckerMultiTaskGP(train_x, train_y)
     17 posterior = gp.posterior(test_x)
---> 18 posterior.rsample(torch.Size([256, 1]))

File ~/miniconda3/envs/al/lib/python3.10/site-packages/botorch/posteriors/multitask.py:269, in MultitaskGPPosterior.rsample(self, sample_shape)
    267 if sample_shape is None:
    268     sample_shape = torch.Size([1])
--> 269 return self.rsample_from_base_samples(
    270     sample_shape=sample_shape, base_samples=None
    271 )

File ~/miniconda3/envs/al/lib/python3.10/site-packages/botorch/posteriors/multitask.py:229, in MultitaskGPPosterior.rsample_from_base_samples(self, sample_shape, base_samples, train_diff)
    225 obs_minus_samples = (
    226     train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
    227 )
    228 train_covar_plus_noise = self.train_train_covar + self.train_noise
--> 229 obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))
    231 # and multiply the test-observed matrix against the result of the solve
    232 updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)

File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/operators/_linear_operator.py:2334, in LinearOperator.solve(self, right_tensor, left_tensor)
   2332 func = Solve
   2333 if left_tensor is None:
-> 2334     return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
   2335 else:
   2336     return func.apply(
   2337         self.representation_tree(),
   2338         True,
   (...)
   2341         *self.representation(),
   2342     )

File ~/miniconda3/envs/al/lib/python3.10/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
    503 if not torch._C._are_functorch_transforms_active():
    504     # See NOTE: [functorch vjp and autograd interaction]
    505     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506     return super().apply(*args, **kwargs)  # type: ignore[misc]
    508 if cls.setup_context == _SingleLevelFunction.setup_context:
    509     raise RuntimeError(
    510         'In order to use an autograd.Function with functorch transforms '
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/functions/_solve.py:53, in Solve.forward(ctx, representation_tree, has_left, *args)
     51     res = left_tensor @ res
     52 else:
---> 53     solves = _solve(linear_op, right_tensor)
     54     res = solves
     56 if ctx.is_vector:

File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/functions/_solve.py:17, in _solve(linear_op, rhs)
     15     return linear_op.solve(rhs)
     16 if settings.fast_computations.solves.off() or linear_op.size(-1) <= settings.max_cholesky_size.value():
---> 17     return linear_op.cholesky()._cholesky_solve(rhs)
     18 else:
     19     with torch.no_grad():

File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/operators/triangular_linear_operator.py:78, in TriangularLinearOperator._cholesky_solve(self, rhs, upper)
     71 def _cholesky_solve(
     72     self: Float[LinearOperator, "*batch N N"],
     73     rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]],
     74     upper: Optional[bool] = False,
     75 ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]:
     76     # use custom method if implemented
     77     try:
---> 78         res = self._tensor._cholesky_solve(rhs=rhs, upper=upper)
     79     except NotImplementedError:
     80         if upper:
     81             # res = (U.T @ U)^-1 @ v = U^-1 @ U^-T @ v

File ~/miniconda3/envs/al/lib/python3.10/site-packages/linear_operator/operators/dense_linear_operator.py:38, in DenseLinearOperator._cholesky_solve(self, rhs, upper)
     33 def _cholesky_solve(
     34     self: Float[LinearOperator, "*batch N N"],
     35     rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]],
     36     upper: Optional[bool] = False,
     37 ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]:
---> 38     return torch.cholesky_solve(rhs, self.to_dense(), upper=upper)

OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB (GPU 0; 79.19 GiB total capacity; 4.73 GiB already allocated; 71.26 GiB free; 5.26 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The final line requires allocating 128 GiB of GPU memory because of the call to torch.cholesky_solve with B shaped (256, 1, 8192, 1) and L shaped (8192, 8192).
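The 128 GiB figure can be reproduced by hand. Assume float64 and a Kronecker train covariance of size n_train × n_tasks = 2048 × 4 = 8192. A dense (8192, 8192) factor then takes 0.5 GiB. Broadcasting it across B's (256, 1) sample batch gives 256 × 0.5 = 128 GiB, which is consistent with the failed allocation.

```python
import math

n_train, n_tasks = 2048, 4
n = n_train * n_tasks        # 8192: size of the Kronecker train covariance
batch_shape = (256, 1)       # sample_shape passed to rsample
bytes_per_elem = 8           # float64

# Memory if L (n x n) is broadcast to B's batch shape, i.e. (256, 1, n, n)
broadcast_bytes = math.prod(batch_shape) * n * n * bytes_per_elem
print(broadcast_bytes / 2**30)         # 128.0
# Memory for a single copy of L
print(n * n * bytes_per_elem / 2**30)  # 0.5
```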

The following line appears to be the root cause. Unsqueezing the last dimension sets up a batched matrix-vector solve, but if we instead transpose one of the batch dimensions to the end, we do a more efficient matrix-matrix solve.

obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))

For example, the following code requires less than 4GB of GPU memory, and I believe it is equivalent. By moving the first batch dimension to the final position, we at least stand a chance of getting a more efficient operation whenever the first batch dimension is greater than 1. It would probably be even better to find the largest batch dimension and move that one to the end, or even to flatten them all.

# Permutation that moves the first batch dimension to the last position
perm = list(range(1, obs_minus_samples.ndim)) + [0]
inverse_perm = torch.argsort(torch.tensor(perm)).tolist()

# Solve against a matrix right-hand side instead of a batch of vectors
obs_minus_samples_p = obs_minus_samples.permute(*perm)
obs_solve = train_covar_plus_noise.solve(obs_minus_samples_p)

# and multiply the test-observed matrix against the result of the solve
updated_samples = self.test_train_covar.matmul(obs_solve).permute(*inverse_perm)
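The equivalence of the two formulations can be checked at small scale. This sketch calls `torch.cholesky_solve` directly on a random SPD matrix rather than going through the `LinearOperator` used in `MultitaskGPPosterior`. The sizes (64 instead of 8192, a batch of 8 instead of 256) are arbitrary stand-ins.

```python
import torch

torch.manual_seed(0)
n, batch = 64, 8  # small stand-ins for 8192 and 256

# Random symmetric positive-definite matrix and its lower Cholesky factor
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T + n * torch.eye(n, dtype=torch.float64)
L = torch.linalg.cholesky(K)

rhs = torch.randn(batch, 1, n, dtype=torch.float64)  # like obs_minus_samples

# Batched matrix-vector solve: rhs becomes (batch, 1, n, 1)
x_mv = torch.cholesky_solve(rhs.unsqueeze(-1), L).squeeze(-1)

# Matrix-matrix solve: move the first batch dim to the end -> (1, n, batch)
perm = list(range(1, rhs.ndim)) + [0]                      # [1, 2, 0]
inverse_perm = torch.argsort(torch.tensor(perm)).tolist()  # [2, 0, 1]
x_mm = torch.cholesky_solve(rhs.permute(*perm), L).permute(*inverse_perm)

assert x_mv.shape == x_mm.shape == (batch, 1, n)
assert torch.allclose(x_mv, x_mm)
```

In the permuted version, `cholesky_solve` receives a single (1, n, batch) right-hand side, so `L` should not need to be broadcast across the sample batch.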

@Balandat @esantorella: should I submit a PR for this, or is there an obvious reason that it wouldn't work in other use cases?

@Balandat
Contributor

Balandat commented Aug 1, 2024

> should I submit a PR for this, or is there an obvious reason that it wouldn't work in other use cases?

A PR for this would be great - I don't see any obvious reason why this wouldn't work in other cases. Ideally, we could even do this at the level of LinearOperator.solve() so that other cases can also benefit from this. Even more ideal, this could be done at the level of pytorch's torch.cholesky_solve() under the hood, but that would be a larger lift - though it's probably worth raising this with the pytorch folks to understand whether there are any plans in that direction. cc @gpleiss, @jacobrgardner

> It would probably be even better to find the largest batch dimension and move that one to the end, or even flatten them all.

Yes, that makes sense. There are probably some nontrivial tradeoffs between flattening them all and keeping them around, depending on how exactly the underlying CUDA kernel parallelizes the evaluation in each case.
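Here is a sketch of the "flatten them all" variant. It assumes the Cholesky factor has no batch dimensions of its own. `solve_flattened` is a hypothetical helper name, not an existing BoTorch or LinearOperator function.

```python
import torch


def solve_flattened(L, rhs):
    """Hypothetical helper: solve K x = rhs for every batch of rhs at once.

    L is the (n, n) lower Cholesky factor of K, assumed to have no batch dims.
    rhs has shape (*batch, n). All batch dims are folded into the columns of
    one (n, prod(batch)) right-hand side, so a single matrix-matrix solve runs.
    """
    *batch, n = rhs.shape
    cols = rhs.reshape(-1, n).T          # (n, prod(batch))
    sol = torch.cholesky_solve(cols, L)  # (n, prod(batch))
    return sol.T.reshape(*batch, n)


torch.manual_seed(0)
n = 32
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T + n * torch.eye(n, dtype=torch.float64)
L = torch.linalg.cholesky(K)
rhs = torch.randn(6, 3, n, dtype=torch.float64)

# Reference result from batched matrix-vector solves
ref = torch.cholesky_solve(rhs.unsqueeze(-1), L).squeeze(-1)
out = solve_flattened(L, rhs)
```

The covariance itself may carry batch dimensions, for example the MCMC sample dimension of a fully Bayesian model. In that case those dimensions would have to stay as batch dimensions, and only the extra right-hand-side dimensions could be folded into the columns. This is one of the tradeoffs mentioned above.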

@Balandat
Contributor

Balandat commented Aug 1, 2024

cc @sdaulton, @jandylin, @SebastianAment re excessive memory usage in Kronecker MTGPs

@Balandat
Contributor

cc @JonathanWenger

facebook-github-bot pushed a commit that referenced this issue Oct 1, 2024
Summary:

## Motivation

See #2310 (comment)

```python
import torch
from botorch.models import KroneckerMultiTaskGP

n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")

train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)

test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

gp = KroneckerMultiTaskGP(train_x, train_y)

posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))
```

The final line requires allocation of 128GB of GPU memory, because of the call to `torch.cholesky_solve` with B shaped `(256, 1, 8192, 1)` and L shaped `(8192, 8192)`.

By moving the largest batch dimension to the final position, we should achieve a more efficient operation.

Also fix docstring for `MultitaskGPPosterior`.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2460

Test Plan:
Passes unit tests (specifically `test_multitask.py`).

Benchmarking results:
![image](https://github.com/user-attachments/assets/1eca54be-1ed4-43c9-bb50-a18cf24d00f5)
![image](https://github.com/user-attachments/assets/016322f6-992a-45bf-b175-e76208c11b12)

## Related PRs

N/A

Reviewed By: saitcakmak

Differential Revision: D63678866

Pulled By: Balandat

fbshipit-source-id: 6675c66dadd62934f95fabafe7b3f0155a1c0c6f