[TPU] Bug: Reverse is orders of magnitude slower on TPU #23191
Comments
I experimented with two alternatives to the plain reverse; the gather-based version of the idea is sketched below.
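A minimal JAX sketch of the gather-based alternative (my reconstruction, not the commenter's actual snippet; `reverse_via_gather` is a hypothetical name):

```python
import jax.numpy as jnp

def reverse_via_gather(x, axis=0):
    # Index with a descending iota so XLA emits a gather
    # instead of a reverse op.
    idx = jnp.arange(x.shape[axis] - 1, -1, -1)
    return jnp.take(x, idx, axis=axis)
```

This is functionally equivalent to `jnp.flip(x, axis)`, but lowers to a gather rather than `lax.rev`.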
The gather generates a while loop, which seems to be much more efficient. One speculation: maybe the reverse operation is VMEM-optimized, but your array is large (500 MB), so it needs to use HBM directly and ends up being inefficient. But I'm curious why the compiler replaces the gather, but not the reverse, with a more optimized version!
I'd imagine this to run at HBM speeds (modulo maybe missing out on coalescing), so in the good case the 1400 it/s on H100 seem about right. Given that v6e has about half the HBM bandwidth, I'd still expect half that performance, meaning it shouldn't take much more than a millisecond per iteration to hit a target of around 700 it/s. Do you have any insight into what libtpu actually does internally for a reverse, in terms of algorithm and memory access pattern?

Also curious: how would I best get this (or any other) manual fix integrated into the overall XLA lowering pipeline? The issue is that in practice I'm not calling reverse myself; XLA lowers irfft into it. I was originally considering doing a matmul with a flipped identity matrix instead of the reverse (not the smartest thing, but probably still better than whatever is actually happening in the reverse; see the sketch below), but I'd be facing the same integration problem there.
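For concreteness, a hedged sketch of that flipped-identity workaround in JAX (hypothetical; the matrix is deliberately built from iotas so constructing it doesn't itself call reverse):

```python
import jax.numpy as jnp

def reverse_via_matmul(x):
    # Multiply by the anti-diagonal "flipped identity" J, where
    # J[i, j] = 1 iff i + j == n - 1; J @ x reverses the leading axis.
    # O(n^2) work instead of O(n), so this is purely a workaround.
    n = x.shape[0]
    i = jnp.arange(n)
    flipped_eye = (i[:, None] + i[None, :] == n - 1).astype(x.dtype)
    return flipped_eye @ x
```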
An important detail: I tested it on a v5e so far, and there the HBM bandwidth suggests 2.5 ms, so I think the optimized gather is quite close to that (although it'd ideally be fused with the scalar add in this repro). Yes, that lowering is problematic... What's the performance like when you implement irfft yourself with the gather trick and ifft? I'll work on creating an internal issue.
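To make the "gather trick plus ifft" suggestion concrete, here is a hedged 1-D sketch (my construction; the helper name and exact indexing are assumptions, not code from the thread):

```python
import jax.numpy as jnp

def irfft_via_ifft(spec, n):
    # Reconstruct the full Hermitian-symmetric spectrum from the half
    # spectrum (k = n // 2 + 1 bins), mirroring via a gather rather
    # than a reverse, then take a full ifft.
    k = spec.shape[-1]
    idx = jnp.arange(n - k, 0, -1)  # maps bins k..n-1 back to n-k..1
    mirror = jnp.conj(jnp.take(spec, idx, axis=-1))
    full = jnp.concatenate([spec, mirror], axis=-1)
    return jnp.fft.ifft(full).real
```

Under these assumptions, `irfft_via_ifft(jnp.fft.rfft(x), x.shape[-1])` should match `jnp.fft.irfft(jnp.fft.rfft(x), x.shape[-1])` up to floating-point error.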
The following sample gets around 1400 it/s on H100 and 11 it/s on v6e, which suggests an issue with the TPU implementation of reverse. This example is already isolated; in practice it gets generated when calling irfft.
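A hypothetical reconstruction of such a repro, assuming a ~500 MB float32 array and the scalar add mentioned in the comments:

```python
import time

import jax
import jax.numpy as jnp

# ~500 MB of float32 (131M elements); size chosen to match the report.
x = jnp.ones((500 * 1024 * 1024 // 4,), dtype=jnp.float32)

@jax.jit
def step(x):
    return jnp.flip(x) + 1.0  # lowers to an XLA reverse plus an add

step(x).block_until_ready()  # warm up / compile
n_iters = 100
start = time.perf_counter()
for _ in range(n_iters):
    y = step(x)
y.block_until_ready()
print(f"{n_iters / (time.perf_counter() - start):.0f} it/s")
```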