diff --git a/exllamav2/exllamav2_ext/config.h b/exllamav2/exllamav2_ext/config.h deleted file mode 100644 index 36de128c..00000000 --- a/exllamav2/exllamav2_ext/config.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _config_h -#define _config_h - -#define MAX_Q_GEMM_ROWS 32 -#define MAX_Q_GEMM_ROWS_KERNEL 4 -#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS_KERNEL - -#define QMODE_2BIT 1 -#define QMODE_3BIT 1 -#define QMODE_4BIT 1 -#define QMODE_5BIT 1 -#define QMODE_6BIT 1 -#define QMODE_8BIT 0 - -#define USE_AVX2 -//#define PROFILING - -#define Q_CACHE_BLOCKSIZE_Q 512 -#define Q_CACHE_SUPER_BLOCKSIZE_Q (128 * 1024) - -#endif diff --git a/exllamav2/exllamav2_ext/ext_cache.cpp b/exllamav2/exllamav2_ext/ext_cache.cpp index 11413171..049df454 100644 --- a/exllamav2/exllamav2_ext/ext_cache.cpp +++ b/exllamav2/exllamav2_ext/ext_cache.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -15,7 +15,8 @@ #include "cpp/util.h" -void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width) +void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, + int64_t batch_size, int64_t offset, int64_t width) { TORCH_CHECK_DTYPE(in_tensor, kHalf); TORCH_CHECK_DTYPE(out_tensor, kUInt8); @@ -46,7 +47,8 @@ void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_si ); } -void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width) +void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor, + int64_t batch_size, int64_t offset, int64_t width) { TORCH_CHECK_DTYPE(in_tensor, kUInt8); TORCH_CHECK_DTYPE(out_tensor, kHalf); @@ -85,15 +87,15 @@ void fp16_to_q_kv torch::Tensor v_in, torch::Tensor v_out, torch::Tensor v_scales, - int batch_size, - int offset, - int width, - int page_size, + int64_t batch_size, + int64_t offset, + int64_t width, + int64_t page_size, torch::Tensor cache_seqlens, torch::Tensor block_table, torch::Tensor cal_k, torch::Tensor cal_v, - int wbits + int64_t wbits ) { TORCH_CHECK_DTYPE(k_in, kHalf); @@ -193,15 +195,15 @@ void q_to_fp16_kv torch::Tensor v_in, torch::Tensor v_out, torch::Tensor v_scales, - int batch_size, - int offset, - int width, - int page_size, + int64_t batch_size, + int64_t offset, + int64_t width, + int64_t page_size, torch::Tensor cache_seqlens, torch::Tensor block_table, torch::Tensor cal_k, torch::Tensor cal_v, - int wbits + int64_t wbits ) { TORCH_CHECK_DTYPE(k_in, kUInt8); @@ -310,7 +312,7 @@ int count_match ( torch::Tensor a, torch::Tensor b, - int max_a + int64_t max_a ) { uint64_t* pa = (uint64_t*) a.data_ptr(); diff --git a/exllamav2/exllamav2_ext/ext_cache.h b/exllamav2/exllamav2_ext/ext_cache.h deleted file mode 100644 index 96e2899f..00000000 --- a/exllamav2/exllamav2_ext/ext_cache.h +++ /dev/null @@ -1,54 +0,0 @@ - -void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width); -void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width); - -void fp16_to_q_kv -( - torch::Tensor k_in, - torch::Tensor k_out, - torch::Tensor k_scales, - torch::Tensor v_in, - torch::Tensor v_out, - torch::Tensor v_scales, - int batch_size, - int offset, - int width, - int page_size, - torch::Tensor cache_seqlens, - torch::Tensor block_table, - torch::Tensor cal_k, - torch::Tensor cal_v, - int wbits -); - -void q_to_fp16_kv -( - torch::Tensor k_in, - torch::Tensor k_out, - torch::Tensor k_scales, - torch::Tensor v_in, - torch::Tensor v_out, - 
torch::Tensor v_scales, - int batch_size, - int offset, - int width, - int page_size, - torch::Tensor cache_seqlens, - torch::Tensor block_table, - torch::Tensor cal_k, - torch::Tensor cal_v, - int wbits -); - -int count_match -( - torch::Tensor a, - torch::Tensor b, - int max_a -); - -//void array_fp16_to_fp8_ref(torch::Tensor in_tensor, torch::Tensor out_tensor, int size); -//void array_fp8_to_fp16_ref(torch::Tensor in_tensor, torch::Tensor out_tensor, int size); - - - diff --git a/exllamav2/exllamav2_ext/ext_element.cpp b/exllamav2/exllamav2_ext/ext_element.cpp index 29cd6952..a60a21b5 100644 --- a/exllamav2/exllamav2_ext/ext_element.cpp +++ b/exllamav2/exllamav2_ext/ext_element.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -17,7 +17,7 @@ void softcap_ ( torch::Tensor x, - float scale + double scale ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); diff --git a/exllamav2/exllamav2_ext/ext_element.h b/exllamav2/exllamav2_ext/ext_element.h deleted file mode 100644 index 97e6c706..00000000 --- a/exllamav2/exllamav2_ext/ext_element.h +++ /dev/null @@ -1,6 +0,0 @@ - -void softcap_ -( - torch::Tensor x, - float scale -); diff --git a/exllamav2/exllamav2_ext/ext_gemm.cpp b/exllamav2/exllamav2_ext/ext_gemm.cpp index 9e1464eb..7dba8e24 100644 --- a/exllamav2/exllamav2_ext/ext_gemm.cpp +++ b/exllamav2/exllamav2_ext/ext_gemm.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -20,8 +20,8 @@ void gemm_half_half_half torch::Tensor a, torch::Tensor b, torch::Tensor c, - const float alpha, - const float beta, + const double alpha, + const double beta, bool force_cublas ) { diff --git a/exllamav2/exllamav2_ext/ext_gemm.h b/exllamav2/exllamav2_ext/ext_gemm.h deleted file mode 100644 index ee2edd77..00000000 --- a/exllamav2/exllamav2_ext/ext_gemm.h +++ /dev/null @@ -1,10 +0,0 @@ - -void gemm_half_half_half -( - torch::Tensor a, - torch::Tensor b, - torch::Tensor c, - const float alpha, - const float beta, - bool force_cublas -); diff --git a/exllamav2/exllamav2_ext/ext_hadamard.cpp b/exllamav2/exllamav2_ext/ext_hadamard.cpp index c855fe3c..51f722c4 100644 --- a/exllamav2/exllamav2_ext/ext_hadamard.cpp +++ b/exllamav2/exllamav2_ext/ext_hadamard.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/exllamav2/exllamav2_ext/ext_hadamard.h b/exllamav2/exllamav2_ext/ext_hadamard.h deleted file mode 100644 index d2c51827..00000000 --- a/exllamav2/exllamav2_ext/ext_hadamard.h +++ /dev/null @@ -1,10 +0,0 @@ - -void had_paley -( - torch::Tensor h -); - -void had_paley2 -( - torch::Tensor h -); \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_norm.cpp b/exllamav2/exllamav2_ext/ext_norm.cpp index 6c08b860..f8774071 100644 --- a/exllamav2/exllamav2_ext/ext_norm.cpp +++ b/exllamav2/exllamav2_ext/ext_norm.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -25,7 +25,7 @@ void rms_norm torch::Tensor x, torch::Tensor w, torch::Tensor y, - float epsilon + double epsilon ) { bool input_fp32 = x.dtype() == torch::kFloat; @@ -61,7 +61,7 @@ void rms_norm_tp std::vector x, std::vector w, std::vector y, - float epsilon, + double epsilon, uintptr_t tp_context ) { @@ -96,7 +96,7 @@ void rms_norm_ ( torch::Tensor x, torch::Tensor w, - float epsilon + double epsilon ) { rms_norm(x, w, x, epsilon); @@ -111,7 +111,7 @@ void layer_norm torch::Tensor w, torch::Tensor b, torch::Tensor y, - float epsilon + double epsilon ) { TORCH_CHECK_DTYPE(x, kHalf); @@ -147,7 +147,7 @@ void layer_norm_ torch::Tensor x, 
torch::Tensor w, torch::Tensor b, - float epsilon + double epsilon ) { layer_norm(x, w, b, x, epsilon); @@ -162,7 +162,7 @@ void head_norm torch::Tensor w, torch::Tensor b, torch::Tensor y, - float epsilon + double epsilon ) { TORCH_CHECK_DTYPE(x, kHalf); @@ -202,7 +202,7 @@ void head_norm_ torch::Tensor x, torch::Tensor w, torch::Tensor b, - float epsilon + double epsilon ) { head_norm(x, w, b, x, epsilon); diff --git a/exllamav2/exllamav2_ext/ext_norm.h b/exllamav2/exllamav2_ext/ext_norm.h deleted file mode 100644 index 15e43f4b..00000000 --- a/exllamav2/exllamav2_ext/ext_norm.h +++ /dev/null @@ -1,61 +0,0 @@ - -void rms_norm -( - torch::Tensor x, - torch::Tensor w, - torch::Tensor y, - float epsilon -); - -void rms_norm_tp -( - std::vector x, - std::vector w, - std::vector y, - float epsilon, - uintptr_t tp_context -); - -void rms_norm_ -( - torch::Tensor x, - torch::Tensor w, - float epsilon -); - -void layer_norm -( - torch::Tensor x, - torch::Tensor w, - torch::Tensor b, - torch::Tensor y, - float epsilon -); - -void layer_norm_ -( - torch::Tensor x, - torch::Tensor w, - torch::Tensor b, - float epsilon -); - -void head_norm -( - torch::Tensor x, - torch::Tensor w, - torch::Tensor b, - torch::Tensor y, - float epsilon -); - -void head_norm_ -( - torch::Tensor x, - torch::Tensor w, - torch::Tensor b, - float epsilon -); - - - diff --git a/exllamav2/exllamav2_ext/ext_ops.h b/exllamav2/exllamav2_ext/ext_ops.h new file mode 100644 index 00000000..c3ab02ae --- /dev/null +++ b/exllamav2/exllamav2_ext/ext_ops.h @@ -0,0 +1,749 @@ +#pragma once + +#include +#include + +#include + +#include "cpp/quantize_func.h" +#include "cpp/safetensors.h" +#include "cpp/generator.h" +#include "cpp/threadpool.h" +#include "cuda/tp.cuh" + +// cache ops +void fp16_to_fp8( + torch::Tensor in_tensor, + torch::Tensor out_tensor, + int64_t batch_size, + int64_t offset, + int64_t width); +void fp8_to_fp16( + torch::Tensor in_tensor, + torch::Tensor out_tensor, + int64_t batch_size, + int64_t offset, + int64_t width); + +void fp16_to_q_kv +( + torch::Tensor k_in, + torch::Tensor k_out, + torch::Tensor k_scales, + torch::Tensor v_in, + torch::Tensor v_out, + torch::Tensor v_scales, + int64_t batch_size, + int64_t offset, + int64_t width, + int64_t page_size, + torch::Tensor cache_seqlens, + torch::Tensor block_table, + torch::Tensor cal_k, + torch::Tensor cal_v, + int64_t wbits +); + +void q_to_fp16_kv +( + torch::Tensor k_in, + torch::Tensor k_out, + torch::Tensor k_scales, + torch::Tensor v_in, + torch::Tensor v_out, + torch::Tensor v_scales, + int64_t batch_size, + int64_t offset, + int64_t width, + int64_t page_size, + torch::Tensor cache_seqlens, + torch::Tensor block_table, + torch::Tensor cal_k, + torch::Tensor cal_v, + int64_t wbits +); + +int64_t count_match +( + torch::Tensor a, + torch::Tensor b, + int64_t max_a +); + +// element ops +void softcap_ +( + torch::Tensor x, + double scale +); + +// gemm ops +void gemm_half_half_half +( + torch::Tensor a, + torch::Tensor b, + torch::Tensor c, + const double alpha, + const double beta, + bool force_cublas +); + +// hadamard ops +void had_paley +( + torch::Tensor h +); + +void had_paley2 +( + torch::Tensor h +); + +// layernorm ops +void rms_norm +( + torch::Tensor x, + torch::Tensor w, + torch::Tensor y, + double epsilon +); + +void rms_norm_tp +( + std::vector x, + std::vector w, + std::vector y, + double epsilon, + uintptr_t tp_context +); + +void rms_norm_ +( + torch::Tensor x, + torch::Tensor w, + double epsilon +); + +void layer_norm +( + torch::Tensor x, + 
torch::Tensor w, + torch::Tensor b, + torch::Tensor y, + double epsilon +); + +void layer_norm_ +( + torch::Tensor x, + torch::Tensor w, + torch::Tensor b, + double epsilon +); + +void head_norm +( + torch::Tensor x, + torch::Tensor w, + torch::Tensor b, + torch::Tensor y, + double epsilon +); + +void head_norm_ +( + torch::Tensor x, + torch::Tensor w, + torch::Tensor b, + double epsilon +); + +// qattn ops +uintptr_t make_q_attn +( + torch::Tensor layernorm, + torch::Tensor layernorm_bias, + bool layernorm_is_rms, + double norm_epsilon, + uintptr_t q_q_proj, + uintptr_t q_k_proj, + uintptr_t q_v_proj, + uintptr_t q_o_proj, + torch::Tensor temp_state, +// torch::Tensor temp_q, +// torch::Tensor temp_k, +// torch::Tensor temp_v, + torch::Tensor temp_dq, + int64_t max_rows, + int64_t hidden_size, + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_dim, + int64_t max_seq_len, + bool has_residual, + int64_t rope_style, + torch::Tensor q_norm, + torch::Tensor k_norm, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias, + bool residual_fp32, + bool use_graphs +); + +void free_q_attn +( + uintptr_t handle +); + +void q_attn_forward_1 +( + uintptr_t q_attn, + torch::Tensor x, + int64_t batch_size, + int64_t q_len, + int64_t past_len, + torch::Tensor past_lens, + torch::Tensor q_temp, + torch::Tensor k_temp, + torch::Tensor v_temp, + torch::Tensor sin, + torch::Tensor cos, + const std::vector& loras, + torch::Tensor loras_temp +); + +void q_attn_forward_2 +( + uintptr_t q_attn, + torch::Tensor x, + torch::Tensor attn_output, + int64_t batch_size, + int64_t q_len, + const std::vector& loras, + torch::Tensor loras_temp +); + +int64_t q_attn_set_loras +( + uintptr_t q_attn, + std::unordered_map& q_proj_lora_a, + std::unordered_map& q_proj_lora_b, + std::unordered_map& k_proj_lora_a, + std::unordered_map& k_proj_lora_b, + std::unordered_map& v_proj_lora_a, + std::unordered_map& v_proj_lora_b, + std::unordered_map& o_proj_lora_a, + std::unordered_map& o_proj_lora_b +); + +// TODO: Find a way to call this function directly without going through pybind + +typedef std::vector (*MHAFwdKVCacheFunc) +( + at::Tensor &, + const at::Tensor &, + const at::Tensor &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + c10::optional &, + const double, + bool, + int64_t, + int64_t, + const double, + bool, + int64_t +); + +//void set_flash_attn_func(MHAFwdKVCacheFunc f); +void set_flash_attn_func(); + +void tp_attn_forward_paged_ +( + uintptr_t tp_context, + torch::Tensor hidden_states, + const std::vector &temp_bc0, + const std::vector &temp_bc1, + const std::vector &temp_bc2, + const std::vector &temp_q, + const std::vector &temp_k, + const std::vector &temp_v, + const std::vector &temp_o, + const std::vector &k_cache, + const std::vector &v_cache, + const std::vector &pre_layernorm, + double norm_epsilon, + const std::vector &q_proj, + const std::vector &k_proj, + const std::vector &v_proj, + const std::vector &o_proj, + int64_t head_dim, + int64_t rope_style, + int64_t batch_size, + int64_t q_len, + const std::vector &sin, + const std::vector &cos, + const std::vector &past_lens, + const std::vector &block_index, + double scaling +); + +void tp_attn_forward_ +( + uintptr_t tp_context, + torch::Tensor hidden_states, + const std::vector &temp_bc0, + const std::vector &temp_bc1, + const std::vector &temp_bc2, + const std::vector &temp_q, + const std::vector &temp_k, + const
std::vector &temp_v, + const std::vector &temp_o, + const std::vector &k_cache, + const std::vector &v_cache, + const std::vector &pre_layernorm, + double norm_epsilon, + const std::vector &q_proj, + const std::vector &k_proj, + const std::vector &v_proj, + const std::vector &o_proj, + int64_t head_dim, + int64_t rope_style, + int64_t batch_size, + int64_t q_len, + const std::vector &sin, + const std::vector &cos, + const std::vector &past_len_tp, + double scaling +); + +// qmatrix ops +uintptr_t make_q_matrix +( + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + torch::Tensor q_scale, + torch::Tensor q_scale_max, + torch::Tensor q_groups, + torch::Tensor q_group_map, + torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, + torch::Tensor bias, + torch::Tensor temp_dq, + int64_t max_dq_rows +); + +uintptr_t make_q_matrix_split +( + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + torch::Tensor q_scale, + torch::Tensor q_scale_max, + torch::Tensor q_groups, + torch::Tensor q_group_map, + torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, + torch::Tensor bias, + torch::Tensor temp_dq, + int64_t max_dq_rows +); + +void free_q_matrix +( + uintptr_t tp_context +); + +void reconstruct +( + uintptr_t q_handle, + torch::Tensor output +); + +void gemm_half_q_half +( + torch::Tensor a, + uintptr_t b, + torch::Tensor c, + bool force_cuda +); + +void gemm_half_q_half_tp +( + const std::vector &a, + const std::vector &b, + const std::vector &c, + bool force_cuda, + uintptr_t tp_context, + int64_t t_device = -1 +); + +void matrix_q4_to_fp16 +( + torch::Tensor in, + torch::Tensor scales, + torch::Tensor out +); + +void matrix_fp16_to_q4 +( + torch::Tensor in, + torch::Tensor out, + torch::Tensor scales +); + +// qmlp ops +uintptr_t make_q_mlp +( + torch::Tensor layernorm, + torch::Tensor layernorm_bias, + bool layernorm_is_rms, + double norm_epsilon, + uintptr_t q_gate, + uintptr_t q_up, + uintptr_t q_down, + torch::Tensor temp_state, + torch::Tensor temp_a, + torch::Tensor temp_b, + torch::Tensor temp_dq, + int64_t max_rows, + bool act_gelu, + bool has_residual, + torch::Tensor post_layernorm, + torch::Tensor post_layernorm_bias, + bool residual_fp32, + bool use_graphs +); + +void free_q_mlp +( + uintptr_t handle +); + +void q_mlp_forward_ +( + uintptr_t q_mlp, + torch::Tensor x, + const std::vector& loras, + torch::Tensor loras_temp +); + +int64_t q_mlp_set_loras +( + uintptr_t q_mlp, + std::unordered_map& gate_proj_lora_a, + std::unordered_map& gate_proj_lora_b, + std::unordered_map& up_proj_lora_a, + std::unordered_map& up_proj_lora_b, + std::unordered_map& down_proj_lora_a, + std::unordered_map& down_proj_lora_b +); + +uintptr_t make_q_moe_mlp +( + torch::Tensor layernorm, + torch::Tensor layernorm_bias, + bool layernorm_is_rms, + double norm_epsilon, + torch::Tensor gate, + int64_t num_experts, + int64_t num_experts_per_token, + const std::vector& w1, + const std::vector& w2, + const std::vector& w3, + torch::Tensor temp_state, + torch::Tensor temp_gathered_state, + torch::Tensor temp_a, + torch::Tensor temp_b, + torch::Tensor temp_logits, + torch::Tensor temp_dq, + int64_t max_rows, + bool act_gelu +); + +void free_q_moe_mlp +( + uintptr_t handle +); + +void q_moe_mlp_forward_ +( + uintptr_t q_moe_mlp, + torch::Tensor x +// const std::vector& loras, +// torch::Tensor loras_temp +); + +//int64_t q_moe_mlp_set_loras +//( +// uintptr_t q_moe_mlp, +// std::vector>& w1_lora_a, +// 
std::vector>& w1_lora_b, +// std::vector>& w2_lora_a, +// std::vector>& w2_lora_b, +// std::vector>& w3_lora_a, +// std::vector>& w3_lora_b +//); + +void tp_mlp_forward_ +( + uintptr_t tp_context, + torch::Tensor hidden_states, + const std::vector &temp_bc0, + const std::vector &temp_bc1, + const std::vector &temp_bc2, + const std::vector &temp_gate, + const std::vector &temp_up, + const std::vector &temp_down, + const std::vector &pre_layernorm, + double norm_epsilon, + const std::vector &gate, + const std::vector &up, + const std::vector &down, + bool act_gelu +); + +// quant ops +void pack_columns +( + torch::Tensor input, + torch::Tensor output, + int64_t bits +); + +void pack_rows_4 +( + torch::Tensor input, + torch::Tensor output +); + +void quantize_err +( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + double qzero, + double maxq, + double err_norm, + double min_p, + double max_p, + int64_t p_grid +); + +void quantize +( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + torch::Tensor out_q, + double qzero, + double maxq +); + +std::tuple>, std::vector, double, uint64_t, double> sim_anneal +( + const std::vector>>& slots, + uint64_t max_cost, + double initial_temp, + double cooling_factor, + double min_temp, + int64_t iterations, + double norm +); + +// rope ops +void rope_ +( + torch::Tensor x, + torch::Tensor sin, + torch::Tensor cos, + int64_t past_len, + int64_t num_heads, + int64_t head_dim, + torch::Tensor offsets, + bool neox_style +); + +// sampling ops +void apply_rep_penalty +( + torch::Tensor sequence, + double penalty_max, + int64_t sustain, + int64_t decay, + double alpha_frequency, + double alpha_presence, + torch::Tensor logits +); + +std::vector sample_basic +( + torch::Tensor logits, // shape [bsz, vocab_size] + double temperature, + int64_t top_k, + float top_p, + double top_a, + double min_p, + double tfs, + double typical, + double random, + torch::Tensor output_tokens, // shape [bsz, 1] + torch::Tensor output_probs, // shape [bsz, 1] + torch::Tensor output_kprobs, // None or [bsz, 1, num_probs] + torch::Tensor output_ktokens, // None or [bsz, 1, num_probs] + torch::Tensor logit_filter, // shape [bsz, vocab_size] + bool mirostat, + std::vector& mirostat_mu, + double mirostat_tau, + double mirostat_eta, + double post_temperature, + double min_temp, + double max_temp, + double temp_exponent, + double smoothing_factor, + double skew +); + +void logit_filter_exclusive +( + torch::Tensor filter, // shape [bsz, vocab_size] + const py::list& exclusive_lists +); + +void fast_fill_cpu_ones_bool(torch::Tensor tensor); + +void fast_fadd_cpu(torch::Tensor a, torch::Tensor b); + +void fast_copy_cpu(torch::Tensor a, torch::Tensor b); + +void dump_profile_results(); + +// tensor parallel ops +#ifndef _ext_tp_h +#define _ext_tp_h + +#define BROADCAST_KV 0 +#define BROADCAST_ID 1 +#define BROADCAST_VC 2 +#define BROADCAST_RS 3 +#define BROADCAST_Q 4 + +class ExtTPContext +{ +public: + std::vector> kv_split; + std::vector> id_split; + std::vector> vc_split; + std::vector> rs_split; + std::vector> q_split; + std::vector pinned_temp; + size_t pinned_size; + std::vector streams; + + std::vector all_devices; + + ThreadPool* thread_pool; + ExtTPData* tp_data; + + void* mapped_globals; + + std::vector sync_events; +// std::vector comms; +// std::vector comms_index; + + ExtTPContext + ( + std::vector> _kv_split, + std::vector> _id_split, + std::vector> _vc_split, + std::vector> _rs_split, + std::vector> _q_split, + std::vector _pinned_temp, + 
std::vector _streams + ); + ~ExtTPContext(); +}; + +uintptr_t make_tp_context +( + const std::vector> kv_split, + const std::vector> id_split, + const std::vector> vc_split, + const std::vector> rs_split, + const std::vector> q_split, + std::vector pinned_temp, + std::vector streams +); + +void free_tp_context(uintptr_t ctx); + +void tp_broadcast +( + uintptr_t tp_context, + int64_t buffer, + torch::Tensor source, + int64_t broadcast_type, + const std::vector &targets, + int64_t dim, + int64_t t_device = -1 +); + +void tp_gather +( + uintptr_t tp_context, + int64_t buffer, + const std::vector &inputs, + int64_t broadcast_type, + const std::vector &targets, + int64_t broadcast_type_target, + int64_t dim, + int64_t t_device = -1 +); + +void tp_gather_barrier +( + uintptr_t tp_context, + int64_t buffer, + const std::vector &inputs, + int64_t broadcast_type, + const std::vector &targets, + int64_t broadcast_type_target, + int64_t dim, + int64_t t_device = -1, + Barrier* barrier = nullptr +); + +void tp_cross_device_barrier +( + uintptr_t tp_context, + int64_t broadcast_type, + int64_t t_device = -1, + int64_t stage = -1, + int64_t next_stage = -1 +); + +//void tp_all_reduce +//( +// uintptr_t tp_context, +// const std::vector &tensors +//); + +void tp_all_reduce +( + uintptr_t tp_context, + int64_t buffer, + const std::vector &tensors, + const std::vector &residuals +); + +#endif \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_qattn.cpp b/exllamav2/exllamav2_ext/ext_qattn.cpp index 4906a5db..eec3a690 100644 --- a/exllamav2/exllamav2_ext/ext_qattn.cpp +++ b/exllamav2/exllamav2_ext/ext_qattn.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -37,14 +37,14 @@ uintptr_t make_q_attn // torch::Tensor temp_k, // torch::Tensor temp_v, torch::Tensor temp_dq, - int max_rows, - int hidden_size, - int num_heads, - int num_kv_heads, - int head_dim, - int max_seq_len, - bool has_residual, - int rope_style, + int64_t max_rows, + int64_t hidden_size, + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_dim, + int64_t max_seq_len, + bool has_residual, + int64_t rope_style, torch::Tensor q_norm, torch::Tensor k_norm, torch::Tensor post_layernorm, @@ -113,9 +113,9 @@ void q_attn_forward_1 ( uintptr_t q_attn, torch::Tensor x, - int batch_size, - int q_len, - int past_len, + int64_t batch_size, + int64_t q_len, + int64_t past_len, torch::Tensor past_lens, torch::Tensor q_temp, torch::Tensor k_temp, @@ -160,8 +160,8 @@ void q_attn_forward_2 uintptr_t q_attn, torch::Tensor x, torch::Tensor attn_output, - int batch_size, - int q_len, + int64_t batch_size, + int64_t q_len, const std::vector& loras, torch::Tensor loras_temp ) @@ -269,20 +269,20 @@ void tp_attn_forward_paged_ const std::vector &k_cache, const std::vector &v_cache, const std::vector &pre_layernorm, - float norm_epsilon, + double norm_epsilon, const std::vector &q_proj, const std::vector &k_proj, const std::vector &v_proj, const std::vector &o_proj, - int head_dim, - int rope_style, - int batch_size, - int q_len, + int64_t head_dim, + int64_t rope_style, + int64_t batch_size, + int64_t q_len, const std::vector &sin, const std::vector &cos, const std::vector &past_lens, const std::vector &block_index, - float scaling + double scaling ) { auto fwd_kvcache_func = py::module_::import("flash_attn_2_cuda").attr("fwd_kvcache"); @@ -506,19 +506,19 @@ void tp_attn_forward_ const std::vector &k_cache, const std::vector &v_cache, const std::vector &pre_layernorm, - float norm_epsilon, + double norm_epsilon, const
std::vector &q_proj, const std::vector &k_proj, const std::vector &v_proj, const std::vector &o_proj, - int head_dim, - int rope_style, - int batch_size, - int q_len, + int64_t head_dim, + int64_t rope_style, + int64_t batch_size, + int64_t q_len, const std::vector &sin, const std::vector &cos, const std::vector &past_len_tp, - float scaling + double scaling ) { auto fwd_kvcache_func = py::module_::import("flash_attn_2_cuda").attr("fwd_kvcache"); diff --git a/exllamav2/exllamav2_ext/ext_qattn.h b/exllamav2/exllamav2_ext/ext_qattn.h deleted file mode 100644 index bd905609..00000000 --- a/exllamav2/exllamav2_ext/ext_qattn.h +++ /dev/null @@ -1,165 +0,0 @@ - -uintptr_t make_q_attn -( - torch::Tensor layernorm, - torch::Tensor layernorm_bias, - bool layernorm_is_rms, - float norm_epsilon, - uintptr_t q_q_proj, - uintptr_t q_k_proj, - uintptr_t q_v_proj, - uintptr_t q_o_proj, - torch::Tensor temp_state, -// torch::Tensor temp_q, -// torch::Tensor temp_k, -// torch::Tensor temp_v, - torch::Tensor temp_dq, - int max_rows, - int hidden_size, - int num_heads, - int num_kv_heads, - int head_dim, - int max_seq_len, - bool has_residual, - int rope_style, - torch::Tensor q_norm, - torch::Tensor k_norm, - torch::Tensor post_layernorm, - torch::Tensor post_layernorm_bias, - bool residual_fp32, - bool use_graphs -); - -void free_q_attn -( - uintptr_t handle -); - -void q_attn_forward_1 -( - uintptr_t q_attn, - torch::Tensor x, - int batch_size, - int q_len, - int past_len, - torch::Tensor past_lens, - torch::Tensor q_temp, - torch::Tensor k_temp, - torch::Tensor v_temp, - torch::Tensor sin, - torch::Tensor cos, - const std::vector& loras, - torch::Tensor loras_temp -); - -void q_attn_forward_2 -( - uintptr_t q_attn, - torch::Tensor x, - torch::Tensor attn_output, - int batch_size, - int q_len, - const std::vector& loras, - torch::Tensor loras_temp -); - -int q_attn_set_loras -( - uintptr_t q_attn, - std::unordered_map& q_proj_lora_a, - std::unordered_map& q_proj_lora_b, - std::unordered_map& k_proj_lora_a, - std::unordered_map& k_proj_lora_b, - std::unordered_map& v_proj_lora_a, - std::unordered_map& v_proj_lora_b, - std::unordered_map& o_proj_lora_a, - std::unordered_map& o_proj_lora_b -); - -// TODO: Find a way to call this function directly without going through pybind - -typedef std::vector (*MHAFwdKVCacheFunc) -( - at::Tensor &, - const at::Tensor &, - const at::Tensor &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - c10::optional &, - const float, - bool, - int, - int, - const float, - bool, - int -); - -//void set_flash_attn_func(MHAFwdKVCacheFunc f); -void set_flash_attn_func(); - -void tp_attn_forward_paged_ -( - uintptr_t tp_context, - torch::Tensor hidden_states, - const std::vector &temp_bc0, - const std::vector &temp_bc1, - const std::vector &temp_bc2, - const std::vector &temp_q, - const std::vector &temp_k, - const std::vector &temp_v, - const std::vector &temp_o, - const std::vector &k_cache, - const std::vector &v_cache, - const std::vector &pre_layernorm, - float norm_epsilon, - const std::vector &q_proj, - const std::vector &k_proj, - const std::vector &v_proj, - const std::vector &o_proj, - int head_dim, - int rope_style, - int batch_size, - int q_len, - const std::vector &sin, - const std::vector &cos, - const std::vector &past_lens, - const std::vector &block_index, - float scaling -); - -void tp_attn_forward_ -( - uintptr_t tp_context, - torch::Tensor 
hidden_states, - const std::vector &temp_bc0, - const std::vector &temp_bc1, - const std::vector &temp_bc2, - const std::vector &temp_q, - const std::vector &temp_k, - const std::vector &temp_v, - const std::vector &temp_o, - const std::vector &k_cache, - const std::vector &v_cache, - const std::vector &pre_layernorm, - float norm_epsilon, - const std::vector &q_proj, - const std::vector &k_proj, - const std::vector &v_proj, - const std::vector &o_proj, - int head_dim, - int rope_style, - int batch_size, - int q_len, - const std::vector &sin, - const std::vector &cos, - const std::vector &past_len_tp, - float scaling -); \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_qmatrix.cpp b/exllamav2/exllamav2_ext/ext_qmatrix.cpp index 5baa3fb2..fc8bc54e 100644 --- a/exllamav2/exllamav2_ext/ext_qmatrix.cpp +++ b/exllamav2/exllamav2_ext/ext_qmatrix.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -32,7 +32,7 @@ uintptr_t make_q_matrix torch::Tensor gptq_g_idx, torch::Tensor bias, torch::Tensor temp_dq, - int max_dq_rows + int64_t max_dq_rows ) { TORCH_CHECK_DTYPE(q_weight, kInt); @@ -120,7 +120,7 @@ uintptr_t make_q_matrix_split torch::Tensor gptq_g_idx, torch::Tensor bias, torch::Tensor temp_dq, - int max_dq_rows + int64_t max_dq_rows ) { TORCH_CHECK( @@ -245,7 +245,7 @@ void gemm_half_q_half_tp const std::vector &c, bool force_cuda, uintptr_t tp_context, - int t_device + int64_t t_device ) { ExtTPContext* ctx = reinterpret_cast (tp_context); diff --git a/exllamav2/exllamav2_ext/ext_qmatrix.h b/exllamav2/exllamav2_ext/ext_qmatrix.h deleted file mode 100644 index 2e89e649..00000000 --- a/exllamav2/exllamav2_ext/ext_qmatrix.h +++ /dev/null @@ -1,79 +0,0 @@ - -uintptr_t make_q_matrix -( - torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor q_scale, - torch::Tensor q_scale_max, - torch::Tensor q_groups, - torch::Tensor q_group_map, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, - torch::Tensor bias, - torch::Tensor temp_dq, - int max_dq_rows -); - -uintptr_t make_q_matrix_split -( - torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor q_scale, - torch::Tensor q_scale_max, - torch::Tensor q_groups, - torch::Tensor q_group_map, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, - torch::Tensor bias, - torch::Tensor temp_dq, - int max_dq_rows -); - -void free_q_matrix -( - uintptr_t tp_context -); - -void reconstruct -( - uintptr_t q_handle, - torch::Tensor output -); - -void gemm_half_q_half -( - torch::Tensor a, - uintptr_t b, - torch::Tensor c, - bool force_cuda -); - -void gemm_half_q_half_tp -( - const std::vector &a, - const std::vector &b, - const std::vector &c, - bool force_cuda, - uintptr_t tp_context, - int t_device = -1 -); - -void matrix_q4_to_fp16 -( - torch::Tensor in, - torch::Tensor scales, - torch::Tensor out -); - -void matrix_fp16_to_q4 -( - torch::Tensor in, - torch::Tensor out, - torch::Tensor scales -); - - diff --git a/exllamav2/exllamav2_ext/ext_qmlp.cpp b/exllamav2/exllamav2_ext/ext_qmlp.cpp index babc1a6f..46a997d9 100644 --- a/exllamav2/exllamav2_ext/ext_qmlp.cpp +++ b/exllamav2/exllamav2_ext/ext_qmlp.cpp @@ -24,7 +24,7 @@ uintptr_t make_q_mlp torch::Tensor layernorm, torch::Tensor layernorm_bias, bool layernorm_is_rms, - float norm_epsilon, + double norm_epsilon, uintptr_t q_gate, uintptr_t q_up, uintptr_t q_down, @@ -32,7 +32,7 @@ uintptr_t make_q_mlp torch::Tensor temp_a, torch::Tensor 
temp_b, torch::Tensor temp_dq, - int max_rows, + int64_t max_rows, bool act_gelu, bool has_residual, torch::Tensor post_layernorm, @@ -173,10 +173,10 @@ uintptr_t make_q_moe_mlp torch::Tensor layernorm, torch::Tensor layernorm_bias, bool layernorm_is_rms, - float norm_epsilon, + double norm_epsilon, torch::Tensor gate, - int num_experts, - int num_experts_per_token, + int64_t num_experts, + int64_t num_experts_per_token, const std::vector& w1, const std::vector& w2, const std::vector& w3, @@ -186,7 +186,7 @@ uintptr_t make_q_moe_mlp torch::Tensor temp_b, torch::Tensor temp_logits, torch::Tensor temp_dq, - int max_rows, + int64_t max_rows, bool act_gelu ) { @@ -334,7 +334,7 @@ void tp_mlp_forward_ const std::vector &temp_up_, const std::vector &temp_down_, const std::vector &pre_layernorm, - float norm_epsilon, + double norm_epsilon, const std::vector &gate, const std::vector &up, const std::vector &down, diff --git a/exllamav2/exllamav2_ext/ext_qmlp.h b/exllamav2/exllamav2_ext/ext_qmlp.h deleted file mode 100644 index 4bc5207b..00000000 --- a/exllamav2/exllamav2_ext/ext_qmlp.h +++ /dev/null @@ -1,110 +0,0 @@ - -uintptr_t make_q_mlp -( - torch::Tensor layernorm, - torch::Tensor layernorm_bias, - bool layernorm_is_rms, - float norm_epsilon, - uintptr_t q_gate, - uintptr_t q_up, - uintptr_t q_down, - torch::Tensor temp_state, - torch::Tensor temp_a, - torch::Tensor temp_b, - torch::Tensor temp_dq, - int max_rows, - bool act_gelu, - bool has_residual, - torch::Tensor post_layernorm, - torch::Tensor post_layernorm_bias, - bool residual_fp32, - bool use_graphs -); - -void free_q_mlp -( - uintptr_t handle -); - -void q_mlp_forward_ -( - uintptr_t q_mlp, - torch::Tensor x, - const std::vector& loras, - torch::Tensor loras_temp -); - -int q_mlp_set_loras -( - uintptr_t q_mlp, - std::unordered_map& gate_proj_lora_a, - std::unordered_map& gate_proj_lora_b, - std::unordered_map& up_proj_lora_a, - std::unordered_map& up_proj_lora_b, - std::unordered_map& down_proj_lora_a, - std::unordered_map& down_proj_lora_b -); - -uintptr_t make_q_moe_mlp -( - torch::Tensor layernorm, - torch::Tensor layernorm_bias, - bool layernorm_is_rms, - float norm_epsilon, - torch::Tensor gate, - int num_experts, - int num_experts_per_token, - const std::vector& w1, - const std::vector& w2, - const std::vector& w3, - torch::Tensor temp_state, - torch::Tensor temp_gathered_state, - torch::Tensor temp_a, - torch::Tensor temp_b, - torch::Tensor temp_logits, - torch::Tensor temp_dq, - int max_rows, - bool act_gelu -); - -void free_q_moe_mlp -( - uintptr_t handle -); - -void q_moe_mlp_forward_ -( - uintptr_t q_moe_mlp, - torch::Tensor x -// const std::vector& loras, -// torch::Tensor loras_temp -); - -//int q_moe_mlp_set_loras -//( -// uintptr_t q_moe_mlp, -// std::vector>& w1_lora_a, -// std::vector>& w1_lora_b, -// std::vector>& w2_lora_a, -// std::vector>& w2_lora_b, -// std::vector>& w3_lora_a, -// std::vector>& w3_lora_b -//); - -void tp_mlp_forward_ -( - uintptr_t tp_context, - torch::Tensor hidden_states, - const std::vector &temp_bc0, - const std::vector &temp_bc1, - const std::vector &temp_bc2, - const std::vector &temp_gate, - const std::vector &temp_up, - const std::vector &temp_down, - const std::vector &pre_layernorm, - float norm_epsilon, - const std::vector &gate, - const std::vector &up, - const std::vector &down, - bool act_gelu -); \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index e76e56e5..9d4807d6 100644 --- 
a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -50,7 +50,7 @@ void pack_columns ( torch::Tensor input, torch::Tensor output, - int bits + int64_t bits ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); @@ -84,12 +84,12 @@ void quantize_err torch::Tensor input, torch::Tensor output, torch::Tensor scale, - float qzero, - float maxq, - float err_norm, - float min_p, - float max_p, - int p_grid + double qzero, + double maxq, + double err_norm, + double min_p, + double max_p, + int64_t p_grid ) { TORCH_CHECK_DTYPE(input, kFloat); @@ -126,8 +126,8 @@ void quantize torch::Tensor output, torch::Tensor scale, torch::Tensor out_q, - float qzero, - float maxq + double qzero, + double maxq ) { TORCH_CHECK_DTYPE(input, kFloat); @@ -152,15 +152,15 @@ void quantize ); } -std::tuple>, std::vector, float, uint64_t, float> sim_anneal +std::tuple>, std::vector, double, uint64_t, double> sim_anneal ( - const std::vector>>& slots, + const std::vector>>& slots, uint64_t max_cost, - float initial_temp, - float cooling_factor, - float min_temp, - int iterations, - float norm + double initial_temp, + double cooling_factor, + double min_temp, + int64_t iterations, + double norm ) { int num_slots = slots.size(); diff --git a/exllamav2/exllamav2_ext/ext_quant.h b/exllamav2/exllamav2_ext/ext_quant.h deleted file mode 100644 index da5821ac..00000000 --- a/exllamav2/exllamav2_ext/ext_quant.h +++ /dev/null @@ -1,51 +0,0 @@ - -#include "cpp/quantize_func.h" -#include -#include - -void pack_columns -( - torch::Tensor input, - torch::Tensor output, - int bits -); - -void pack_rows_4 -( - torch::Tensor input, - torch::Tensor output -); - -void quantize_err -( - torch::Tensor input, - torch::Tensor output, - torch::Tensor scale, - float qzero, - float maxq, - float err_norm, - float min_p, - float max_p, - int p_grid -); - -void quantize -( - torch::Tensor input, - torch::Tensor output, - torch::Tensor scale, - torch::Tensor out_q, - float qzero, - float maxq -); - -std::tuple>, std::vector, float, uint64_t, float> sim_anneal -( - const std::vector>>& slots, - uint64_t max_cost, - float initial_temp, - float cooling_factor, - float min_temp, - int iterations, - float norm -); diff --git a/exllamav2/exllamav2_ext/ext_rope.cpp b/exllamav2/exllamav2_ext/ext_rope.cpp index 9f89cdbf..7c66dde2 100644 --- a/exllamav2/exllamav2_ext/ext_rope.cpp +++ b/exllamav2/exllamav2_ext/ext_rope.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -22,8 +22,8 @@ void rope_ torch::Tensor x, torch::Tensor sin, torch::Tensor cos, - int past_len, - int num_heads, + int64_t past_len, + int64_t num_heads, int head_dim, torch::Tensor offsets, bool neox_style diff --git a/exllamav2/exllamav2_ext/ext_rope.h b/exllamav2/exllamav2_ext/ext_rope.h deleted file mode 100644 index 640ad456..00000000 --- a/exllamav2/exllamav2_ext/ext_rope.h +++ /dev/null @@ -1,12 +0,0 @@ - -void rope_ -( - torch::Tensor x, - torch::Tensor sin, - torch::Tensor cos, - int past_len, - int num_heads, - int head_dim, - torch::Tensor offsets, - bool neox_style -); diff --git a/exllamav2/exllamav2_ext/ext_safetensors.h b/exllamav2/exllamav2_ext/ext_safetensors.h deleted file mode 100644 index 9a4e4971..00000000 --- a/exllamav2/exllamav2_ext/ext_safetensors.h +++ /dev/null @@ -1,2 +0,0 @@ - -#include "cpp/safetensors.h" \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_sampling.h b/exllamav2/exllamav2_ext/ext_sampling.h deleted file 
mode 100644 index 10664a39..00000000 --- a/exllamav2/exllamav2_ext/ext_sampling.h +++ /dev/null @@ -1,55 +0,0 @@ - -#include "cpp/generator.h" - -void apply_rep_penalty -( - torch::Tensor sequence, - float penalty_max, - int sustain, - int decay, - float alpha_frequency, - float alpha_presence, - torch::Tensor logits -); - -std::vector sample_basic -( - torch::Tensor logits, // shape [bsz, vocab_size] - float temperature, - int top_k, - float top_p, - float top_a, - float min_p, - float tfs, - float typical, - float random, - torch::Tensor output_tokens, // shape [bsz, 1] - torch::Tensor output_probs, // shape [bsz, 1] - torch::Tensor output_kprobs, // None or [bsz, 1, num_probs] - torch::Tensor output_ktokens, // None or [bsz, 1, num_probs] - torch::Tensor logit_filter, // shape [bsz, vocab_size] - bool mirostat, - std::vector& mirostat_mu, - float mirostat_tau, - float mirostat_eta, - float post_temperature, - float min_temp, - float max_temp, - float temp_exponent, - float smoothing_factor, - float skew -); - -void logit_filter_exclusive -( - torch::Tensor filter, // shape [bsz, vocab_size] - const py::list& exclusive_lists -); - -void fast_fill_cpu_ones_bool(torch::Tensor tensor); - -void fast_fadd_cpu(torch::Tensor a, torch::Tensor b); - -void fast_copy_cpu(torch::Tensor a, torch::Tensor b); - -void dump_profile_results(); \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/ext_tp.cpp b/exllamav2/exllamav2_ext/ext_tp.cpp index 374f8a9a..2e149600 100644 --- a/exllamav2/exllamav2_ext/ext_tp.cpp +++ b/exllamav2/exllamav2_ext/ext_tp.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/exllamav2/exllamav2_ext/ext_tp.h b/exllamav2/exllamav2_ext/ext_tp.h deleted file mode 100644 index ef8eb495..00000000 --- a/exllamav2/exllamav2_ext/ext_tp.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef _ext_tp_h -#define _ext_tp_h - -#define BROADCAST_KV 0 -#define BROADCAST_ID 1 -#define BROADCAST_VC 2 -#define BROADCAST_RS 3 -#define BROADCAST_Q 4 - -//#define TP_MULTITHREADED - -//#include -#include "cpp/threadpool.h" -#include "cuda/tp.cuh" - -class ExtTPContext -{ -public: - std::vector> kv_split; - std::vector> id_split; - std::vector> vc_split; - std::vector> rs_split; - std::vector> q_split; - std::vector pinned_temp; - size_t pinned_size; - std::vector streams; - - std::vector all_devices; - - ThreadPool* thread_pool; - ExtTPData* tp_data; - - void* mapped_globals; - - std::vector sync_events; -// std::vector comms; -// std::vector comms_index; - - ExtTPContext - ( - std::vector> _kv_split, - std::vector> _id_split, - std::vector> _vc_split, - std::vector> _rs_split, - std::vector> _q_split, - std::vector _pinned_temp, - std::vector _streams - ); - ~ExtTPContext(); -}; - -uintptr_t make_tp_context -( - const std::vector> kv_split, - const std::vector> id_split, - const std::vector> vc_split, - const std::vector> rs_split, - const std::vector> q_split, - std::vector pinned_temp, - std::vector streams -); - -void free_tp_context(uintptr_t ctx); - -void tp_broadcast -( - uintptr_t tp_context, - int buffer, - torch::Tensor source, - int broadcast_type, - const std::vector &targets, - int dim, - int t_device = -1 -); - -void tp_gather -( - uintptr_t tp_context, - int buffer, - const std::vector &inputs, - int broadcast_type, - const std::vector &targets, - int broadcast_type_target, - int dim, - int t_device = -1 -); - -void tp_gather_barrier -( - uintptr_t tp_context, - int buffer, - const std::vector &inputs, - int broadcast_type, - const std::vector &targets, - int 
broadcast_type_target, - int dim, - int t_device = -1, - Barrier* barrier = nullptr -); - -void tp_cross_device_barrier -( - uintptr_t tp_context, - int broadcast_type, - int t_device = -1, - int stage = -1, - int next_stage = -1 -); - -//void tp_all_reduce -//( -// uintptr_t tp_context, -// const std::vector &tensors -//); - -void tp_all_reduce -( - uintptr_t tp_context, - int buffer, - const std::vector &tensors, - const std::vector &residuals -); - -#endif \ No newline at end of file
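
The signature changes throughout this patch (int -> int64_t, float -> double) line up with the scalar types the PyTorch dispatcher accepts when C++ functions are registered as custom operators. Below is a minimal sketch of such a registration against the consolidated ext_ops.h declarations; the registration itself is not part of this diff, and the library name exllamav2_sketch and the choice of ops are illustrative only.

    #include <torch/library.h>
    #include "ext_ops.h"

    // Illustrative only: m.def() infers the schema from the C++ signature, which is
    // why scalar arguments must be int64_t / double rather than int / float.
    // Mutating ops such as softcap_ would normally declare an explicit schema with
    // mutation annotations, e.g. "softcap_(Tensor(a!) x, float scale) -> ()".
    TORCH_LIBRARY(exllamav2_sketch, m)
    {
        m.def("rms_norm", &rms_norm);
        m.def("count_match", &count_match);
    }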
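
A small usage sketch for the fp16_to_fp8 / fp8_to_fp16 cache ops. The dtypes follow the TORCH_CHECK_DTYPE checks in ext_cache.cpp (kHalf in, kUInt8 out); the tensor layout and the exact meaning of batch_size, offset and width are assumptions made for illustration.

    #include <torch/torch.h>
    #include "ext_ops.h"

    void fp8_cache_roundtrip_demo()
    {
        // Assumed layout [batch, seq_len, heads, head_dim]; offset/width are assumed
        // to select a range of positions along the sequence axis.
        auto opts   = torch::dtype(torch::kHalf).device(torch::kCUDA);
        auto k_fp16 = torch::randn({1, 16, 8, 128}, opts);
        auto k_fp8  = torch::zeros({1, 16, 8, 128}, opts.dtype(torch::kUInt8));
        auto k_back = torch::zeros_like(k_fp16);

        fp16_to_fp8(k_fp16, k_fp8, /*batch_size=*/1, /*offset=*/0, /*width=*/16);
        fp8_to_fp16(k_fp8, k_back, /*batch_size=*/1, /*offset=*/0, /*width=*/16);
        // k_back now holds a lossy FP8 round trip of the selected range of k_fp16.
    }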
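
Non-fused reference versions of two of the ops whose epsilon/scale arguments moved to double, useful as a correctness baseline. They express the commonly used definitions (tanh soft-capping and RMSNorm); whether the fused CUDA kernels match them exactly, e.g. in accumulation precision or residual handling, is not established by this diff.

    #include <torch/torch.h>

    // Assumed semantics of softcap_: x <- scale * tanh(x / scale), in place.
    void softcap_reference(torch::Tensor x, double scale)
    {
        x.div_(scale).tanh_().mul_(scale);
    }

    // Standard RMSNorm: y = x / sqrt(mean(x^2) + eps) * w, accumulated in fp32.
    torch::Tensor rms_norm_reference(const torch::Tensor& x, const torch::Tensor& w, double eps)
    {
        auto xf  = x.to(torch::kFloat);
        auto var = xf.pow(2).mean(-1, /*keepdim=*/true);
        return (xf * torch::rsqrt(var + eps)).to(x.scalar_type()) * w;
    }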
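
For the quant ops, a hedged reference of what quantize is assumed to compute, based on the GPTQ-style affine scheme its scale/qzero/maxq parameters suggest; the real kernel's rounding and per-group handling may differ.

    #include <torch/torch.h>
    #include <utility>

    // Assumption: q = clamp(round(x / scale) + qzero, 0, maxq) and
    // dequant = (q - qzero) * scale. 'scale' is per-group/per-row in the real
    // kernel; here it only needs to broadcast against x.
    std::pair<torch::Tensor, torch::Tensor> quantize_reference(
        const torch::Tensor& x, const torch::Tensor& scale, double qzero, double maxq)
    {
        auto q  = torch::clamp(torch::round(x / scale) + qzero, 0.0, maxq);
        auto dq = (q - qzero) * scale;
        return {q, dq};
    }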
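
A reference for the rotation applied by rope_ under the NeoX-style interpretation (neox_style = true); the layout of x and the shapes of the precomputed sin/cos tables are assumptions, chosen so they broadcast against the rotated halves.

    #include <torch/torch.h>

    // Assumed NeoX-style rotation: split the head dimension in half and rotate the
    // pairs. sin/cos are assumed precomputed for the relevant positions (i.e. already
    // offset by past_len) and broadcastable against x[..., :head_dim/2].
    torch::Tensor rope_reference(const torch::Tensor& x, const torch::Tensor& sin, const torch::Tensor& cos)
    {
        auto d  = x.size(-1) / 2;
        auto x1 = x.slice(-1, 0, d);
        auto x2 = x.slice(-1, d);
        return torch::cat({x1 * cos - x2 * sin, x2 * cos + x1 * sin}, -1);
    }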