[BUG][QST] Hopper Grouped GEMM Fails When Workspace not aligned at 64, but MinWorkspaceAlignment =16 #2042

Open
ankutalev opened this issue Jan 16, 2025 · 3 comments
Labels
? - Needs Triage bug Something isn't working

Comments

@ankutalev

Describe the bug
See the title. I expected GroupedGemm to work when the workspace pointer is 16-byte aligned, but it fails with "Got bad cuda status: misaligned address at line: 596" for 16- and 32-byte alignments.

Steps/Code to reproduce bug
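To reproduce, modify the example as shown in the diff below. Note that the diff is written from the modified file to the original, so the lines marked "-" are the modified code that triggers the failure: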

diff --git a/hopper_grouped.cu b/orig_hopper_groped.cu
index a927a2b..f578b85 100644
--- a/hopper_grouped.cu
+++ b/orig_hopper_groped.cu
@@ -664,13 +664,13 @@ int run(Options &options, bool host_problem_shapes_available = true)
   size_t workspace_size = GemmT::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size * 2);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 
   // Check if the problem size is supported or not
   CUTLASS_CHECK(gemm.can_implement(arguments));
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get() + 32));
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
 
   // Correctness / Warmup iteration
   CUTLASS_CHECK(gemm.run());

Changing 32 to 16 still fails, while offsets of 64, 128, or 256 work fine. Note that workspace_size * 2 is only there so that the offset pointer stays within bounds.
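The pattern above (16 and 32 fail, 64/128/256 work) is what one would see if the kernel needed 64-byte alignment. Here is a minimal host-side sketch of that arithmetic. It assumes the base allocation is at least 256-byte aligned, as cudaMalloc allocations are. The helpers is_aligned and offset_keeps_alignment are hypothetical names for illustration, not CUTLASS APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not part of CUTLASS): true if `ptr` is a multiple of
// `alignment` bytes. `alignment` must be a power of two.
inline bool is_aligned(const void* ptr, std::size_t alignment) {
  assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
  return (reinterpret_cast<std::uintptr_t>(ptr) & (alignment - 1)) == 0;
}

// Hypothetical helper: true if stepping `offset` bytes from a 256-byte-aligned
// base still satisfies `alignment`. The static buffer stands in for the
// device allocation; only the address arithmetic matters here.
inline bool offset_keeps_alignment(std::size_t offset, std::size_t alignment) {
  alignas(256) static unsigned char base[512];
  assert(offset < sizeof(base));
  return is_aligned(base + offset, alignment);
}
```

Under these assumptions, offsets of 16 and 32 satisfy a 16-byte requirement but not a 64-byte one, while offsets of 64, 128, and 256 satisfy both.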

Expected behavior
The GEMM should work, because MinWorkspaceAlignment is set to 16. (Link)
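Until this is resolved, a caller holding a workspace pointer that is only 16-byte aligned could round it up before passing it to gemm.initialize. The sketch below is a possible workaround, not a confirmed fix. The value 64 is an assumption taken from the smallest offset reported as working in this issue, not from a documented CUTLASS requirement. align_up, padded_size, and kAssumedWorkspaceAlignment are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumption: 64 bytes is the smallest alignment reported as working in this
// issue; it is not a documented CUTLASS requirement.
constexpr std::size_t kAssumedWorkspaceAlignment = 64;

// Hypothetical helper: round `ptr` up to the next multiple of `alignment`
// bytes. `alignment` must be a power of two.
inline std::uint8_t* align_up(std::uint8_t* ptr, std::size_t alignment) {
  assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  const auto mask = static_cast<std::uintptr_t>(alignment) - 1;
  const auto aligned = (addr + mask) & ~mask;
  return ptr + (aligned - addr);
}

// Hypothetical helper: bytes to allocate so that `size` usable bytes remain
// after the pointer has been rounded up by at most `alignment - 1` bytes.
constexpr std::size_t padded_size(std::size_t size, std::size_t alignment) {
  return size + alignment - 1;
}
```

In the example, this would mean allocating padded_size(workspace_size, kAssumedWorkspaceAlignment) bytes and passing align_up(ptr, kAssumedWorkspaceAlignment) to gemm.initialize, where ptr is the possibly under-aligned workspace pointer.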

Environment details (please complete the following information):

  • Bare metal, CUDA Version is 12.6

Thanks!

@ankutalev added the "? - Needs Triage" and "bug" (Something isn't working) labels on Jan 16, 2025
@ankutalev changed the title from "[BUG][QST] Hopper Grouped GEMM Fails When Workspace not aligned for 64, but MinWorkspaceAlignment =16" to "[BUG][QST] Hopper Grouped GEMM Fails When Workspace not aligned at 64, but MinWorkspaceAlignment =16" on Jan 16, 2025
@ankutalev
Author

cuda-gdb reports the following for this example built with -g -G:

Thread 1 "hopper_grouped" received signal CUDA_EXCEPTION_14, Warp Illegal Address.
[Switching focus to CUDA kernel 0, grid 4, block (4,0,0), thread (0,0,0), device 0, sm 0, warp 3, lane 0]
0x00007ffdbd309940 in cute::tma_descriptor_replace_dims_strides_in_shared_mem (smem_desc=..., prob_shape=..., prob_stride=...)
    at cutlass/include/cute/arch/copy_sm90_desc.hpp:325
325         :: "l"(smem_int64_desc), "r"(prob_shape[0]));

...
at cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp:719
    cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
                                                            prob_shape_A,
                                                            prob_stride_A);

This happens even without the changes above, i.e. the unmodified example is broken.

@ankutalev
Author

@thakkarV can you take a look?
Thanks!

@thakkarV
Collaborator

@ANIKET-SHIVAM CC
