-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathprepare.py
81 lines (63 loc) · 2.12 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from utils import *
def lineiter(fo):
    """Yield tokenized (source, target) pairs from an iterable of lines.

    Each line is split on a tab into exactly two fields; a line that does
    not split into exactly two fields raises ValueError from the unpacking.
    Both fields are passed through ``tokenize(field, UNIT)``, and the pair
    is yielded only when both token sequences have a length within
    ``[MIN_LEN, MAX_LEN]`` (inclusive).

    Args:
        fo: an iterable of text lines (e.g. an open file object).

    Yields:
        Tuples ``(x, y)`` of the tokenized source and target fields.
    """
    for line in fo:
        src_text, tgt_text = line.split("\t")
        # Tokenize both sides before any length check, as callers may rely
        # on tokenize() being invoked for every field of every line.
        src = tokenize(src_text, UNIT)
        tgt = tokenize(tgt_text, UNIT)
        if all(MIN_LEN <= len(seq) <= MAX_LEN for seq in (src, tgt)):
            yield src, tgt
def dict_to_tti(tti, vocab_size=0):
    """Convert a token-frequency mapping into a token-to-index mapping.

    The predefined tokens PAD, SOS, EOS and UNK always occupy the first
    four indices, in that order. The remaining tokens follow in order of
    decreasing frequency (ties keep the iteration order of ``tti``, since
    ``sorted`` is stable).

    Args:
        tti: mapping from token to its occurrence count.
        vocab_size: if non-zero, keep only this many of the most frequent
            non-predefined tokens; 0 keeps all of them.

    Returns:
        A dict mapping each kept token to a unique, contiguous index.
    """
    tokens = [PAD, SOS, EOS, UNK]  # predefined tokens
    reserved = set(tokens)
    ranked = sorted(tti, key=lambda w: -tti[w])
    # A corpus token identical to a predefined one would otherwise appear
    # twice in the list below; the later duplicate would overwrite the
    # reserved index and leave a gap in the index space.
    ranked = [w for w in ranked if w not in reserved]
    if vocab_size:
        ranked = ranked[:vocab_size]
    return {w: i for i, w in enumerate(tokens + ranked)}
def load_data():
    """Read the training file named by ``sys.argv[1]`` and index it.

    The file is scanned twice through ``lineiter``: the first pass counts
    source characters, source words and target words; the second pass
    encodes every pair using the index tables built from those counts.

    Each source word is encoded as the ``+``-joined indices of its
    characters, followed by ``:`` and the word index. Each target word is
    encoded as its word index. Words missing from a word table (for
    example because of the vocabulary size limit) fall back to UNK_IDX.

    Returns:
        A tuple ``(data, x_cti, x_wti, y_wti)`` where ``data`` is a list of
        ``(x, y)`` pairs of encoded string lists, sorted by decreasing
        source length, and the other three are the source character,
        source word and target word token-to-index dicts.
    """
    data = []
    x_cti = defaultdict(int)
    x_wti = defaultdict(int)
    y_wti = defaultdict(int)
    # The context manager closes the file even when a malformed line makes
    # lineiter() raise part-way through either pass.
    with open(sys.argv[1]) as fo:
        # First pass: collect frequencies.
        for x, y in lineiter(fo):
            for w in x:
                for c in w:
                    x_cti[c] += 1
                x_wti[w] += 1
            for w in y:
                y_wti[w] += 1
        x_cti = dict_to_tti(x_cti)
        x_wti = dict_to_tti(x_wti, SRC_VOCAB_SIZE)
        y_wti = dict_to_tti(y_wti, TGT_VOCAB_SIZE)
        # Second pass: rewind and encode with the finished tables.
        fo.seek(0)
        for x, y in lineiter(fo):
            x = [
                "+".join(str(x_cti[c]) for c in w)
                + ":%d" % x_wti.get(w, UNK_IDX)
                for w in x
            ]
            y = [str(y_wti.get(w, UNK_IDX)) for w in y]
            data.append((x, y))
    # Sort by source sequence length, longest first (stable for ties).
    data.sort(key=lambda pair: -len(pair[0]))
    return data, x_cti, x_wti, y_wti
def save_data(filename, data):
    """Write encoded sequence pairs to ``filename``, one pair per line.

    For each non-empty entry, the items of ``seq[0]`` are written separated
    by single spaces, then a tab, then the items of ``seq[1]`` separated by
    single spaces, then a newline. An empty (falsy) entry produces a blank
    line.

    Args:
        filename: path of the file to create or overwrite.
        data: iterable of entries, each either empty or an indexable pair
            of iterables.
    """
    # The context manager closes the file even if writing an entry raises.
    with open(filename, "w") as fo:
        for seq in data:
            if not seq:
                print(file=fo)
                continue
            print(*seq[0], end="\t", file=fo)
            print(*seq[1], file=fo)
def save_tkn_to_idx(filename, tti):
    """Write the tokens of ``tti`` to ``filename``, one per line.

    Tokens are written in ascending order of their index value, so the
    position of a line in the file follows the ordering of the indices.

    Args:
        filename: path of the file to create or overwrite.
        tti: mapping from token to index.
    """
    # The context manager closes the file even if a write raises.
    with open(filename, "w") as fo:
        fo.writelines("%s\n" % tkn for tkn in sorted(tti, key=tti.get))
if __name__ == "__main__":
    # Exactly one command-line argument is expected: the training data path.
    if len(sys.argv) != 2:
        sys.exit("Usage: %s training_data" % sys.argv[0])
    data, x_cti, x_wti, y_wti = load_data()
    # All outputs are written next to the input, using its path as a prefix.
    save_data(sys.argv[1] + ".csv", data)
    vocab_tables = (
        (".src.char_to_idx", x_cti),
        (".src.word_to_idx", x_wti),
        (".tgt.word_to_idx", y_wti),
    )
    for suffix, table in vocab_tables:
        save_tkn_to_idx(sys.argv[1] + suffix, table)