[TileLang][Dev] Enhance Layout Inference Pass to infer with complex parallel primitives #268
Merged
Conversation
Also, we introduced an automatic split-k factor selection:

```python
# If the architecture is CUDA and we have a static shape, proceed with optimization
if arch_is_cuda and is_static_shape:
    sm_waste_threshold = 5e-2  # Allow at most 5% SM waste
    num_sms = self.arch.compute_max_core  # Maximum number of streaming multiprocessors

    # Compute block sizes based on the configuration
    block_M = hint.block[0]  # Block size in the M dimension
    block_N = hint.block[1]  # Block size in the N dimension
    block_K = hint.rstep[0]  # Block size in the K dimension

    # Calculate the grid dimensions in the M and N directions
    grid_m = M // block_M
    grid_n = N // block_N
    total_grids = grid_m * grid_n  # Total number of grids

    # Initialize the split-k factor (used to distribute K-dimension work across blocks)
    split_k_factor = 1

    # Optimize the split-k factor to minimize SM waste
    while True:
        # Total grids after applying split-k
        total_grids_split_k = total_grids * split_k_factor
        # Calculate the waste in SMs after split-k distribution
        waste_sm_splitk = total_grids_split_k - (total_grids_split_k // num_sms) * num_sms
        waste_sm_splitk_ratio = waste_sm_splitk / total_grids_split_k
        # If the SM waste ratio is within the allowed threshold, stop optimizing
        if waste_sm_splitk_ratio <= sm_waste_threshold:
            break
        # Double the split-k factor and stop if the resulting K-dimension size is too large
        expand_split_k = split_k_factor * 2
        if expand_split_k * block_K >= K:
            break
        # Update the split-k factor for the next iteration
        split_k_factor = expand_split_k

    # Store the optimized split-k factor on the hint
    hint.split_k_factor = split_k_factor

    # Convert the hint to a configuration object using the TLHint mapping
    config = self.TLHint.from_roller_hint(hint)
```

Analysis based on our TileDevice abstractions :)
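The selection loop above can be sketched as a self-contained function (the GEMM shape, tile sizes, and SM count in the usage below are illustrative assumptions, not values from the PR):

```python
def pick_split_k(M, N, K, block_M, block_N, block_K, num_sms,
                 sm_waste_threshold=5e-2):
    """Double the split-k factor until the last-wave SM waste drops below
    the threshold, or K can no longer be subdivided."""
    total_grids = (M // block_M) * (N // block_N)
    split_k_factor = 1
    while True:
        total = total_grids * split_k_factor
        waste = total - (total // num_sms) * num_sms  # blocks beyond full waves
        if waste / total <= sm_waste_threshold:
            break
        expand = split_k_factor * 2
        if expand * block_K >= K:  # resulting K tile would be too large
            break
        split_k_factor = expand
    return split_k_factor
```

For example, a 256x256x4096 GEMM with 128x128x32 tiles on a 108-SM device yields only four output tiles, so the loop keeps splitting K until the K-dimension limit stops it; with only 4 SMs, the four tiles already fill every SM and no split is needed.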
Local tests have passed.
This was referenced Dec 16, 2024
Relevant Issues
Our LayoutInference pass currently fails, or does not infer the optimal layout, when working with complex layout indices. Examples include transformations such as a 16x16 ladder-based basic layout or dequantize layout transformations. To achieve optimal performance, BitBlas often relies on manually written, thread-level high-performance TensorIR (TIR) code. For example:
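The manually written kernel referenced here is not reproduced on this page. As a rough plain-Python illustration of the style involved (the `ladder_index` permutation, tile size, and thread count below are hypothetical, not the actual BitBlas kernel), thread-level code forces each thread to compute its source and destination indices by hand:

```python
def ladder_index(i, j, vec=8):
    """Hypothetical 16x16 ladder-style permutation: elements are regrouped
    into vec-wide segments so each thread can load a contiguous vector."""
    linear = i * 16 + j
    return divmod(linear, vec)  # (segment, lane) in the transformed buffer

def manual_threaded_copy(src):
    """Each simulated 'thread' (tx) derives its indices explicitly,
    mimicking threadIdx.x-style arithmetic in thread-level TIR."""
    dst = [[0] * 8 for _ in range(32)]  # 16x16 elements viewed as 32 vectors of 8
    for tx in range(32):                # 32 simulated threads
        for v in range(8):              # each thread moves one 8-wide vector
            linear = tx * 8 + v
            i, j = divmod(linear, 16)   # source coordinates in the 16x16 tile
            seg, lane = ladder_index(i, j)
            dst[seg][lane] = src[i][j]
    return dst
```

Every index expression here is an opportunity for an off-by-one or mis-permuted access, which is exactly the verbosity and fragility described.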
This approach is both verbose and error-prone. After this pull request, the process becomes significantly simpler: the code above can be replaced with a short `T.Parallel` loop, or, alternatively, the entire expression can be written even more succinctly using `T.copy`.
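The simplified forms themselves are not reproduced on this page. Based on TileLang's documented `T.Parallel` and `T.copy` primitives, the two variants would look roughly like the following fragment (buffer names `A`, `B` and extents `M`, `N` are placeholders; this is an illustrative, non-runnable sketch, not the PR's actual code):

```python
# Variant 1 - T.Parallel: layout inference now derives the thread mapping
# and vectorization for the parallel loop automatically.
for i, j in T.Parallel(M, N):
    B[i, j] = A[i, j]

# Variant 2 - T.copy: the same transfer collapses into a single primitive.
T.copy(A, B)
```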
How Does This Change Address the Issue?
Introduction of LoopFusion for Consecutive Parallel Loops
This allows us to simplify and legalize complex cases into a single, unified paradigm. The newly fused loop structure is easier to read and manage while maintaining optimal performance.
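As a schematic of the fusion idea (plain Python, with illustrative extents), two consecutive parallel loops over `(i, j)` can be rewritten as a single loop over a fused index `f`, recovering the coordinates as `i = f // N` and `j = f % N` without changing the iteration space:

```python
M, N = 4, 6  # illustrative extents

# Two nested parallel loops over (i, j) ...
nested = [(i, j) for i in range(M) for j in range(N)]

# ... fused into a single parallel loop over f
fused = [divmod(f, N) for f in range(M * N)]
```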
Improved Vectorization Analysis
We now use a smarter method to analyze vectorized lengths, enabling better handling of advanced indexing and complex transformations.
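One common form such an analysis takes (a hedged sketch, not necessarily the exact rule this pass implements) is to bound the vector width by both the contiguous extent of the innermost dimension and the widest hardware load, e.g. 128 bits on CUDA:

```python
from math import gcd

def max_vectorize_length(inner_extent, dtype_bits, max_load_bits=128):
    """Largest vector width (in elements) that both divides the contiguous
    inner extent and fits in a single hardware load."""
    hw_limit = max_load_bits // dtype_bits  # widest load, in elements
    return gcd(inner_extent, hw_limit)
```

For instance, a contiguous float16 extent of 64 admits 8-wide vectors under 128-bit loads, while an awkward float32 extent of 30 caps out at 2.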
By introducing these improvements, the need for manually written, complex thread-level TIR code is eliminated in most cases, reducing development effort while achieving similar or even better performance.