[Backend] Codegen for ttg.warp_specialize
#5968
This is in preparation for warp specialization, which will turn the number of warps into a scoped property of regions. This PR just rearranges the API for looking up the number of warps. In the next PR, the `"ttg.num-warps"` attribute will be moved to `tt.func`.
Warp specialization will cause these to become relative to the current warpgroup, so funnel all the code through a set of common APIs.
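To illustrate what "a scoped property of regions" could mean for lookups, here is a minimal standalone sketch. It does not use the MLIR or Triton APIs; the `Op` struct and `lookup_num_warps` function are hypothetical names invented for this example. The idea shown is only that a query starts at an operation and walks outward until it finds the nearest enclosing warp count.

```cpp
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for an IR operation (not an MLIR type): it may carry
// its own warp count, and it knows its enclosing operation.
struct Op {
    const Op* parent = nullptr;
    std::optional<int> num_warps;
};

// Walk outward from `op` and return the first warp count found.
// With warp specialization, a partition region can override the count
// that the enclosing function would otherwise provide.
int lookup_num_warps(const Op& op) {
    for (const Op* cur = &op; cur != nullptr; cur = cur->parent) {
        if (cur->num_warps) {
            return *cur->num_warps;
        }
    }
    throw std::runtime_error("no enclosing op defines a warp count");
}
```

In this model, an op nested inside a partition that sets its own count resolves to that count, while an op directly under the function resolves to the function's count. Funnelling every query through one such helper is what makes the later change of scope a local edit.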
Looks great, just a few minor questions. I also wonder whether we really need the LLVM data layout, since the rules for shared memory allocation are generally pretty simple.
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp
Do you still use data layout in the code? If not, can you update your PR description?
@@ -875,4 +889,26 @@ tt.func @two_different_ws() {
  tt.return
}

// expected-remark @below {{ptr_allocation_datalayout}}
// expected-remark @below {{size = 8}}
tt.func @ptr_allocation_datalayout(%arg0: !tt.ptr<i32>) {
Do you still need this test?
Yes, but now it's just testing that PointerType can have its size queried (just not through `llvm::DataLayout`).
Ah yes, let me update that.
LGTM
This PR primarily implements `ConvertWarpSpecializeToLLVM`, a pass that runs after all other LLVM conversion and rewrites a warp-specialized function, removing `ttg.warp_specialize` ops. This pass generates synchronization by putting all other warp groups into a waiting loop, where each warp waits for a state ID to be populated into shared memory. The ID represents the switch case the warp should branch to.

This PR also does a few other things:
- Propagates `ttg.total-num-warps` through the compiler so that warp specialization works correctly in the frontend
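
The waiting loop described above can be modeled on the host as a rough sketch. This is a single-threaded illustration, not the code the pass emits: the barrier wait is reduced to "read the next posted value", the function name `run_worker_warp` is made up for this example, and the exit sentinel `kExitState` is an assumption (the description does not say how warps leave the loop).

```cpp
#include <vector>

// Assumed sentinel meaning "the kernel is done, leave the loop".
// The PR description does not specify how termination is signalled.
constexpr int kExitState = -1;

// Single-threaded model of one worker warp. `posted` plays the role of the
// shared-memory slot: element i is the state ID the warp observes after its
// i-th wait. Returns the partition IDs the warp executed, in order.
std::vector<int> run_worker_warp(const std::vector<int>& posted) {
    std::vector<int> executed;
    for (int state : posted) {   // each iteration models one wake-up
        if (state == kExitState) {
            break;               // stop waiting; the warp is finished
        }
        switch (state) {         // the ID selects the case to branch to
        case 0:
            executed.push_back(0);  // body of partition 0 would run here
            break;
        case 1:
            executed.push_back(1);  // body of partition 1 would run here
            break;
        default:
            break;               // no case for this ID: go back to waiting
        }
    }
    return executed;
}
```

For example, feeding the model the sequence `{1, 0, 7, kExitState, 0}` runs partition 1, then partition 0, ignores the unknown ID 7, and stops at the sentinel without seeing the trailing 0.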