From 7e15947aa1412b21c52077c1feb7c73122242ce7 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sun, 25 Aug 2024 14:51:46 +0200
Subject: [PATCH 01/10] Fix possible sync issues with fasttensors

---
 exllamav2/fasttensors.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/exllamav2/fasttensors.py b/exllamav2/fasttensors.py
index 3823c25c..ef49ca04 100644
--- a/exllamav2/fasttensors.py
+++ b/exllamav2/fasttensors.py
@@ -189,6 +189,8 @@ def get_tensor(self,
                    out_dtype = None) -> torch.Tensor:
         global global_tensorcache
 
+        torch.cuda.synchronize()
+
         if self.tensor_remap and (not_fast or not self.fast):
             key = self.tensor_remap[key]
 
@@ -236,4 +238,6 @@ def get_tensor(self,
                 global_tensorcache = global_tensorcache[1:]
             global_tensorcache.append((cachekey, tensor))
 
+        torch.cuda.synchronize()
+
         return tensor

From d3fe9f25d2f1955a19b5f3015d50e5872a55cf1e Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sun, 25 Aug 2024 21:15:28 +0200
Subject: [PATCH 02/10] Unmap tensors on CPU to reduce temp VRAM overhead while loading

---
 exllamav2/exllamav2_ext/cpp/safetensors.cpp | 63 ++++++++++++++++++++-
 exllamav2/exllamav2_ext/cpp/safetensors.h   | 13 +++++
 exllamav2/exllamav2_ext/ext_bindings.cpp    |  2 +
 exllamav2/ext.py                            |  4 +-
 exllamav2/linear.py                         | 11 ++--
 exllamav2/module.py                         | 16 +++---
 6 files changed, 95 insertions(+), 14 deletions(-)

diff --git a/exllamav2/exllamav2_ext/cpp/safetensors.cpp b/exllamav2/exllamav2_ext/cpp/safetensors.cpp
index b15cd26c..712fe4be 100644
--- a/exllamav2/exllamav2_ext/cpp/safetensors.cpp
+++ b/exllamav2/exllamav2_ext/cpp/safetensors.cpp
@@ -453,4 +453,65 @@ void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor
             remaining -= chunk;
         }
     }
-}
\ No newline at end of file
+}
+
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(tensor, 1, index, 0, 1);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = tensor.size(1);
+    uint32_t* temp = (uint32_t*) calloc(cols, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols);
+        for (int c = 0; c < cols; ++c)
+        {
+            *a++ = temp[idx[c]];
+        }
+    }
+    free(temp);
+}
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(index, 0, tensor, 1, 8);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = index.size(0);
+    uint32_t* temp = (uint32_t*) calloc(cols / 8, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols / 8);
+        for (int c = 0; c < cols;)
+        {
+            uint32_t rv = 0;
+            for (int b = 0; b < 8; ++b, ++c)
+            {
+                uint32_t i = idx[c];
+                uint32_t v = (temp[i / 8] >> ((i & 7) * 4) & 0x0f);
+                rv |= v << (b * 4);
+            }
+            *a++ = rv;
+        }
+    }
+    free(temp);
+}
diff --git a/exllamav2/exllamav2_ext/cpp/safetensors.h b/exllamav2/exllamav2_ext/cpp/safetensors.h
index f4c6d284..4ad054bb 100644
--- a/exllamav2/exllamav2_ext/cpp/safetensors.h
+++ b/exllamav2/exllamav2_ext/cpp/safetensors.h
@@ -47,4 +47,17 @@ uintptr_t safetensors_open_fb(const char* filename);
 void safetensors_close_fb(uintptr_t handle);
 void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target);
 
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+
 #endif
\ No newline at end of file
diff --git a/exllamav2/exllamav2_ext/ext_bindings.cpp b/exllamav2/exllamav2_ext/ext_bindings.cpp
index 387c2aef..da8d2fd5 100644
--- a/exllamav2/exllamav2_ext/ext_bindings.cpp
+++ b/exllamav2/exllamav2_ext/ext_bindings.cpp
@@ -55,6 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("safetensors_pinned_buffer", &safetensors_pinned_buffer, "safetensors_pinned_buffer");
     m.def("safetensors_free_pinned_buffer", &safetensors_free_pinned_buffer, "safetensors_free_pinned_buffer");
     m.def("safetensors_read_fb", &safetensors_read_fb, "safetensors_read_fb");
+    m.def("tensor_remap", &tensor_remap, "tensor_remap");
+    m.def("tensor_remap_4bit", &tensor_remap_4bit, "tensor_remap_4bit");
 
     // qmatrix
 
diff --git a/exllamav2/ext.py b/exllamav2/ext.py
index f23daa6d..a828c7b4 100644
--- a/exllamav2/ext.py
+++ b/exllamav2/ext.py
@@ -173,9 +173,9 @@ def find_msvc():
 # gcc / cl.exe flags
 
 if windows:
-    extra_cflags = ["/Ox", "/openmp"]
+    extra_cflags = ["/Ox"]
 else:
-    extra_cflags = ["-Ofast", "-fopenmp"]
+    extra_cflags = ["-Ofast"]
 
 if ext_debug:
     extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
diff --git a/exllamav2/linear.py b/exllamav2/linear.py
index 1c0e264d..7eb34ab0 100644
--- a/exllamav2/linear.py
+++ b/exllamav2/linear.py
@@ -8,6 +8,7 @@
 from exllamav2.compat import safe_move_tensor
 from exllamav2.tensor_p import BROADCAST_VC
 from exllamav2.util import unpack_4bit, pack_4bit
+import gc
 
 from typing import TYPE_CHECKING
 
@@ -118,7 +119,7 @@ def load(self,
         cfg = self.model.config
 
         if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
-        if w is None: w = self.load_weight()
+        if w is None: w = self.load_weight(cpu = output_map is not None)
 
         # Load quantized linear layer from dictionary
 
@@ -137,7 +138,7 @@ def load(self,
             self.q_tensors = w
 
             if unmap and "q_perm" in w:
-                perm = w["q_perm"]
+                perm = w["q_perm"].cpu()
                 del w["q_perm"]
                 del w["q_invperm"]
                 # w["q_perm"] = torch.arange(0, w["q_perm"].shape[-1], dtype = w["q_perm"].dtype, device = w["q_perm"].device)
@@ -146,8 +147,10 @@ def load(self,
                 perm = None
 
             if output_map is not None:
-                w["q_weight"] = w["q_weight"][:, output_map]
-                w["q_scale"] = pack_4bit(unpack_4bit(w["q_scale"])[:, output_map])
+                ext_c.tensor_remap(w["q_weight"], output_map)
+                ext_c.tensor_remap_4bit(w["q_scale"], output_map)
+                for k in w.keys():
+                    w[k] = safe_move_tensor(w[k], self.device())
 
             self.q_handle = ext.make_q_matrix(w,
                                               self.temp_dq,
diff --git a/exllamav2/module.py b/exllamav2/module.py
index bb1d0ce4..4657c911 100644
--- a/exllamav2/module.py
+++ b/exllamav2/module.py
@@ -60,7 +60,8 @@ def device(self) -> str:
     def load_multi(self,
                    key: str,
                    keys: list[str],
-                   measure: bool = False) -> int | dict[str: torch.Tensor]:
+                   measure: bool = False,
+                   cpu: bool = False) -> int | dict[str: torch.Tensor]:
 
         tensors = {}
         submap = {}
@@ -85,13 +86,14 @@ def load_multi(self,
             if measure:
                 size += stfile.measure(key + "." + k)
             else:
-                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device())
+                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device() if not cpu else "cpu")
 
         return size if measure else tensors
 
 
     def load_weight(self,
-                    override_key: str | None = None):
+                    override_key: str | None = None,
+                    cpu: bool = False):
 
         if override_key is not None:
             keys = [override_key]
@@ -105,14 +107,14 @@ def load_weight(self,
 
         # EXL2
 
         if key + ".q_weight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"])
+            qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], cpu = cpu)
             qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
             return qtensors
 
         # GPTQ
 
         if key + ".qweight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"])
+            qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"], cpu = cpu)
             if "bias" in qtensors and torch.all(qtensors["bias"].eq(0)): del qtensors["bias"]
             qtensors["scales"] = qtensors["scales"].half()
@@ -122,14 +124,14 @@ def load_weight(self,
 
         if key + ".weight" in self.model.config.tensor_file_map:
             if key + ".bias" in self.model.config.tensor_file_map:
-                tensors = self.load_multi(key, ["weight", "bias"])
+                tensors = self.load_multi(key, ["weight", "bias"], cpu = cpu)
                 tensor = tensors["weight"].half()
                 bias = tensors["bias"].half()
                 if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2:
                     tensor = tensor.T
                 return nn.Parameter(tensor, requires_grad = False), nn.Parameter(bias, requires_grad = False)
             else:
-                tensors = self.load_multi(key, ["weight"])
+                tensors = self.load_multi(key, ["weight"], cpu = cpu)
                 tensor = tensors["weight"].half()
                 # if self.model.config.arch.orig_weights_transposed:
                 #     tensor = tensor.T

From e539f7cc286f185603c8cb6414a7db2668982211 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sun, 25 Aug 2024 21:56:02 +0200
Subject: [PATCH 03/10] Fix another possible sync issue with fasttensors

---
 exllamav2/fasttensors.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/exllamav2/fasttensors.py b/exllamav2/fasttensors.py
index ef49ca04..585c5608 100644
--- a/exllamav2/fasttensors.py
+++ b/exllamav2/fasttensors.py
@@ -226,7 +226,8 @@ def get_tensor(self,
             offset = data_offsets[0] + self.header_size
             length = data_offsets[1] - data_offsets[0]
             assert np.prod(sh) * dts == length, f"Tensor shape doesn't match storage size: {key}"
-
+            if device != "cpu":
+                torch.cuda.set_stream(torch.cuda.default_stream(device))
             tensor = torch.empty(sh, device = device, dtype = dtt)
             ext_c.safetensors_load(self.handle, tensor, offset, length)
 

From 69291d13335dd553ac475927866f58606d99c342 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Sun, 25 Aug 2024 22:01:53 +0200
Subject: [PATCH 04/10] Fix another possible sync issue with fasttensors (for Windows)

---
 exllamav2/fasttensors.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/exllamav2/fasttensors.py b/exllamav2/fasttensors.py
index 585c5608..87d36ce8 100644
--- a/exllamav2/fasttensors.py
+++ b/exllamav2/fasttensors.py
@@ -213,6 +213,8 @@ def get_tensor(self,
             size = end - beg
             numel = size // esize
             shape = h["shape"]
+            if device != "cpu":
+                torch.cuda.set_stream(torch.cuda.default_stream(device))
             tensor = torch.zeros(shape, dtype = dtype, device = device)
             assert tensor.is_contiguous, "Non-contiguous tensor"
             ext_c.safetensors_read_fb(self.handle_fb, beg + self.header_size, size, tensor)

From 4230dab3c12fa1ad7745afa1eb8c53bae5cf7198 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 27 Aug 2024 19:16:23 +0200
Subject: [PATCH 05/10] Fix model_init feedback

---
 exllamav2/model_init.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py
index 1f91adec..f03c3a3d 100644
--- a/exllamav2/model_init.py
+++ b/exllamav2/model_init.py
@@ -34,7 +34,7 @@ def print_options(args):
     print_opts = []
 
     if args.gpu_split is not None: print_opts += [f"gpu_split: {args.gpu_split}"]
-    if args.tensor_parallel is not None: print_opts += ["tensor_parallel"]
+    if args.tensor_parallel: print_opts += ["tensor_parallel"]
     if args.length is not None: print_opts += [f"length: {args.length}"]
     if args.rope_scale is not None: print_opts += [f"rope_scale: {args.rope_scale}"]
     if args.rope_alpha is not None: print_opts += [f"rope_alpha: {args.rope_alpha}"]

From 7319b6ea31234eff132e1c91c4869a315a2e1563 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 27 Aug 2024 20:07:21 +0200
Subject: [PATCH 06/10] Fix graph update for MLP with post layernorm

---
 exllamav2/exllamav2_ext/cuda/graph.cu  | 21 ++++++++++++---------
 exllamav2/exllamav2_ext/cuda/graph.cuh |  6 +++---
 exllamav2/exllamav2_ext/cuda/q_mlp.cu  |  4 ++--
 3 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/exllamav2/exllamav2_ext/cuda/graph.cu b/exllamav2/exllamav2_ext/cuda/graph.cu
index aefb4850..877b5d15 100644
--- a/exllamav2/exllamav2_ext/cuda/graph.cu
+++ b/exllamav2/exllamav2_ext/cuda/graph.cu
@@ -133,7 +133,7 @@ void Graph::attach_label(cudaStream_t stream, int label, int sublabel)
 }
 
 template <typename T>
-void Graph::update_param(int label, int sublabel, int param, T value)
+void Graph::update_param(int label, int sublabel, int param, T value, bool debug)
 {
     for (int i = 0; i < node_labels.size(); ++i)
     {
@@ -145,19 +145,22 @@ void Graph::update_param(int label, int sublabel, int param, T value)
 
         node_needs_update[i] = true;
 
-//        printf("-----------------------------------------------------\n");
-//        printf("UPDATED:\n");
-//        DBGI(i);
-//        inspect_graph();
+        if (debug)
+        {
+            printf("-----------------------------------------------------\n");
+            printf("UPDATED: ");
+            DBGI(i);
+            inspect_graph();
+        }
     }
 }
 
-void Graph::update_param_ptr(int label, int sublabel, int param, void* value)
+void Graph::update_param_ptr(int label, int sublabel, int param, void* value, bool debug)
 {
-    update_param(label, sublabel, param, value);
+    update_param(label, sublabel, param, value, debug);
 }
 
-void Graph::update_param_int(int label, int sublabel, int param, int value)
+void Graph::update_param_int(int label, int sublabel, int param, int value, bool debug)
 {
-    update_param(label, sublabel, param, value);
+    update_param(label, sublabel, param, value, debug);
 }
diff --git a/exllamav2/exllamav2_ext/cuda/graph.cuh b/exllamav2/exllamav2_ext/cuda/graph.cuh
index 084a60f3..2d261b8f 100644
--- a/exllamav2/exllamav2_ext/cuda/graph.cuh
+++ b/exllamav2/exllamav2_ext/cuda/graph.cuh
@@ -46,10 +46,10 @@ public:
     void attach_label(cudaStream_t stream, int label, int sublabel);
 
     template <typename T>
-    void update_param(int label, int sublabel, int param, T value);
+    void update_param(int label, int sublabel, int param, T value, bool debug);
 
-    void update_param_ptr(int label, int sublabel, int param, void* value);
-    void update_param_int(int label, int sublabel, int param, int value);
+    void update_param_ptr(int label, int sublabel, int param, void* value, bool debug = false);
+    void update_param_int(int label, int sublabel, int param, int value, bool debug = false);
 
 };
 
diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu
index 53b5ce57..8e9cda2e 100644
--- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu
+++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu
@@ -109,7 +109,7 @@ void QMLP::forward_
     if (graph->count())
     {
         graph->begin_capture(stream);
-        forward_run_(stream, cublas_handle, (half*) x, rows, columns, loras, lora_temp, graph);
+        forward_run_(stream, cublas_handle, (void*) x, rows, columns, loras, lora_temp, graph);
         graph->end_capture(stream);
 //        printf("**** record ****\n");
 //        DBGI2(rows, columns);
@@ -225,7 +225,7 @@ void QMLP::forward_run_
 
     else
     {
-        gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, graph, 0);
+        gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, false, NULL, 0, false, graph, 0);
 
         if (layernorm_is_rms) rms_norm_cuda(stream, temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32, graph, KernelLabels::POST_NORM);
         else

From d9f0ecc12cc353df7a29bc3a0bdb540023d836dd Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 27 Aug 2024 21:47:49 +0200
Subject: [PATCH 07/10] TP: Fix vocab split for models with odd vocab sizes

---
 exllamav2/tensor_p.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/exllamav2/tensor_p.py b/exllamav2/tensor_p.py
index b4ac0f99..f8e88e59 100644
--- a/exllamav2/tensor_p.py
+++ b/exllamav2/tensor_p.py
@@ -155,7 +155,7 @@ def define_split(
 
         # Vocab split
 
-        vc_split = [s * 32 for s in integer_split(cfg.vocab_size // 32, gpu_split, 16)]
+        vc_split = [s * 32 for s in integer_split((cfg.vocab_size + 31) // 32, gpu_split, 16)]
 
         def set_split(raw_split):
             b = 0

From 8d3d4c227e70dc0e3f3bc1bac9fd8899765d2182 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 27 Aug 2024 21:48:30 +0200
Subject: [PATCH 08/10] Ensure logit padding happens on default stream

---
 exllamav2/model.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 77ae01dd..6eb90df3 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -989,6 +989,9 @@ def forward_chunk(self,
         if self.tp_context:
             self.tp_context.wait_streams()
 
+        if x is not None and x.is_cuda:
+            torch.cuda.set_stream(torch.cuda.default_stream(x.device))
+
         # Apply logit scale
 
         # if x is not None and self.config.logit_scale != 1:

From db14154fee76d4c3a6a74c0bd38fb86d08a80673 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 27 Aug 2024 22:06:19 +0200
Subject: [PATCH 09/10] Don't use default stream for logit padding mask after all

---
 exllamav2/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 6eb90df3..3419e032 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -990,7 +990,8 @@ def forward_chunk(self,
             self.tp_context.wait_streams()
 
         if x is not None and x.is_cuda:
-            torch.cuda.set_stream(torch.cuda.default_stream(x.device))
+            context = self.get_device_context(x.device.index)
+            torch.cuda.set_stream(context.stream)
 
         # Apply logit scale
 

From f1d8909809899070cc298a0b3a0b5be9fa531bd3 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Wed, 28 Aug 2024 20:25:56 +0200
Subject: [PATCH 10/10] Catch all exceptions for nvidia-smi and rocm-smi

---
 exllamav2/util.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/exllamav2/util.py b/exllamav2/util.py
index a1bd570a..fd44462e 100644
--- a/exllamav2/util.py
+++ b/exllamav2/util.py
@@ -291,19 +291,19 @@ def get_all_gpu_memory():
 
     try:
         nvidia_memory = get_nvidia_gpu_memory(visible_devices)
         gpu_memory.update(nvidia_memory)
-    except FileNotFoundError:
+    except:
         pass
         # print("nvidia-smi not found. Skipping NVIDIA GPU check.")
 
     try:
         amd_memory = get_amd_gpu_memory()
         gpu_memory.update(amd_memory)
-    except FileNotFoundError:
+    except:
         pass
-        # print("rocm-smi not found. Skipping AMD GPU check.")  # TODO: remove warning on NVidia, test on AMD
+        # print("rocm-smi not found. Skipping AMD GPU check.")  # TODO: test on AMD
 
     assert gpu_memory, \
-        "Unable to read available VRAM from nvidia-smi or rocm-smi"
+        "Unable to read available VRAM from either nvidia-smi or rocm-smi"
 
     return gpu_memory