-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathmem_spd_test.py
71 lines (65 loc) · 2.17 KB
/
mem_spd_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# LLaMA model with KIVI
import torch
import os
from models.llama_kivi import LlamaForCausalLM_KIVI
from transformers import LlamaConfig, AutoTokenizer
import time
K_BITS = 2
V_BITS = 2
GROUP_SIZE = 32
RESIDUAL_LENGTH = 128
BATCH_SIZE = 96
PATH_TO_YOUR_SAVE_DIR = './cached_models'
model_name_or_path = 'meta-llama/Llama-2-7b-hf'
config = LlamaConfig.from_pretrained(model_name_or_path)
config.k_bits = K_BITS # current support 2/4 bit for KV Cache
config.v_bits = V_BITS # current support 2/4 bit for KV Cache
config.group_size = GROUP_SIZE
config.residual_length = RESIDUAL_LENGTH # the number of recent fp16 tokens
CACHE_DIR = PATH_TO_YOUR_SAVE_DIR
if K_BITS < 16 and V_BITS < 16:
model = LlamaForCausalLM_KIVI.from_pretrained(
pretrained_model_name_or_path=model_name_or_path,
config=config,
cache_dir=CACHE_DIR,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto",
)
else:
from transformers import LlamaForCausalLM
model = LlamaForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_name_or_path,
config=config,
cache_dir=CACHE_DIR,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
use_fast=False,
trust_remote_code=True,
tokenizer_type='llama')
model.cuda().eval()
context = []
batch_size = BATCH_SIZE
prompt_lenth = 160
output_length = 338
num_repeats = 3
for _ in range(batch_size):
string = 't,' * (prompt_lenth // 2)
context.append(string[:-1])
inputs = tokenizer(context, return_tensors="pt").to('cuda')
input_ids = inputs['input_ids']
print(f"bs: {batch_size}, seqlen: {input_ids.shape[1]}+{output_length}\nmodel:{model_name_or_path}")
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
torch.cuda.synchronize()
st = time.time()
for i in range(num_repeats):
outputs = model.generate(**inputs, max_new_tokens=output_length)
torch.cuda.synchronize()
print(f'used time: {(time.time() - st) / num_repeats * 1000} ms')
used_mem = torch.cuda.max_memory_allocated()
print(f'peak mem: {used_mem / 1024 ** 3} GB')