Skip to content

Commit

Permalink
Fix tests.
Browse files (browse the repository at this point in the history)
  • Loading branch information
justinxzhao committed Oct 12, 2023
1 parent bd3f897 commit dbe795f
Show file tree
Hide file tree
Showing 9 changed files with 491 additions and 12 deletions.
5 changes: 4 additions & 1 deletion ludwig/config_validation/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
def check_global_max_sequence_length_fits_prompt_template(metadata, global_preprocessing_parameters):
"""Checks that the prompt template fits into the global max sequence length."""

if global_preprocessing_parameters["global_max_sequence_length"] is not None:
if (
"global_max_sequence_length" in global_preprocessing_parameters
and global_preprocessing_parameters["global_max_sequence_length"] is not None
):
for feature_name, feature_metadata in metadata.items():
if (
"prompt_template_num_tokens" in feature_metadata
Expand Down
2 changes: 1 addition & 1 deletion ludwig/features/bag_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_feature_meta(
"str2idx": vocabulary.str2idx,
"str2freq": vocabulary.str2freq,
"vocab_size": len(vocabulary.str2idx),
"max_set_size": vocabulary.line_length_max,
"max_set_size": vocabulary.max_sequence_length,
}

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions ludwig/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def get_feature_meta(
processor=backend.df_engine,
)
logger.info(
f"Max length of feature '{column.name}': {vocabulary.line_length_max} (without start and stop symbols)"
f"Max length of feature '{column.name}': {vocabulary.max_sequence_length} (without start and stop symbols)"
)

# Use sequence_length if provided, otherwise use max length found in dataset.
Expand All @@ -229,7 +229,7 @@ def get_feature_meta(
)
max_sequence_length = preprocessing_parameters["sequence_length"]
else:
max_sequence_length = vocabulary.line_length_max + 2 # For start and stop symbols.
max_sequence_length = vocabulary.max_sequence_length + 2 # For start and stop symbols.
logger.info(f"Setting max length using dataset: {max_sequence_length} (including start and stop symbols)")

# If max_sequence_length is None, then use the max length found in the dataset.
Expand Down
2 changes: 1 addition & 1 deletion ludwig/features/set_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_feature_meta(
"str2idx": vocabulary.str2idx,
"str2freq": vocabulary.str2freq,
"vocab_size": len(vocabulary.str2idx),
"max_set_size": vocabulary.line_length_max,
"max_set_size": vocabulary.max_sequence_length,
}

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion ludwig/features/text_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def get_feature_meta(
"""
prompt_template = config.get("prompt", {}).get("template", "")
vocabulary: Vocabulary = create_vocabulary(
prompt_template,
column,
tokenizer_type=preprocessing_parameters["tokenizer"],
num_most_frequent=preprocessing_parameters["most_common"],
Expand All @@ -149,6 +148,7 @@ def get_feature_meta(
ngram_size=preprocessing_parameters["ngram_size"],
compute_idf=preprocessing_parameters["compute_idf"],
processor=backend.df_engine,
prompt_template=prompt_template,
)
# Note: The vocabulary's max_sequence_length includes the prompt template, which is merged into the column prior
# to computing feature metadata.
Expand Down
Loading

0 comments on commit dbe795f

Please sign in to comment.