diff --git a/exllamav2/exllamav2_ext/ext_qattn.cpp b/exllamav2/exllamav2_ext/ext_qattn.cpp
index 4906a5db..05c96bdf 100644
--- a/exllamav2/exllamav2_ext/ext_qattn.cpp
+++ b/exllamav2/exllamav2_ext/ext_qattn.cpp
@@ -8,7 +8,6 @@
 #include
 #include
 #include
-#include

 #include "config.h"
 #include "ext_qattn.h"
diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py
index 8acc3dbb..575e8f4d 100644
--- a/exllamav2/generator/dynamic.py
+++ b/exllamav2/generator/dynamic.py
@@ -525,6 +525,7 @@ def generate(
         prompt: list[tuple] | list[str] | tuple | str,
         max_new_tokens: int,
         min_new_tokens: int = 0,
+        tokens: torch.Tensor | None = None,
         seed: int or None = None,
         gen_settings: ExLlamaV2Sampler.Settings | list[ExLlamaV2Sampler.Settings] | None = None,
         token_healing: bool = False,
@@ -601,57 +602,102 @@ def generate(
         """

         order = {}

-        if isinstance(prompt, list):
-            prompts = prompt
-        else:
-            prompts = [prompt]
-            filters = [filters]

-        if filters is None:
-            filters = [None] * len(prompts)
-        else:
-            assert len(filters) == len(prompts) and \
-                all((f is None or isinstance(f, list)) for f in filters), \
-                "If using filters, must provide one filter list (or None-value) per prompt."
-
-        prompts = prompt if isinstance(prompt, list) else [prompt]
-        batch_size = len(prompts)
-        for idx, p in enumerate(prompts):
-
-            if isinstance(p, str):
-                input_ids = self.tokenizer.encode(p, encode_special_tokens = encode_special_tokens, add_bos = add_bos)
-            elif isinstance(p, tuple):
-                input_ids = [self.tokenizer.encode(p_, encode_special_tokens = encode_special_tokens, add_bos = add_bos) for p_ in p]
+        if tokens is None:
+
+            if isinstance(prompt, list):
+                prompts = prompt
             else:
-                assert False, "Unexpected type in prompt"
-
-            if gen_settings is None:
-                p_settings = ExLlamaV2Sampler.Settings()
-            elif isinstance(gen_settings, ExLlamaV2Sampler.Settings):
-                p_settings = gen_settings
-            elif isinstance(gen_settings, list):
-                assert len(gen_settings) == len(prompts)
-                p_settings = gen_settings[idx]
+                prompts = [prompt]
+                filters = [filters]
+
+            if filters is None:
+                filters = [None] * len(prompts)
             else:
-                assert False, "Unexpected type in gen_settings"
-
-            job = ExLlamaV2DynamicJob(
-                input_ids = input_ids,
-                max_new_tokens = max_new_tokens,
-                min_new_tokens = min_new_tokens,
-                seed = seed,
-                stop_conditions = stop_conditions,
-                gen_settings = p_settings,
-                filters = filters[idx] or [],
-                filter_prefer_eos = filter_prefer_eos,
-                token_healing = token_healing,
-                decode_special_tokens = decode_special_tokens,
-            )
-
-            if seed is not None: seed += 1
+                assert len(filters) == len(prompts) and \
+                    all((f is None or isinstance(f, list)) for f in filters), \
+                    "If using filters, must provide one filter list (or None-value) per prompt."
+
+            prompts = prompt if isinstance(prompt, list) else [prompt]
+            batch_size = len(prompts)
+
+            for idx, p in enumerate(prompts):
+
+                if isinstance(p, str):
+                    input_ids = self.tokenizer.encode(p, encode_special_tokens = encode_special_tokens, add_bos = add_bos)
+                elif isinstance(p, tuple):
+                    input_ids = [self.tokenizer.encode(p_, encode_special_tokens = encode_special_tokens, add_bos = add_bos) for p_ in p]
+                else:
+                    assert False, "Unexpected type in prompt"
+
+                if gen_settings is None:
+                    p_settings = ExLlamaV2Sampler.Settings()
+                elif isinstance(gen_settings, ExLlamaV2Sampler.Settings):
+                    p_settings = gen_settings
+                elif isinstance(gen_settings, list):
+                    assert len(gen_settings) == len(prompts)
+                    p_settings = gen_settings[idx]
+                else:
+                    assert False, "Unexpected type in gen_settings"
+
+                job = ExLlamaV2DynamicJob(
+                    input_ids = input_ids,
+                    max_new_tokens = max_new_tokens,
+                    min_new_tokens = min_new_tokens,
+                    seed = seed,
+                    stop_conditions = stop_conditions,
+                    gen_settings = p_settings,
+                    filters = filters[idx] or [],
+                    filter_prefer_eos = filter_prefer_eos,
+                    token_healing = token_healing,
+                    decode_special_tokens = decode_special_tokens,
+                )
+
+                if seed is not None: seed += 1
+
+                serial = self.enqueue(job)
+                order[serial] = idx
-
-            serial = self.enqueue(job)
-            order[serial] = idx
+        else:
+
+            if tokens.ndim == 1:
+                tokens = tokens.unsqueeze(0)
+                tokens_ndim_was_1 = True
+            else:
+                tokens_ndim_was_1 = False
+
+            batch_size = tokens.shape[0]
+
+            for idx in range(tokens.shape[0]):
+                token_sequence = tokens[idx:idx+1]
+
+                if gen_settings is None:
+                    p_settings = ExLlamaV2Sampler.Settings()
+                elif isinstance(gen_settings, ExLlamaV2Sampler.Settings):
+                    p_settings = gen_settings
+                elif isinstance(gen_settings, list):
+                    assert len(gen_settings) == tokens.shape[0]
+                    p_settings = gen_settings[idx]
+                else:
+                    assert False, "Unexpected type in gen_settings"
+
+                job = ExLlamaV2DynamicJob(
+                    input_ids = token_sequence,
+                    max_new_tokens = max_new_tokens,
+                    min_new_tokens = min_new_tokens,
+                    seed = seed,
+                    stop_conditions = stop_conditions,
+                    gen_settings = p_settings,
+                    filters = [],
+                    filter_prefer_eos = filter_prefer_eos,
+                    token_healing = token_healing,
+                    decode_special_tokens = decode_special_tokens,
+                )
+
+                if seed is not None: seed += 1
+
+                serial = self.enqueue(job)
+                order[serial] = idx

         # Collect outputs until all jobs finish

@@ -673,12 +719,21 @@ def generate(

         # Return results

-        if not completion_only:
-            completions = [(p if isinstance(p, str) else p[0]) + c for p, c in zip(prompts, completions)]
+        if tokens is None:
+            if not completion_only:
+                completions = [(p if isinstance(p, str) else p[0]) + c for p, c in zip(prompts, completions)]
+
+            if not isinstance(prompt, list):
+                completions = completions[0]
+                last_results = last_results[0]
+
+        else:
+            if not completion_only:
+                completions = [(p if isinstance(p, str) else p[0]) + c for p, c in zip(self.tokenizer.decode(tokens), completions)]

-        if not isinstance(prompt, list):
-            completions = completions[0]
-            last_results = last_results[0]
+            if tokens_ndim_was_1:
+                completions = completions[0]
+                last_results = last_results[0]

         if return_last_results:
             return completions, last_results
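
Usage sketch for the new `tokens` argument (not part of the patch; a minimal, untested example assuming the usual exllamav2 dynamic-generator setup, with the model directory and prompt text as placeholders):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator

# Standard setup, as in the exllamav2 examples; the path is a placeholder.
config = ExLlamaV2Config("/path/to/model")
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2DynamicGenerator(model = model, cache = cache, tokenizer = tokenizer)

# Pre-tokenized input: a 1D tensor is treated as a single sequence and
# unsqueezed to (1, seq_len) internally; a 2D tensor is treated as a batch.
ids = tokenizer.encode("Once upon a time", add_bos = True).squeeze(0)

# When tokens is given, the prompt argument is ignored, no re-tokenization
# happens, and filters are not applied in this mode.
output = generator.generate(
    prompt = None,
    max_new_tokens = 32,
    tokens = ids,
)
print(output)
```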