[Dev] Enhance Backend Abstraction for TileLang #255

Merged: 55 commits, Dec 8, 2024

Changes from all commits

Commits (55)
2a0f59c
relax transform update
Nov 10, 2024
b475407
End2end Fix
Nov 11, 2024
b0738ba
Merge branch 'main' of https://github.com/microsoft/BitBLAS into relax
Nov 11, 2024
f23a2ec
lint fix
Nov 11, 2024
79826b6
Merge branch 'main' of https://github.com/microsoft/BitBLAS into relax
LeiWang1999 Nov 27, 2024
1961bc4
bf16 test fix
LeiWang1999 Nov 27, 2024
3aa5d82
format fix
LeiWang1999 Nov 27, 2024
353e279
lint fix
LeiWang1999 Nov 27, 2024
7eb315f
test fix
LeiWang1999 Nov 27, 2024
c1b452f
test fix
LeiWang1999 Nov 27, 2024
fe93429
update commits
LeiWang1999 Nov 27, 2024
ccac456
test fix
LeiWang1999 Nov 27, 2024
ddaeba2
Merge branch 'main' of https://github.com/microsoft/BitBLAS into bf16…
LeiWang1999 Nov 28, 2024
4b6fddb
submodule update
LeiWang1999 Nov 28, 2024
a8ccb17
Implement FP4
LeiWang1999 Nov 29, 2024
e2632e6
lint fix
LeiWang1999 Nov 29, 2024
47abe0a
lint fix
LeiWang1999 Nov 29, 2024
1b5a336
testfix
LeiWang1999 Nov 29, 2024
02c09eb
test fix
LeiWang1999 Nov 29, 2024
ec0e00c
lint fix
LeiWang1999 Nov 29, 2024
667b36c
lint fix
LeiWang1999 Nov 29, 2024
2193164
bugfix
LeiWang1999 Nov 29, 2024
478a0c7
support dp4a and fix test
LeiWang1999 Nov 29, 2024
c323c79
format fix
LeiWang1999 Nov 29, 2024
a9559a2
implement simt
LeiWang1999 Nov 29, 2024
32e8141
submodule update
LeiWang1999 Nov 29, 2024
017b0a7
lint fix
LeiWang1999 Nov 29, 2024
5eb8c16
Code refactorization
LeiWang1999 Dec 1, 2024
7e2b3a9
BUG Fix
LeiWang1999 Dec 1, 2024
a4a741d
optimize import
LeiWang1999 Dec 1, 2024
347dc31
optimize import
LeiWang1999 Dec 1, 2024
e3c371e
submodule update
LeiWang1999 Dec 1, 2024
6e2e595
test case fix
LeiWang1999 Dec 1, 2024
ccf66a8
Enhance top warp hint
LeiWang1999 Dec 1, 2024
c70c6c0
typo fix
LeiWang1999 Dec 1, 2024
de63591
optimize code
LeiWang1999 Dec 2, 2024
4663996
Support TL Wrapper with Dynamic Shape
LeiWang1999 Dec 3, 2024
801e675
Code Reformat
LeiWang1999 Dec 3, 2024
0852f86
Enhance Layout Inference Pass
LeiWang1999 Dec 3, 2024
5cd120c
Implement tuning with dynamic shape
LeiWang1999 Dec 3, 2024
5b689bc
Merge branch 'main' of https://github.com/microsoft/BitBLAS into dyna…
LeiWang1999 Dec 3, 2024
d4dd664
optimize dequantize code structure
LeiWang1999 Dec 3, 2024
b231f84
Support WMMA
LeiWang1999 Dec 4, 2024
658a7f4
Smart Rewrite Support
LeiWang1999 Dec 4, 2024
70fba29
support simt
LeiWang1999 Dec 4, 2024
58f470c
typofix
LeiWang1999 Dec 4, 2024
e3b42b6
implement dequant test
LeiWang1999 Dec 5, 2024
b89d71e
test fix
LeiWang1999 Dec 5, 2024
1f28c19
test fix
LeiWang1999 Dec 5, 2024
ff8966b
test fix
LeiWang1999 Dec 5, 2024
142771d
lint fix
LeiWang1999 Dec 6, 2024
a06f773
fix for rescale zeros
LeiWang1999 Dec 6, 2024
f697321
Support A100
LeiWang1999 Dec 6, 2024
a8a0af5
update
LeiWang1999 Dec 8, 2024
00f170d
format
LeiWang1999 Dec 8, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 321f41 to 8e2f4b
2 changes: 0 additions & 2 deletions bitblas/__init__.py
@@ -133,7 +133,6 @@ def remove_tvm_path(path):
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)

import tvm as tvm # noqa: E402
-from . import gpu # noqa: F401
from .base import (
TileDevice, # noqa: F401
fast_tune, # noqa: F401
@@ -148,7 +147,6 @@ def remove_tvm_path(path):
ApplyDefaultSchedule, # noqa: F401
ApplyFastTuning, # noqa: F401
)
-from . import testing # noqa: F401
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
4 changes: 3 additions & 1 deletion bitblas/base/__init__.py
@@ -11,7 +11,9 @@
normalize_prim_func, # noqa: F401
) # noqa: F401
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial # noqa: F401
+from .base_scheduler import simplify_prim_func # noqa: F401
from .schedule_rule import ScheduleRule # noqa: F401
-from .utils import fast_tune, fast_tune_with_dynamic_range # noqa: F401
+from .tuner import fast_tune, fast_tune_with_dynamic_range # noqa: F401
from .roller import *
from .arch import CUDA, CDNA # noqa: F401
+from .operator_common import TransformKind, OptimizeStrategy, BackendKind # noqa: F401
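
A minimal usage sketch, assuming a BitBLAS build that includes these changes: `fast_tune` and `fast_tune_with_dynamic_range` now come from `.tuner` rather than `.utils`, `simplify_prim_func` is re-exported from the new `.base_scheduler`, and the backend/transform enums from the new `.operator_common`, all at the `bitblas.base` level.

```python
# Sketch only: exercises the names re-exported from bitblas.base after this change.
from bitblas.base import (
    simplify_prim_func,            # re-exported from .base_scheduler
    fast_tune,                     # now sourced from .tuner instead of .utils
    fast_tune_with_dynamic_range,  # likewise moved to .tuner
    TransformKind,
    BackendKind,
)

print(BackendKind.TileLang.is_tilelang_backend())                # True
print(TransformKind.LDMatrixTransform.is_ld_matrix_transform())  # True
```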
59 changes: 55 additions & 4 deletions bitblas/base/arch/__init__.py
@@ -4,9 +4,13 @@
from .cuda import *
from .cpu import *
from .cdna import *
+from typing import Union


-def get_arch(target: tvm.target.Target) -> TileDevice:
+def get_arch(target: Union[str, tvm.target.Target] = "cuda") -> TileDevice:
+    if isinstance(target, str):
+        target = tvm.target.Target(target)

if target.kind.name == "cuda":
return CUDA(target)
elif target.kind.name == "llvm":
@@ -17,16 +21,63 @@ def get_arch(target: tvm.target.Target) -> TileDevice:
raise ValueError(f"Unsupported target: {target.kind.name}")


def auto_infer_current_arch() -> TileDevice:
# TODO(lei): This is a temporary solution to infer the current architecture
# Can be replaced by a more sophisticated method in the future
return get_arch("cuda")


def is_cpu_arch(arch: TileDevice) -> bool:
return isinstance(arch, CPU)


def is_cuda_arch(arch: TileDevice) -> bool:
return isinstance(arch, CUDA)


def is_ampere_arch(arch: TileDevice) -> bool:
conditions = [True]
-    conditions.append(isinstance(arch, CUDA))
-    conditions.append(arch.sm_version >= 80)
+    conditions.append(is_cuda_arch(arch))
+    conditions.append(arch.sm_version >= 80 and arch.sm_version < 90)
return all(conditions)


def is_volta_arch(arch: TileDevice) -> bool:
conditions = [True]
-    conditions.append(isinstance(arch, CUDA))
+    conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 70)
conditions.append(arch.sm_version < 80)
return all(conditions)


def is_cdna_arch(arch: TileDevice) -> bool:
return isinstance(arch, CDNA)


def has_mma_support(arch: TileDevice) -> bool:
conditions = [True]
conditions.append(is_cuda_arch(arch))
conditions.append(arch.sm_version >= 80)
return all(conditions)


def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
volta_tensorcore_supported = [
("float16", "float32"),
("float16", "float16"),
]
ampere_tensorcore_supported = [
("float16", "float32"),
("float16", "float16"),
("int8", "int32"),
("int4", "int32"),
("int2", "int32"),
("int1", "int32"),
]

if is_volta_arch(arch):
return (in_dtype, accum_dtype) in volta_tensorcore_supported
elif is_ampere_arch(arch):
return (in_dtype, accum_dtype) in ampere_tensorcore_supported
else:
raise ValueError(f"Unsupported architecture: {arch}")
148 changes: 148 additions & 0 deletions bitblas/base/base_scheduler.py
@@ -0,0 +1,148 @@
from tvm import te
from tvm import IRModule
from tvm.tir import PrimFunc
from typing import Optional, Union, Callable, List, Dict
from dataclasses import dataclass, field
from tvm.tl.transform import Simplify
from abc import ABC, abstractmethod
from bitblas.base.arch import TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch
from bitblas.base.roller.hint import Hint
from bitblas.tl.base_hint import BaseTLHint


# Decorator to simplify the output of a function
def maybe_simplify(self, func: Callable) -> Callable:

def wrapper(*args, **kwargs):
stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs)
if self._enable_simplify:
return self.Simplify(stmt)
return stmt

return wrapper


@dataclass
class BaseScheduler(ABC):

_arch: TileDevice = field(default=auto_infer_current_arch(), init=False, repr=False)

_enable_simplify: bool = field(default=True, init=False, repr=False)

_dynamic_range: Dict[str, int] = field(default_factory=dict, init=False, repr=False)

@staticmethod
def Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
if isinstance(stmt, PrimFunc):
mod = Simplify()(IRModule.from_expr(stmt))
assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop()
elif isinstance(stmt, IRModule):
return Simplify()(stmt)
else:
raise ValueError(f"Unsupported type: {type(stmt)}")

def get_hardware_aware_configs(self,
arch: TileDevice = None,
topk: int = 10) -> List[BaseTLHint]:
raise NotImplementedError(
f"{self.__class__.__name__} does not support hardware-aware tuning for {arch} with topk={topk}"
)

def activate_simplify(self) -> "BaseScheduler":
self._enable_simplify = True
return self

def deactivate_simplify(self) -> "BaseScheduler":
self._enable_simplify = False
return self

def maybe_simplify(self, stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
if self._enable_simplify:
return self.Simplify(stmt)
return stmt

def with_self_attrs(self, func: PrimFunc) -> PrimFunc:
if self._dynamic_range:
func = func.with_attr("opt_shapes", self._dynamic_range)
return func

def post_process(self, func: PrimFunc) -> PrimFunc:
func = self.with_self_attrs(func)
func = self.maybe_simplify(func)
return func

def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler":
self._dynamic_range = dynamic_range
return self

def has_dynamic_range(self) -> bool:
return bool(self._dynamic_range)

def with_arch(self, arch: TileDevice) -> "BaseScheduler":
self._arch = arch
return self

def has_arch(self) -> bool:
return self._arch is not None

def is_volta_arch(self) -> bool:
return is_volta_arch(self._arch) if self._arch is not None else False

def is_ampere_arch(self) -> bool:
return is_ampere_arch(self._arch) if self._arch is not None else False

def is_cdna_arch(self) -> bool:
return is_cdna_arch(self._arch) if self._arch is not None else False

@staticmethod
def maybe_dynamic(arg: Union[int, List[int]], dynamic_symbol: str = "m") -> PrimFunc:
if isinstance(arg, int):
return arg
return te.var(dynamic_symbol)

@abstractmethod
def with_default_config(self, *args, **kwargs) -> PrimFunc:
pass

@abstractmethod
def apply_config(
self,
*args,
**kwargs,
) -> PrimFunc:
pass

def serialize_hints_to_configs(self, hints: List[Hint]) -> List[BaseTLHint]:
# Convert Roller Hints to TileLang Hints
raise NotImplementedError("Serialization of hints to configs is not implemented")

def specialize_from_dynamic_range(self,
dynamic_range: Optional[Dict[str,
int]] = None) -> "BaseScheduler":
raise NotImplementedError("Specialization from dynamic range is not implemented")

@property
def common_header(self) -> str:
# TODO(lei): For HIP Backend it should be different
common_header = "#include <tl_templates/cuda/common.h>\n"
return common_header

@property
def global_symbol(self):
# For kernel name generation
return "default"

@property
def arch(self) -> TileDevice:
return self._arch


# Decorator to simplify the output of a function
def simplify_prim_func(func: Callable) -> Callable:

def wrapper(*args, **kwargs):
stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs)
return BaseScheduler.Simplify(stmt)

return wrapper
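
To illustrate the contract `BaseScheduler` defines, here is a hypothetical minimal subclass; the `CopyScheduler` name, its `M`/`N` fields, and the element-wise copy kernel are illustrative only and not part of this PR. It exercises the chained helpers (`with_arch`, `set_dynamic_range`, `deactivate_simplify`) and the `post_process` hook that attaches `opt_shapes` and optionally simplifies.

```python
from dataclasses import dataclass

from tvm import te
from tvm.tir import PrimFunc

from bitblas.base.arch import get_arch
from bitblas.base.base_scheduler import BaseScheduler


@dataclass
class CopyScheduler(BaseScheduler):
    # Illustration-only scheduler: lowers an element-wise copy so the
    # BaseScheduler plumbing (arch, dynamic range, simplify) can be shown.
    M: int = 1024
    N: int = 1024

    def with_default_config(self) -> PrimFunc:
        # A real scheduler would pick a sensible default hint here.
        return self.apply_config()

    def apply_config(self) -> PrimFunc:
        A = te.placeholder((self.M, self.N), dtype="float16", name="A")
        B = te.compute((self.M, self.N), lambda i, j: A[i, j], name="B")
        func = te.create_prim_func([A, B])
        # post_process attaches the "opt_shapes" attribute when a dynamic
        # range was set and runs Simplify unless it was deactivated.
        return self.post_process(func)


scheduler = (
    CopyScheduler(M=4096, N=4096)
    .with_arch(get_arch("cuda"))     # attach a TileDevice for arch queries
    .set_dynamic_range({"m": 4096})  # recorded as the opt_shapes attribute
    .deactivate_simplify()           # skip the tvm.tl Simplify pass
)
assert scheduler.has_arch() and scheduler.has_dynamic_range()
func = scheduler.with_default_config()
print(func.attrs["opt_shapes"])      # shows the attached dynamic range
```

Concrete schedulers are expected to emit real TileLang kernels from `apply_config` and, when they support hardware-aware tuning, override `serialize_hints_to_configs` and `get_hardware_aware_configs`, which the base class leaves unimplemented.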
102 changes: 102 additions & 0 deletions bitblas/base/operator_common.py
@@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from enum import IntEnum


class OptimizeStrategy(IntEnum):
SingleBatchDecodeOnly = 0
ContigousBatching = 1

def is_single_batch_decode_only(self):
return self == OptimizeStrategy.SingleBatchDecodeOnly

def is_contigous_batching(self):
return self == OptimizeStrategy.ContigousBatching


class TransformKind(IntEnum):
NonTransform = 0
InterWarpTransform = 1
IntraWarpTransform = 2
LDMatrixTransform = 3

def is_non_transform(self):
return self == TransformKind.NonTransform

def is_inter_warp_transform(self):
return self == TransformKind.InterWarpTransform

def is_intra_warp_transform(self):
return self == TransformKind.IntraWarpTransform

def is_ld_matrix_transform(self):
return self == TransformKind.LDMatrixTransform


class BackendKind(IntEnum):
TIR = 0
TileLang = 1

def is_tir_backend(self):
return self == BackendKind.TIR

def is_tilelang_backend(self):
return self == BackendKind.TileLang


class QuantizationMemoryStage(IntEnum):
# Represents in which stage the dequantize operation is performed
#
# 1. For devices without async copy, we can use a simple dequantize schedule
# without shared memory prefetch.
# quantized weight
# |
# V
# dequantized in register
# |
# V
# save into shared memory
# |
# V
# compute
#
# 2. For A100 Like devices, the shared memory prefetch(async) is required
# to achieve optimal performance.
# quantized weight
# |
# V
# shared memory prefetch (with async copy)
# |
# V
# dequantized into shared memory
# |
# V
# compute
# 3. For A100 Like devices, the shared memory prefetch(async) is required
# to achieve optimal performance.
# quantized weight
# |
# V
# shared memory prefetch (with async copy)
# |
# V
# LDMatrix into warp memory
# |
# V
# Dequantize
# |
# V
# Compute
Local = 0
Shared = 1
Global = 2

def is_quant_memory_in_local(self):
return self == QuantizationMemoryStage.Local

def is_quant_memory_in_shared(self):
return self == QuantizationMemoryStage.Shared

def is_quant_memory_in_global(self):
return self == QuantizationMemoryStage.Global
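
A brief sketch of how the new enums are meant to be queried through their predicate helpers; the dispatch printed below is illustrative, not taken from this PR.

```python
from bitblas.base.operator_common import (
    BackendKind,
    OptimizeStrategy,
    QuantizationMemoryStage,
    TransformKind,
)

backend = BackendKind.TileLang
transform = TransformKind.LDMatrixTransform
stage = QuantizationMemoryStage.Shared
strategy = OptimizeStrategy.SingleBatchDecodeOnly

# Prefer the predicate helpers over comparing raw enum values.
if backend.is_tilelang_backend() and transform.is_ld_matrix_transform():
    print("TileLang backend with LDMatrix weight transform")
if stage.is_quant_memory_in_shared():
    print("dequantize into shared memory (async-copy prefetch pipeline)")
if strategy.is_single_batch_decode_only():
    print("tuned for single-batch decode")

# IntEnum members behave like plain ints, which keeps serialized configs
# and cache keys simple.
assert int(TransformKind.IntraWarpTransform) == 2
```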