
Commit

Group tests under a single test class
arnavgarg1 committed Oct 11, 2023
1 parent a8a2cc8 commit 4d198d4
Showing 1 changed file with 57 additions and 59 deletions.
116 changes: 57 additions & 59 deletions tests/ludwig/schema/test_model_config.py
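Context for the change: pytest collects test methods from classes whose names start with Test (and which define no __init__), so moving the four module-level test_max_new_tokens_override_* functions onto one such class groups them under a single node id without changing how they execute. A minimal illustrative sketch of the pattern, using hypothetical names rather than anything from this repository:

class TestExampleGroup:
    # Each method takes self and is collected and run like a plain test function.
    def test_first_case(self):
        assert 1 + 1 == 2

    def test_second_case(self):
        assert "ab".upper() == "AB"

Once grouped, the whole set can be selected at once, e.g. pytest tests/ludwig/schema/test_model_config.py::TestMaxNewTokensOverride.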
@@ -952,62 +952,60 @@ def test_llm_quantization_backend_compatibility():
    ray.shutdown()


def test_max_new_tokens_override_no_changes_to_max_new_tokens():
"""Tests that the default value for max_new_tokens is respected when explicitly set in the config."""
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_sequence_length is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
"generation": {"max_new_tokens": 64},
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.generation.max_new_tokens == 64


def test_max_new_tokens_override_large_max_sequence_length():
"""Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a large
value than the default max_new_tokens."""
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_sequence_length is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}],
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.generation.max_new_tokens == 100


def test_max_new_tokens_override_large_global_max_sequence_length():
"""Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to a
larger value than the default max_new_tokens."""
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_sequence_length is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
PREPROCESSING: {"global_max_sequence_length": 100},
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.generation.max_new_tokens == 100


def test_max_new_tokens_override_fallback_to_model_window_size():
    config = {
        MODEL_TYPE: MODEL_LLM,
        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
        # Default value for generation.max_sequence_length is 32
        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
    }

    config_obj = ModelConfig.from_dict(config)
    # Base model context length is 2048 tokens by default
    # Since we fall back to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens
    assert config_obj.generation.max_new_tokens == 1024
class TestMaxNewTokensOverride:
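    """Tests covering how generation.max_new_tokens is resolved when an LLM ModelConfig is built."""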
    def test_max_new_tokens_override_no_changes_to_max_new_tokens(self):
        """Tests that the default value for max_new_tokens is respected when explicitly set in the config."""
        config = {
            MODEL_TYPE: MODEL_LLM,
            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
            # Default value for generation.max_sequence_length is 32
            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
            "generation": {"max_new_tokens": 64},
        }

        config_obj = ModelConfig.from_dict(config)
        assert config_obj.generation.max_new_tokens == 64

    def test_max_new_tokens_override_large_max_sequence_length(self):
        """Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a
        larger value than the default max_new_tokens."""
        config = {
            MODEL_TYPE: MODEL_LLM,
            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
            # Default value for generation.max_sequence_length is 32
            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}],
        }

        config_obj = ModelConfig.from_dict(config)
        assert config_obj.generation.max_new_tokens == 100

    def test_max_new_tokens_override_large_global_max_sequence_length(self):
        """Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to
        a larger value than the default max_new_tokens."""
        config = {
            MODEL_TYPE: MODEL_LLM,
            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
            # Default value for generation.max_sequence_length is 32
            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
            PREPROCESSING: {"global_max_sequence_length": 100},
        }

        config_obj = ModelConfig.from_dict(config)
        assert config_obj.generation.max_new_tokens == 100

    def test_max_new_tokens_override_fallback_to_model_window_size(self):
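        """Tests that max_new_tokens falls back to half of the base model's context window when no explicit value
        or larger sequence-length settings are provided."""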
        config = {
            MODEL_TYPE: MODEL_LLM,
            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
            # Default value for generation.max_sequence_length is 32
            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
        }

        config_obj = ModelConfig.from_dict(config)
        # Base model context length is 2048 tokens by default
        # Since we fall back to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens
        assert config_obj.generation.max_new_tokens == 1024
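
Taken together, these four tests pin down a precedence order for resolving generation.max_new_tokens. Below is a minimal sketch of that resolution logic written against the tests above, not Ludwig's actual implementation; the function name, argument names, and the 32-token default are assumptions inferred from the test comments and docstrings:

def resolve_max_new_tokens(
    explicit_max_new_tokens=None,  # generation.max_new_tokens from the user config
    output_max_sequence_length=None,  # output feature preprocessing.max_sequence_length
    global_max_sequence_length=None,  # preprocessing.global_max_sequence_length
    model_context_length=2048,  # context window of the tiny-random Llama base model
    default_max_new_tokens=32,  # assumed schema default, per the comments above
):
    # 1. An explicitly configured max_new_tokens is always respected.
    if explicit_max_new_tokens is not None:
        return explicit_max_new_tokens
    # 2. A larger output-feature max_sequence_length overrides the default.
    if output_max_sequence_length and output_max_sequence_length > default_max_new_tokens:
        return output_max_sequence_length
    # 3. A larger global_max_sequence_length also overrides the default.
    if global_max_sequence_length and global_max_sequence_length > default_max_new_tokens:
        return global_max_sequence_length
    # 4. With nothing set, fall back to half of the model's context window (2048 // 2 == 1024).
    return model_context_length // 2

For example, resolve_max_new_tokens() returns 1024, matching the fallback test, while resolve_max_new_tokens(output_max_sequence_length=100) returns 100.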
