Math RL data preparation #368
base: main
Conversation
Signed-off-by: Igor Gitman <[email protected]>
```
@@ -350,3 +350,14 @@ def compute_chunk_ids(chunk_ids: list[int] | str, num_chunks: int) -> list[int]
        assert chunk_id >= 0, "Run ids should have 1-based indexing"

    return chunk_ids


def prefill_judgement(data_point: dict) -> str | None:
```
Should we keep this function here, or is it more appropriate for the training folder? Maybe a `utils.py` for the training folder makes sense.
So this is supposed to be a basic judge model which only cares about the surface form, and should be task agnostic. Is this preferred because it's lightweight and generic?
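For context, a minimal sketch of what such a surface-form judge could look like (an assumed implementation; the hunk above only shows the signature, and the `expected_answer` key and judgement string format are assumptions):

```python
def prefill_judgement(data_point: dict) -> str | None:
    # Compare answers purely by surface form; anything ambiguous is
    # deferred to the LLM judge by returning None.
    predicted = data_point.get("predicted_answer")
    expected = data_point.get("expected_answer")  # assumed key name
    if predicted is None or expected is None:
        return None
    if str(predicted).strip() == str(expected).strip():
        return "Judgement: correct"  # assumed judgement string format
    return None
```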
```diff
@@ -171,4 +171,4 @@ def test_openmathinstruct2():

     assert (
         expected_md5 == output_md5
-    ), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_sft_data.py"
+    ), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_data.py"
```
), "MD5 hashes do not match, something is wrong with nemo_skills/finetuning/prepare_data.py" | |
), "MD5 hashes do not match, something is wrong with nemo_skills/training/prepare_data.py" |
```
@@ -0,0 +1,90 @@
processors_to_run: all
```
To avoid clutter we can move all the config files into a `config` folder.
- "majority_votes" | ||
|
||
# this will optimize processors inside to avoid serializing data to disk | ||
- _target_: nemo_skills.training.data_preparation_utils.merge_processor.MergeProcessor |
This might be a good addition for `math_sft.yaml` and `code_sft.yaml` as well. We can do it later, but it's good to know about this functionality.
```python
# take only required keys from the input if exclude_optional_keys is True
output_sample = {}
if not self.exclude_optional_keys:
    output_sample = json.loads(line)
```
This `if`/`elif` structure is testing for two very different things. Is there a missing `else` for the first `if`? Can the second `elif` be an `if` in itself?
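To illustrate the point (a hypothetical sketch; the `elif` branch in question is not visible in the hunk above): when two conditions are unrelated, independent `if` statements read more clearly than an `if`/`elif` chain.

```python
import json

# Hypothetical restructuring; the second condition is a placeholder, since
# the elif branch under discussion is not shown in the diff above.
def build_sample(line: str, exclude_optional_keys: bool) -> dict:
    # take only required keys from the input if exclude_optional_keys is True
    output_sample = {}
    if not exclude_optional_keys:
        output_sample = json.loads(line)

    # Formerly an elif: an unrelated check reads better as its own statement.
    if not output_sample:
        output_sample = {"placeholder": True}
    return output_sample
```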
```python
        if judgement is not None:
            prefilled_judgements.append(judgement)
            prefilled_indices.add(len(data_points) - 1)

    llm = get_model(server_type="trtllm")
    prompt = get_prompt('judge/math', 'qwen-instruct')
```
`server_type`, `prompt_config`, and `prompt_template` should not be hardcoded in the function. Ideally, make them parameters.
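A minimal sketch of that parameterization, keeping the current values as defaults (hypothetical; the PR's actual signature may differ):

```python
from nemo_skills.inference.server.model import get_model
from nemo_skills.prompt.utils import get_prompt

# Hypothetical sketch: expose the hardcoded values as keyword arguments so
# existing callers keep working while other setups can override them.
def reward_func(
    queries: list[str],
    prompts: list[str],
    prompt_metadata: list[dict],
    server_type: str = "trtllm",
    prompt_config: str = "judge/math",
    prompt_template: str = "qwen-instruct",
):
    llm = get_model(server_type=server_type)
    prompt = get_prompt(prompt_config, prompt_template)
    ...
```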
"predicted_answer": extract_answer(query), | ||
} | ||
) | ||
judgement = prefill_judgement(data_points[-1]) |
Prefilling can ideally save computation, but right now these data points are still being passed to the LLM. For cases where the judgement is True/Correct, we can drop them from the LLM batch; for the ones where it is None or False, we can rely on the LLM as a judge.
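A sketch of the suggested flow (assumptions: it reuses the `data_points` and `prefilled_indices` names from the hunk above, stores prefilled judgements in a dict keyed by index rather than the list in the diff, and `run_llm_judge` is a stand-in for the PR's batched LLM call):

```python
# Toy example data; in the PR these come from the loop shown above.
data_points = [{"predicted_answer": "42"}, {"predicted_answer": "7"}]
prefilled_indices = {0}
prefilled = {0: "Judgement: correct"}  # index -> judgement, instead of a list

def run_llm_judge(points: list[dict]) -> list[str]:
    # Stand-in for the PR's batched LLM judge call (assumption, not a real API).
    return ["Judgement: correct"] * len(points)

# Send only the undecided points to the LLM...
to_judge = [dp for i, dp in enumerate(data_points) if i not in prefilled_indices]
llm_results = iter(run_llm_judge(to_judge))

# ...then merge prefilled and LLM judgements back into the original order.
judgements = [
    prefilled[i] if i in prefilled_indices else next(llm_results)
    for i in range(len(data_points))
]
```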
```python
from nemo_skills.code_execution.math_grader import extract_answer
from nemo_skills.evaluation.metrics.utils import is_correct_judgement
from nemo_skills.inference.server.model import get_model
from nemo_skills.prompt.utils import get_prompt
from nemo_skills.utils import prefill_judgement


def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[dict]):
```
Add a one-line description stating that the function assigns a binary score to the prompt-response/problem-response pairs.
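For example, a docstring along these lines (suggested wording only, not the author's):

```python
def reward_func(queries: list[str], prompts: list[str], prompt_metadata: list[dict]):
    """Assign a binary correctness score to each problem-response pair."""
```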
I am confused by the parameter `prompts`. It's passed as a parameter, and then the variable `prompts` is created inside the function on line 29. The other variables could also use better names; e.g., `queries` could be `generations`/`responses`.
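For illustration, a possible renaming (hypothetical; the final names are the author's call):

```python
# 'responses' marks the model generations being scored, and the judge
# prompts built inside the function get their own name instead of
# rebinding the 'prompts' parameter.
def reward_func(responses: list[str], prompts: list[str], prompt_metadata: list[dict]):
    judge_prompts = []  # fill per response; avoids shadowing 'prompts'
    ...
```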
Breaking change: `prepare_sft_data` is renamed to `prepare_data` as it now covers more cases.