[tritonbench] Fix colfax_cutlass flash_attention operator #2401
Conversation
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```diff
@@ -128,6 +128,7 @@ class Operator(BenchmarkOperator):
     def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
         super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
+        self.use_cuda_graphs = False
```
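For context, here is a minimal sketch of how a flag like `self.use_cuda_graphs` is presumably consumed inside `_do_bench` (inferred from the traceback below; the actual logic in `torchbenchmark/util/triton_op.py` may differ):

```python
import triton.testing

def _do_bench_sketch(fn, use_cuda_graphs: bool) -> float:
    # Hypothetical simplification: when the operator opts into CUDA graphs,
    # latency is measured by capturing and replaying the kernel inside a
    # graph; otherwise the plain wall-clock benchmark path is used.
    if use_cuda_graphs:
        return triton.testing.do_bench_cudagraph(fn)
    return triton.testing.do_bench(fn)
```

Under that assumption, setting `self.use_cuda_graphs = False` simply routes this operator through the non-graph timing path.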
I wonder why we need to turn off cuda_graphs.
It is not about colfax_cutlass; with ThunderKittens (#2370), my run fails with this error:

```
Caught exception, terminating early with partial results
Traceback (most recent call last):
  File "/home/xz/git/benchmark/torchbenchmark/util/triton_op.py", line 558, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/home/xz/git/benchmark/torchbenchmark/util/triton_op.py", line 546, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/home/xz/git/benchmark/torchbenchmark/util/triton_op.py", line 753, in _do_bench
    metrics.latency = triton.testing.do_bench_cudagraph(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xz/miniconda3/lib/python3.11/site-packages/triton/testing.py", line 46, in do_bench_cudagraph
    with torch.cuda.graph(g):
  File "/home/xz/miniconda3/lib/python3.11/site-packages/torch/cuda/graphs.py", line 186, in __exit__
    self.cuda_graph.capture_end()
  File "/home/xz/miniconda3/lib/python3.11/site-packages/torch/cuda/graphs.py", line 84, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
So I am thinking CUDA graphs might not work with ThunderKittens? Since it works with colfax_cutlass, I am reverting this line for this PR.
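For reference, a minimal sketch of the capture-and-replay pattern that `do_bench_cudagraph` relies on (an assumed simplification based on the traceback above, not the actual Triton source). A kernel that performs an operation forbidden during stream capture, such as an implicit device synchronization or an unsupported allocation, aborts the capture with exactly this kind of "operation failed due to a previous error during capture" error:

```python
import torch

def bench_with_cuda_graph(fn, n_repeat: int = 10) -> float:
    """Sketch of CUDA-graph-based timing (hypothetical simplification,
    not the real triton.testing.do_bench_cudagraph implementation)."""
    # Warm up outside the graph so lazy initialization does not
    # happen during capture, which would abort it.
    fn()
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    # Capture fails here if fn() launches work that is illegal during
    # stream capture, matching the RuntimeError in the traceback above.
    with torch.cuda.graph(g):
        for _ in range(n_repeat):
            fn()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    g.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_repeat
```

If capture is indeed the problem, rerunning with `CUDA_LAUNCH_BLOCKING=1`, as the error message suggests, should surface the offending kernel launch.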
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from b5f4ef3 to b7c6ad7.
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
colfax_cutlass kernels will fail because of missing C++ template instantiations. We need to explicitly include the header file so that all template parameterizations are instantiated.
Test plan:
1. Install the colfax_cutlass operators.
2. Run the flash_attention operator from colfax_cutlass.