diff --git a/lm/vocabulary.py b/lm/vocabulary.py index 59c9d749..080c6bf1 100644 --- a/lm/vocabulary.py +++ b/lm/vocabulary.py @@ -159,7 +159,7 @@ class VocabularyFromTextJob(Job): Extract vocabulary from given text files based on frequency. """ - def __init__(self, file_paths: List[tk.Path], num_words: int = 1_000_000): + def __init__(self, file_paths: List[tk.Path], num_words: Union[int, tk.Variable] = 1_000_000): """ :param file_paths: paths to the text files :param num_words: expected size of the vocabulary @@ -185,12 +185,14 @@ def run(self): words = line.strip().split() counter.update(words) - cutoff = min(self.num_words, len(counter)) + num_words = self.num_words.get() if isinstance(self.num_words, tk.Variable) else self.num_words + + cutoff = min(num_words, len(counter)) with open(self.out_vocabulary, "w") as vocabulary, open( self.out_vocabulary_with_counts, "w" ) as vocabulary_with_counts: - for (word, count) in counter.most_common(cutoff): + for word, count in counter.most_common(cutoff): vocabulary.write(f"{word}\n") vocabulary_with_counts.write(f"{word} {count}\n")