inference_dedup.py
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer, Timer
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler
from util import format_prompt, get_stop_conditions
model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
model.load_autosplit(cache, progress = True)
print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
)
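
# Note: ExLlamaV2Sampler is imported above but not used below, so the generator presumably runs
# with its default sampling settings. As an illustrative sketch (not part of the original
# example), custom settings could be built and passed to generate() via its gen_settings
# argument, e.g.:
#
#   gen_settings = ExLlamaV2Sampler.Settings()
#   gen_settings.temperature = 0.8
#   gen_settings.top_p = 0.9
#   outputs = generator.generate(..., gen_settings = gen_settings)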
# Load a short story and prepare some questions about it.
dirname = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(dirname, "the_black_veil.utf8")
with open(filename, "r", encoding = "utf8") as f:
    input_text = f.read()
questions = [
"What are the themes of the story?",
"What is the setting for the story?",
"List the characters mentioned in the story.",
"Write a short summary of the story.",
"Does this story have a happy ending?",
"Does this story relate to any other works by the same author?",
"Is the text appropriate for all ages?",
"Provide up to five examples of outmoded language in the text."
]
# Create prompts to evaluate in parallel. The prompts will all contain the full context, but identical pages are
# deduplicated in the cache, so keys/values for the common prefix of all the prompts are only evaluated and stored
# once. Each prompt works out to about 6200 tokens including the response, but with deduplication up to 5 prompts
# can be batched together in the 8192-token cache
prompt_format = "llama3"
prompts = [
    format_prompt(prompt_format, "You are a helpful AI assistant.", input_text + "\n###\n\n" + question)
    for question in questions
]
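
# Optional sanity check (an added illustration, not in the original example): the story text
# appears verbatim at the start of every prompt above, so its keys/values only need to be
# evaluated and stored once. Tokenizing it on its own gives a rough idea of how large that
# shared, deduplicated prefix is.
story_tokens = tokenizer.encode(input_text).shape[-1]
print(f"Story length: ~{story_tokens} tokens (shared prefix, stored once in the cache)")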
# Generate
with Timer() as timer:
    outputs = generator.generate(
        prompt = prompts,
        max_new_tokens = 400,
        stop_conditions = get_stop_conditions(prompt_format, tokenizer),
        completion_only = True,       # return only the completions, not the echoed prompts
        encode_special_tokens = True  # the prompt format uses special tokens, so encode them
    )
for question, output in zip(questions, outputs):
print("-----------------------------------------------------------------------------------")
print("Q: " + question)
print()
print("A: " + output)
print()
print("-----------------------------------------------------------------------------------")
print(f"Total time: {timer.interval:.2f} seconds")