Commit

ExllamaV2 tensor parallelism to increase multi gpu inference speeds (#…
RandomInternetPreson authored Sep 28, 2024
1 parent 3013758 commit 46996f6
Showing 3 changed files with 38 additions and 17 deletions.
modules/exllamav2.py: 26 changes (18 additions & 8 deletions)
@@ -7,6 +7,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -54,21 +55,30 @@ def from_pretrained(self, path_to_model):
 
         model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-            model.load(split)
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+        if shared.args.enable_tp:
+            model.load_tp(split)
+        elif not shared.args.autosplit:
+            model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            cache = ExLlamaV2Cache_TP(model, base=cache_type)
+        else:
+            cache = cache_type(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             model.load_autosplit(cache)
 
         tokenizer = ExLlamaV2Tokenizer(config)
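Read end to end, the hunk above amounts to the flow below. This is a minimal sketch for orientation only: the load_model_and_cache wrapper and its keyword arguments are hypothetical, while the exllamav2 names it calls (load_tp, load, load_autosplit, and ExLlamaV2Cache_TP with its base argument) are taken from the diff itself.

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_TP,
    ExLlamaV2Config,
)


def load_model_and_cache(config: ExLlamaV2Config, gpu_split=None, enable_tp=False,
                         autosplit=False, cache_8bit=False, cache_4bit=False):
    """Hypothetical wrapper mirroring the logic the commit adds to from_pretrained()."""
    model = ExLlamaV2(config)

    # "17,24" -> [17.0, 24.0]; None lets the loader decide the allocation.
    split = [float(alloc) for alloc in gpu_split.split(",")] if gpu_split else None

    # TP uses its own load path; otherwise load immediately unless autosplit
    # defers loading until the cache exists.
    if enable_tp:
        model.load_tp(split)
    elif not autosplit:
        model.load(split)

    # Pick the cache class, then either wrap it for TP or instantiate it directly.
    if cache_8bit:
        cache_type = ExLlamaV2Cache_8bit
    elif cache_4bit:
        cache_type = ExLlamaV2Cache_Q4
    else:
        cache_type = ExLlamaV2Cache

    if enable_tp:
        cache = ExLlamaV2Cache_TP(model, base=cache_type)
    else:
        cache = cache_type(model, lazy=autosplit)

    # Autosplit loading needs the cache and is skipped entirely under TP.
    if autosplit and not enable_tp:
        model.load_autosplit(cache)

    return model, cache

The key ordering constraint is unchanged: with --autosplit the cache is created lazily first and the weights are loaded afterwards via load_autosplit(), whereas with --enable_tp the weights are loaded up front by load_tp() and the cache is then built as an ExLlamaV2Cache_TP wrapper around the chosen base class.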
modules/exllamav2_hf.py: 28 changes (19 additions & 9 deletions)
@@ -9,6 +9,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
 from torch.nn import CrossEntropyLoss
@@ -42,21 +43,30 @@ def __init__(self, config: ExLlamaV2Config):
 
         self.ex_model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
-
-            self.ex_model.load(split)
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+
+        if shared.args.enable_tp:
+            self.ex_model.load_tp(split)
+        elif not shared.args.autosplit:
+            self.ex_model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
+        else:
+            self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             self.ex_model.load_autosplit(self.ex_cache)
 
         self.past_seq = None
modules/shared.py: 1 change (1 addition & 0 deletions)
@@ -146,6 +146,7 @@
 group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
 group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
+group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
 
 # AutoGPTQ
 group = parser.add_argument_group('AutoGPTQ')
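In use, the new flag composes with the existing ExLlamaV2 options (an illustrative reading of the change, not documented behavior): passing --enable_tp together with --gpu_split, e.g. --gpu_split 17,24, makes the loaders call load_tp([17.0, 24.0]) instead of load(), while --enable_tp on its own passes split=None and leaves the per-GPU allocation to ExLlamaV2; --autosplit has no effect once TP is enabled.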
