Add functions for input-masked loss calculation and batching #825

Open · wants to merge 11 commits into base: main
20 changes: 20 additions & 0 deletions llms/mlx_lm/lora.py
@@ -91,6 +91,15 @@ def build_parser():
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)

parser.add_argument(
"--mask-inputs",
dest="mask_inputs",
action="store_true",
help="Whether to mask the inputs when training. Default is False.",
default=False,
)

parser.add_argument(
"--num-layers",
type=int,
@@ -169,6 +178,13 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
from .tuner.trainer import (
default_loss,
input_masked_loss,
iterate_batches,
iterate_delineated_batches,
)

model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-min(args.num_layers, 0) :]:
@@ -225,6 +241,10 @@ def train_model(
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
iterate_batches=(
iterate_delineated_batches if args.mask_inputs else iterate_batches
),
loss=input_masked_loss if args.mask_inputs else default_loss,
)


3 changes: 3 additions & 0 deletions llms/mlx_lm/tuner/datasets.py
@@ -63,6 +63,9 @@ def __init__(
self._prompt_key = prompt_key
self._completion_key = completion_key

def get_prompt_and_completion(self, idx: int):
return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]

def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
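A minimal sketch of what the new accessor returns (editor's illustration, not part of the diff; the CompletionsDataset constructor call and the tokenizer=None shortcut are assumptions — the shortcut only works because get_prompt_and_completion never touches the tokenizer):

from mlx_lm.tuner.datasets import CompletionsDataset

# Toy in-memory data; get_prompt_and_completion returns the raw text pair,
# bypassing the chat template that __getitem__ applies.
data = [{"prompt": "What is 2 + 2?", "completion": "4"}]
ds = CompletionsDataset(data, tokenizer=None, prompt_key="prompt", completion_key="completion")
assert ds.get_prompt_and_completion(0) == ("What is 2 + 2?", "4")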
134 changes: 133 additions & 1 deletion llms/mlx_lm/tuner/trainer.py
@@ -5,13 +5,16 @@
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
from typing import List, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

from .datasets import CompletionsDataset


def grad_checkpoint(layer):
@@ -63,6 +66,24 @@ class TrainingArgs:
)


def input_masked_loss(model, inputs, input_lengths, lengths):
shifted_inputs = inputs[:, :-1]
shifted_labels = inputs[:, 1:]
logits = model(shifted_inputs)
logits = logits.astype(mx.float32)

mask_width = shifted_inputs.shape[1]
token_indices = mx.arange(mask_width)[None, :]
mask = mx.logical_and(
token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
)

ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks

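For intuition (editor's illustration, not part of the diff), this is what the mask in input_masked_loss evaluates to for a toy batch: only positions at or past the prompt length and before the total length contribute to the loss.

import mlx.core as mx

# Toy values: padded width 6, prompt lengths 3 and 2, total lengths 5 and 6.
input_lengths = mx.array([3, 2])
lengths = mx.array([5, 6])
token_indices = mx.arange(6)[None, :]
mask = mx.logical_and(token_indices >= input_lengths[:, None], token_indices < lengths[:, None])
# mask == [[False, False, False, True, True, False],
#          [False, False, True,  True, True, True ]]
# Prompt and padding positions are zeroed out, so the loss is averaged over completion tokens only.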

def default_loss(model, inputs, targets, lengths):
logits = model(inputs)
logits = logits.astype(mx.float32)
@@ -76,6 +97,117 @@ def default_loss(model, inputs, targets, lengths):
return ce, ntoks


def contains(small_list: List, big_list: List) -> Tuple[int, int]:
"""
Returns the beginning and end index of the first occurrence of small_list in big_list.
Raises a RuntimeError if small_list does not occur in big_list.
"""
for i in range(len(big_list) - len(small_list) + 1):
for j in range(len(small_list)):
if big_list[i + j] != small_list[j]:
break
else:
return i, i + len(small_list)
raise RuntimeError("Not found")

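A quick usage example for contains (editor's addition, not from the PR):

# Locates the sub-sequence [5, 6] inside [1, 5, 6, 2]; the result is a half-open (start, end) range.
begin, end = contains([5, 6], [1, 5, 6, 2])
assert (begin, end) == (1, 3)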

def no_bos(sequence: List, bos: int) -> List:
return sequence if sequence[0] != bos else sequence[1:]


def input_and_output_lengths(
input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
) -> Tuple[int, int]:
"""
Applies the tokenizer's chat template to input_text and output_text and returns the number of tokens
that precede the output (the input/prompt portion) and the number of tokens in the output itself.
"""
message = [
{"role": "user", "content": input_text},
{"role": "assistant", "content": output_text},
]
output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
output_begin, output_end = contains(output_tokens, full_sequence)
return output_begin, output_end - output_begin

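A hedged usage sketch for input_and_output_lengths (editor's addition; the model name is a placeholder, and the call relies on the same assumption the helper makes — that the completion tokenizes identically on its own and inside the chat template):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # placeholder model
prompt_len, completion_len = input_and_output_lengths(
    "What is 2 + 2?", "2 + 2 equals 4.", tok
)
# prompt_len is the value iterate_delineated_batches later passes to input_masked_loss
# as the per-example input length used for masking.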

def iterate_delineated_batches(
dataset: CompletionsDataset,
tokenizer: PreTrainedTokenizer,
batch_size: int,
max_seq_length: int,
train: bool = False,
):
"""
A version of iterate_batches for completion datasets. It tracks the boundary between input (prompt) and
output (completion) tokens and yields, alongside each token batch, the per-example prompt lengths and the
full (possibly truncated) sequence lengths.
"""
idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)

# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches:
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
prompt_lengths = []
completion_lengths = []
batch = []
for j in batch_idx[i]:
prompt, completion = dataset.get_prompt_and_completion(j)
prompt_length, completion_length = input_and_output_lengths(
prompt, completion, tokenizer
)

prompt_lengths.append(prompt_length)
completion_lengths.append(completion_length)

full_sequence = tokenizer.encode(dataset[j])
if full_sequence[-1] != tokenizer.eos_token_id:
full_sequence.append(tokenizer.eos_token_id)
batch.append(full_sequence)

lengths = [len(x) for x in batch]

if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
"Consider pre-splitting your data to save memory."
)

# Pad to the nearest multiple of 8 or the maximum length
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)

batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)

for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = truncated_length  # Update lengths to match truncated lengths

yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)

if not train:
break

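Finally, a hedged end-to-end sketch of how the new pieces fit together (editor's illustration; the model name and the CompletionsDataset constructor call are assumptions, and it inherits the tokenization assumption noted above):

from transformers import AutoTokenizer
from mlx_lm.tuner.datasets import CompletionsDataset

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # placeholder model
data = [
    {"prompt": "What is 2 + 2?", "completion": "2 + 2 equals 4."},
    {"prompt": "Name a prime number.", "completion": "7 is a prime number."},
]
ds = CompletionsDataset(data, tok, prompt_key="prompt", completion_key="completion")

# One pass over the data (train=False): each batch carries the padded tokens plus the
# per-example prompt lengths and total lengths that input_masked_loss needs for masking.
for tokens, prompt_lengths, lengths in iterate_delineated_batches(
    ds, tok, batch_size=2, max_seq_length=512
):
    print(tokens.shape, prompt_lengths, lengths)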

def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))