Describe the bug
See title: I expected GroupedGemm to work when the workspace pointer is 16-byte aligned, but it fails with `Got bad cuda status: misaligned address at line: 596` for 16- and 32-byte alignments.
Steps/Code to reproduce bug
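To reproduce, apply the following patch to the Hopper grouped GEMM example. Note that the diff runs from the modified file (`a/hopper_grouped.cu`) to the original (`b/orig_hopper_groped.cu`), so the `-` lines are the reproducer and the `+` lines are the unmodified example code: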
```diff
diff --git a/hopper_grouped.cu b/orig_hopper_groped.cu
index a927a2b..f578b85 100644
--- a/hopper_grouped.cu
+++ b/orig_hopper_groped.cu
@@ -664,13 +664,13 @@ int run(Options &options, bool host_problem_shapes_available = true)
   size_t workspace_size = GemmT::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size * 2);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

   // Check if the problem size is supported or not
   CUTLASS_CHECK(gemm.can_implement(arguments));

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get() + 32));
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));

   // Correctness / Warmup iteration
   CUTLASS_CHECK(gemm.run());
```
Changing the offset from 32 to 16 still fails, while offsets of 64, 128, and 256 work fine. The `workspace_size * 2` allocation is only there so that the extra offset does not run past the end of the buffer.
Expected behavior
The GEMM should work, because `MinWorkspaceAlignment` is set to 16 (Link).
Environment details
Bare metal, CUDA 12.6.
Thanks!
ankutalev changed the title from "[BUG][QST] Hopper Grouped GEMM Fails When Workspace not aligned for 64, but MinWorkspaceAlignment =16" to "[BUG][QST] Hopper Grouped GEMM Fails When Workspace not aligned at 64, but MinWorkspaceAlignment =16" on Jan 16, 2025.