added option for tokenized input to dynamic generator #613

Open
wants to merge 6 commits into base: dev
Changes from all commits
1 change: 0 additions & 1 deletion exllamav2/exllamav2_ext/ext_qattn.cpp
@@ -8,7 +8,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/torch.h>
#include <barrier>

#include "config.h"
#include "ext_qattn.h"
159 changes: 107 additions & 52 deletions exllamav2/generator/dynamic.py
@@ -525,6 +525,7 @@ def generate(
prompt: list[tuple] | list[str] | tuple | str,
max_new_tokens: int,
min_new_tokens: int = 0,
tokens: torch.Tensor | None = None,
seed: int or None = None,
gen_settings: ExLlamaV2Sampler.Settings | list[ExLlamaV2Sampler.Settings] | None = None,
token_healing: bool = False,
@@ -601,57 +602,102 @@ def generate(
"""

order = {}
if isinstance(prompt, list):
prompts = prompt
else:
prompts = [prompt]
filters = [filters]

if filters is None:
filters = [None] * len(prompts)
else:
assert len(filters) == len(prompts) and \
all((f is None or isinstance(f, list)) for f in filters), \
"If using filters, must provide one filter list (or None-value) per prompt."

prompts = prompt if isinstance(prompt, list) else [prompt]
batch_size = len(prompts)
for idx, p in enumerate(prompts):

if isinstance(p, str):
input_ids = self.tokenizer.encode(p, encode_special_tokens = encode_special_tokens, add_bos = add_bos)
elif isinstance(p, tuple):
input_ids = [self.tokenizer.encode(p_, encode_special_tokens = encode_special_tokens, add_bos = add_bos) for p_ in p]
if tokens is None:

if isinstance(prompt, list):
prompts = prompt
else:
assert False, "Unexpected type in prompt"

if gen_settings is None:
p_settings = ExLlamaV2Sampler.Settings()
elif isinstance(gen_settings, ExLlamaV2Sampler.Settings):
p_settings = gen_settings
elif isinstance(gen_settings, list):
assert len(gen_settings) == len(prompts)
p_settings = gen_settings[idx]
prompts = [prompt]
filters = [filters]

if filters is None:
filters = [None] * len(prompts)
else:
assert False, "Unexpected type in gen_settings"

job = ExLlamaV2DynamicJob(
input_ids = input_ids,
max_new_tokens = max_new_tokens,
min_new_tokens = min_new_tokens,
seed = seed,
stop_conditions = stop_conditions,
gen_settings = p_settings,
filters = filters[idx] or [],
filter_prefer_eos = filter_prefer_eos,
token_healing = token_healing,
decode_special_tokens = decode_special_tokens,
)

if seed is not None: seed += 1
assert len(filters) == len(prompts) and \
all((f is None or isinstance(f, list)) for f in filters), \
"If using filters, must provide one filter list (or None-value) per prompt."

prompts = prompt if isinstance(prompt, list) else [prompt]
batch_size = len(prompts)

for idx, p in enumerate(prompts):

if isinstance(p, str):
input_ids = self.tokenizer.encode(p, encode_special_tokens = encode_special_tokens, add_bos = add_bos)
elif isinstance(p, tuple):
input_ids = [self.tokenizer.encode(p_, encode_special_tokens = encode_special_tokens, add_bos = add_bos) for p_ in p]
else:
assert False, "Unexpected type in prompt"

if gen_settings is None:
p_settings = ExLlamaV2Sampler.Settings()
elif isinstance(gen_settings, ExLlamaV2Sampler.Settings):
p_settings = gen_settings
elif isinstance(gen_settings, list):
assert len(gen_settings) == len(prompts)
p_settings = gen_settings[idx]
else:
assert False, "Unexpected type in gen_settings"

job = ExLlamaV2DynamicJob(
input_ids = input_ids,
max_new_tokens = max_new_tokens,
min_new_tokens = min_new_tokens,
seed = seed,
stop_conditions = stop_conditions,
gen_settings = p_settings,
filters = filters[idx] or [],
filter_prefer_eos = filter_prefer_eos,
token_healing = token_healing,
decode_special_tokens = decode_special_tokens,
)

if seed is not None: seed += 1

serial = self.enqueue(job)
order[serial] = idx

serial = self.enqueue(job)
order[serial] = idx
else:

if tokens.ndim == 1:
tokens = tokens.unsqueeze(0)
tokens_ndim_was_1 = True
else:
tokens_ndim_was_1 = False

batch_size = tokens.shape[0]

for idx in range(tokens.shape[0]):
token_sequence = tokens[idx:idx+1]

if gen_settings is None:
p_settings = ExLlamaV2Sampler.Settings()
elif isinstance(gen_settings, ExLlamaV2Sampler.Settings):
p_settings = gen_settings
elif isinstance(gen_settings, list):
assert len(gen_settings) == tokens.shape[0]
p_settings = gen_settings[idx]
else:
assert False, "Unexpected type in gen_settings"

job = ExLlamaV2DynamicJob(
input_ids = token_sequence,
max_new_tokens = max_new_tokens,
min_new_tokens = min_new_tokens,
seed = seed,
stop_conditions = stop_conditions,
gen_settings = p_settings,
filters = [],
filter_prefer_eos = filter_prefer_eos,
token_healing = token_healing,
decode_special_tokens = decode_special_tokens,
)

if seed is not None: seed += 1

serial = self.enqueue(job)
order[serial] = idx

# Collect outputs until all jobs finish

@@ -673,12 +719,21 @@ def generate(

# Return results

if not completion_only:
completions = [(p if isinstance(p, str) else p[0]) + c for p, c in zip(prompts, completions)]
if tokens is None:
if not completion_only:
completions = [(p if isinstance(p, str) else p[0]) + c for p, c in zip(prompts, completions)]

if not isinstance(prompt, list):
completions = completions[0]
last_results = last_results[0]

else:
if not completion_only:
completions = [(p if isinstance(p, str) else p[0]) + c for p, c in zip(self.tokenizer.decode(tokens), completions)]

if not isinstance(prompt, list):
completions = completions[0]
last_results = last_results[0]
if tokens_ndim_was_1:
completions = completions[0]
last_results = last_results[0]

if return_last_results:
return completions, last_results
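For reference, a minimal usage sketch of the tokens path added here (a hypothetical example, not part of the diff; it assumes an already-initialized ExLlamaV2DynamicGenerator bound to the name generator, and assumes completion_only is the existing keyword argument referenced in the changed code):

# Encode the prompt once and reuse the ids. ExLlamaV2Tokenizer.encode() returns a
# (1, seq_len) tensor for a single string.
prompt = "Once upon a time"
input_ids = generator.tokenizer.encode(prompt, add_bos = True)

# With this patch, the pre-tokenized ids can be passed via the new tokens argument,
# skipping the string-encoding branch of generate(). prompt remains a required
# positional argument; in the tokens branch the prompt text is only reconstructed
# via tokenizer.decode(tokens) when completion_only is False.
output = generator.generate(
    prompt = prompt,
    tokens = input_ids,
    max_new_tokens = 64,
    completion_only = True,
)

# Depending on the shape handling above, output is either a single string or a
# one-element list of completions.
print(output)

Note that in the tokens branch the job is created with filters = [], so per-prompt filters only take effect on the existing string/tuple prompt path.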