Add pallas fused attention and a corresponding forward Laplacian #26
Conversation
@microsoft-github-policy-service agree company="Microsoft"
Amazing stuff! :) I am quite busy at the moment with ICLR. I will have a look around mid-October. Is there any benchmarking on the expected speedup? :)
Force-pushed from 75c57ca to e722320
No worries @n-gao, we don't need this immediately as we can add these registrations internally. Good luck with ICLR :) Benchmarking: we have done some internally, and fused attention certainly seems to make a big improvement, although it is sensitive to the exact pallas hyperparameters you choose and to the system size. To give you one number, I just ran a Psiformer on benzene (batch size 256) with and without fused attention and attained 1.59 it/s and 1.25 it/s respectively. One of the biggest limitations to be aware of is that, when run on the GPU, every tensor dimension must be a power of 2, and any tensor that participates in a matrix multiplication […]
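As an aside, here is a minimal sketch (not from this PR) of padding an input dimension up to the next power of two before handing it to such a kernel; the helper names and the zero-padding strategy are purely illustrative:

import jax.numpy as jnp

def next_pow2(n: int) -> int:
    # Smallest power of two >= n.
    return 1 << max(n - 1, 0).bit_length()

def pad_to_pow2(x: jnp.ndarray, axis: int) -> jnp.ndarray:
    # Zero-pad `axis` of x up to the next power of two (an illustrative
    # workaround for the power-of-two shape requirement mentioned above).
    target = next_pow2(x.shape[axis])
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, target - x.shape[axis])
    return jnp.pad(x, widths)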
Great stuff! Unfortunately, I haven't written a lot of pallas code yet, so my review of that might be limited.
A few things:
- pytest prints lots of deprecation warnings, e.g.:
  DeprecationWarning: BlockSpec now expects ``block_shape`` to be passed before ``index_map``. Update your code by swapping the order of these arguments. For example, ``pl.BlockSpec(lambda i: i, (42,))`` should be written as ``pl.BlockSpec((42,), lambda i: i)``
- A few comments describing what's going on there would be nice.
- Is there a good reason not to use JAX's default multi-head attention for the forward and backward passes? Why do we implement everything from scratch? Afaik, jax.nn.dot_product_attention defaults to the cuDNN implementation (see the sketch after this list).
- I just ran the pytest suite several times and sometimes the tests fail?
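For reference (not code from this PR), a minimal sketch of calling the built-in attention available in recent jax versions; shapes follow its (batch, sequence, heads, head_dim) convention, and the cuDNN backend can be requested explicitly on supported GPUs:

import jax
import jax.numpy as jnp

B, N, H, D = 2, 16, 4, 32  # batch, sequence length, heads, head dim (illustrative)
key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(k_, (B, N, H, D)) for k_ in jax.random.split(key, 3))

# implementation=None lets jax pick a backend; on a supported GPU the cuDNN
# flash-attention path can be requested with implementation="cudnn".
out = jax.nn.dot_product_attention(q, k, v, implementation=None)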
Thanks for your comments @n-gao.
- I switched to keyword arguments, which should fix the deprecation warnings.
- I added a longer docstring to the main […].
- Yeah, good point. This function was actually not available in the version of jax we were using.
- Are you using […]?
I will investigate the cuDNN function on the latest jax version and get back to you.
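For illustration, the deprecation warning quoted above concerns the argument order of pl.BlockSpec; below is a generic sketch (not the PR's actual block specs) of the keyword form that avoids it:

from jax.experimental import pallas as pl

# Deprecated positional order put index_map before block_shape, e.g.
#   pl.BlockSpec(lambda i: (i, 0), (128, 64))
# Keyword (or new positional) form that silences the warning:
spec = pl.BlockSpec(block_shape=(128, 64), index_map=lambda i: (i, 0))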
Alright, after a quick investigation of jax.nn.dot_product_attention: […] For any application that wants numerical stability, this seems a bit annoying. I am sure they will add more flexibility later, but we probably want to wait for a later version of jax. Alternatively, […]. Maybe we can stick with the current hand-written kernels, but swap them out when either of the existing […]
Thanks a lot for checking out the JAX implementation! I agree that there's value in having our own implementation.
I left some comments; many are questions I would have loved to test myself, but I unfortunately broke my hand recently. From a naive perspective, there appear to be a few unnecessary operations and some opportunities for caching.
p = jax.nn.softmax(s, axis=-1)
o = jnp.einsum("BnhN,BNhd->Bnhd", p, v)

# Jacobian
While I am a huge fan of einsum, here we potentially leave performance on the table. The reason is that we are at the mercy of opt_einsum yielding the same contraction order every time. I can't tell whether that's an issue here, but here's an example:
Imagine I want to compute ABC and DAB.
If I select the orders (AB)C and (DA)B, I have to do 4 matmuls in total.
If I instead order them as (AB)C and D(AB), I only need 3 matmuls, since I can reuse AB.
In standard JAX, if I write down the first version, I get 4 matmuls. The latter (without explicit caching of AB) would yield 3 matmuls due to common subexpression elimination.
So here we implicitly hope that opt_einsum yields the minimal number of total matmuls, which it does not guarantee or even consider. Note that in some cases doing more matmuls may still be faster if it avoids materializing large low-rank tensors (which is what opt_einsum optimizes for).
An example:
import jax
import jax.numpy as jnp

A, B, C, D = jax.random.normal(jax.random.PRNGKey(0), (4, 10, 10))

def f1(A, B, C, D):
    return A @ B @ C, D @ A @ B

print("f1")
lowered = jax.jit(f1).lower(A, B, C, D)
compiled = lowered.compile()
print(compiled.as_text())
print(compiled.cost_analysis()[0]["flops"])

def f2(A, B, C, D):
    return A @ B @ C, D @ (A @ B)

print("f2")
lowered = jax.jit(f2).lower(A, B, C, D)
compiled = lowered.compile()
print(compiled.as_text())
print(compiled.cost_analysis()[0]["flops"])
f1
HloModule jit_f1, is_scheduled=true, entry_computation_layout={(f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0})->(f32[10,10]{1,0}, f32[10,10]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}
ENTRY %main.10 (Arg_0.1: f32[10,10], Arg_1.2: f32[10,10], Arg_2.3: f32[10,10], Arg_3.4: f32[10,10]) -> (f32[10,10], f32[10,10]) {
%Arg_0.1 = f32[10,10]{1,0} parameter(0), metadata={op_name="A"}
%Arg_1.2 = f32[10,10]{1,0} parameter(1), metadata={op_name="B"}
%dot.5 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_0.1, f32[10,10]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
%Arg_2.3 = f32[10,10]{1,0} parameter(2), metadata={op_name="C"}
%dot.6 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %dot.5, f32[10,10]{1,0} %Arg_2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
%Arg_3.4 = f32[10,10]{1,0} parameter(3), metadata={op_name="D"}
%dot.7 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_3.4, f32[10,10]{1,0} %Arg_0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
%dot.8 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %dot.7, f32[10,10]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
ROOT %tuple.9 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %dot.6, f32[10,10]{1,0} %dot.8)
}
8000.0
f2
HloModule jit_f2, is_scheduled=true, entry_computation_layout={(f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0})->(f32[10,10]{1,0}, f32[10,10]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}
ENTRY %main.10 (Arg_0.1: f32[10,10], Arg_1.2: f32[10,10], Arg_2.3: f32[10,10], Arg_3.4: f32[10,10]) -> (f32[10,10], f32[10,10]) {
%Arg_0.1 = f32[10,10]{1,0} parameter(0), metadata={op_name="A"}
%Arg_1.2 = f32[10,10]{1,0} parameter(1), metadata={op_name="B"}
%dot.5 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_0.1, f32[10,10]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f2)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=19}
%Arg_2.3 = f32[10,10]{1,0} parameter(2), metadata={op_name="C"}
%dot.6 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %dot.5, f32[10,10]{1,0} %Arg_2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f2)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=19}
%Arg_3.4 = f32[10,10]{1,0} parameter(3), metadata={op_name="D"}
%dot.8 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_3.4, f32[10,10]{1,0} %dot.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f2)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=19}
ROOT %tuple.9 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %dot.6, f32[10,10]{1,0} %dot.8)
}
6000.0
I do agree with you in principle. However, since this is a reference implementation, it should not be used in production, so performance probably isn't a major concern here. We primarily use it to test the pallas variant. I would rather err on the side of less performant but more readable code in this case.
I agree that it's not very important, but did your benchmarks compare against this unoptimized variant?
I haven't benchmarked the reference implementation
Hi Nick,
The point of this "reference" implementation was only to help in testing the correctness of the algorithm during development. It was useful to have a version where all the matrix products are written out explicitly, but which is considerably more concise than the full-blown pallas kernel. It was used, e.g., to implement the various terms in the pallas kernel incrementally, checking each term against this reference implementation. It was then kept as part of the unit tests, as a second, sort-of-independent reference to check against. However, keeping it just for the unit tests is not really necessary.
On the other hand, this reference version was never used to benchmark the speedup of the pallas kernel. Benchmarking was only ever done against the "vanilla" jax+folx version of MHA. I agree with what you brought up about einsum not optimizing the order of the contraction. In fact, I think there are even more suboptimal things in this reference implementation, e.g. we could be reusing some intermediates.
To move forward, I think we could either:
- remove this reference implementation completely, to avoid anyone using it mistakenly
- keep it but make it print some warnings about how it shouldn't be used in production
What do you think? :)
I'm generally fine with either solution, but I agree that a user should not accidentally use it. :)
After resolving this one, I'd merge the PR.
Ok, then I added some warning messages here.
Curiously, two tests are failing for me?
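For concreteness, a hypothetical illustration of the kind of guard being discussed; the actual message and where it is emitted in the repo may differ:

import warnings

warnings.warn(
    "This reference attention implementation exists only for testing the "
    "pallas kernel and should not be used in production."
)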
Ok, seems like the tests are a bit flaky between different […]
I pushed an update that sets the same tolerance level for […]
Removing the […]
@n-gao I have now removed the manual masking as far as I was able without causing problems in the unit tests. I would be tempted to leave it at that; my hunch is that it could only be beneficial in terms of numerics, at a rather minimal computational cost. The unit tests are passing for me locally and on the CI, maybe you can try again on your version of […]
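As an aside, here is a generic sketch of the kind of comparison test discussed here, with a single tolerance shared across backends; the function names and tolerance values are illustrative and not the repo's actual tests:

import numpy as np

def check_against_reference(fused_attention, reference_attention, q, k, v,
                            rtol=1e-5, atol=1e-5):
    # Compare the fused kernel against the reference implementation at a
    # fixed tolerance (illustrative values).
    out_fused = fused_attention(q, k, v)
    out_ref = reference_attention(q, k, v)
    np.testing.assert_allclose(out_fused, out_ref, rtol=rtol, atol=atol)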
Thank you so much, and sorry for the lengthy process!
This adds a basic implementation of fused attention using pallas, along with a consistent forward Laplacian algorithm. This accelerates the forward, backward and forward Laplacian computations of attention layers.
Note: we do not include all the features of Flash Attention […]
Credit for this code goes to @szbernat.
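For orientation, a hedged sketch of how a forward Laplacian is typically taken with folx; the placeholder function below stands in for a network that would call the fused attention internally and is not the PR's actual API:

import jax
import jax.numpy as jnp
import folx

def f(x):
    # Placeholder scalar-valued network; in practice this would call the
    # (fused) attention layers internally.
    return jnp.sum(jnp.tanh(x))

x = jax.random.normal(jax.random.PRNGKey(0), (16,))
fwd_lapl = folx.forward_laplacian(f)
result = fwd_lapl(x)
print(result.x, result.laplacian)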