Skip to content

Commit

Permalink
Parse TextDicts chunkwise to avoid OverflowError
Browse files Browse the repository at this point in the history
  • Loading branch information
NeoLegends committed Sep 3, 2024
1 parent fc01cd0 commit c646539
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
26 changes: 8 additions & 18 deletions returnn/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ def tasks(self):
yield Task("run", mini_task=True)

def run(self):
d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> bpe string
d = util.parse_text_dict(self.search_py_output)
assert not os.path.exists(self.out_word_search_results.get_path())
with util.uopen(self.out_word_search_results, "wt") as out:
out.write("{\n")
Expand Down Expand Up @@ -400,8 +399,7 @@ def tasks(self):
yield Task("run", mini_task=True)

def run(self):
d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> bpe string
d = util.parse_text_dict(self.search_py_output)
assert not os.path.exists(self.out_search_results.get_path())

def _transform_text(s: str):
Expand Down Expand Up @@ -446,8 +444,7 @@ def tasks(self):
def run(self):
corpus = Corpus()
corpus.load(self.bliss_corpus.get_path())
d = eval(util.uopen(self.recog_words_file.get_path(), "rt").read())
assert isinstance(d, dict), "only search output file with dict format is supported"
d = util.parse_text_dict(self.recog_words_file)
with util.uopen(self.out_ctm_file.get_path(), "wt") as out:
out.write(";; <name> <track> <start> <duration> <word> <confidence> [<n-best>]\n")
for seg in corpus.segments():
Expand Down Expand Up @@ -531,10 +528,7 @@ def tasks(self):
yield Task("run", mini_task=True)

def run(self):
# nan/inf should not be needed, but avoids errors at this point and will print an error below,
# that we don't expect an N-best list here.
d = eval(util.uopen(self.recog_words_file, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict), "only search output file with dict format is supported"
d = util.parse_text_dict(self.recog_words_file)
if self.seq_order_file is not None:
seq_order = eval(util.uopen(self.seq_order_file, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(seq_order, (dict, list, tuple))
Expand Down Expand Up @@ -647,8 +641,7 @@ def tasks(self):

def run(self):
"""run"""
d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> bpe string
d = util.parse_text_dict(self.search_py_output)
assert not os.path.exists(self.out_best_search_results.get_path())
with util.uopen(self.out_best_search_results, "wt") as out:
out.write("{\n")
Expand Down Expand Up @@ -686,8 +679,7 @@ def tasks(self):

def run(self):
"""run"""
d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> bpe string
d = util.parse_text_dict(self.search_py_output)
assert not os.path.exists(self.out_search_results.get_path())
with util.uopen(self.out_search_results, "wt") as out:
out.write("{\n")
Expand Down Expand Up @@ -727,8 +719,7 @@ def tasks(self):

def run(self):
"""run"""
d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> bpe string
d = util.parse_text_dict(self.search_py_output)
assert not os.path.exists(self.out_search_results.get_path())
with util.uopen(self.out_search_results, "wt") as out:
out.write("{\n")
Expand Down Expand Up @@ -786,8 +777,7 @@ def logsumexp(*args):
lsp = numpy.log(sum(numpy.exp(a - a_max) for a in args))
return a_max + lsp

d = eval(util.uopen(self.search_py_output, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> bpe string
d = util.parse_text_dict(self.search_py_output)
assert not os.path.exists(self.out_search_results.get_path())
with util.uopen(self.out_search_results, "wt") as out:
out.write("{\n")
Expand Down
10 changes: 4 additions & 6 deletions text/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"TextDictToStmJob",
]

from typing import Optional, Union, Sequence, Dict, List, Tuple
from typing import Union, Sequence, Dict, Tuple
import re
from sisyphus import Job, Path, Task
from i6_core.util import uopen
from i6_core.util import parse_text_dict, uopen


class TextDictToTextLinesJob(Job):
Expand All @@ -30,8 +30,7 @@ def tasks(self):
def run(self):
# nan/inf should not be needed, but avoids errors at this point and will print an error below,
# that we don't expect an N-best list here.
d = eval(uopen(self.text_dict, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(d, dict) # seq_tag -> text
d = parse_text_dict(self.text_dict)

with uopen(self.out_text_lines, "wt") as out:
for seq_tag, entry in d.items():
Expand Down Expand Up @@ -83,8 +82,7 @@ def tasks(self):
def run(self):
# nan/inf should not be needed, but avoids errors at this point and will print an error below,
# that we don't expect an N-best list here.
c = eval(uopen(self.text_dict, "rt").read(), {"nan": float("nan"), "inf": float("inf")})
assert isinstance(c, dict)
c = parse_text_dict(self.text_dict)

all_tags = [
("d%d" % i, "default%d" % i, "all other segments of category %d" % i)
Expand Down
21 changes: 21 additions & 0 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,24 @@ def update_nested_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]):
else:
dict1[k] = v
return dict1


def parse_text_dict(path: Union[str, tk.Path]) -> Dict[str, str]:
    """
    Load a text dict (a Python dict literal, one ``key: value,`` entry per line)
    from :param:`path`, evaluating it in chunks so that a single huge ``eval``
    does not trigger an internal line-counter OverflowError on very large files.

    :param path: path to the text-dict file (opened via ``uopen``)
    :return: the parsed dict, e.g. seq_tag -> text
    """

    with uopen(path, "rt") as text_dict_file:
        txt = text_dict_file.read()

    # Remove the outermost dict brackets; each chunk is re-wrapped in "{...}" below.
    # NOTE(review): this assumes the standard text-dict layout -- the file begins
    # with "{" and ends with "}", and every entry is on its own line ending with
    # "," -- confirm for any non-standard producers.
    txt = txt.strip().strip("{}").strip()

    lines = txt.splitlines()
    result = {
        k: v
        # parse chunkwise to avoid line counter overflow when the text dict is very large
        for chunk in chunks(lines, max(1, len(lines) // 1000))
        # fix: join the chunk lines with str.join -- the original
        # ["{", *chunk, "}"].join("\n") is a JavaScript idiom; Python lists
        # have no .join, so it raised AttributeError on every call.
        for k, v in eval(
            "\n".join(["{", *chunk, "}"]),
            {"nan": float("nan"), "inf": float("inf")},
        ).items()
    }
    return result

0 comments on commit c646539

Please sign in to comment.