Added logprobs to vllm and nemo #264

Open
wants to merge 73 commits into base: main
Changes from 49 commits
Commits (73)
691a0c6
Allow for not passing prompt config/template
shtoshni Nov 5, 2024
c1d49d5
Allow for not passing prompt config/template
shtoshni Nov 5, 2024
0f24697
Merge branch 'main' into gen_rm
shtoshni Nov 6, 2024
805ed64
Prompt config for math generative RM
shtoshni Nov 7, 2024
2df0e4b
Merge branch 'main' into gen_rm
shtoshni Nov 11, 2024
73a77a1
added include_generation parameter to generation config
i-vainn Nov 11, 2024
52e4290
modified conf template
i-vainn Nov 12, 2024
9d7e9ae
make vllm return logprobs
i-vainn Nov 15, 2024
d03a33f
Merge branch 'main' into imoshkov/continue_generation
i-vainn Nov 22, 2024
34185ad
new prompt template
i-vainn Nov 22, 2024
8e06822
merged with main
i-vainn Nov 25, 2024
ea6691e
added logprobs param
i-vainn Nov 25, 2024
f55bfc2
added logprobs param
i-vainn Nov 25, 2024
fca3203
Merge branch 'main' into imoshkov/continue_generation
i-vainn Nov 25, 2024
0244e1b
Merge branch 'imoshkov/continue_generation' into imoshkov/debiasing_a…
i-vainn Nov 25, 2024
c25368f
added logprobs inference param
i-vainn Nov 25, 2024
12add35
attempted to add nemo logprobs
i-vainn Nov 25, 2024
1bbaf6d
removed exp files from commit
i-vainn Nov 25, 2024
d779ef1
added template
i-vainn Nov 26, 2024
326b2d5
Merging with main
shtoshni Dec 5, 2024
8f95d7a
Merging with main
shtoshni Dec 6, 2024
b27a08f
Merge branch 'gen_rm' into imoshkov/debiasing_answer
shtoshni Dec 6, 2024
c56227d
Fixing model implementation
shtoshni Dec 6, 2024
d6525a1
merged with main
i-vainn Dec 9, 2024
5d52ebf
Merge branch 'main' into imoshkov/debiasing_answer
shtoshni Dec 19, 2024
ba944fe
Merge branch 'main' into imoshkov/debiasing_answer
shtoshni Dec 19, 2024
5fb2ace
Logprobs for gen rm
shtoshni Dec 22, 2024
51147e0
Reward model score support
shtoshni Jan 13, 2025
3dd9e54
merged
i-vainn Jan 14, 2025
0bfe6f0
merged with main
i-vainn Jan 14, 2025
cc0d128
merged with main
i-vainn Feb 3, 2025
c4b5cda
fix merge
i-vainn Feb 3, 2025
8fdd18b
fix merge
i-vainn Feb 3, 2025
acae93f
fix
i-vainn Feb 3, 2025
4a0701e
enabled trt logprobs
i-vainn Feb 3, 2025
013613f
enabled trt logprobs
i-vainn Feb 3, 2025
8ab3732
modified include generation logic
i-vainn Feb 3, 2025
35b85ae
fix
i-vainn Feb 3, 2025
7b2f5af
minor fixes
i-vainn Feb 4, 2025
470b4a1
started logporbs refactoring, wip
i-vainn Feb 4, 2025
835decb
trt fix
i-vainn Feb 4, 2025
7aff1d5
trt fix
i-vainn Feb 4, 2025
dedc694
Merge branch 'imoshkov/debiasing_answer' into imoshkov/logprobs_wip
i-vainn Feb 4, 2025
470e955
completed refactoring
i-vainn Feb 4, 2025
d00ffe4
quick fix
i-vainn Feb 4, 2025
ba91774
quick fix
i-vainn Feb 4, 2025
bad1fab
quick fix
i-vainn Feb 4, 2025
e6e9a3e
merged with main
i-vainn Feb 4, 2025
e0a33ff
minor naming fixes
i-vainn Feb 5, 2025
4944d66
removed rm related things from pr
i-vainn Feb 5, 2025
0aafec4
make nemo return full tokens
i-vainn Feb 5, 2025
7d5c82b
fixed nemo generation handling
i-vainn Feb 6, 2025
0afb319
added inference tests
i-vainn Feb 6, 2025
395d561
merged with main
i-vainn Feb 6, 2025
f2aecfe
fixed tests
i-vainn Feb 7, 2025
190ffca
fixed tests
i-vainn Feb 7, 2025
4379a29
removed tmp change
i-vainn Feb 7, 2025
ed9bc8d
merged with main
i-vainn Feb 10, 2025
a944d33
fixed logprobs tests
i-vainn Feb 10, 2025
03e515b
fixed tests
i-vainn Feb 10, 2025
a31fbcf
fixed nemo num_generated_tokens issue
i-vainn Feb 11, 2025
2daad96
fixed logprobs test
i-vainn Feb 11, 2025
a0738d8
Merge branch 'main' into imoshkov/debiasing_answer
i-vainn Feb 11, 2025
5e52e19
fixed logprobs tests
i-vainn Feb 11, 2025
d74d046
Merge branch 'main' into imoshkov/debiasing_answer
i-vainn Feb 13, 2025
d1d54b7
minor fixes
i-vainn Feb 13, 2025
83fbe29
Merge branch 'main' into imoshkov/debiasing_answer
i-vainn Feb 16, 2025
ca6b979
test fixes
i-vainn Feb 16, 2025
e521b5a
test fix
i-vainn Feb 16, 2025
b3fc4c2
test fix
i-vainn Feb 17, 2025
9870151
fix
i-vainn Feb 17, 2025
2dc56f8
fix
i-vainn Feb 18, 2025
281e137
fix
i-vainn Feb 18, 2025
10 changes: 6 additions & 4 deletions nemo_skills/inference/generate.py
@@ -42,6 +42,7 @@ class InferenceConfig:
     random_seed: int = 0
     tokens_to_generate: int = 2048
     repetition_penalty: float = 1.0
+    top_logprobs: int | None = None
 
 
 @nested_dataclass(kw_only=True)
@@ -76,6 +77,7 @@ class GenerateSolutionsConfig:
     chunk_id: int | None = None  # if specified, will index the specified chunk only
 
     generation_key: str = "generation"
+    partial_generation: bool = False  # if True, model will be prompted to continue "generation" without closing assistant tag
     # if specified, we will have a loop over that key in the data file and
     # treat each element as a new turn of conversation
     # E.g. if multi_turn_key="turns" and a line in your data file has
@@ -178,7 +180,7 @@ def sync_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params)
         if len(data_points) == cfg.batch_size or idx == len(data) - 1:
             if cfg.multi_turn_key is None:
                 outputs = llm.generate(
-                    prompts=[prompt.fill(dp, include_generation=cfg.include_generation) for dp in data_points],
+                    prompts=[prompt.fill(dp, include_generation=cfg.include_generation, partial_generation=cfg.partial_generation) for dp in data_points],
                     stop_phrases=combine_stop_phrases(prompt.stop_phrases, extra_stop_phrases),
                     **asdict(cfg.inference),
                     **extra_generate_params,
@@ -206,7 +208,7 @@ def sync_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params)
                 # getting a new set of generations
                 turn_outputs = llm.generate(
                     prompts=[
-                        prompt.fill(turn_data_points[dp_index], multi_turn_key=cfg.multi_turn_key, include_generation=cfg.include_generation)
+                        prompt.fill(turn_data_points[dp_index], multi_turn_key=cfg.multi_turn_key, include_generation=cfg.include_generation, partial_generation=cfg.partial_generation)
                         for dp_index in dp_indices
                     ],
                     stop_phrases=combine_stop_phrases(prompt.stop_phrases, extra_stop_phrases),
@@ -276,7 +278,7 @@ def async_loop(cfg, data, llm, prompt, extra_stop_phrases, extra_generate_params)
         # Dynamic sending requests to maintain cfg.max_concurrent_requests running requests
         num_to_submit = min(cfg.max_concurrent_requests - len(in_progress), len(request_queue))
         batch_indices = [request_queue.popleft() for _ in range(num_to_submit)]
-        batch_prompts = [prompt.fill(data[idx], include_generation=cfg.include_generation) for idx in batch_indices]
+        batch_prompts = [prompt.fill(data[idx], include_generation=cfg.include_generation, partial_generation=cfg.partial_generation) for idx in batch_indices]
 
         if len(batch_prompts) > 0:
             generation_ids = llm.generate_async(
@@ -382,7 +384,7 @@ def generate(cfg: GenerateSolutionsConfig):
     LOG.info("Prompt used: %s", prompt)
 
     if cfg.multi_turn_key is None:
-        LOG.info("Example prompt:\nData dictionary: %s\nPrompt: %s", data[0], prompt.fill(data[0]))
+        LOG.info("Example prompt:\nData dictionary: %s\nPrompt: %s", data[0], prompt.fill(data[0], include_generation=cfg.include_generation, partial_generation=cfg.partial_generation))
     else:
         first_sample = deepcopy(data[0])
         first_sample[cfg.multi_turn_key] = first_sample[cfg.multi_turn_key][:1]
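The two additions in generate.py are `top_logprobs` (forwarded to the server through `**asdict(cfg.inference)`) and `partial_generation`, which keeps the assistant turn open so the model continues the existing `generation` text instead of starting a new response. Below is a minimal sketch of that prompt-filling behaviour under an assumed chat template; the `fill_prompt` helper and the tag names are hypothetical and are not the repo's actual `Prompt.fill` implementation.

```python
# Hypothetical sketch only: fill_prompt and the <user>/<assistant> tags are
# illustrative assumptions, not the templating used by nemo_skills.
def fill_prompt(
    question: str,
    generation: str = "",
    include_generation: bool = False,
    partial_generation: bool = False,
) -> str:
    text = f"<user>{question}</user>\n<assistant>"
    if include_generation:
        text += generation
        if not partial_generation:
            # closing the assistant tag makes the model start a fresh response
            text += "</assistant>"
        # with partial_generation=True the tag stays open, so the model
        # continues the existing generation text instead of restarting it
    return text


if __name__ == "__main__":
    print(fill_prompt("What is 2 + 2?", "Let's reason step by step:",
                      include_generation=True, partial_generation=True))
```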
20 changes: 15 additions & 5 deletions nemo_skills/inference/reward_model.py
@@ -98,6 +98,8 @@ def generate(cfg: RewardModelConfig):
     LOG.info("Config used: %s", cfg)
     llm = get_reward_model(model_type=cfg.reward_model_type, **cfg.server)
 
+    rm_type = cfg.server['rm_type']
+
     # making sure output dir exists
     Path(cfg.output_file).absolute().parent.mkdir(parents=True, exist_ok=True)
@@ -137,13 +139,21 @@ def generate(cfg: RewardModelConfig):
     if len(data) == 0:  # we might not have any examples if skip_filled=True
         return
 
-    LOG.info(
-        "Example prompt:\nData dictionary: %s\nPrompt: %s", data[0], prompt.fill(data[0], include_generation=True)
-    )
-
     if cfg.dry_run:
         return
 
+    if rm_type == 'disc':
+        include_generation = True
+    else:
+        # The template for GenRM already includes the generation, so we don't need to include it again
+        include_generation = False
+
+    LOG.info(
+        "Example prompt:\nData dictionary: %s\nPrompt: %s",
+        data[0],
+        prompt.fill(data[0], include_generation=include_generation),
+    )
+
     # setting buffering=1 to force to dump the output after every line, so that we can see intermediate generations
     with open(cfg.output_file, "at" if cfg.skip_filled else "wt", encoding="utf-8", buffering=1) as fout:
         data_points = []
@@ -154,7 +164,7 @@
 
             if len(data_points) == cfg.batch_size or idx == cfg.max_samples - 1:
                 outputs = llm.score(
-                    prompts=[prompt.fill(dp, include_generation=True) for dp in data_points],
+                    prompts=[prompt.fill(dp, include_generation=include_generation) for dp in data_points],
                 )
 
                 for output, original_data_point in zip(outputs, data_points):
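The behavioural change here: a discriminative reward model (`rm_type == 'disc'`) still needs the generation appended to the prompt, while the GenRM template already embeds it, so appending it again would duplicate the text. A self-contained sketch of that rule, where the `'genrm'` value is an assumed example (only `'disc'` appears in the diff):

```python
def should_include_generation(rm_type: str) -> bool:
    """Mirror of the branch added above: discriminative RMs score a
    (question, generation) pair, so the generation is appended to the prompt;
    GenRM templates already contain it."""
    return rm_type == "disc"


assert should_include_generation("disc") is True
assert should_include_generation("genrm") is False  # 'genrm' is an assumed example value
```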
6 changes: 6 additions & 0 deletions nemo_skills/inference/server/code_execution_model.py
@@ -61,9 +61,12 @@ def _generate_single(
         repetition_penalty: float,
         random_seed: int,
         stop_phrases: list[str] | None = None,
+        top_logprobs: int | None = None,
     ):
         if not isinstance(prompt, str):
             raise NotImplementedError("OpenAI API is not supported yet.")
+        if top_logprobs is not None:  # TODO: add this
+            raise NotImplementedError("top_logprobs is not supported yet.")
 
         if stop_phrases is None:
             stop_phrases = []
@@ -133,6 +136,7 @@ def generate_async(
         random_seed: int | list[int] = 0,
         stop_phrases: list[str] | list[list[str]] | None = None,
         remove_stop_phrases: bool = True,
+        top_logprobs: int | list[int] | None = None,
     ) -> list[dict]:
         """For any generation parameter you can specify a list of values that needs to match the number of prompts.
 
@@ -141,6 +145,8 @@
         # TODO: currently nemo server would get separate 1-batch requests, which is likely really inefficient
         # but the alternative is to have a fully separate implementation, which is also not nice
         # If we find ourselves needing to use nemo with code execution often, we should fix this
+        if top_logprobs is not None:  # TODO: add this
+            raise NotImplementedError("top_logprobs is not supported yet.")
         kwargs = {
             'code_begin': code_begin,
             'code_end': code_end,
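For the code-execution path the PR only adds `top_logprobs` to the signatures and rejects it with `NotImplementedError` for now. A hedged sketch of how a caller might handle that during the transition; the wrapper below is illustrative and not part of this PR, and the `llm.generate` keyword usage is assumed from the other diffs in this change.

```python
# Illustrative fallback, not part of the PR: retry without top_logprobs if the
# backing model (e.g. the code-execution wrapper above) does not support it yet.
def generate_with_optional_logprobs(llm, prompts, top_logprobs=None, **kwargs):
    if top_logprobs is None:
        return llm.generate(prompts=prompts, **kwargs)
    try:
        return llm.generate(prompts=prompts, top_logprobs=top_logprobs, **kwargs)
    except NotImplementedError:
        # top_logprobs not supported yet (see the TODOs above) -> drop it
        return llm.generate(prompts=prompts, **kwargs)
```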