Math RL data preparation #368

Open · wants to merge 24 commits into main
Conversation

Kipok (Collaborator) commented Feb 9, 2025

Breaking change: prepare_sft_data is renamed to prepare_data, as it now covers more cases.

Kipok added 24 commits February 7, 2025 14:41
Signed-off-by: Igor Gitman <[email protected]> (all commits)
Kipok requested a review from shtoshni February 9, 2025 01:46
@@ -350,3 +350,14 @@ def compute_chunk_ids(chunk_ids: list[int] | str, num_chunks: int) -> list[int]
assert chunk_id >= 0, "Run ids should have 1-based indexing"

return chunk_ids


def prefill_judgement(data_point: dict) -> str | None:
Collaborator:

Should we keep this function here, or is it more appropriate for training? Maybe a utils.py in the training folder makes sense.

Collaborator:

So this is supposed to be a basic judge model that only cares about the surface form and should be task-agnostic. Is this preferred because it's lightweight and generic?
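
For context, a minimal sketch of what such a surface-form prefill judge could look like; the function body is not shown in this diff, so the expected_answer field and the exact judgement strings are assumptions:

def prefill_judgement(data_point: dict) -> str | None:
    """Cheap surface-form check; returns None when the LLM judge is needed."""
    predicted = data_point.get("predicted_answer")
    expected = data_point.get("expected_answer")
    if predicted is None:
        # nothing was extracted from the generation, so it cannot be correct
        return "Judgement: No"
    if str(predicted).strip() == str(expected).strip():
        # identical surface forms are safe to mark correct without an LLM
        return "Judgement: Yes"
    # different surface forms may still be equivalent (e.g. 0.5 vs 1/2)
    return None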

@@ -171,4 +171,4 @@ def test_openmathinstruct2():

 assert (
     expected_md5 == output_md5
-), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sft_data.py"
+), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_data.py"
Collaborator:

Suggested change:
-), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_data.py"
+), "MD5 hashes do not match, something is wrong with nemo_skills/training/prepare_data.py"

@@ -0,0 +1,90 @@
processors_to_run: all
Collaborator:

To avoid clutter, we can move all the config files into a config folder.

- "majority_votes"

# this will optimize processors inside to avoid serializing data to disk
- _target_: nemo_skills.training.data_preparation_utils.merge_processor.MergeProcessor
Collaborator:

This might be a good addition for math_sft.yaml and code_sft.yaml as well. We can do it later, but it's good to know about this functionality.

# take only required keys from the input if exclude_optional_keys is True
output_sample = {}
if not self.exclude_optional_keys:
    output_sample = json.loads(line)
Collaborator:

This if/elif structure tests two very different things. Is there a missing else for the first if? Could the second elif be a standalone if?
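
For illustration, a hedged sketch of the "missing else" reading, assuming the filtering keeps a fixed set of required keys (self.required_keys is a hypothetical name):

output_sample = {}
if not self.exclude_optional_keys:
    # keep every key from the raw input line
    output_sample = json.loads(line)
else:
    # copy only the required keys (attribute name is assumed)
    full_sample = json.loads(line)
    output_sample = {key: full_sample[key] for key in self.required_keys}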

if judgement is not None:
    prefilled_judgements.append(judgement)
    prefilled_indices.add(len(data_points) - 1)

llm = get_model(server_type="trtllm")
prompt = get_prompt('judge/math', 'qwen-instruct')
Collaborator:

server_type, prompt_config, and prompt_template should not be hardcoded in the function. Ideally, make them parameters.
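
A minimal sketch of that parameterization, with defaults mirroring the currently hardcoded values (the wrapper function name is hypothetical):

def make_judge(server_type: str = "trtllm",
               prompt_config: str = "judge/math",
               prompt_template: str = "qwen-instruct"):
    # same calls as in the diff, but configurable by the caller
    llm = get_model(server_type=server_type)
    prompt = get_prompt(prompt_config, prompt_template)
    return llm, prompt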

"predicted_answer": extract_answer(query),
}
)
judgement = prefill_judgement(data_points[-1])
Collaborator:

Prefilling should ideally save computation, but right now these data points are still passed to the LLM. For cases where the prefilled judgement is True/Correct, we can drop those from the LLM batch; for the ones where it is None or False, we can rely on the LLM as a judge.
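
A sketch of the suggested skip, reusing the variable names visible in the diff (judge_with_llm is a hypothetical helper standing in for the actual LLM call):

# send only the data points without a prefilled judgement to the LLM
to_judge = [dp for i, dp in enumerate(data_points) if i not in prefilled_indices]
llm_judgements = iter(judge_with_llm(to_judge))  # hypothetical helper
prefilled = iter(prefilled_judgements)

# stitch both judgement streams back into the original order
judgements = [
    next(prefilled) if i in prefilled_indices else next(llm_judgements)
    for i in range(len(data_points))
]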

from nemo_skills.code_execution.math_grader import extract_answer
from nemo_skills.evaluation.metrics.utils import is_correct_judgement
from nemo_skills.inference.server.model import get_model
from nemo_skills.prompt.utils import get_prompt
from nemo_skills.utils import prefill_judgement


def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[dict]):
Collaborator:

Add a one-line description stating that the function assigns a binary score to the prompt-response (problem-response) pairs.
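
For example (wording is only a suggestion):

def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[dict]):
    """Assign a binary correctness score to each prompt-response pair."""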

from nemo_skills.code_execution.math_grader import extract_answer
from nemo_skills.evaluation.metrics.utils import is_correct_judgement
from nemo_skills.inference.server.model import get_model
from nemo_skills.prompt.utils import get_prompt
from nemo_skills.utils import prefill_judgement


def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[dict]):
Collaborator:

I am confused by the parameter prompts: it's passed as a parameter, and then a variable prompts is created inside the function on line 29. The other variables could also use better names; e.g., queries could be generations/responses.
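
A sketch of the renaming being suggested; the new names and the helper are proposals, not the PR's code:

def reward_func(responses: list[str], prompts: list[str], prompt_metadata: list[dict]):
    # bind the judge inputs to a distinct local name instead of
    # rebinding `prompts`, so the parameter is never shadowed
    judge_prompts = build_judge_prompts(responses, prompt_metadata)  # hypothetical helper
    ...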
