
Commit d40b75f: Address comments

arnavgarg1 committed Oct 12, 2023
1 parent 3b5ec39

Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions ludwig/schema/model_types/utils.py
@@ -5,6 +5,7 @@
 from typing import Any, Dict, List, Mapping, Set, TYPE_CHECKING
 
 from marshmallow import ValidationError
+from transformers import AutoConfig
 
 from ludwig.api_annotations import DeveloperAPI
 from ludwig.constants import (
@@ -29,9 +30,11 @@
 from ludwig.features.feature_utils import compute_feature_hash
 from ludwig.schema.features.utils import output_config_registry
 from ludwig.schema.hyperopt.scheduler import BaseHyperbandSchedulerConfig
+from ludwig.schema.llms.generation import LLMGenerationConfig
 from ludwig.schema.trainer import ECDTrainerConfig
 from ludwig.types import HyperoptConfigDict, ModelConfigDict
 from ludwig.utils.data_utils import get_sanitized_feature_name
+from ludwig.utils.llm_utils import get_context_len
 
 if TYPE_CHECKING:
     from ludwig.schema.model_types.base import ModelConfig
@@ -346,33 +349,14 @@ def _set_llm_tokenizers(config: "ModelConfig") -> None:
         output_feature.decoder.fallback_label = output_feature.preprocessing.fallback_label
 
 
-def _set_generation_max_new_tokens(config: "ModelConfig") -> None:
-    """Sets the max_new_tokens parameter in the generation config to the max sequence length of the output
-    features.
-
-    This ensures that the generation config is set to the correct value for the LLM model type.
-    """
-    from transformers import AutoConfig
-
-    from ludwig.schema.llms.generation import LLMGenerationConfig
-    from ludwig.utils.llm_utils import get_context_len
-
-    _DEFAULT_MAX_SEQUENCE_LENGTH = LLMGenerationConfig().max_new_tokens
-    if config.generation.max_new_tokens != _DEFAULT_MAX_SEQUENCE_LENGTH:
-        # Max new tokens is explicitly set by user, so don't override
-        return
-
-    if config.output_features[0].type != TEXT:
-        # This is trickier to set for other output features, so don't override for now.
-        # TODO: Add better support for category output features
-        return
-
-    max_possible_sequence_length = _DEFAULT_MAX_SEQUENCE_LENGTH
+def _get_maximum_possible_sequence_length(config: "ModelConfig", default_max_sequence_length: int) -> int:
+    """Returns the maximum possible sequence length for the LLM model based on the model config."""
+    max_possible_sequence_length = default_max_sequence_length
     if config.output_features[0].preprocessing.max_sequence_length is not None:
         # Note: We don't need to check for max between feature.preprocessing.max_sequence_length and
         # defaults.text.preprocessing.max_sequence_length because the latter is only applied to input features.
         max_possible_sequence_length = max(
-            _DEFAULT_MAX_SEQUENCE_LENGTH, config.output_features[0].preprocessing.max_sequence_length
+            default_max_sequence_length, config.output_features[0].preprocessing.max_sequence_length
         )
     elif config.preprocessing.global_max_sequence_length is not None:
         # This is not perfect since it includes tokens from both input + output features, but this at least
@@ -381,7 +365,7 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None:
         max_possible_sequence_length = max(
             max_possible_sequence_length, config.preprocessing.global_max_sequence_length
         )
-    elif max_possible_sequence_length == _DEFAULT_MAX_SEQUENCE_LENGTH:
+    elif max_possible_sequence_length == default_max_sequence_length:
         # It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case
         # we should fall back to the window size of the pretrained model. By this point, because of schema validation
         # checks, we know that the base_model exists so we can safely grab the base model's config.
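The fallback branch above reads the pretrained model's window size through ludwig.utils.llm_utils.get_context_len (now imported at module scope alongside transformers.AutoConfig). A rough sketch of what such a lookup can look like; the function name context_window, the attribute-probing order, and the 2048 fallback are assumptions for illustration, not a quote of Ludwig's implementation:

```python
from transformers import AutoConfig


def context_window(base_model: str, fallback: int = 2048) -> int:
    """Assumed approximation of get_context_len: probe common config attributes."""
    cfg = AutoConfig.from_pretrained(base_model)
    # Different architectures expose their context window under different names.
    for attr in ("max_sequence_length", "max_position_embeddings", "n_positions"):
        value = getattr(cfg, attr, None)
        if value:
            return value
    return fallback  # conservative default when no attribute is present
```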
@@ -391,6 +375,26 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None:
         # Artifically leave a buffer of half the total model window size to trade off
         # runtime while likely covering a majority of the max sequence length.
         max_possible_sequence_length = max_possible_sequence_length // 2
+    return max_possible_sequence_length
+
+
+def _set_generation_max_new_tokens(config: "ModelConfig") -> None:
+    """Sets the max_new_tokens parameter in the generation config to the max sequence length of the output
+    features.
+
+    This ensures that the generation config is set to the correct value for the LLM model type.
+    """
+    _DEFAULT_MAX_SEQUENCE_LENGTH = LLMGenerationConfig().max_new_tokens
+    if config.generation.max_new_tokens != _DEFAULT_MAX_SEQUENCE_LENGTH:
+        # Max new tokens is explicitly set by user, so don't override
+        return
+
+    if config.output_features[0].type != TEXT:
+        # This is trickier to set for other output features, so don't override for now.
+        # TODO: Add better support for category output features
+        return
+
+    max_possible_sequence_length = _get_maximum_possible_sequence_length(config, _DEFAULT_MAX_SEQUENCE_LENGTH)
+
     logger.info(
         f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max "
