
[TileLang][Dev] Enhance Layout Inference Pass to infer with complex parallel primitives #268

Merged · 18 commits · Dec 16, 2024

Conversation

LeiWang1999 (Contributor)

Relevant Issues

Our LayoutInference pass currently either fails or does not infer the optimal layout when working with complex layout indices, for example a 16x16 ladder-based layout or a dequantize layout transformation.
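For intuition, an index map of this flavor might look like the following sketch; the formula is a made-up swizzle for illustration, not the actual ladder layout:

# Hypothetical illustration only, not the real ladder transformation:
# a 16x16 tile whose column index is XOR-swizzled by the row index,
# the kind of non-trivial index arithmetic the pass must see through.
def swizzled_16x16(i: int, j: int) -> tuple:
    return i, j ^ (i % 8)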

To achieve optimal performance, BitBLAS often relies on manually written, thread-level, high-performance TensorIR (TIR) code. For example:

# Manually flatten the copy into per-thread work and delinearize the
# flat index back into 4-D shared-memory coordinates by hand.
for i in T.serial(block_N * block_K // num_elems_per_byte //
                  (threads * vec_load_qb)):
    for v in T.vectorized(0, vec_load_qb):
        idx = i * threads * vec_load_qb + tx * vec_load_qb + v
        # Recover (vj, vk, vjj, vkk) from the flattened index
        vkk = idx % (micro_size_k // num_elems_per_byte)
        vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
        vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
            block_K // micro_size_k)
        vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
              (block_K // micro_size_k)) % (
                  block_N // micro_size_y)
        B_shared[vj, vk, vjj, vkk] = B[
            bx * (block_N // micro_size_y) + vj,
            ko * (block_K // micro_size_k) + vk,
            vjj,
            vkk,
        ]

This approach is both verbose and error-prone. With this pull request, the process becomes significantly simpler; the code above can now be replaced with:

Simplified Code (Using T.Parallel)

for j, k, jj, kk in T.Parallel(
        block_N // micro_size_y,
        block_K // micro_size_k,
        micro_size_y,
        micro_size_k // num_elems_per_byte,
    ):
    B_shared[j, k, jj, kk] = B[
        bx * (block_N // micro_size_y) + j,
        bz * (splitK // micro_size_k) + ko * (block_K // micro_size_k) + k,
        jj,
        kk,
    ]

Simplified Code (Using T.copy)

Alternatively, the whole copy can be written even more succinctly using T.copy:

T.copy(
    B[
        bx * (block_N // micro_size_y),
        bz * (splitK // micro_size_k) + ko * (block_K // micro_size_k),
        0,
        0,
    ], 
    B_shared
)

How Does This Change Address the Issue?

  1. Introduction of LoopFusion for Consecutive Parallel Loops
    This allows us to simplify and legalize complex cases into a single, unified paradigm. The newly fused loop structure is easier to read and manage while maintaining optimal performance; a minimal sketch follows this list.

  2. Improved Vectorization Analysis
    We now use a smarter method to analyze vectorization lengths, enabling better handling of advanced indexing and complex transformations.
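For intuition on point 1, here is a minimal pure-Python sketch of the fused form the pass targets; the function and the extent values below are illustrative assumptions, not TileLang API:

# Pure-Python sketch (hypothetical, not the actual pass): consecutive
# parallel loops collapse into one flat loop, and the original
# coordinates are recovered by delinearizing the fused index.
def delinearize(flat_idx, extents):
    """Recover per-dimension indices from a fused (flattened) index."""
    coords = []
    for extent in reversed(extents):
        coords.append(flat_idx % extent)
        flat_idx //= extent
    return tuple(reversed(coords))

extents = [4, 2, 16, 8]  # stand-ins for block_N // micro_size_y, etc.
total = 1
for e in extents:
    total *= e

out = {}
for flat in range(total):  # a single fused parallel loop
    j, k, jj, kk = delinearize(flat, extents)
    out[(j, k, jj, kk)] = flat  # body, e.g. B_shared[j, k, jj, kk] = B[...]

Roughly speaking, once the loops are in this fused form, the vectorization analysis in point 2 only has to reason about one flat index expression, which makes it easier to find the largest contiguous innermost chunk that can be vectorized.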

By introducing these improvements, the need for manually written, complex thread-level TIR code is eliminated in most cases, reducing development effort while achieving similar or even better performance.

@LeiWang1999 (Contributor, Author)

Also, we introduced SplitK Infer for dequantize cases (necessary for good performance on dequant-gemm):

# If the architecture is CUDA and the shape is static, proceed with the optimization
if arch_is_cuda and is_static_shape:
    sm_waste_threshold = 5e-2  # Allow at most 5% SM waste
    num_sms = self.arch.compute_max_core  # Maximum number of streaming multiprocessors

    # Compute block sizes from the configuration
    block_M = hint.block[0]  # Block size in the M dimension
    block_N = hint.block[1]  # Block size in the N dimension
    block_K = hint.rstep[0]  # Block size in the K dimension

    # Calculate the grid dimensions in the M and N directions
    grid_m = M // block_M
    grid_n = N // block_N
    total_grids = grid_m * grid_n  # Total number of grid blocks

    # Initialize the split-k factor (distributes K-dimension work across blocks)
    split_k_factor = 1

    # Grow the split-k factor until SM waste falls within the threshold
    while True:
        # Total grids after applying split-k
        total_grids_split_k = total_grids * split_k_factor

        # SM waste left over after split-k distribution
        waste_sm_splitk = total_grids_split_k - (total_grids_split_k //
                                                 num_sms) * num_sms
        waste_sm_splitk_ratio = waste_sm_splitk / total_grids_split_k

        # Stop once the SM waste ratio is within the allowed threshold
        if waste_sm_splitk_ratio <= sm_waste_threshold:
            break

        # Double the split-k factor, unless the resulting K coverage is too large
        expand_split_k = split_k_factor * 2
        if expand_split_k * block_K >= K:
            break

        # Update the split-k factor for the next iteration
        split_k_factor = expand_split_k

    # Apply the optimized split-k factor to the hint
    hint.split_k_factor = split_k_factor

# Convert the hint to a configuration object using the TLHint mapping
config = self.TLHint.from_roller_hint(hint)

Analysis based on our TileDevice abstractions :)
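As a sanity check, the search above can be replayed standalone; every number below is a made-up illustration, not a value from this PR:

# Standalone replay of the split-k search with illustrative numbers.
M, N, K = 768, 1152, 4096              # problem shape (assumed)
block_M, block_N, block_K = 128, 128, 32
num_sms = 108                          # e.g. an A100-class GPU
sm_waste_threshold = 5e-2              # allow at most 5% SM waste

total_grids = (M // block_M) * (N // block_N)  # 6 * 9 = 54 blocks
split_k_factor = 1
while True:
    total_grids_split_k = total_grids * split_k_factor
    waste = total_grids_split_k % num_sms
    if waste / total_grids_split_k <= sm_waste_threshold:
        break
    if split_k_factor * 2 * block_K >= K:
        break
    split_k_factor *= 2

print(split_k_factor)  # prints 2: 54 * 2 = 108 grids exactly fill the 108 SMs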

@LeiWang1999 (Contributor, Author)

Local tests have passed.
