Add functions for input-masked loss calculation and batching #825
base: main
Conversation
llms/mlx_lm/tuner/trainer.py
Outdated
def iterate_input_masked_batches(
    input_text, output_text, tokenizer, max_seq_length=2048
):
    batch_size = len(input_text)
Why set the batch size to the length of the dataset?
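For context, iterate_batches-style loops typically take the batch size as a fixed hyperparameter and walk the dataset in chunks of that size. A minimal sketch of that pattern (names here are hypothetical, not the PR's API):

# Conventional fixed-size batching: batch_size is a fixed hyperparameter
# stepped over the dataset in chunks, not set to the dataset length.
def iterate_fixed_batches(examples, batch_size=4):
    for i in range(0, len(examples), batch_size):
        yield examples[i : i + batch_size]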
llms/mlx_lm/tuner/trainer.py
Outdated
    input_lengths = mx.array(input_lengths)
    lengths = mx.array(adjusted_lengths)

    return batch[:, :-1], input_lengths, lengths
This only returns one example? Is that intentional? I assumed this is a drop-in replacement for iterate_batches, but it's not clear that's the case from how it's written.
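For comparison, a drop-in replacement for iterate_batches would be a generator that yields one padded batch at a time, together with the length vectors the masked loss needs. A rough sketch under those assumptions (function and argument names are illustrative, not the PR's exact code):

import mlx.core as mx
import numpy as np

def iterate_masked_batches(prompt_ids, completion_ids, batch_size=4, max_seq_length=2048):
    # prompt_ids / completion_ids: lists of token-id lists, one pair per example.
    for i in range(0, len(prompt_ids), batch_size):
        prompts = prompt_ids[i : i + batch_size]
        completions = completion_ids[i : i + batch_size]
        full = [p + c for p, c in zip(prompts, completions)]
        lengths = [min(len(seq), max_seq_length) for seq in full]
        max_len = max(lengths)
        # Right-pad every sequence with zeros to the longest length in the batch.
        batch = np.zeros((len(full), max_len), dtype=np.int32)
        for j, seq in enumerate(full):
            batch[j, : lengths[j]] = seq[: lengths[j]]
        input_lengths = [min(len(p), max_seq_length) for p in prompts]
        # Yield one batch per step, like iterate_batches does.
        yield mx.array(batch), mx.array(input_lengths), mx.array(lengths)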
… an updated attempt to better sync with iterate_batches logic
Adds support for completion-only finetuning: functions that iterate over batches while computing input masks (with padding), plus a loss function that applies those masks.
-- Updated 5 months later to keep pace with mlx(_lm) changes, etc.
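The core idea of an input-masked loss is to compute cross-entropy over the whole shifted sequence but zero out positions belonging to the prompt, so only completion tokens contribute. A minimal sketch in MLX, assuming the batch layout above (names are illustrative, not the PR's exact implementation):

import mlx.core as mx
import mlx.nn as nn

def input_masked_loss(model, batch, input_lengths, lengths):
    # Next-token setup: predict batch[:, 1:] from batch[:, :-1].
    logits = model(batch[:, :-1])
    targets = batch[:, 1:]

    # Step index of every target position, broadcast against the batch.
    steps = mx.arange(targets.shape[1])[None, :]

    # The target at step t is original token t + 1, so completion targets are
    # the steps with input_length - 1 <= t < length - 1; prompt tokens and
    # padding contribute nothing to the loss.
    mask = mx.logical_and(
        steps >= input_lengths[:, None] - 1,
        steps < lengths[:, None] - 1,
    )

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    return ce.sum() / ntoks, ntoks

Returning the (loss, token-count) pair keeps the shape of mlx_lm's default loss, which is what lets a batching generator like the one above slot in where iterate_batches is used today.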