[Dev] Enhance Backend Abstraction for TileLang #255

Merged · 55 commits · Dec 8, 2024

Conversation

LeiWang1999
Contributor

This pull request includes several changes to the bitblas module, primarily focusing on refactoring, enhancements to the scheduler, and improvements in tuning functions. The most important changes include the addition of new functions for architecture inference, refactoring the scheduler, and introducing dynamic range tuning.

Refactoring and Enhancements:

  • bitblas/base/__init__.py: Added imports for simplify_prim_func and moved fast_tune and fast_tune_with_dynamic_range from utils to tuner.
  • bitblas/base/arch/__init__.py: Updated the get_arch function to accept a string or tvm.target.Target and added new functions for architecture inference (auto_infer_current_arch, is_ampere_arch, is_volta_arch, is_cdna_arch); a usage sketch follows below.

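A hedged usage sketch of the new architecture-inference helpers (call sites and return types are inferred from the description above, not taken verbatim from the diff):

import tvm
from bitblas.base.arch import get_arch, auto_infer_current_arch, is_ampere_arch

arch = get_arch("cuda")                      # get_arch now also accepts a plain target string
arch = get_arch(tvm.target.Target("cuda"))   # ... or a tvm.target.Target
arch = auto_infer_current_arch()             # ... or infer the arch of the local device
if is_ampere_arch(arch):
    pass  # e.g. pick Ampere-specific schedules/hints
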
Scheduler Improvements:

  • bitblas/base/base_scheduler.py: Renamed from bitblas/ops/base_scheduler.py and added new attributes and methods to the BaseScheduler class, including with_self_attrs, post_process, set_dynamic_range, and serialize_hints_to_configs; a minimal sketch follows below, and the full class is shown later in this thread.

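A minimal sketch of how these hooks compose (the concrete scheduler name below is hypothetical; the full BaseScheduler definition appears later in this thread):

# "MyTileLangScheduler" is a hypothetical concrete subclass, used only to illustrate the flow.
scheduler = MyTileLangScheduler(...).set_dynamic_range({"m": 32})   # record dynamic dims; chainable
func = scheduler.with_default_config()
# Concrete schedulers end with `return self.post_process(func)`, which runs with_self_attrs
# (attaching the "opt_shapes" attribute when a dynamic range is set) and then maybe_simplify.
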
Dynamic Range Tuning:

  • bitblas/base/tuner.py: Introduced new functions fast_tune, fast_tune_with_dynamic_range_tir, fast_tune_with_dynamic_range_tilelang, and fast_tune_with_dynamic_range to support dynamic-range tuning and dispatching (sketched below).

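A hedged sketch of the intended entry points (the import path follows the file above; keyword arguments beyond func/target/topk are omitted and may differ):

from bitblas.base.tuner import fast_tune, fast_tune_with_dynamic_range

best_hints = fast_tune(func, target, topk=10)                  # static-shape tuning
dyn_mod = fast_tune_with_dynamic_range(func, target, topk=10)
# fast_tune_with_dynamic_range presumably dispatches to the `_tir` or `_tilelang`
# variant depending on which scheduler/IR kind produced `func`.
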
Miscellaneous:

@LeiWang1999 marked this pull request as ready for review on December 3, 2024 at 12:50
@LeiWang1999
Contributor Author

Local tests have passed.

@LeiWang1999
Contributor Author

A brief introduction to the new data structure for TileLang: every TileLang implementation is recommended to be wrapped in a Scheduler class.

# Imports for context (bitblas-internal module paths are approximate):
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from tvm import IRModule, te, tir
from tvm.tir import PrimFunc
from tvm.tir.transform import Simplify

from bitblas.base.arch import (TileDevice, auto_infer_current_arch, is_ampere_arch,
                               is_cdna_arch, is_volta_arch)
from bitblas.base.roller.hint import Hint    # path approximate
from bitblas.tl.base_hint import BaseTLHint  # path approximate


@dataclass
class BaseScheduler(ABC):

    _arch: TileDevice = field(default=auto_infer_current_arch(), init=False, repr=False)

    _enable_simplify: bool = field(default=True, init=False, repr=False)

    _dynamic_range: Dict[str, int] = field(default_factory=dict, init=False, repr=False)

    @staticmethod
    def Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
        if isinstance(stmt, PrimFunc):
            mod = Simplify()(IRModule.from_expr(stmt))
            assert len(mod.functions) == 1, "Simplify should return a single function"
            return list(mod.functions.values()).pop()
        elif isinstance(stmt, IRModule):
            return Simplify()(stmt)
        else:
            raise ValueError(f"Unsupported type: {type(stmt)}")

    def get_hardware_aware_configs(self,
                                   arch: TileDevice = None,
                                   topk: int = 10) -> List[BaseTLHint]:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support hardware-aware tuning for {arch} with topk={topk}"
        )

    def activate_simplify(self) -> "BaseScheduler":
        self._enable_simplify = True
        return self

    def deactivate_simplify(self) -> "BaseScheduler":
        self._enable_simplify = False
        return self

    def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
        if self._enable_simplify:
            return self.Simplify(stmt)
        return stmt

    def with_self_attrs(self, func: PrimFunc) -> PrimFunc:
        if self._dynamic_range:
            func = func.with_attr("opt_shapes", self._dynamic_range)
        return func

    def post_process(self, func: PrimFunc) -> PrimFunc:
        func = self.with_self_attrs(func)
        func = self.maybe_simplify(func)
        return func

    def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler":
        self._dynamic_range = dynamic_range
        return self

    def has_dynamic_range(self) -> bool:
        return bool(self._dynamic_range)

    def with_arch(self, arch: TileDevice) -> "BaseScheduler":
        self._arch = arch
        return self

    def has_arch(self) -> bool:
        return self._arch is not None

    def is_volta_arch(self) -> bool:
        return is_volta_arch(self._arch) if self._arch is not None else False

    def is_ampere_arch(self) -> bool:
        return is_ampere_arch(self._arch) if self._arch is not None else False

    def is_cdna_arch(self) -> bool:
        return is_cdna_arch(self._arch) if self._arch is not None else False

    @staticmethod
    def maybe_dynamic(arg: Union[int, List[int]], dynamic_symbol: str = "m") -> Union[int, tir.Var]:
        # Keep static integer dimensions as-is; otherwise return a symbolic (dynamic) variable.
        if isinstance(arg, int):
            return arg
        return te.var(dynamic_symbol)

    @abstractmethod
    def with_default_config(self, *args, **kwargs) -> PrimFunc:
        pass

    @abstractmethod
    def apply_config(
        self,
        *args,
        **kwargs,
    ) -> PrimFunc:
        pass

    def serialize_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]:
        # Convert Roller Hints to TileLang Hints
        raise NotImplementedError("Serialization of hints to configs is not implemented")

    def specialize_from_dynamic_range(self,
                                      dynamic_range: Optional[Dict[str,
                                                                   int]] = None) -> "BaseScheduler":
        raise NotImplementedError("Specialization from dynamic range is not implemented")

    @property
    def common_header(self) -> str:
        # TODO(lei): For HIP Backend it should be different
        common_header = "#include <tl_templates/cuda/common.h>\n"
        return common_header

    @property
    def global_symbol(self):
        # For kernel name generation
        return "default"

    @property
    def arch(self) -> TileDevice:
        return self._arch

And the TileLang implementation must go in the apply_config method; for example, the SIMT matmul:

    def apply_config(
        self,
        block_size_x: Optional[int] = None,
        block_size_y: Optional[int] = None,
        thread_row_tiles: Optional[int] = None,
        thread_col_tiles: Optional[int] = None,
        chunk: Optional[int] = None,
    ):
        assert block_size_x is not None, "block_size_x must be provided"
        assert block_size_y is not None, "block_size_y must be provided"
        assert thread_row_tiles is not None, "thread_row_tiles must be provided"
        assert thread_col_tiles is not None, "thread_col_tiles must be provided"
        assert chunk is not None, "chunk must be provided"

        M = self.maybe_dynamic(self.M, "m")
        N, K = self.N, self.K
        assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"

        in_dtype, out_dtype, accum_dtype = (
            self.in_dtype,
            self.out_dtype,
            self.accum_dtype,
        )

        shared_scope = "shared.dyn"

        block_M = block_size_x * thread_row_tiles
        block_N = block_size_y * thread_col_tiles
        block_K = chunk

        A_shape = (M, K)
        B_shape = (N, K)
        C_shape = (M, N)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K)

        threads = thread_row_tiles * thread_col_tiles
        local_size_a = block_M // thread_row_tiles
        local_size_b = block_N // thread_col_tiles
        local_size_c = (block_M // thread_row_tiles) * (block_N // thread_col_tiles)

        micro_size_k = 128 // DataType(in_dtype).bits

        dp4a_size = 4
        use_dp4a = in_dtype == "int8" and accum_dtype == "int32"

        @T.prim_func
        def main(
                A: T.Buffer(A_shape, in_dtype),
                B: T.Buffer(B_shape, in_dtype),
                C: T.Buffer(C_shape, out_dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

                A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
                B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)

                A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype)
                B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype)
                C_local = T.alloc_local((local_size_c,), accum_dtype)

                thread_binding = T.thread_binding(threads, "threadIdx.x")

                warp_m = thread_binding % thread_row_tiles
                warp_n = thread_binding // thread_row_tiles

                T.clear(C_local)

                for ko in T.serial(K // block_K):

                    # Load A into shared memory
                    for i, k in T.Parallel(block_M, block_K):
                        A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                    # Load B into shared memory
                    for j, k in T.Parallel(block_N, block_K):
                        B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                    for ki in T.serial((block_K // micro_size_k)):
                        for i in T.serial(local_size_a):
                            for mk in T.vectorized(micro_size_k):
                                A_local[i, mk] = A_shared[warp_m * local_size_a + i,
                                                          ki * micro_size_k + mk]

                        for i in T.serial(local_size_b):
                            for mk in T.vectorized(micro_size_k):
                                B_local[i, mk] = B_shared[warp_n * local_size_b + i,
                                                          ki * micro_size_k + mk]

                        for i, j in T.grid(local_size_a, local_size_b):
                            for mk in T.serial(micro_size_k // dp4a_size):
                                if use_dp4a:
                                    T.dp4a(
                                        A_local[i, mk * dp4a_size],
                                        B_local[j, mk * dp4a_size],
                                        C_local[i * local_size_b + j],
                                    )
                                else:
                                    for dp4a_idx in T.serial(dp4a_size):
                                        C_local[i * local_size_b + j] += (
                                            A_local[i, mk * dp4a_size + dp4a_idx] *
                                            B_local[j, mk * dp4a_size + dp4a_idx])

                for i, j in T.grid(local_size_a, local_size_b):
                    C[
                        by * block_M + warp_m * local_size_a + i,
                        bx * block_N + warp_n * local_size_b + j,
                    ] = C_local[i * local_size_b + j]

        return self.post_process(main)

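As a usage sketch of the scheduler above (the concrete class name and constructor arguments are hypothetical, and the tile sizes are arbitrary example values rather than tuned results):

# Hypothetical concrete scheduler wrapping the SIMT matmul implementation shown above.
scheduler = MatmulSIMTScheduler(
    M=1024, N=1024, K=1024,
    in_dtype="int8", out_dtype="int32", accum_dtype="int32",
).with_arch(auto_infer_current_arch())

func = scheduler.apply_config(
    block_size_x=16,
    block_size_y=16,
    thread_row_tiles=8,
    thread_col_tiles=8,
    chunk=64,
)
# `func` is a post-processed tir.PrimFunc: "opt_shapes" is attached if a dynamic range
# was set, Simplify is applied, and the function is ready to be lowered by the TileLang backend.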