
Add pallas fused attention and a corresponding forward Laplacian #26

Merged: 18 commits from ae-foster/pallas-attention into main on Dec 7, 2024

Conversation

@ae-foster (Collaborator) commented Sep 24, 2024

This adds a basic implementation of fused attention using pallas, along with a consistent forward Laplacian algorithm. This accelerates the forward, backward and forward Laplacian computations of attention layers.
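For context, a minimal sketch of the unfused baseline that this kind of kernel accelerates, written in plain JAX around folx's forward_laplacian transform; the shapes and the self-attention wrapper here are illustrative only and are not the API added by this PR.

import jax
import jax.numpy as jnp
from folx import forward_laplacian

def mhsa(q, k, v):
    # Plain multi-head self-attention, softmax(q k^T / sqrt(d)) v,
    # with per-example shapes (sequence, heads, head_dim).
    s = jnp.einsum("nhd,Nhd->nhN", q, k) / jnp.sqrt(q.shape[-1])
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum("nhN,Nhd->nhd", p, v)

# folx propagates the Jacobian and Laplacian forward through each op; a
# fused pallas kernel with a matching forward-Laplacian rule lets the whole
# attention block be treated as a single primitive instead.
x = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 32))
out = forward_laplacian(lambda y: mhsa(y, y, y))(x)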

Note: We do not include all the features of Flash Attention

Credit for this code goes to @szbernat

@ae-foster (Collaborator, Author) commented:

@microsoft-github-policy-service agree company="Microsoft"

@n-gao (Collaborator) commented Sep 24, 2024

Amazing stuff! :) I am quite busy at the moment with ICLR. I will have a look at approximately mid-October. Is there some benchmarking on the expected speedup? :)

@ae-foster force-pushed the ae-foster/pallas-attention branch from 75c57ca to e722320 (September 24, 2024 13:30)
@ae-foster (Collaborator, Author) commented:

> Amazing stuff! :) I am quite busy at the moment with ICLR. I will have a look at approximately mid-October. Is there some benchmarking on the expected speedup? :)

No worries @n-gao, we don't need this immediately, as we can add these registrations internally.

Good luck with ICLR :)

Benchmarking: we have done some internally, and fused attention certainly seems to make a big improvement, although the gain is sensitive to the exact pallas hyperparameters you choose and to the system size. To give you one number, I just ran a Psiformer on benzene (batch size 256) with and without fused attention and attained 1.59 it/s and 1.25 it/s respectively.

One of the biggest limitations to be aware of is that, when run on the GPU, every tensor dimension must be a power of 2, and any tensor that participates in a matrix multiplication (pl.dot) must have dimension at least 16. To handle tensors of other sizes, you have to pad and mask your inputs to the relevant size. Hence, you get the best speedup when your tensor dimensions are already powers of 2. I think you also need an A100 or later for pallas to work.
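To illustrate the padding and masking described above, here is a rough, hypothetical helper (not part of this PR) that pads one axis up to the next power of two and returns the corresponding validity mask:

import jax.numpy as jnp

def pad_to_pow2(x, axis, min_size=16):
    # Pad `axis` up to the next power of two, with a floor of `min_size`
    # because operands of pl.dot need a dimension of at least 16 on GPU.
    n = x.shape[axis]
    target = max(min_size, 1 << (n - 1).bit_length())
    pad = [(0, 0)] * x.ndim
    pad[axis] = (0, target - n)
    mask = jnp.arange(target) < n  # True for real entries, False for padding
    return jnp.pad(x, pad), mask

The mask is then used inside the kernel, e.g. to set padded attention logits to a large negative value before the softmax.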

@n-gao (Collaborator) left a comment:


Great stuff! Unfortunately, I haven't written a lot of pallas code yet, so my review of that might be limited.

A few things:

  1. pytest prints lots of deprecation warnings, e.g.:
     DeprecationWarning: BlockSpec now expects ``block_shape`` to be passed before ``index_map``. Update your code by swapping the order of these arguments. For example, ``pl.BlockSpace(lambda i: i, (42,))`` should be written as ``pl.BlockSpec((42,), lambda i: i)``
  2. A few comments would be nice describing what's going on there.
  3. Is there a good reason not to use JAX's default multiheaded attention for fwd and bwd? Why do we implement everything from scratch? Afaik, jax.nn.dot_product_attention defaults to the cudnn implementation.
  4. I just ran the pytest several times and sometimes they fail?

Resolved review threads on: folx/experimental/pallas/custom_gradients.py, folx/experimental/pallas/forward_laplacian.py, folx/experimental/pallas/mha.py, folx/experimental/pallas/utils.py
@ae-foster (Collaborator, Author) commented Oct 28, 2024

Thanks for your comments, @n-gao.

> 1. pytest prints lots of deprecation warnings, e.g.: ...

I switched to keyword arguments, which should fix this (see the sketch below).
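For reference, the keyword form looks roughly like this (sketch only; the parameter names are the ones mentioned in the warning):

from jax.experimental import pallas as pl

# The positional order of BlockSpec arguments changed between jax versions,
# which is what triggers the DeprecationWarning quoted above; passing them
# by keyword is unambiguous either way.
block_spec = pl.BlockSpec(block_shape=(42,), index_map=lambda i: i)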

> 2. A few comments would be nice describing what's going on there.

I added a longer docstring in the main mha function, which should be clearer now :)

> 3. Is there a good reason not to use JAX's default multiheaded attention for fwd and bwd?

Yeah, good point. This function was actually not available in the version of jax I wrote this code against. I will try it on the latest version.

> 4. I just ran the pytest several times and sometimes they fail?

Are you using JAX_PLATFORMS=cpu python -m pytest ...? The CPU implementation should be deterministic and should run in float32. The GPU implementation will, by default, run at lower precision and will not be deterministic.

@ae-foster (Collaborator, Author) commented:

I will investigate the cuDNN function on the latest jax version and get back to you

@ae-foster (Collaborator, Author) commented:

Alright, after a quick investigation on jax==0.4.34, I found a bit of a limitation of the cuDNN implementation: the dtype of the inputs must be either float16 or bfloat16, see https://github.com/jax-ml/jax/blob/36c56fa19be6c8d6c4a19a9adaf58cbf382ad9df/jax/_src/cudnn/fused_attention_stablehlo.py#L76

For any applications that want numerical stability, this seems a bit annoying. I am sure they will add more flexibility later, but we probably want to wait for a later version of jax. The pallas version definitely supports float32.
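Concretely, a float32 caller would have to round-trip through bfloat16 to use the cuDNN path, roughly along these lines (a sketch, assuming jax.nn.dot_product_attention's implementation argument in recent jax versions):

import jax
import jax.numpy as jnp

def cudnn_mhsa(q, k, v, mask=None):
    # The cuDNN backend only accepts float16/bfloat16 inputs, so we cast down,
    # run the fused kernel, and cast the output back, losing precision on the way.
    q16, k16, v16 = (x.astype(jnp.bfloat16) for x in (q, k, v))
    o = jax.nn.dot_product_attention(q16, k16, v16, mask=mask, implementation="cudnn")
    return o.astype(q.dtype)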

Alternatively, jax has its own pallas kernel: https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/attention.py. However, I want to add an edge bias argument (hopefully in a later PR), which jax's own pallas implementation does not support.

Maybe we can stick with the current hand-written kernels, but swap them out when either of the existing jax variants becomes sufficiently flexible?

@n-gao (Collaborator) left a comment:


Thanks a lot for checking out the JAX implementation! I agree that there's value in having our own implementation.
I left some comments; many are questions I would have loved to test myself, but unfortunately I recently broke my hand. From a naive perspective, there appear to be a few unnecessary operations and some opportunities for caching.

Resolved review threads on: folx/experimental/pallas/mha.py, folx/experimental/pallas/__init__.py, folx/experimental/pallas/custom_gradients.py, folx/experimental/pallas/forward_laplacian.py
p = jax.nn.softmax(s, axis=-1)
o = jnp.einsum("BnhN,BNhd->Bnhd", p, v)

# Jacobian
A collaborator commented on the excerpt above:

While I am a huge fan of einsum, here we potentially leave performance on the table. The reason is that we are at the mercy of opt_einsum yielding the same contraction order every time. I can't tell whether that's an issue here, but here's an example.
Imagine I want to compute ABC and DAB.
If I select the orders (AB)C and (DA)B, I have to do 4 matmuls in total.
If I instead order them as (AB)C and D(AB), I only need 3 matmuls, since I can reuse AB.

In standard JAX, if I write down the first version, I get 4 matmuls. The latter (without explicit caching of AB) would yield 3 matmuls due to common subexpression elimination.

So here we implicitly hope that opt_einsum yields the minimal number of total matmuls, which it does not guarantee or even consider. Note that in some cases doing more matmuls may still be faster, to avoid materializing large low-rank tensors (which is what opt_einsum optimizes for).


An example:

import jax
import jax.numpy as jnp

A, B, C, D = jax.random.normal(jax.random.PRNGKey(0), (4, 10, 10))


def f1(A, B, C, D):
    return A @ B @ C, D @ A @ B


print("f1")
lowered = jax.jit(f1).lower(A, B, C, D)
compiled = lowered.compile()
print(compiled.as_text())
print(compiled.cost_analysis()[0]["flops"])


def f2(A, B, C, D):
    return A @ B @ C, D @ (A @ B)


print("f2")
lowered = jax.jit(f2).lower(A, B, C, D)
compiled = lowered.compile()
print(compiled.as_text())
print(compiled.cost_analysis()[0]["flops"])
f1
HloModule jit_f1, is_scheduled=true, entry_computation_layout={(f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0})->(f32[10,10]{1,0}, f32[10,10]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}

ENTRY %main.10 (Arg_0.1: f32[10,10], Arg_1.2: f32[10,10], Arg_2.3: f32[10,10], Arg_3.4: f32[10,10]) -> (f32[10,10], f32[10,10]) {
  %Arg_0.1 = f32[10,10]{1,0} parameter(0), metadata={op_name="A"}
  %Arg_1.2 = f32[10,10]{1,0} parameter(1), metadata={op_name="B"}
  %dot.5 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_0.1, f32[10,10]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
  %Arg_2.3 = f32[10,10]{1,0} parameter(2), metadata={op_name="C"}
  %dot.6 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %dot.5, f32[10,10]{1,0} %Arg_2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
  %Arg_3.4 = f32[10,10]{1,0} parameter(3), metadata={op_name="D"}
  %dot.7 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_3.4, f32[10,10]{1,0} %Arg_0.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
  %dot.8 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %dot.7, f32[10,10]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f1)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=8}
  ROOT %tuple.9 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %dot.6, f32[10,10]{1,0} %dot.8)
}


8000.0
f2
HloModule jit_f2, is_scheduled=true, entry_computation_layout={(f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0}, f32[10,10]{1,0})->(f32[10,10]{1,0}, f32[10,10]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true}

ENTRY %main.10 (Arg_0.1: f32[10,10], Arg_1.2: f32[10,10], Arg_2.3: f32[10,10], Arg_3.4: f32[10,10]) -> (f32[10,10], f32[10,10]) {
  %Arg_0.1 = f32[10,10]{1,0} parameter(0), metadata={op_name="A"}
  %Arg_1.2 = f32[10,10]{1,0} parameter(1), metadata={op_name="B"}
  %dot.5 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_0.1, f32[10,10]{1,0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f2)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=19}
  %Arg_2.3 = f32[10,10]{1,0} parameter(2), metadata={op_name="C"}
  %dot.6 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %dot.5, f32[10,10]{1,0} %Arg_2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f2)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=19}
  %Arg_3.4 = f32[10,10]{1,0} parameter(3), metadata={op_name="D"}
  %dot.8 = f32[10,10]{1,0} dot(f32[10,10]{1,0} %Arg_3.4, f32[10,10]{1,0} %dot.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f2)/jit(main)/dot_general" source_file="/Users/gaoni/Documents/Repositories/test/hello.py" source_line=19}
  ROOT %tuple.9 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %dot.6, f32[10,10]{1,0} %dot.8)
}


6000.0

ae-foster (Collaborator, Author) replied:

I do agree with you in principle. However, since this is a reference implementation, it should not be used in production and so performance probably isn't a major concern here. We primarily use it to test the pallas variant. I would rather err on the side of less performant but more readable code in this case.

A collaborator replied:

I agree that it's not very important, but did your benchmarks compare against this unoptimized variant?

ae-foster (Collaborator, Author) replied:

I haven't benchmarked the reference implementation

A collaborator replied:

Hi Nick,
The point of this "reference" implementation was only to help test the correctness of the algorithm during development. It was useful to have a version where all the matrix products are written out explicitly but which is considerably more concise than the full-blown pallas kernel. It was used, e.g., to implement the various terms in the pallas kernel incrementally, checking each term against this reference implementation. It was then kept as part of the unit tests, as a second, more or less independent reference to check against. However, keeping it just for the unit tests is not strictly necessary.

On the other hand, this reference version was never used to benchmark the speedup of the pallas kernel. Benchmarking was only ever done against the "vanilla" jax+folx version of MHA. I agree with your point about einsum not optimizing the contraction order across expressions. In fact, I think there are even more suboptimal things in this reference implementation; for instance, we could be reusing some intermediates.

To move forward, I think we could either:

  • remove this reference implementation completely, to avoid anyone using it mistakenly
  • keep it but make it print some warnings about how it shouldn't be used in production

What do you think? :)

A collaborator replied:

I'm generally fine with both solutions but I agree that a user should not accidentally use it. :)

After resolving this one, I'd merge the PR.

A collaborator replied:

Ok, then I added some warning messages here.
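(For illustration, the kind of guard meant here is a module-level warning along these lines; the exact wording and placement in the PR may differ.)

import warnings

warnings.warn(
    "The reference MHSA implementation exists only to test the pallas kernels; "
    "it is unoptimized and should not be used in production.",
    stacklevel=2,
)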

Resolved review threads on: folx/experimental/pallas/forward_laplacian.py, folx/experimental/pallas/utils.py
@n-gao (Collaborator) commented Oct 30, 2024

Curiously, two tests are failing for me?

JAX_PLATFORMS=cpu python -m pytest test/experimental/pallas/test_attention.py 
============================================================== test session starts ===============================================================
platform linux -- Python 3.12.5, pytest-8.3.3, pluggy-1.5.0
rootdir: /ceph/ssd/staff/gaoni/repos/folx_ms
configfile: pyproject.toml
plugins: jaxtyping-0.2.34, typeguard-2.13.3
collected 12 items                                                                                                                               

test/experimental/pallas/test_attention.py .....F..F...                                                                                    [100%]

==================================================================== FAILURES ====================================================================
________________________________________________________ test_vjp[rng2-1-16-4-32-4-False] ________________________________________________________

rng = Array([0, 5], dtype=uint32), batch_dim = 1, sequence_dim = 16, num_heads = 4, head_dim = 32, max_sequence = 4, with_vmap = False

    @pytest.mark.parametrize(
        "rng, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, with_vmap",
        [
            (jax.random.PRNGKey(3), 1, 1, 1, 1, 1, False),
            (jax.random.PRNGKey(4), 1, 16, 4, 32, 16, False),
            (jax.random.PRNGKey(5), 1, 16, 4, 32, 4, False),
            (jax.random.PRNGKey(6), 1, 1, 1, 1, 1, True),
            (jax.random.PRNGKey(7), 1, 16, 4, 32, 16, True),
            (jax.random.PRNGKey(8), 1, 16, 4, 32, 4, True),
        ],
    )
    def test_vjp(rng, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, with_vmap):
        input_dim = 3 * sequence_dim
        q, k, v, mask, input_mask = inputs_to_mhsa(
            rng, input_dim, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, True
        )
        if with_vmap:
            q, k, v = jax.tree.map(lambda x: x[None], (q, k, v))
        o_vjp = q
    
        fn = partial(custom_vjp_mhsa, mask=mask, input_mask=input_mask, kernel="pallas", interpret=True)
        if with_vmap:
            fn = jax.vmap(fn)
        o, mhsa_vjp_fn = jax.vjp(fn, q, k, v)
        q_vjp, k_vjp, v_vjp = mhsa_vjp_fn(o_vjp)
    
        ref_fn = partial(
            custom_vjp_mhsa, mask=mask, input_mask=input_mask, kernel="reference", interpret=True
        )
        if with_vmap:
            ref_fn = jax.vmap(ref_fn)
        ref_o, ref_mhsa_vjp_fn = jax.vjp(ref_fn, q, k, v)
        ref_q_vjp, ref_k_vjp, ref_v_vjp = ref_mhsa_vjp_fn(o_vjp)
    
        jax_fn = partial(reference_mhsa_kernel, mask=mask)
        if with_vmap:
            jax_fn = jax.vmap(jax_fn)
        jax_o, jax_mhsa_vjp_fn = jax.vjp(jax_fn, q, k, v)
        jax_q_vjp, jax_k_vjp, jax_v_vjp = jax_mhsa_vjp_fn(o_vjp)
    
        print("ours", mask_array(k_vjp, mask))
        print("jax", mask_array(jax_k_vjp, mask))
        print("ref", mask_array(ref_k_vjp, mask))
        assert jnp.allclose(mask_array(o, mask), mask_array(ref_o, mask), atol=1e-6)
        assert jnp.allclose(mask_array(q_vjp, mask), mask_array(ref_q_vjp, mask), atol=1e-6)
        assert jnp.allclose(mask_array(k_vjp, mask), mask_array(ref_k_vjp, mask), atol=1e-6)
>       assert jnp.allclose(mask_array(v_vjp, mask), mask_array(ref_v_vjp, mask))
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <PjitFunction of <function allclose at 0x7f63a43cbb00>>(Array([[[[-0.01125506,  0.13717116,  0.02391454, ...,  0.1669852 ,\n          -0.07873881, -0.04022644],\n         [-0.0...       [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n           0.        ,  0.        ]]]], dtype=float32), Array([[[[-0.01125506,  0.13717116,  0.02391454, ...,  0.1669852 ,\n          -0.07873881, -0.04022644],\n         [-0.0...       [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n           0.        ,  0.        ]]]], dtype=float32))
E        +    where <PjitFunction of <function allclose at 0x7f63a43cbb00>> = jnp.allclose
E        +    and   Array([[[[-0.01125506,  0.13717116,  0.02391454, ...,  0.1669852 ,\n          -0.07873881, -0.04022644],\n         [-0.0...       [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n           0.        ,  0.        ]]]], dtype=float32) = mask_array(Array([[[[-0.01125506,  0.13717116,  0.02391454, ...,  0.1669852 ,\n          -0.07873881, -0.04022644],\n         [-0.0...       [-0.06737792, -0.01882527,  0.0371132 , ..., -0.02737187,\n          -0.01528279,  0.00219313]]]], dtype=float32), Array([[ True,  True,  True,  True, False, False, False, False, False,\n        False, False, False, False, False, False, False]], dtype=bool))
E        +    and   Array([[[[-0.01125506,  0.13717116,  0.02391454, ...,  0.1669852 ,\n          -0.07873881, -0.04022644],\n         [-0.0...       [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n           0.        ,  0.        ]]]], dtype=float32) = mask_array(Array([[[[-0.01125506,  0.13717116,  0.02391454, ...,  0.1669852 ,\n          -0.07873881, -0.04022644],\n         [-0.0...       [-0.06737792, -0.01882527,  0.0371132 , ..., -0.02737187,\n          -0.01528279,  0.00219313]]]], dtype=float32), Array([[ True,  True,  True,  True, False, False, False, False, False,\n        False, False, False, False, False, False, False]], dtype=bool))

test/experimental/pallas/test_attention.py:128: AssertionError
-------------------------------------------------------------- Captured stdout call --------------------------------------------------------------
ours [[[[-0.01265414 -0.00514394  0.01741042 ... -0.05660322  0.09775228
     0.07223597]
   [-0.01514321 -0.05074158  0.00319952 ...  0.11107461  0.00463992
    -0.051954  ]
   [-0.01784979  0.16220643  0.0139447  ... -0.1168594  -0.07455422
     0.00746383]
   [-0.01416412  0.02854457  0.03813607 ...  0.00025696  0.05752483
    -0.03615897]]

  [[ 0.04642555 -0.07991068 -0.14719492 ... -0.02718095 -0.01173019
    -0.05468796]
   [-0.06374478  0.05247931  0.04064969 ... -0.00995755  0.09136452
    -0.04732993]
   [ 0.01249861 -0.06154702  0.04532392 ...  0.09267802  0.00184924
    -0.05109798]
   [-0.08168254 -0.06194501 -0.03066411 ...  0.02067267 -0.03827596
     0.01320639]]

  [[-0.01620414  0.16409208  0.12814718 ...  0.16434848 -0.12097032
    -0.05773024]
   [ 0.17431861  0.05311927 -0.08083844 ... -0.00409125 -0.234876
     0.21461338]
   [-0.02084438  0.03327532 -0.04442266 ...  0.04805569  0.10349575
     0.01232386]
   [ 0.01905433 -0.01281284 -0.02913808 ...  0.00185004 -0.0019207
     0.03069953]]

  ...

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]]]
jax [[[[-0.01265414 -0.00514393  0.01741043 ... -0.0566032   0.09775227
     0.07223596]
   [-0.01514321 -0.05074158  0.00319952 ...  0.11107464  0.00463993
    -0.05195402]
   [-0.0178498   0.16220643  0.01394471 ... -0.11685939 -0.0745542
     0.00746382]
   [-0.01416412  0.02854457  0.03813606 ...  0.00025696  0.05752483
    -0.03615896]]

  [[ 0.04642555 -0.07991068 -0.14719492 ... -0.02718094 -0.01173019
    -0.05468797]
   [-0.06374479  0.05247931  0.04064969 ... -0.00995754  0.09136452
    -0.04732994]
   [ 0.01249861 -0.06154702  0.04532392 ...  0.09267802  0.00184925
    -0.05109798]
   [-0.08168252 -0.06194501 -0.03066411 ...  0.02067267 -0.03827596
     0.01320639]]

  [[-0.01620414  0.16409208  0.12814718 ...  0.16434848 -0.12097031
    -0.05773024]
   [ 0.17431861  0.05311926 -0.08083844 ... -0.00409125 -0.23487599
     0.2146134 ]
   [-0.02084439  0.03327533 -0.04442266 ...  0.0480557   0.10349575
     0.01232386]
   [ 0.01905434 -0.01281284 -0.02913808 ...  0.00185004 -0.0019207
     0.03069953]]

  ...

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]]]
ref [[[[-0.01265414 -0.00514393  0.01741043 ... -0.05660322  0.09775227
     0.07223596]
   [-0.01514321 -0.05074158  0.00319952 ...  0.11107461  0.00463993
    -0.05195401]
   [-0.01784979  0.16220643  0.0139447  ... -0.1168594  -0.07455422
     0.00746383]
   [-0.01416412  0.02854457  0.03813607 ...  0.00025696  0.05752483
    -0.03615897]]

  [[ 0.04642554 -0.07991068 -0.14719492 ... -0.02718095 -0.01173019
    -0.05468797]
   [-0.06374478  0.05247931  0.04064969 ... -0.00995755  0.09136452
    -0.04732993]
   [ 0.0124986  -0.06154702  0.04532391 ...  0.09267802  0.00184924
    -0.05109798]
   [-0.08168253 -0.06194501 -0.03066411 ...  0.02067267 -0.03827596
     0.01320639]]

  [[-0.01620414  0.16409208  0.12814718 ...  0.16434847 -0.12097032
    -0.05773024]
   [ 0.17431861  0.05311926 -0.08083844 ... -0.00409125 -0.234876
     0.21461338]
   [-0.02084439  0.03327532 -0.04442266 ...  0.04805569  0.10349575
     0.01232386]
   [ 0.01905433 -0.01281284 -0.02913808 ...  0.00185004 -0.00192069
     0.03069953]]

  ...

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]]]
________________________________________________________ test_vjp[rng5-1-16-4-32-4-True] _________________________________________________________

rng = Array([0, 8], dtype=uint32), batch_dim = 1, sequence_dim = 16, num_heads = 4, head_dim = 32, max_sequence = 4, with_vmap = True

    @pytest.mark.parametrize(
        "rng, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, with_vmap",
        [
            (jax.random.PRNGKey(3), 1, 1, 1, 1, 1, False),
            (jax.random.PRNGKey(4), 1, 16, 4, 32, 16, False),
            (jax.random.PRNGKey(5), 1, 16, 4, 32, 4, False),
            (jax.random.PRNGKey(6), 1, 1, 1, 1, 1, True),
            (jax.random.PRNGKey(7), 1, 16, 4, 32, 16, True),
            (jax.random.PRNGKey(8), 1, 16, 4, 32, 4, True),
        ],
    )
    def test_vjp(rng, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, with_vmap):
        input_dim = 3 * sequence_dim
        q, k, v, mask, input_mask = inputs_to_mhsa(
            rng, input_dim, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, True
        )
        if with_vmap:
            q, k, v = jax.tree.map(lambda x: x[None], (q, k, v))
        o_vjp = q
    
        fn = partial(custom_vjp_mhsa, mask=mask, input_mask=input_mask, kernel="pallas", interpret=True)
        if with_vmap:
            fn = jax.vmap(fn)
        o, mhsa_vjp_fn = jax.vjp(fn, q, k, v)
        q_vjp, k_vjp, v_vjp = mhsa_vjp_fn(o_vjp)
    
        ref_fn = partial(
            custom_vjp_mhsa, mask=mask, input_mask=input_mask, kernel="reference", interpret=True
        )
        if with_vmap:
            ref_fn = jax.vmap(ref_fn)
        ref_o, ref_mhsa_vjp_fn = jax.vjp(ref_fn, q, k, v)
        ref_q_vjp, ref_k_vjp, ref_v_vjp = ref_mhsa_vjp_fn(o_vjp)
    
        jax_fn = partial(reference_mhsa_kernel, mask=mask)
        if with_vmap:
            jax_fn = jax.vmap(jax_fn)
        jax_o, jax_mhsa_vjp_fn = jax.vjp(jax_fn, q, k, v)
        jax_q_vjp, jax_k_vjp, jax_v_vjp = jax_mhsa_vjp_fn(o_vjp)
    
        print("ours", mask_array(k_vjp, mask))
        print("jax", mask_array(jax_k_vjp, mask))
        print("ref", mask_array(ref_k_vjp, mask))
        assert jnp.allclose(mask_array(o, mask), mask_array(ref_o, mask), atol=1e-6)
        assert jnp.allclose(mask_array(q_vjp, mask), mask_array(ref_q_vjp, mask), atol=1e-6)
        assert jnp.allclose(mask_array(k_vjp, mask), mask_array(ref_k_vjp, mask), atol=1e-6)
>       assert jnp.allclose(mask_array(v_vjp, mask), mask_array(ref_v_vjp, mask))
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = <PjitFunction of <function allclose at 0x7f63a43cbb00>>(Array([[[[[-0.02818926,  0.0783406 , -0.03920441, ..., -0.03300796,\n            0.01370931,  0.05514015],\n          [-...     [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n            0.        ,  0.        ]]]]], dtype=float32), Array([[[[[-0.02818927,  0.0783406 , -0.0392044 , ..., -0.03300797,\n            0.01370932,  0.05514015],\n          [-...     [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n            0.        ,  0.        ]]]]], dtype=float32))
E        +    where <PjitFunction of <function allclose at 0x7f63a43cbb00>> = jnp.allclose
E        +    and   Array([[[[[-0.02818926,  0.0783406 , -0.03920441, ..., -0.03300796,\n            0.01370931,  0.05514015],\n          [-...     [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n            0.        ,  0.        ]]]]], dtype=float32) = mask_array(Array([[[[[-0.02818926,  0.0783406 , -0.03920441, ..., -0.03300796,\n            0.01370931,  0.05514015],\n          [-...     [-0.02599467,  0.03981862, -0.04427316, ..., -0.06761377,\n           -0.03061614,  0.01796181]]]]], dtype=float32), Array([[ True,  True,  True,  True, False, False, False, False, False,\n        False, False, False, False, False, False, False]], dtype=bool))
E        +    and   Array([[[[[-0.02818927,  0.0783406 , -0.0392044 , ..., -0.03300797,\n            0.01370932,  0.05514015],\n          [-...     [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n            0.        ,  0.        ]]]]], dtype=float32) = mask_array(Array([[[[[-0.02818927,  0.0783406 , -0.0392044 , ..., -0.03300797,\n            0.01370932,  0.05514015],\n          [-...     [-0.02599467,  0.03981862, -0.04427316, ..., -0.06761377,\n           -0.03061614,  0.01796181]]]]], dtype=float32), Array([[ True,  True,  True,  True, False, False, False, False, False,\n        False, False, False, False, False, False, False]], dtype=bool))

test/experimental/pallas/test_attention.py:128: AssertionError
-------------------------------------------------------------- Captured stdout call --------------------------------------------------------------
ours [[[[[ 0.03918849 -0.05213178  0.0268001  ...  0.02943189  0.01782213
      0.03870804]
    [-0.03439355 -0.01018475 -0.00093702 ...  0.07854606 -0.0473554
     -0.10907508]
    [-0.01603687  0.0520269   0.09639733 ...  0.01151111  0.04496527
      0.0028101 ]
    [ 0.1090112   0.08749303 -0.14269403 ...  0.0766882   0.04787894
      0.06491714]]

   [[-0.01632031  0.05832538 -0.04605425 ...  0.00103438 -0.08423408
     -0.07199297]
    [ 0.0357766  -0.11856882  0.04105911 ...  0.03963502  0.21791284
     -0.04519068]
    [ 0.04110055 -0.09093985 -0.37334543 ...  0.01539735 -0.11850898
      0.02842933]
    [-0.16019848 -0.13069683  0.30381435 ... -0.0421362  -0.06647228
     -0.00533755]]

   [[ 0.00416413  0.03159913  0.00303013 ... -0.03852243  0.0908694
      0.01786394]
    [-0.0628048   0.09151968 -0.01255641 ...  0.03402098 -0.1385731
      0.01136297]
    [-0.0680137   0.09260172  0.23994067 ...  0.00999346  0.07607073
      0.0010826 ]
    [-0.11575231 -0.0861637   0.05980099 ... -0.04376325 -0.07687512
     -0.15914239]]

   ...

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]]]]
jax [[[[[ 0.0391885  -0.05213178  0.02680009 ...  0.0294319   0.01782214
      0.03870805]
    [-0.03439355 -0.01018474 -0.00093703 ...  0.07854605 -0.04735539
     -0.10907509]
    [-0.01603687  0.05202691  0.09639731 ...  0.01151111  0.04496529
      0.00281009]
    [ 0.1090112   0.08749308 -0.14269403 ...  0.0766882   0.04787892
      0.06491715]]

   [[-0.01632032  0.05832538 -0.04605424 ...  0.00103438 -0.08423408
     -0.07199297]
    [ 0.0357766  -0.11856882  0.04105911 ...  0.03963502  0.21791282
     -0.04519068]
    [ 0.04110054 -0.09093985 -0.3733454  ...  0.01539735 -0.11850899
      0.02842933]
    [-0.1601985  -0.13069685  0.30381435 ... -0.04213619 -0.0664723
     -0.00533756]]

   [[ 0.00416413  0.03159913  0.00303013 ... -0.03852244  0.09086941
      0.01786395]
    [-0.0628048   0.0915197  -0.01255642 ...  0.03402099 -0.13857314
      0.01136297]
    [-0.06801369  0.09260172  0.23994066 ...  0.00999346  0.07607072
      0.00108261]
    [-0.1157523  -0.0861637   0.05980099 ... -0.04376325 -0.07687511
     -0.15914239]]

   ...

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]]]]
ref [[[[[ 0.0391885  -0.05213178  0.0268001  ...  0.0294319   0.01782213
      0.03870804]
    [-0.03439355 -0.01018475 -0.00093702 ...  0.07854607 -0.04735541
     -0.1090751 ]
    [-0.01603688  0.0520269   0.09639733 ...  0.01151111  0.04496528
      0.00281009]
    [ 0.1090112   0.08749306 -0.14269403 ...  0.0766882   0.04787891
      0.06491713]]

   [[-0.01632032  0.05832538 -0.04605424 ...  0.00103438 -0.08423408
     -0.07199297]
    [ 0.0357766  -0.11856882  0.04105911 ...  0.03963502  0.21791282
     -0.04519068]
    [ 0.04110055 -0.09093984 -0.37334543 ...  0.01539735 -0.11850896
      0.02842932]
    [-0.16019851 -0.13069685  0.30381435 ... -0.04213618 -0.06647231
     -0.00533757]]

   [[ 0.00416412  0.03159914  0.00303013 ... -0.03852244  0.0908694
      0.01786394]
    [-0.0628048   0.09151968 -0.01255642 ...  0.03402098 -0.1385731
      0.01136296]
    [-0.06801369  0.0926017   0.23994064 ...  0.00999346  0.07607072
      0.0010826 ]
    [-0.11575231 -0.08616371  0.05980099 ... -0.04376324 -0.07687512
     -0.1591424 ]]

   ...

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]

   [[ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]
    [ 0.          0.          0.         ...  0.          0.
      0.        ]]]]]
============================================================ short test summary info =============================================================
FAILED test/experimental/pallas/test_attention.py::test_vjp[rng2-1-16-4-32-4-False] - assert Array(False, dtype=bool)
FAILED test/experimental/pallas/test_attention.py::test_vjp[rng5-1-16-4-32-4-True] - assert Array(False, dtype=bool)
========================================================= 2 failed, 10 passed in 17.67s ==========================================================

@ae-foster (Collaborator, Author) commented:

Ok, it seems the tests are a bit flaky across different jax versions; I will look into it.

@ae-foster (Collaborator, Author) commented:

I pushed an update that sets the same tolerance level for v_vjp as we use for the other vjp tensors, since that's the one that was failing for you.
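In other words, the previously failing assertion now presumably carries the same atol as its neighbours, roughly:

# Before: default tolerances, which is what tripped on some jax versions.
# assert jnp.allclose(mask_array(v_vjp, mask), mask_array(ref_v_vjp, mask))
# After: match the atol already used for o, q_vjp and k_vjp.
assert jnp.allclose(mask_array(v_vjp, mask), mask_array(ref_v_vjp, mask), atol=1e-6)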

@ae-foster (Collaborator, Author) commented:

Removing the where statements is still causing me issues, but I will try doing it line by line to see which one is the most problematic.

@ae-foster (Collaborator, Author) commented Nov 4, 2024

@n-gao I have now removed the manual masking as far as I could without causing problems in the unit tests. I would be tempted to leave it here: my hunch is that it can only be beneficial in terms of numerics, at a rather minimal computational cost. The unit tests are passing for me locally and on the CI; maybe you can try again on your version of jax. I sadly don't have time to benchmark the reference implementation, and since it is not the implementation we would use in production anyway, it's probably fine to leave as is.

@n-gao merged commit a983cdc into main on Dec 7, 2024
6 checks passed
@n-gao (Collaborator) commented Dec 7, 2024

Thank you so much, and sorry for the lengthy process!
