Add q-cache 6 and 8 support for Exllamav2 #6280

Open · wants to merge 1 commit into dev
modules/exllamav2.py (7 additions, 1 deletion)
@@ -7,6 +7,8 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_Q6,
+    ExLlamaV2Cache_Q8,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -63,8 +65,12 @@ def from_pretrained(self, path_to_model):

         if shared.args.cache_8bit:
             cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
-        elif shared.args.cache_4bit:
+        elif shared.args.cache_q4:
             cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
+        elif shared.args.cache_q6:
+            cache = ExLlamaV2Cache_Q6(model, lazy=shared.args.autosplit)
+        elif shared.args.cache_q8:
+            cache = ExLlamaV2Cache_Q8(model, lazy=shared.args.autosplit)
         else:
             cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
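The same flag-to-cache selection appears in both loaders touched by this PR. Purely as an illustrative sketch (the make_cache helper and _CACHE_CLASSES table below are hypothetical and not part of this diff), the elif chain above can be read as a table-driven lookup in which the first set flag wins, in the same order as the diff:

from exllamav2 import (
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
)

# Hypothetical helper mirroring the elif chains above: the first flag that is
# set selects the cache class, in the same order as the diff (8-bit, then Q4, Q6, Q8).
_CACHE_CLASSES = [
    ('cache_8bit', ExLlamaV2Cache_8bit),
    ('cache_q4', ExLlamaV2Cache_Q4),
    ('cache_q6', ExLlamaV2Cache_Q6),
    ('cache_q8', ExLlamaV2Cache_Q8),
]

def make_cache(model, args):
    for flag, cls in _CACHE_CLASSES:
        if getattr(args, flag, False):
            return cls(model, lazy=getattr(args, 'autosplit', False))
    # Default: unquantized FP16 cache.
    return ExLlamaV2Cache(model, lazy=getattr(args, 'autosplit', False))

A factoring like this would keep the flag precedence in one place if further cache variants are added later.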
modules/exllamav2_hf.py (12 additions, 2 deletions)
@@ -9,6 +9,8 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_Q6,
+    ExLlamaV2Cache_Q8,
     ExLlamaV2Config
 )
 from torch.nn import CrossEntropyLoss
@@ -51,8 +53,12 @@ def __init__(self, config: ExLlamaV2Config):

         if shared.args.cache_8bit:
             self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
-        elif shared.args.cache_4bit:
+        elif shared.args.cache_q4:
             self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+        elif shared.args.cache_q6:
+            self.ex_cache = ExLlamaV2Cache_Q6(self.ex_model, lazy=shared.args.autosplit)
+        elif shared.args.cache_q8:
+            self.ex_cache = ExLlamaV2Cache_Q8(self.ex_model, lazy=shared.args.autosplit)
         else:
             self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
@@ -63,8 +69,12 @@ def __init__(self, config: ExLlamaV2Config):
         if shared.args.cfg_cache:
             if shared.args.cache_8bit:
                 self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
-            elif shared.args.cache_4bit:
+            elif shared.args.cache_q4:
                 self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)
+            elif shared.args.cache_q6:
+                self.ex_cache_negative = ExLlamaV2Cache_Q6(self.ex_model)
+            elif shared.args.cache_q8:
+                self.ex_cache_negative = ExLlamaV2Cache_Q8(self.ex_model)
             else:
                 self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
modules/loaders.py (6 additions, 2 deletions)
@@ -88,7 +88,9 @@
         'no_sdpa',
         'num_experts_per_token',
         'cache_8bit',
-        'cache_4bit',
+        'cache_q4',
+        'cache_q6',
+        'cache_q8',
         'autosplit',
         'alpha_value',
         'compress_pos_emb',
@@ -103,7 +105,9 @@
         'no_sdpa',
         'num_experts_per_token',
         'cache_8bit',
-        'cache_4bit',
+        'cache_q4',
+        'cache_q6',
+        'cache_q8',
         'autosplit',
         'alpha_value',
         'compress_pos_emb',
modules/shared.py (5 additions, 2 deletions)
@@ -142,8 +142,11 @@
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
 group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
-group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
-group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
+group.add_argument('--cache_4bit', action='store_true', help='Use 4-bit cache to save VRAM (llama.cpp).')
+group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit (FP8) cache to save VRAM.')
+group.add_argument('--cache_q4', action='store_true', help='Use Q4 cache to save VRAM.')
+group.add_argument('--cache_q6', action='store_true', help='Use Q6 cache to save VRAM.')
+group.add_argument('--cache_q8', action='store_true', help='Use Q8 cache to save VRAM.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')

 # AutoGPTQ
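With these flags registered, requesting the new Q6 cache from the command line might look like the line below (server.py and --loader reflect the project's usual entry point and are an assumption here, not part of this diff). Only one cache flag is expected at a time; when several are set, the elif chains above give cache_8bit precedence, then cache_q4, cache_q6, and cache_q8.

    python server.py --loader exllamav2 --cache_q6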
modules/ui.py (3 additions)
@@ -89,6 +89,9 @@ def list_model_elements():
         'num_experts_per_token',
         'cache_8bit',
         'cache_4bit',
+        'cache_q4',
+        'cache_q6',
+        'cache_q8',
         'autosplit',
         'threads',
         'threads_batch',
modules/ui_model_menu.py (5 additions, 2 deletions)
@@ -118,8 +118,11 @@ def create_ui():
             shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
             shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
             shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.')
-            shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
-            shared.gradio['cache_4bit'] = gr.Checkbox(label="cache_4bit", value=shared.args.cache_4bit, info='Use Q4 cache to save VRAM.')
+            shared.gradio['cache_4bit'] = gr.Checkbox(label="cache_4bit", value=shared.args.cache_4bit, info='Use 4-bit (FP4) cache to save VRAM.')
+            shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit (FP8) cache to save VRAM.')
+            shared.gradio['cache_q4'] = gr.Checkbox(label="cache_q4", value=shared.args.cache_q4, info='Use Q4 cache to save VRAM.')
+            shared.gradio['cache_q6'] = gr.Checkbox(label="cache_q6", value=shared.args.cache_q6, info='Use Q6 cache to save VRAM.')
+            shared.gradio['cache_q8'] = gr.Checkbox(label="cache_q8", value=shared.args.cache_q8, info='Use Q8 cache to save VRAM.')
             shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
             shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
             shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')