From 46996f6519d483841a49c8a857104a6bf6e2ac7a Mon Sep 17 00:00:00 2001
From: RandoInternetPreson
Date: Fri, 27 Sep 2024 23:26:03 -0400
Subject: [PATCH] ExllamaV2 tensor parallelism to increase multi gpu inference speeds (#6356)

---
 modules/exllamav2.py    | 26 ++++++++++++++++++--------
 modules/exllamav2_hf.py | 28 +++++++++++++++++++---------
 modules/shared.py       |  1 +
 3 files changed, 38 insertions(+), 17 deletions(-)

diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index a770e34257..42b9ade145 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -7,6 +7,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -54,21 +55,30 @@ def from_pretrained(self, path_to_model):
 
         model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
+        if shared.args.enable_tp:
+            model.load_tp(split)
+        elif not shared.args.autosplit:
             model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            cache = ExLlamaV2Cache_TP(model, base=cache_type)
+        else:
+            cache = cache_type(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             model.load_autosplit(cache)
 
         tokenizer = ExLlamaV2Tokenizer(config)
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 53143d9a92..febb2c6495 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -9,6 +9,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
 from torch.nn import CrossEntropyLoss
@@ -42,21 +43,30 @@ def __init__(self, config: ExLlamaV2Config):
 
         self.ex_model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
-            self.ex_model.load(split)
+        if shared.args.enable_tp:
+            self.ex_model.load_tp(split)
+        elif not shared.args.autosplit:
+            self.ex_model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
+        else:
+            self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             self.ex_model.load_autosplit(self.ex_cache)
 
         self.past_seq = None
diff --git a/modules/shared.py b/modules/shared.py
index 43533a1480..894ed6fe56 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -146,6 +146,7 @@
 group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
 group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
+group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
 
 # AutoGPTQ
 group = parser.add_argument_group('AutoGPTQ')
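
Below is a minimal usage sketch of the tensor-parallel load path this patch enables, calling the exllamav2 API directly with the same calls the diff introduces (load_tp, and ExLlamaV2Cache_TP wrapping a base cache class). The model directory and the example split values are placeholders; adjust them for your setup.

    from exllamav2 import (
        ExLlamaV2,
        ExLlamaV2Cache,
        ExLlamaV2Cache_TP,
        ExLlamaV2Config,
        ExLlamaV2Tokenizer
    )

    # Prepare the model config (same pattern the webui modules use).
    config = ExLlamaV2Config()
    config.model_dir = "/path/to/exl2-model"  # placeholder path
    config.prepare()

    model = ExLlamaV2(config)

    # Equivalent of --enable_tp with an optional --gpu_split "a,b,...";
    # passing None lets the loader divide weights across all visible GPUs.
    split = None  # e.g. [20.0, 20.0] to mirror --gpu_split 20,20 (GB per GPU)
    model.load_tp(split)

    # TP-aware cache wrapping a regular FP16 cache; substitute
    # ExLlamaV2Cache_8bit or ExLlamaV2Cache_Q4 as the base to mirror
    # --cache_8bit / --cache_4bit.
    cache = ExLlamaV2Cache_TP(model, base=ExLlamaV2Cache)

    tokenizer = ExLlamaV2Tokenizer(config)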