
[TPU] Bug: Reverse is orders of magnitude slower on TPU #23191

Open
bjenik opened this issue Feb 27, 2025 · 3 comments
bjenik commented Feb 27, 2025

The following sample runs at around 1400 it/s on an H100 but only ~11 it/s on a v6e. It seems there's an issue with the TPU implementation of reverse. This example is already isolated; in practice the reverse gets generated when calling irfft.

import os 
os.environ["XLA_FLAGS"] = "--xla_dump_to=./hlo"
from tqdm import tqdm
import jax 
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("batch",))

@jax.jit
def rp1(data):
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    data = data[:,:,:,:,:,::-1] + 1
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    return data

@jax.jit
def make_data():
    data = jnp.ones((num_devices * 32, 64, 16, 16, 8, 31))
    data = jax.lax.with_sharding_constraint(data, NamedSharding(mesh, PartitionSpec("batch")))
    return data

data = make_data()
with jax.profiler.trace("./tensorboard"):
    for i in tqdm(range(1000)):
        data = rp1(data)


rdyro commented Feb 27, 2025

I experimented with two alternatives:

  • data = jnp.flip(data, axis=-1) + 1: ~50 ms


  • a gather, data = data[..., jnp.arange(data.shape[-1])[::-1]] + 1: ~6 ms


The gather generates a while loop, which seems to be much more efficient.

One speculation: maybe the reverse operation is optimized for VMEM, but this array is large (~500 MB), so it has to go through HBM directly and ends up inefficient. I'm curious why the compiler replaces the gather, but not the reverse, with a more optimized version!
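For reference, the gather workaround can be wrapped as a drop-in replacement for the negative-stride slice. A minimal sketch (the name gather_reverse is made up here, not a JAX API):

```python
import jax.numpy as jnp

def gather_reverse(x, axis=-1):
    # Reverse along `axis` with an explicit index gather instead of a
    # negative-stride slice; in this repro the gather lowered to a while
    # loop that ran much faster on TPU.
    return jnp.take(x, jnp.arange(x.shape[axis])[::-1], axis=axis)

x = jnp.arange(12.0).reshape(3, 4)
print(bool(jnp.array_equal(gather_reverse(x), x[:, ::-1])))  # True
```

The negative-stride index array jnp.arange(n)[::-1] is built at trace time on a tiny 1-D iota, so only the gather itself runs on the large array.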

bjenik commented Feb 27, 2025

I'd imagine this to run at HBM speeds (modulo maybe missing out on coalescing), so in the good case the 1400 it/s on H100 seems about right. Given that v6e has about half the HBM bandwidth, I'd still expect half that performance, meaning it shouldn't take much more than a millisecond per call to hit a target of around 700 it/s. Do you have any insight into what libtpu actually does internally for a reverse, in terms of algorithm and memory-access pattern?
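The bandwidth estimate can be checked with quick arithmetic, using only numbers from this thread (the per-device shard of the repro array and the reported rates):

```python
# Back-of-envelope check: bytes moved per rp1 call on one device.
elems = 32 * 64 * 16 * 16 * 8 * 31   # per-device shard of the repro array
bytes_moved = 2 * elems * 4          # one float32 read + one write
print(bytes_moved / 1e9)             # ~1.04 GB per call
print(1400 * bytes_moved / 1e12)     # ~1.46 TB/s effective at 1400 it/s
print(1e3 / 700)                     # ~1.4 ms/call at the 700 it/s target
```

This matches the ~500 MB array size mentioned above and the "not much more than a millisecond" target.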

Also curious: how would I best get this (or any other) manual fix integrated into the overall XLA lowering pipeline? The issue is that in practice I'm not calling reverse myself; XLA lowers irfft into it. I was originally considering doing a matmul with a flipped identity matrix instead of the reverse (not the smartest approach, but probably still better than whatever is actually happening in the reverse), but I'd face the same integration problem there.

rdyro commented Feb 27, 2025

An important detail: so far I tested on a v5e, where the HBM bandwidth suggests ~2.5 ms, so I think the optimized gather is quite close to that (although it'd ideally be fused with the scalar add in this repro).

Yes, that lowering is problematic... What's the performance like when you implement irfft yourself with the gather trick and an ifft? I'll work on creating an internal issue.
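A minimal sketch of what such a hand-rolled irfft could look like (the names gather_reverse and irfft_via_ifft are hypothetical; this assumes a Hermitian rfft half-spectrum on the last axis and reconstructs the full spectrum before a plain ifft):

```python
import jax.numpy as jnp

def gather_reverse(x, axis=-1):
    # Index-gather reverse, the workaround from this thread; avoids the
    # slow negative-stride-slice lowering on TPU.
    return jnp.take(x, jnp.arange(x.shape[axis])[::-1], axis=axis)

def irfft_via_ifft(x, n=None):
    # Rebuild the full Hermitian spectrum from the rfft half, then run a
    # plain ifft and keep the real part. Assumes x = rfft(real signal).
    if n is None:
        n = 2 * (x.shape[-1] - 1)
    head = x[..., : n // 2 + 1]          # DC up to Nyquist / highest bin
    # Mirror the interior bins: X[n - k] = conj(X[k]).
    tail = jnp.conj(gather_reverse(x[..., 1 : (n + 1) // 2]))
    full = jnp.concatenate([head, tail], axis=-1)
    return jnp.fft.ifft(full, axis=-1).real
```

This should match jnp.fft.irfft up to float32 tolerance for both even and odd n; whether it actually beats the current lowering on TPU would need to be measured.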
