Skip to content

Commit

Permalink
Adds dataset parameter to :class:`DiscountNgramsJob` (#525)
Browse files Browse the repository at this point in the history
* add dataset variable for discount optimization

* address reviewer comments

* Nicer scientific number handling

Co-authored-by: michelwi <[email protected]>

* Consistent naming

Co-authored-by: michelwi <[email protected]>

* black

---------

Co-authored-by: Christoph Lüscher <[email protected]>
Co-authored-by: michelwi <[email protected]>
  • Loading branch information
3 people authored Jul 11, 2024
1 parent d26cb61 commit 3b48a7f
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions lm/srilm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,16 @@ class DiscountNgramsJob(Job):
Create a file with the discounted ngrams with SRILM
"""

__sis_hash_exclude__ = {"data_for_optimization": None}

def __init__(
self,
ngram_order: int,
counts: tk.Path,
count_exe: tk.Path,
*,
vocab: Optional[tk.Path] = None,
data_for_optimization: Optional[tk.Path] = None,
extra_discount_args: Optional[List[str]] = None,
use_modified_srilm: bool = False,
cpu_rqmt: int = 1,
Expand All @@ -116,8 +119,9 @@ def __init__(
"""
:param ngram_order: order of the ngram counts, typically 3 or 4.
:param counts: file with the ngram counts, see :class:`CountNgramsJob.out_counts`.
:param vocab: vocabulary file for the discounting.
:param count_exe: path to the binary.
:param vocab: vocabulary file for the discounting.
:param data_for_optimization: the discounting will be optimized on this dataset.
:param extra_discount_args: additional arguments for the discounting step.
:param use_modified_srilm: Use the i6 modified SRILM version by Sundermeyer.
The SRILM binary ngram-count was modified.
Expand All @@ -129,6 +133,7 @@ def __init__(
self.ngram_order = ngram_order
self.counts = counts
self.vocab = vocab
self.data_for_optimization = data_for_optimization
self.discount_args = extra_discount_args or []
self.use_modified_srilm = use_modified_srilm

Expand All @@ -153,10 +158,10 @@ def create_files(self):
f" -order {self.ngram_order} \\\n",
]
if self.vocab is not None:
cmd += [
f" -vocab {self.vocab.get_cached_path()} \\\n",
]
cmd += [" -kn discounts\\\n"] if not self.use_modified_srilm else [f" -multi-kn-file discounts\\\n"]
cmd.append(f" -vocab {self.vocab.get_cached_path()} \\\n")
cmd += [" -kn discounts\\\n"] if not self.use_modified_srilm else [f" -multi-kn-file discounts \\\n"]
if self.data_for_optimization is not None:
cmd.append(f" -optimize-discounts {self.data_for_optimization.get_cached_path()} \\\n")
cmd += [
f" -read {self.counts.get_cached_path()} \\\n",
f" {' '.join(self.discount_args)} -memuse\n",
Expand Down Expand Up @@ -410,7 +415,7 @@ def get_ppl(self):
if ln == "sentences,":
self.out_num_sentences.set(int(line[idx - 1]))
if ln == "words,":
self.out_num_words.set(int(line[idx - 1]))
self.out_num_words.set(int(float(line[idx - 1])))
if ln == "OOVs":
self.out_num_oovs.set(int(line[idx - 1]))
if ln == "ppl=":
Expand Down

0 comments on commit 3b48a7f

Please sign in to comment.