diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py
index 4165a8801e4..21e2883b989 100644
--- a/tests/ludwig/schema/test_model_config.py
+++ b/tests/ludwig/schema/test_model_config.py
@@ -952,62 +952,60 @@ def test_llm_quantization_backend_compatibility():
     ray.shutdown()
 
 
-def test_max_new_tokens_override_no_changes_to_max_new_tokens():
-    """Tests that the default value for max_new_tokens is respected when explicitly set in the config."""
-    config = {
-        MODEL_TYPE: MODEL_LLM,
-        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
-        # Default value for generation.max_sequence_length is 32
-        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
-        "generation": {"max_new_tokens": 64},
-    }
-
-    config_obj = ModelConfig.from_dict(config)
-    assert config_obj.generation.max_new_tokens == 64
-
-
-def test_max_new_tokens_override_large_max_sequence_length():
-    """Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a large
-    value than the default max_new_tokens."""
-    config = {
-        MODEL_TYPE: MODEL_LLM,
-        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
-        # Default value for generation.max_sequence_length is 32
-        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}],
-    }
-
-    config_obj = ModelConfig.from_dict(config)
-    assert config_obj.generation.max_new_tokens == 100
-
-
-def test_max_new_tokens_override_large_global_max_sequence_length():
-    """Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to a
-    larger value than the default max_new_tokens."""
-    config = {
-        MODEL_TYPE: MODEL_LLM,
-        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
-        # Default value for generation.max_sequence_length is 32
-        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
-        PREPROCESSING: {"global_max_sequence_length": 100},
-    }
-
-    config_obj = ModelConfig.from_dict(config)
-    assert config_obj.generation.max_new_tokens == 100
-
-
-def test_max_new_tokens_override_fallback_to_model_window_size():
-    config = {
-        MODEL_TYPE: MODEL_LLM,
-        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
-        # Default value for generation.max_sequence_length is 32
-        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
-    }
-
-    config_obj = ModelConfig.from_dict(config)
-    # Base model context length is 2048 tokens by default
-    # Since we fallback to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens
-    assert config_obj.generation.max_new_tokens == 1024
+class TestMaxNewTokensOverride:
+    def test_max_new_tokens_override_no_changes_to_max_new_tokens(self):
+        """Tests that the default value for max_new_tokens is respected when explicitly set in the config."""
+        config = {
+            MODEL_TYPE: MODEL_LLM,
+            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+            # Default value for generation.max_sequence_length is 32
+            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
+            "generation": {"max_new_tokens": 64},
+        }
+
+        config_obj = ModelConfig.from_dict(config)
+        assert config_obj.generation.max_new_tokens == 64
+
+    def test_max_new_tokens_override_large_max_sequence_length(self):
+        """Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a larger
+        value than the default max_new_tokens."""
+        config = {
+            MODEL_TYPE: MODEL_LLM,
+            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+            # Default value for generation.max_sequence_length is 32
+            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}],
+        }
+
+        config_obj = ModelConfig.from_dict(config)
+        assert config_obj.generation.max_new_tokens == 100
+
+    def test_max_new_tokens_override_large_global_max_sequence_length(self):
+        """Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to
+        a larger value than the default max_new_tokens."""
+        config = {
+            MODEL_TYPE: MODEL_LLM,
+            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+            # Default value for generation.max_sequence_length is 32
+            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
+            PREPROCESSING: {"global_max_sequence_length": 100},
+        }
+
+        config_obj = ModelConfig.from_dict(config)
+        assert config_obj.generation.max_new_tokens == 100
+
+    def test_max_new_tokens_override_fallback_to_model_window_size(self):
+        config = {
+            MODEL_TYPE: MODEL_LLM,
+            BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+            INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+            # Default value for generation.max_sequence_length is 32
+            OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
+        }
+
+        config_obj = ModelConfig.from_dict(config)
+        # Base model context length is 2048 tokens by default
+        # Since we fall back to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens
+        assert config_obj.generation.max_new_tokens == 1024