
[tritonbench] Fix colfax_cutlass flash_attention operator #2401

Merged (4 commits)

Conversation

@xuzhao9 (Contributor) commented on Jul 31, 2024

The colfax_cutlass kernels fail to build because of missing C++ template instantiations. We need to explicitly include the implementation header so that all required template specializations are instantiated.

Test plan:

Install the colfax_cutlass operators:

python install.py --userbenchmark triton --cutlass
/home/xz/git/benchmark/submodules/cutlass-kernels/src/fmha/fmha_forward.cu(826): warning #117-D: non-void function "main" should return a value
      return;
            ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"


Run the flash_attention operator from colfax_cutlass

python run_benchmark.py triton --op flash_attention --only colfax_cutlass --num-inputs 1

  (Batch, Heads, SeqLen, Dhead)    colfax_cutlass-latency
-------------------------------  ------------------------
              (32, 32, 512, 64)                  0.001024

@facebook-github-bot (Contributor) commented: @xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -128,6 +128,7 @@ class Operator(BenchmarkOperator):
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
super().__init__(tb_args, extra_args)
args = parse_op_args(self.extra_args)
self.use_cuda_graphs = False
A reviewer (Contributor) commented:
I wonder why we need to turn off cuda_graphs.

@xuzhao9 (Author) replied on Aug 1, 2024:

This is not about colfax_cutlass; with ThunderKittens (#2370), my run fails with this error:

Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/home/xz/git/benchmark/torchbenchmark/util/triton_op.py", line 558, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/home/xz/git/benchmark/torchbenchmark/util/triton_op.py", line 546, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/home/xz/git/benchmark/torchbenchmark/util/triton_op.py", line 753, in _do_bench
    metrics.latency = triton.testing.do_bench_cudagraph(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xz/miniconda3/lib/python3.11/site-packages/triton/testing.py", line 46, in do_bench_cudagraph
    with torch.cuda.graph(g):
  File "/home/xz/miniconda3/lib/python3.11/site-packages/torch/cuda/graphs.py", line 186, in __exit__
    self.cuda_graph.capture_end()
  File "/home/xz/miniconda3/lib/python3.11/site-packages/torch/cuda/graphs.py", line 84, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

So I am thinking CUDA graph capture might not work with ThunderKittens?

Since it is working with colfax_cutlass, I am reverting this line for this PR.


@facebook-github-bot: @xuzhao9 merged this pull request in 0a2ff22.

@xuzhao9 deleted the xz9/fix-cutlass branch on August 1, 2024.