Add support for eval batch size tuning for LLMs on local backend (#3957)
arnavgarg1 authored Mar 11, 2024
1 parent e0f112c commit 4610574
Showing 4 changed files with 104 additions and 54 deletions.
26 changes: 24 additions & 2 deletions ludwig/schema/trainer.py
@@ -922,9 +922,31 @@ class FineTuneTrainerConfig(ECDTrainerConfig):
         description="Base learning rate used for training in the LLM trainer.",
     )

-    eval_batch_size: int = schema_utils.PositiveInteger(
+    batch_size: Union[int, str, None] = schema_utils.OneOfOptionsField(
+        default=1,
+        allow_none=False,
+        description=(
+            "The number of training examples utilized in one training step of the model. If `auto`, the "
+            "batch size that maximized training throughput (samples / sec) will be used."
+        ),
+        field_options=[
+            schema_utils.PositiveInteger(default=1, description="", allow_none=False),
+            schema_utils.StringOptions(options=["auto"], default="auto", allow_none=False),
+        ],
+    )
+
+    eval_batch_size: Union[int, str, None] = schema_utils.OneOfOptionsField(
         default=2,
-        description="Batch size used for evaluation in the LLM trainer.",
         allow_none=True,
+        description=(
+            "Size of batch to pass to the model for evaluation. If it is `0` or `None`, the same value of `batch_size` "
+            "is used. This is useful to speedup evaluation with a much bigger batch size than training, if enough "
+            "memory is available. If `auto`, the biggest batch size (power of 2) that can fit in memory will be used."
+        ),
+        field_options=[
+            schema_utils.PositiveInteger(default=2, description="", allow_none=False),
+            schema_utils.StringOptions(options=["auto"], default="auto", allow_none=False),
+        ],
     )

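Both fields above accept either a positive integer or the string "auto". Below is a minimal hedged sketch of a trainer section that opts into the new behavior; only `type`, `batch_size`, `eval_batch_size`, and the `auto` option come from the schema change above, while the surrounding keys mirror the config dict used in tests/integration_tests/test_llm.py and are illustrative assumptions.

# Hedged sketch: LLM fine-tuning trainer section using the new "auto" options.
trainer_section = {
    "type": "finetune",
    "batch_size": "auto",       # picks the batch size that maximizes training throughput (samples/sec)
    "eval_batch_size": "auto",  # picks the largest power-of-two batch size that fits in memory
    "epochs": 2,                # illustrative value, not part of this change
}
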
61 changes: 10 additions & 51 deletions ludwig/trainers/trainer_llm.py
@@ -3,7 +3,6 @@
 import time
 from typing import Callable, Dict, List, Optional, Union

-import torch
 from torch.utils.tensorboard import SummaryWriter

 from ludwig.constants import MINIMUM_BATCH_SIZE, TEST, TRAINING, VALIDATION
@@ -19,7 +18,11 @@
 from ludwig.trainers.trainer import Trainer
 from ludwig.types import ModelConfigDict
 from ludwig.utils import time_utils
-from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
+from ludwig.utils.batch_size_tuner import (
+    BatchSizeEvaluator,
+    LLMFinetunePredictBatchSizeEvaluator,
+    LLMFinetuneTrainerBatchSizeEvaluator,
+)
 from ludwig.utils.defaults import default_random_seed
 from ludwig.utils.metric_utils import TrainerMetric
 from ludwig.utils.metrics_printed_table import print_metrics_table
@@ -415,7 +418,7 @@ def __init__(
         skip_save_log: bool = False,
         callbacks: List = None,
         report_tqdm_to_ray=False,
-        random_seed: float = default_random_seed,
+        random_seed: int = default_random_seed,
         distributed: Optional[DistributedStrategy] = None,
         device: Optional[str] = None,
         **kwargs,
@@ -500,54 +503,10 @@ def tune_batch_size(
         )

     def _create_batch_size_evaluator(self) -> BatchSizeEvaluator:
-        trainer = self
-
-        class _TrainerBatchSizeEvaluator(BatchSizeEvaluator):
-            def __init__(self):
-                self.input_feature_name, self.input_feature = trainer.model.input_features.items()[0]
-                self.output_feature_name, self.output_feature = trainer.model.output_features.items()[0]
-
-                # Get the length of the longest input sequence from the training data
-                self.input_msl = self.input_feature.input_shape[0]
-                # Get the length of the longest output sequence from the training data
-                self.output_msl = self.output_feature.output_shape[0]
-                # max_sequence_length here is the smaller value between the global max sequence length of the model
-                # and the model's context length
-                if trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length:
-                    self.output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
-
-                # This is useful to create the synthetic input and target data which will be a
-                # random sequence of integers between 0 and vocab_size
-                self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)
-
-            def reset(self):
-                trainer.model.reset_metrics()
-                trainer.optimizer.zero_grad()
-
-            def step(self, batch_size: int, global_max_sequence_length: int):
-                trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
-
-                input_msl = self.input_msl
-                output_msl = self.output_msl
-                if self.input_msl + self.output_msl > global_max_sequence_length:
-                    # In this case, we just need to make sure that the length of the synthetic data exceeds
-                    # max_sequence_length by at most a small amount
-                    input_msl = global_max_sequence_length // 2 + 1
-                    output_msl = global_max_sequence_length // 2 + 1
-
-                inputs = {
-                    self.input_feature_name: torch.randint(0, self.vocab_size, size=(batch_size, input_msl))
-                    .to(self.input_feature.input_dtype)
-                    .to(trainer.device)
-                }
-                targets = {
-                    self.output_feature_name: torch.randint(0, self.vocab_size, size=(batch_size, output_msl))
-                    .to(self.output_feature.get_output_dtype())
-                    .to(trainer.device)
-                }
-                trainer.train_step(inputs, targets)
-
-        return _TrainerBatchSizeEvaluator()
+        return LLMFinetuneTrainerBatchSizeEvaluator(self)
+
+    def _create_predict_batch_size_evaluator(self) -> BatchSizeEvaluator:
+        return LLMFinetunePredictBatchSizeEvaluator(self)


 class RemoteLLMTrainer(NoneTrainer):
65 changes: 65 additions & 0 deletions ludwig/utils/batch_size_tuner.py
@@ -121,3 +121,68 @@ def reset(self):
     def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
         """Called each step to evaluate the given batch size."""
         raise NotImplementedError("`step` must be implemented by concrete evaluator.")
+
+
+class BaseLLMBatchSizeEvaluator(BatchSizeEvaluator):
+    """Base class for batch size evaluators for LLM models."""
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.input_feature_name, self.input_feature = list(trainer.model.input_features.items())[0]
+        self.output_feature_name, self.output_feature = list(trainer.model.output_features.items())[0]
+
+        # Get the length of the longest input sequence from the training data
+        self.input_msl = self.input_feature.input_shape[0]
+        if trainer.model.config_obj.input_features[0].preprocessing.max_sequence_length:
+            self.input_msl = trainer.model.config_obj.input_features[0].preprocessing.max_sequence_length
+
+        # Get the length of the longest output sequence from the training data
+        self.output_msl = self.output_feature.output_shape[0]
+        if trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length:
+            self.output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
+
+        # This is useful to create the synthetic input and target data which will be a
+        # random sequence of integers between 0 and vocab_size
+        self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)
+
+    def reset(self):
+        self.trainer.model.reset_metrics()
+        self.trainer.optimizer.zero_grad()
+
+    def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
+        if global_max_sequence_length and self.input_msl + self.output_msl > global_max_sequence_length:
+            # In this case, we just need to make sure that the length of the synthetic data exceeds
+            # max_sequence_length by at most a small amount
+            self.input_msl = global_max_sequence_length // 2 + 1
+            self.output_msl = global_max_sequence_length // 2 + 1
+
+        inputs = {
+            self.input_feature_name: torch.randint(0, self.vocab_size, size=(batch_size, self.input_msl))
+            .to(self.input_feature.input_dtype)
+            .to(self.trainer.device)
+        }
+        targets = {
+            self.output_feature_name: torch.randint(0, self.vocab_size, size=(batch_size, self.output_msl))
+            .to(self.output_feature.get_output_dtype())
+            .to(self.trainer.device)
+        }
+
+        self.perform_step(inputs, targets)
+
+    def perform_step(self, inputs, targets):
+        raise NotImplementedError("perform_step method must be implemented in subclasses")
+
+
+class LLMFinetuneTrainerBatchSizeEvaluator(BaseLLMBatchSizeEvaluator):
+    """Batch size evaluator for training batch size for LLM finetuning."""
+
+    def perform_step(self, inputs, targets):
+        self.trainer.train_step(inputs, targets)
+
+
+class LLMFinetunePredictBatchSizeEvaluator(BaseLLMBatchSizeEvaluator):
+    """Batch size evaluator for prediction/evaluation batch size for LLM finetuning."""
+
+    def perform_step(self, inputs, targets):
+        with torch.no_grad():
+            self.trainer.dist_model((inputs, targets))
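
To make the reset()/step() contract above concrete, here is a minimal hedged usage sketch. The probing loop, its name, its default arguments, and the OOM handling are assumptions for illustration only; in practice these evaluators are consumed through the trainer's batch size tuning path via the _create_batch_size_evaluator and _create_predict_batch_size_evaluator factories shown in trainer_llm.py above.

# Hedged sketch, not part of this commit: probe batch sizes by hand with an evaluator.
def probe_largest_batch_size(evaluator, global_max_sequence_length=512, max_batch_size=128):
    """Return the largest power-of-two batch size whose step() completes without running out of memory."""
    best = None
    batch_size = 1
    while batch_size <= max_batch_size:
        try:
            evaluator.reset()
            evaluator.step(batch_size, global_max_sequence_length=global_max_sequence_length)
            best = batch_size
            batch_size *= 2
        except RuntimeError:
            # CUDA out-of-memory errors surface as RuntimeError subclasses.
            break
    return best

# The training evaluator runs a full train_step (forward + backward), while the predict
# evaluator runs forward-only under torch.no_grad(), so the evaluation batch size can
# usually be larger than the training one:
#   train_bs = probe_largest_batch_size(LLMFinetuneTrainerBatchSizeEvaluator(trainer))
#   eval_bs = probe_largest_batch_size(LLMFinetunePredictBatchSizeEvaluator(trainer))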
6 changes: 5 additions & 1 deletion tests/integration_tests/test_llm.py
@@ -21,6 +21,7 @@
     BATCH_SIZE,
     COMBINER,
     EPOCHS,
+    EVAL_BATCH_SIZE,
     GENERATION,
     INPUT_FEATURES,
     MERGE_ADAPTER_INTO_BASE_MODEL,
@@ -350,9 +351,11 @@ def _prepare_finetuning_test(
         BASE_MODEL: model_name,
         INPUT_FEATURES: input_features,
         OUTPUT_FEATURES: output_features,
+        GENERATION: {"max_new_tokens": 64},
         TRAINER: {
             TYPE: "finetune",
-            BATCH_SIZE: 8,
+            BATCH_SIZE: "auto",
+            EVAL_BATCH_SIZE: "auto",
             EPOCHS: 2,
         },
         BACKEND: backend,
@@ -1305,6 +1308,7 @@ def test_llm_batch_size_tuning():
         type: finetune
         optimizer:
             type: adam
+        batch_size: auto
         train_steps: 1
         learning_rate: 0.0002
         eval_batch_size: 2
