Migrate multi_head_jagged_flash_attention SLL ops to OSS
Summary: - Migrate `multi_head_jagged_flash_attention` SLL ops to OSS

Differential Revision: D66972360
Benson Ma authored and facebook-github-bot committed Dec 10, 2024
1 parent 5dbeebf commit 46e0c8e
Showing 4 changed files with 962 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/scripts/fbgemm_gpu_test.bash
@@ -89,6 +89,7 @@ __configure_fbgemm_gpu_test_cpu () {
./sll/jagged_flash_attention_basic_test.py
./sll/jagged_jagged_bmm_jagged_out_test.py
./sll/jagged_dense_flash_attention_test.py
./sll/multi_head_jagged_flash_attention_test.py
)
}

22 changes: 22 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -48,6 +48,7 @@
jagged_jagged_bmm,
jagged_jagged_bmm_jagged_out,
jagged_softmax,
multi_head_jagged_flash_attention,
triton_jagged_self_substraction_jagged_out,
)

@@ -263,6 +264,19 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"""
)

if "fbgemm::sll_multi_head_jagged_flash_attention" not in torch.library._defs:
lib.define(
"""sll_multi_head_jagged_flash_attention(
Tensor q_weights,
Tensor k_weights,
Tensor v_weights,
Tensor offsets,
int max_seq_len,
bool allow_tf32=True
) -> Tensor
"""
)

# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function.
# However, this is not ideal: in the inference case the autograd forward does not need
# to save the context, since backward is never run.
@@ -396,3 +410,11 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"AutogradCPU": cpu_jagged_dense_flash_attention,
},
)

register_sll_op(
"sll_multi_head_jagged_flash_attention",
{
"CUDA": multi_head_jagged_flash_attention,
"AutogradCUDA": multi_head_jagged_flash_attention,
},
)
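
With the schema defined and the CUDA/AutogradCUDA kernels registered, the op becomes reachable through the torch.ops dispatcher once fbgemm_gpu.sll has been imported. A minimal usage sketch follows; the tensor layouts, dtypes, and the interpretation of offsets are assumptions inferred from the schema and the sibling SLL jagged attention ops, not documented behavior. See multi_head_jagged_flash_attention_test.py for the authoritative calling convention.

    # Hedged sketch of calling the newly registered op via the dispatcher (CUDA only).
    # Shapes below are illustrative assumptions, not the op's documented contract.
    import torch
    import fbgemm_gpu.sll  # noqa: F401  -- importing runs the register_sll_op calls above

    num_heads, head_dim, max_seq_len = 4, 64, 128
    lengths = torch.tensor([100, 72], device="cuda")  # per-batch jagged lengths (assumed)
    offsets = torch.cat(
        [torch.zeros(1, dtype=lengths.dtype, device="cuda"), lengths.cumsum(0)]
    )
    total_len = int(lengths.sum())

    # Assumed jagged layout: tokens of all batches flattened along one dimension, per head.
    q = torch.randn(num_heads, total_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(num_heads, total_len, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(num_heads, total_len, head_dim, device="cuda", dtype=torch.float16)

    out = torch.ops.fbgemm.sll_multi_head_jagged_flash_attention(
        q, k, v, offsets, max_seq_len, allow_tf32=True
    )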