add script for decode
carlushuang committed Aug 11, 2024
1 parent 8148b54 commit b9f05ff
Showing 1 changed file with 255 additions and 0 deletions.
255 changes: 255 additions & 0 deletions benchmarks/kernels/benchmark_mixtral_moe_decode.py
@@ -0,0 +1,255 @@
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm
from vllm import envs
from torch import nn
from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name)


def main(model, tp_size, gpu, dtype: str):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
method = fused_moe
    # Full batch-size sweep (disabled); this decode-focused script only
    # times the smaller batch sizes below.
    # for bs in [
    #         1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
    #         2048, 3072, 4096
    # ]:
    for bs in [8, 16, 32, 64, 96, 112, 120, 128]:
run_grid(bs,
model=model,
method=method,
gpu=gpu,
tp_size=tp_size,
dtype=dtype)


def run_grid(bs, model, method, gpu, tp_size, dtype: str):
if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
num_layers = 32
elif model == '8x22B':
d_model = 6144
model_intermediate_size = 16384
num_layers = 56
else:
raise ValueError(f'Unsupported Mixtral model {model}')
num_total_experts = 8
top_k = 2
# tp_size = 2
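    # Number of kernel launches averaged per timing measurement.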
num_calls = 100

num_warmup_trials = 1
num_trials = 1

configs = []

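    # Cartesian product of Triton tile sizes and pipelining parameters that an
    # autotune sweep would cover; with the sweep disabled below, this list is
    # built but not consumed.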
for block_size_n in [32, 64, 128, 256]:
for block_size_m in [16, 32, 64, 128, 256]:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
for num_stages in [2, 3, 4, 5]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
})

best_config = None
best_time_us = 1e20

print(f'{tp_size=} {bs=}')

    # Grid search over `configs` is disabled for this decode benchmark; only
    # the default kernel config (config=None) is timed.
    # for config in tqdm(configs):
    if True:
# warmup
try:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=None,
dtype=dtype,
)
        except triton.runtime.autotuner.OutOfResources:
            # Ignore resource-limit failures during warmup.
            pass

# trial
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=None,
dtype=dtype,
)

kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers

if kernel_dur_us < best_time_us:
# best_config = config
best_time_us = kernel_dur_us
tqdm.write(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')

print("best_time_us", best_time_us)
print("best_config", best_config)

    # Persisting the tuned config is disabled along with the sweep above.
    # The config file holds Dict[str, Dict[str, int]].
# filename = get_config_file_name(num_total_experts,
# model_intermediate_size // tp_size,
# "float8" if dtype == "float8" else None)
# print(f"writing config to file {filename}")
# existing_content = {}
# if os.path.exists(filename):
# with open(filename, "r") as f:
# existing_content = json.load(f)
# existing_content[str(bs)] = best_config
# with open(filename, "w") as f:
# json.dump(existing_content, f, indent=4)
# f.write("\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config, dtype: str) -> float:
shard_intermediate_size = model_intermediate_size // tp_size

hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.float16,
)

w1 = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

w2 = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

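    # vLLM's optional MoE weight padding: padding the innermost dim by 128
    # elements can improve memory access efficiency on some GPUs.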
if envs.VLLM_MOE_PADDING:
w1 = nn.Parameter(F.pad(w1.data,
(0, 128), "constant", 0),
requires_grad=False)
torch.cuda.empty_cache()
w2 = nn.Parameter(F.pad(w2.data,
(0, 128), "constant", 0),
requires_grad=False)
torch.cuda.empty_cache()

w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None

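    # fp8 path: fused_moe takes per-expert weight scales and per-tensor
    # activation scales; unit scales are sufficient for timing purposes.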
if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
w2_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
a1_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
a2_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)

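    # A distinct softmax-normalized routing matrix per timed call, so each
    # iteration exercises different expert assignments.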
gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1)

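    # Time num_calls back-to-back launches with CUDA events and report the
    # mean per-call latency.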
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
gating_output=gating_output[i],
        topk=top_k,
renormalize=True,
inplace=True,
override_config=config,
use_fp8=dtype == "float8",
)
end_event.record()
end_event.synchronize()

    # torch_output = torch_moe(a, w1, w2, score, topk)

dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='benchmark_mixtral_moe',
description='Benchmark and tune the fused_moe kernel',
)
parser.add_argument(
'--dtype',
type=str,
    default='float16',
choices=['float8', 'float16'],
help='Data type used for fused_moe kernel computations',
)
parser.add_argument('--model',
type=str,
default='8x7B',
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')
parser.add_argument('--tp-size',
type=int,
default=2,
                        help='Tensor parallel size')
parser.add_argument('--gpu',
type=int,
default=0,
help="GPU ID for benchmarking")
args = parser.parse_args()
sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
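
Example invocation (using the flags defined above; assumes a vLLM build with the fused MoE Triton kernel and at least one visible GPU):

    python benchmarks/kernels/benchmark_mixtral_moe_decode.py --model 8x7B --tp-size 2 --gpu 0 --dtype float16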
