-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext_ctc_utils.py
102 lines (86 loc) · 2.78 KB
/
text_ctc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import pandas as pd
import numpy as np
from sklearn import preprocessing
import torch
def remove_duplicates(x):
    """Collapse every run of identical consecutive items in ``x`` into one.

    Args:
        x: Sequence of characters to squeeze, e.g. a ``str``.

    Returns:
        ``x`` itself when it holds fewer than two items; otherwise a new
        ``str`` in which each run of equal neighbours appears exactly once
        (``"aabbbc"`` becomes ``"abc"``). Non-adjacent repeats are kept
        (``"abab"`` stays ``"abab"``).
    """
    # Function-scope import keeps this block self-contained.
    from itertools import groupby

    if len(x) < 2:
        return x
    # groupby yields one key per run of equal neighbours; a single join
    # avoids rebuilding the result string for every appended character.
    return "".join(run_key for run_key, _ in groupby(x))
def decode_predictions(preds, encoder):
    """Greedy CTC decoding of a batch of per-step class scores.

    For every sequence the best-scoring class of each step is taken, runs of
    the same class are collapsed to one occurrence, and blanks (class
    index 0) are dropped. Collapsing happens *before* blank removal, so a
    blank between two identical labels keeps them distinct: the step
    sequence ``a, blank, a`` decodes to ``"aa"`` rather than ``"a"``.
    Collapsing is done on class indices, not on the characters of the joined
    text, so labels longer than one character are never altered.

    Args:
        preds: Score tensor indexed as (sequence, step, class); the argmax is
            taken over dimension 2.
        encoder: Object whose ``inverse_transform([k])`` returns a
            one-element sequence holding the string label of class ``k``.
            It is queried with ``index - 1`` because index 0 is the blank.

    Returns:
        A list with one decoded string per sequence in ``preds``.
    """
    # softmax is monotonic, so the argmax of the raw scores is the same.
    best = torch.argmax(preds, 2).detach().cpu().numpy()
    sign_preds = []
    for row in best:
        labels = []
        # Starting from "blank" is safe: a blank never emits a label.
        previous = 0
        for idx in row:
            # Emit only at the start of a run, and never for the blank.
            if idx != previous and idx != 0:
                labels.append(encoder.inverse_transform([idx - 1])[0])
            previous = idx
        sign_preds.append("".join(labels))
    return sign_preds
def numerize(sents, vocab_map, full_transformer):
    """Map each text in ``sents`` to a list of vocabulary indices.

    Args:
        sents: Iterable of texts. Entries that are ``float`` instances
            (e.g. a NaN standing in for a missing label) are skipped, so the
            result can be shorter than the input.
        vocab_map: Mapping from symbol to integer index. A symbol missing
            from it raises ``KeyError``.
        full_transformer: When truthy, every encoded text is wrapped as
            ``[32] + indices + [0]``.

    Returns:
        A list of integer lists, one per non-float entry of ``sents``.
    """
    # NOTE(review): 32 and 0 look like start / end-or-padding token ids for
    # the transformer path — confirm against the model's vocabulary.
    outs = []
    for sent in sents:
        # isinstance also catches float subclasses such as numpy.float64,
        # which an exact type comparison would let through to the lookup.
        if isinstance(sent, float):
            continue
        indices = [vocab_map[symbol] for symbol in sent]
        if full_transformer:
            indices = [32] + indices + [0]
        outs.append(indices)
    return outs
def invert_to_chars(sents, inv_ctc_map):
    """Translate a batch of index sequences back into symbols.

    Each row is read up to (not including) its first 0. The symbols of all
    rows are appended to one flat list, so row boundaries are not preserved
    in the result.

    Args:
        sents: Tensor iterated as rows of integer indices.
        inv_ctc_map: Mapping from integer index to symbol.

    Returns:
        Flat list of symbols, in row order.
    """
    # NOTE(review): 0 presumably marks end-of-sequence / padding — confirm
    # against how the sequences are built.
    # .cpu() is a no-op for CPU tensors and lets tensors on other devices be
    # converted to numpy as well.
    rows = sents.detach().cpu().numpy()
    outs = []
    for row in rows:
        for index in row:
            if index == 0:
                break
            outs.append(inv_ctc_map[index])
    return outs
def get_ctc_vocab(char_list, blank="_"):
    """Build CTC lookup tables with the blank symbol at index 0.

    Args:
        char_list: String of vocabulary symbols in index order.
        blank: Symbol prepended to ``char_list`` so that it receives
            index 0. Defaults to ``"_"``.

    Returns:
        A tuple ``(ctc_map, inv_ctc_map, ctc_char_list)``: symbol -> index,
        index -> symbol, and the string ``blank + char_list``. If a symbol
        occurs more than once (including ``blank`` appearing inside
        ``char_list``), ``ctc_map`` keeps its last index while
        ``inv_ctc_map`` keeps every position.
    """
    ctc_char_list = blank + char_list
    inv_ctc_map = dict(enumerate(ctc_char_list))
    # Inverting in index order means a repeated symbol ends on its last index.
    ctc_map = {char: i for i, char in inv_ctc_map.items()}
    return ctc_map, inv_ctc_map, ctc_char_list
def get_autoreg_vocab(char_list):
    """Build symbol/index lookup tables straight from ``char_list``.

    No extra symbol is inserted: index ``i`` is simply the ``i``-th item of
    ``char_list``.

    Args:
        char_list: Iterable of vocabulary symbols in index order. It is
            traversed exactly once.

    Returns:
        A tuple ``(ctc_map, inv_ctc_map, char_list)``: symbol -> index,
        index -> symbol, and the ``char_list`` argument unchanged. If a
        symbol occurs more than once, ``ctc_map`` keeps its last index while
        ``inv_ctc_map`` keeps every position.
    """
    inv_ctc_map = dict(enumerate(char_list))
    # Derive the forward map from the inverse one so the input is read once.
    ctc_map = {char: i for i, char in inv_ctc_map.items()}
    return ctc_map, inv_ctc_map, char_list
def convert_text_for_ctc(DATASET_CSV_PATH, vocab_map, full_transformer=False):
    """Load a dataset CSV and encode its labels as index sequences.

    Rows with a missing ``filename`` or ``label_proc`` value are discarded
    before the labels are encoded.

    Args:
        DATASET_CSV_PATH: Location of the CSV, handed to ``pandas.read_csv``.
        vocab_map: Symbol -> index mapping forwarded to ``numerize``.
        full_transformer: Flag forwarded to ``numerize``.

    Returns:
        A DataFrame with two columns: ``names`` (the ``filename`` values of
        the kept rows) and ``enc`` (the encoded label of each row).
    """
    frame = pd.read_csv(DATASET_CSV_PATH).dropna(subset=["filename", "label_proc"])
    encoded = numerize(frame["label_proc"], vocab_map, full_transformer)

    result = pd.DataFrame()
    result["names"] = frame["filename"]
    # Assigned positionally; numerize must return one entry per kept row.
    result["enc"] = encoded
    return result