forked from vllm-project/vllm
commit b9f05ff (1 parent: 8148b54)
Showing 1 changed file with 255 additions and 0 deletions.
@@ -0,0 +1,255 @@
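# Benchmark (and optionally tune) the fused_moe Triton kernel for Mixtral-style
# MoE layers (8x7B / 8x22B), with an fp8 option and support for padded expert
# weights via the VLLM_MOE_PADDING environment flag.
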
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from torch import nn
from tqdm import tqdm

from vllm import envs
from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)


def main(model, tp_size, gpu, dtype: str):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    method = fused_moe
    # for bs in [
    #         1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
    #         2048, 3072, 4096
    # ]:
    for bs in [8, 16, 32, 64, 96, 112, 120, 128]:
        run_grid(bs,
                 model=model,
                 method=method,
                 gpu=gpu,
                 tp_size=tp_size,
                 dtype=dtype)


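# Build the candidate Triton config grid and time fused_moe for one batch size.
# Note: in this version the per-config sweep is commented out below, so only the
# kernel's default config (config=None) is actually timed.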
def run_grid(bs, model, method, gpu, tp_size, dtype: str):
    if model == '8x7B':
        d_model = 4096
        model_intermediate_size = 14336
        num_layers = 32
    elif model == '8x22B':
        d_model = 6144
        model_intermediate_size = 16384
        num_layers = 56
    else:
        raise ValueError(f'Unsupported Mixtral model {model}')
    num_total_experts = 8
    top_k = 2
    # tp_size = 2
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in [16, 32, 64, 128, 256]:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        for num_stages in [2, 3, 4, 5]:
                            configs.append({
                                "BLOCK_SIZE_M": block_size_m,
                                "BLOCK_SIZE_N": block_size_n,
                                "BLOCK_SIZE_K": block_size_k,
                                "GROUP_SIZE_M": group_size_m,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })
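    # 4 * 5 * 3 * 4 * 2 * 4 = 1920 candidate configs; not used while the
    # per-config sweep below is disabled (only config=None is timed).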

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

    # for config in tqdm(configs):
    if 1:
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=None,
                    dtype=dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            # continue
            pass

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=None,
                dtype=dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                # best_config = config
                best_time_us = kernel_dur_us
                tqdm.write(
                    f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                    f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                    f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # holds Dict[str, Dict[str, int]]
    # filename = get_config_file_name(num_total_experts,
    #                                 model_intermediate_size // tp_size,
    #                                 "float8" if dtype == "float8" else None)
    # print(f"writing config to file {filename}")
    # existing_content = {}
    # if os.path.exists(filename):
    #     with open(filename, "r") as f:
    #         existing_content = json.load(f)
    # existing_content[str(bs)] = best_config
    # with open(filename, "w") as f:
    #     json.dump(existing_content, f, indent=4)
    #     f.write("\n")


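# Time `num_calls` back-to-back fused_moe launches with CUDA events and return
# the mean duration per call in milliseconds.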
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

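    # When VLLM_MOE_PADDING is enabled, pad the last dimension of both expert
    # weight tensors by 128 elements so the benchmark runs against the padded
    # weight layout.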
    if envs.VLLM_MOE_PADDING:
        w1 = nn.Parameter(F.pad(w1.data, (0, 128), "constant", 0),
                          requires_grad=False)
        torch.cuda.empty_cache()
        w2 = nn.Parameter(F.pad(w2.data, (0, 128), "constant", 0),
                          requires_grad=False)
        torch.cuda.empty_cache()

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

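    # fp8 path: cast the weights to float8_e4m3fn and use unit scales; the
    # benchmark only measures kernel time, so the scale values do not matter.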
    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

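    # Random router logits, pre-softmaxed, with one (bs, num_experts) slice per
    # timed call so every iteration routes tokens differently.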
    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

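    # Measure GPU time across all launches with CUDA events; inplace=True
    # reuses the hidden_states tensor for the output between iterations.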
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    # torch_output = torch_moe(a, w1, w2, score, topk)

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    parser.add_argument('--model',
                        type=str,
                        default='8x7B',
                        choices=['8x7B', '8x22B'],
                        help='The Mixtral model to benchmark')
    parser.add_argument('--tp-size',
                        type=int,
                        default=2,
                        help='Tensor parallel size')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='GPU ID for benchmarking')
    args = parser.parse_args()
    sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
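
# Example invocation (script file name assumed from the argparse `prog`):
#
#   python benchmark_mixtral_moe.py --model 8x7B --tp-size 2 --gpu 0 --dtype float16
#
# Set the VLLM_MOE_PADDING environment variable to also exercise the padded
# weight path benchmarked above.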