-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathasl_utils.py
108 lines (84 loc) · 4.01 KB
/
asl_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from asl_data import SinglesData, WordsData
import numpy as np
from IPython.core.display import display, HTML
# Raw feature column names (left/right x and y); test_std_tryit selects these
# columns from the per-speaker standard-deviation frame.
RAW_FEATURES = [
    'left-x',
    'left-y',
    'right-x',
    'right-y',
]
# "Ground" feature column names; test_features_tryit reads these columns from asl.df.
GROUND_FEATURES = [
    'grnd-rx',
    'grnd-ry',
    'grnd-lx',
    'grnd-ly',
]
def get_wer(guesses: list, test_set: "SinglesData"):
    """Compute the word error rate (WER) of ``guesses`` against ``test_set``.

    For isolated-word recognition there are no insertions or deletions, so
    WER = S / N, where S is the number of wrongly recognized words and N is the
    number of test words.

    :param guesses: list of recognized words, ordered like ``test_set.wordlist``
    :param test_set: SinglesData object; only its ``wordlist`` attribute is read
    :return: tuple ``(S, N, WER)``.
        - A test word that has no corresponding guess (``guesses`` shorter than
          the word list) counts as an error.
        - Extra guesses beyond the end of the word list are ignored.
        - An empty word list yields ``(0, 0, 0.0)`` rather than raising
          ZeroDivisionError.
    """
    expected = test_set.wordlist
    N = len(expected)
    if N == 0:
        return 0, 0, 0.0
    # zip stops at the shorter sequence; test words left without a guess are
    # added separately below so they are never silently counted as correct.
    S = sum(1 for guess, word in zip(guesses, expected) if guess != word)
    S += max(0, N - len(guesses))
    return S, N, S / N
def show_errors(guesses: list, test_set: SinglesData):
    """ Print WER and sentence differences in tabular form
    :param guesses: list of test item answers, ordered
    :param test_set: SinglesData object
    :return:
        nothing returned, prints error report
        WER = (S+I+D)/N but we have no insertions or deletions for isolated words so WER = S/N

    If ``len(guesses)`` differs from the number of test words, only a warning
    is printed: the guesses cannot be lined up with the test words, so no
    report is produced.
    """
    num_test_words = len(test_set.wordlist)
    # Validate before computing anything: a mismatched guess list previously
    # crashed (IndexError in the WER loop, or NameError in this message).
    if len(guesses) != num_test_words:
        print("Size of guesses must equal number of test words ({})!".format(num_test_words))
        return
    S, N, WER = get_wer(guesses, test_set)
    print("\n**** WER = {}".format(WER))
    print("Total correct: {} out of {}".format(N - S, N))
    print('Video Recognized Correct')
    print('=====================================================================================================')
    for video_num in test_set.sentences_index:
        word_ids = test_set.sentences_index[video_num]
        correct_sentence = [test_set.wordlist[i] for i in word_ids]
        # Prefix misrecognized words with '*' so errors stand out in the table.
        recognized_sentence = [
            guess if guess == correct else '*' + guess
            for guess, correct in zip((guesses[i] for i in word_ids), correct_sentence)
        ]
        print('{:5}: {:60} {}'.format(video_num, ' '.join(recognized_sentence), ' '.join(correct_sentence)))
def getKey(item):
    """Return ``item[1]``; intended as a ``key=`` function (e.g. for ``sorted``)."""
    second = item[1]
    return second
def train_all_words(training: WordsData, model_selector):
    """Select and train one model for every word in a training set.

    :param training: WordsData object (training set)
    :param model_selector: class (subclassed from ModelSelector); it is
        instantiated once per word as
        ``model_selector(sequences, Xlengths, word, n_constant=3)`` and the
        result of its ``select()`` call is stored for that word
    :return: dict of models keyed by word
    """
    all_sequences = training.get_all_sequences()
    all_xlengths = training.get_all_Xlengths()
    return {
        word: model_selector(all_sequences, all_xlengths, word, n_constant=3).select()
        for word in training.words
    }
def combine_sequences(split_index_list, sequences):
    """Concatenate the sequences selected by an index list into ``(X, lengths)``.

    Useful when recombining sequences split using KFold for hmmlearn.

    :param split_index_list: a list of indices as created by KFold splitting
    :param sequences: list of feature sequences
    :return: tuple ``(X, lengths)``. ``X`` is a flat list of every element of
        the selected sequences, in index-list order. ``lengths`` holds the
        length of each selected sequence. This is the X, lengths format used
        in hmmlearn.
    """
    X = []
    lengths = []
    for idx in split_index_list:
        chosen = sequences[idx]
        X.extend(chosen)
        lengths.append(len(chosen))
    return X, lengths
def putHTML(color, msg):
    """Return ``msg`` wrapped in a colored ``<font>`` tag as an IPython HTML object.

    ``msg`` is inserted verbatim (it is not HTML-escaped).
    """
    markup = '<font color={}>{}</font><br/>'.format(color, msg)
    return HTML(markup)
def feedback(passed, failmsg='', passmsg='Correct!'):
    """Render a pass/fail message via putHTML.

    :param passed: truthy for success
    :param failmsg: text shown (in red) when ``passed`` is falsy
    :param passmsg: text shown (in green) when ``passed`` is truthy
    :return: the HTML object produced by putHTML
    """
    color, message = ('green', passmsg) if passed else ('red', failmsg)
    return putHTML(color, message)
def test_features_tryit(asl):
    """Display a sample of ``asl.df`` and check the ground features of one row.

    The row keyed ``(98, 1)`` is looked up (presumably (video, frame) in a
    two-level row index — TODO confirm). Its GROUND_FEATURES values are then
    compared with known-correct numbers.

    :param asl: object exposing a pandas DataFrame as ``asl.df``
    :return: result of feedback(): the pass message on a match, otherwise the
        fail message listing expected and found values
    """
    print('asl.df sample')
    display(asl.df.head())
    # DataFrame.ix was deprecated in pandas 0.20 and removed in 1.0. An explicit
    # (row-key tuple, column list) pair to .loc performs the same label lookup.
    sample = asl.df.loc[(98, 1), GROUND_FEATURES].tolist()
    correct = [9, 113, -12, 119]
    failmsg = 'The values returned were not correct. Expected: {} Found: {}'.format(correct, sample)
    return feedback(sample == correct, failmsg)
def test_std_tryit(df_std):
    """Display ``df_std`` and check the RAW_FEATURES values of its 'man-1' row.

    :param df_std: pandas DataFrame with a row labelled 'man-1' and the
        RAW_FEATURES columns
    :return: result of feedback(): pass message when the values match the
        expected numbers within a relative tolerance of 1e-3, otherwise the
        fail message
    """
    print('df_std')
    display(df_std)
    # DataFrame.ix was removed in pandas 1.0; .loc performs the same label lookup.
    sample = df_std.loc['man-1', RAW_FEATURES]
    correct = [15.154425, 36.328485, 18.901917, 54.902340]
    failmsg = 'The raw man-1 values returned were not correct.\nExpected: {} for {}'.format(correct, RAW_FEATURES)
    return feedback(np.allclose(sample, correct, rtol=.001), failmsg)