[Dev] Enhance Backend Abstraction for TileLang #255
Merged
Conversation
Local tests have passed.

Brief introduction to the new data structure for TileLang: every TileLang scheduler implementation is recommended to be wrapped in a `@dataclass` deriving from `BaseScheduler`:
```python
# bitblas/base/base_scheduler.py (excerpt)
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from tvm import DataType, IRModule, te
from tvm.tir import PrimFunc, Var
from tvm.tir.transform import Simplify

# TileDevice and the architecture helpers live in bitblas.base.arch (extended
# in this PR). Hint (roller hint) and BaseTLHint (TileLang hint) are
# bitblas-internal types whose imports are omitted in this excerpt.
from bitblas.base.arch import (TileDevice, auto_infer_current_arch,
                               is_ampere_arch, is_cdna_arch, is_volta_arch)


@dataclass
class BaseScheduler(ABC):

    _arch: TileDevice = field(default=auto_infer_current_arch(), init=False, repr=False)
    _enable_simplify: bool = field(default=True, init=False, repr=False)
    _dynamic_range: Dict[str, int] = field(default_factory=dict, init=False, repr=False)

    @staticmethod
    def Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
        if isinstance(stmt, PrimFunc):
            mod = Simplify()(IRModule.from_expr(stmt))
            assert len(mod.functions) == 1, "Simplify should return a single function"
            return list(mod.functions.values()).pop()
        elif isinstance(stmt, IRModule):
            return Simplify()(stmt)
        else:
            raise ValueError(f"Unsupported type: {type(stmt)}")

    def get_hardware_aware_configs(self,
                                   arch: Optional[TileDevice] = None,
                                   topk: int = 10) -> List["BaseTLHint"]:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support hardware-aware tuning for {arch} with topk={topk}"
        )

    def activate_simplify(self) -> "BaseScheduler":
        self._enable_simplify = True
        return self

    def deactivate_simplify(self) -> "BaseScheduler":
        self._enable_simplify = False
        return self

    def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
        if self._enable_simplify:
            return self.Simplify(stmt)
        return stmt

    def with_self_attrs(self, func: PrimFunc) -> PrimFunc:
        if self._dynamic_range:
            func = func.with_attr("opt_shapes", self._dynamic_range)
        return func

    def post_process(self, func: PrimFunc) -> PrimFunc:
        func = self.with_self_attrs(func)
        func = self.maybe_simplify(func)
        return func

    def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler":
        self._dynamic_range = dynamic_range
        return self

    def has_dynamic_range(self) -> bool:
        return bool(self._dynamic_range)

    def with_arch(self, arch: TileDevice) -> "BaseScheduler":
        self._arch = arch
        return self

    def has_arch(self) -> bool:
        return self._arch is not None

    def is_volta_arch(self) -> bool:
        return is_volta_arch(self._arch) if self._arch is not None else False

    def is_ampere_arch(self) -> bool:
        return is_ampere_arch(self._arch) if self._arch is not None else False

    def is_cdna_arch(self) -> bool:
        return is_cdna_arch(self._arch) if self._arch is not None else False

    @staticmethod
    def maybe_dynamic(arg: Union[int, List[int]], dynamic_symbol: str = "m") -> Union[int, Var]:
        if isinstance(arg, int):
            return arg
        return te.var(dynamic_symbol)

    @abstractmethod
    def with_default_config(self, *args, **kwargs) -> PrimFunc:
        pass

    @abstractmethod
    def apply_config(
        self,
        *args,
        **kwargs,
    ) -> PrimFunc:
        pass

    def serialize_hints_to_configs(self, hints: List["Hint"]) -> List["BaseTLHint"]:
        # Convert Roller Hints to TileLang Hints
        raise NotImplementedError("Serialization of hints to configs is not implemented")

    def specialize_from_dynamic_range(
            self, dynamic_range: Optional[Dict[str, int]] = None) -> "BaseScheduler":
        raise NotImplementedError("Specialization from dynamic range is not implemented")

    @property
    def common_header(self) -> str:
        # TODO(lei): For HIP Backend it should be different
        common_header = "#include <tl_templates/cuda/common.h>\n"
        return common_header

    @property
    def global_symbol(self):
        # For kernel name generation
        return "default"

    @property
    def arch(self) -> TileDevice:
        return self._arch
```

The concrete TileLang implementation must then go into `apply_config`. For example, a SIMT matmul scheduler implements it as follows (`T` is the TileLang scripting namespace; `self.M`, `self.N`, `self.K` and the dtype attributes are fields of the concrete scheduler):

```python
def apply_config(
    self,
    block_size_x: Optional[int] = None,
    block_size_y: Optional[int] = None,
    thread_row_tiles: Optional[int] = None,
    thread_col_tiles: Optional[int] = None,
    chunk: Optional[int] = None,
):
    assert block_size_x is not None, "block_size_x must be provided"
    assert block_size_y is not None, "block_size_y must be provided"
    assert thread_row_tiles is not None, "thread_row_tiles must be provided"
    assert thread_col_tiles is not None, "thread_col_tiles must be provided"
    assert chunk is not None, "chunk must be provided"

    M = self.maybe_dynamic(self.M, "m")
    N, K = self.N, self.K
    assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K currently"

    in_dtype, out_dtype, accum_dtype = (
        self.in_dtype,
        self.out_dtype,
        self.accum_dtype,
    )

    shared_scope = "shared.dyn"

    block_M = block_size_x * thread_row_tiles
    block_N = block_size_y * thread_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    C_shape = (M, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)

    threads = thread_row_tiles * thread_col_tiles
    local_size_a = block_M // thread_row_tiles
    local_size_b = block_N // thread_col_tiles
    local_size_c = (block_M // thread_row_tiles) * (block_N // thread_col_tiles)

    micro_size_k = 128 // DataType(in_dtype).bits
    dp4a_size = 4
    use_dp4a = in_dtype == "int8" and accum_dtype == "int32"

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer(C_shape, out_dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype)
            B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype)
            C_local = T.alloc_local((local_size_c,), accum_dtype)

            thread_binding = T.thread_binding(threads, "threadIdx.x")

            warp_m = thread_binding % thread_row_tiles
            warp_n = thread_binding // thread_row_tiles

            T.clear(C_local)

            for ko in T.serial(K // block_K):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(block_K // micro_size_k):
                    for i in T.serial(local_size_a):
                        for mk in T.vectorized(micro_size_k):
                            A_local[i, mk] = A_shared[warp_m * local_size_a + i,
                                                      ki * micro_size_k + mk]

                    for i in T.serial(local_size_b):
                        for mk in T.vectorized(micro_size_k):
                            B_local[i, mk] = B_shared[warp_n * local_size_b + i,
                                                      ki * micro_size_k + mk]

                    for i, j in T.grid(local_size_a, local_size_b):
                        for mk in T.serial(micro_size_k // dp4a_size):
                            if use_dp4a:
                                T.dp4a(
                                    A_local[i, mk * dp4a_size],
                                    B_local[j, mk * dp4a_size],
                                    C_local[i * local_size_b + j],
                                )
                            else:
                                for dp4a_idx in T.serial(dp4a_size):
                                    C_local[i * local_size_b + j] += (
                                        A_local[i, mk * dp4a_size + dp4a_idx] *
                                        B_local[j, mk * dp4a_size + dp4a_idx])

            for i, j in T.grid(local_size_a, local_size_b):
                C[
                    by * block_M + warp_m * local_size_a + i,
                    bx * block_N + warp_n * local_size_b + j,
                ] = C_local[i * local_size_b + j]

    return self.post_process(main)
```
This pull request includes several changes to the `bitblas` module, primarily focusing on refactoring, enhancements to the scheduler, and improvements to the tuning functions. The most important changes include the addition of new functions for architecture inference, refactoring of the scheduler, and the introduction of dynamic range tuning.

Refactoring and Enhancements:
- `bitblas/base/__init__.py`: Added imports for `simplify_prim_func` and moved `fast_tune` and `fast_tune_with_dynamic_range` from `utils` to `tuner`.
- `bitblas/base/arch/__init__.py`: Updated the `get_arch` function to accept a string or `tvm.target.Target` and added new functions for architecture inference (`auto_infer_current_arch`, `is_ampere_arch`, `is_volta_arch`, `is_cdna_arch`); a usage sketch follows this summary. [1] [2]

Scheduler Improvements:
- `bitblas/base/base_scheduler.py`: Renamed from `bitblas/ops/base_scheduler.py` and added new attributes and methods to the `BaseScheduler` class, including `with_self_attrs`, `post_process`, `set_dynamic_range`, and `serialize_hints_to_configs`. [1] [2] [3]

Dynamic Range Tuning:
- `bitblas/base/tuner.py`: Introduced new functions `fast_tune`, `fast_tune_with_dynamic_range_tir`, `fast_tune_with_dynamic_range_tilelang`, and `fast_tune_with_dynamic_range` to support dynamic range tuning and dispatching.

Miscellaneous:
- `3rdparty/tvm`: Updated the submodule commit reference.
- `bitblas/base/utils.py`: Removed unused imports and updated import paths.
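As a rough illustration of how the new architecture-inference helpers fit together, here is a hedged sketch: the import path `bitblas.base.arch` follows the file touched in this PR, but the concrete target strings and the printed messages are only examples, not part of the change.

```python
from tvm.target import Target

from bitblas.base.arch import (auto_infer_current_arch, get_arch,
                               is_ampere_arch, is_cdna_arch, is_volta_arch)

# get_arch now accepts either a target string or a tvm.target.Target.
arch_from_str = get_arch("cuda")
arch_from_target = get_arch(Target("nvidia/nvidia-a100"))

# auto_infer_current_arch() is what BaseScheduler uses as its default _arch.
current = auto_infer_current_arch()

# The is_*_arch helpers take a TileDevice and back BaseScheduler.is_*_arch().
if is_ampere_arch(current):
    print("running on an Ampere-class GPU")
elif is_volta_arch(current) or is_cdna_arch(current):
    print("running on a Volta or CDNA device")
```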