-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathperformance_evaluation.py
179 lines (143 loc) · 8.77 KB
/
performance_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
from typing import List
def current_best_val_acc(val_acc: float, test_acc: float, best_val_acc: List[float],
                         best_test_acc_based_on_val_acc: List[float], naswt_score=None,
                         best_naswt_score_based_on_val_acc=None, fitness='val_acc'):
    """Extend the best-by-validation-accuracy histories by one step.

    Every tracked list grows by exactly one entry, in place. The current
    candidate's values are recorded when the history is empty or when
    ``val_acc`` strictly exceeds the last recorded best. Otherwise the
    previous entry of each list is repeated.

    Args:
        val_acc: Validation accuracy of the current candidate.
        test_acc: Test accuracy of the current candidate.
        best_val_acc: Running best validation accuracy (mutated).
        best_test_acc_based_on_val_acc: Test accuracy of the best-by-validation
            candidate at each step (mutated).
        naswt_score: NASWT score of the current candidate. Only used when
            ``fitness == 'naswt'``.
        best_naswt_score_based_on_val_acc: NASWT score of the best-by-validation
            candidate at each step (mutated). Must be a list when
            ``fitness == 'naswt'``.
        fitness: ``'naswt'`` additionally tracks the NASWT score list.

    Returns:
        The (same) list objects that were passed in.

        - When ``fitness == 'naswt'``: a 3-tuple
          ``(best_val_acc, best_test_acc_based_on_val_acc,
          best_naswt_score_based_on_val_acc)``.
        - Otherwise: the 2-tuple
          ``(best_val_acc, best_test_acc_based_on_val_acc)``.
    """
    histories = [best_val_acc, best_test_acc_based_on_val_acc]
    candidates = [val_acc, test_acc]
    if fitness == 'naswt':
        histories.append(best_naswt_score_based_on_val_acc)
        candidates.append(naswt_score)

    # Decide once, before anything is appended, whether this candidate wins.
    improved = best_val_acc == [] or val_acc > best_val_acc[-1]

    # Lists are updated one after another, so each list's previous entry is
    # read just before that same list is extended.
    for history, candidate in zip(histories, candidates):
        history.append(candidate if improved else history[-1])

    return tuple(histories)
def current_best_test_acc(test_acc, best_test_acc):
    """Append the running maximum test accuracy to ``best_test_acc``.

    Args:
        test_acc: Test accuracy of the current candidate.
        best_test_acc: History of best test accuracies. It is extended in
            place by one entry: ``test_acc`` if the history is empty or
            ``test_acc`` strictly exceeds the last entry, otherwise a repeat
            of the last entry.

    Returns:
        The same ``best_test_acc`` list object.
    """
    take_candidate = best_test_acc == [] or test_acc > best_test_acc[-1]
    best_test_acc.append(test_acc if take_candidate else best_test_acc[-1])
    return best_test_acc
def current_best_naswt_score(naswt_score: float, val_acc: float, test_acc: float, best_naswt_score: List[float],
                             best_val_acc_based_on_naswt_score: List[float],
                             best_test_acc_based_on_naswt_score: List[float]):
    """Extend the best-by-NASWT-score histories by one step.

    All three lists grow by one entry, in place. The current candidate's
    values are recorded when the history is empty or when ``naswt_score``
    strictly exceeds the last recorded best score. Otherwise each list
    repeats its previous entry.

    Args:
        naswt_score: NASWT score of the current candidate.
        val_acc: Validation accuracy of the current candidate.
        test_acc: Test accuracy of the current candidate.
        best_naswt_score: Running best NASWT score (mutated).
        best_val_acc_based_on_naswt_score: Validation accuracy of the
            best-by-score candidate at each step (mutated).
        best_test_acc_based_on_naswt_score: Test accuracy of the best-by-score
            candidate at each step (mutated).

    Returns:
        ``(best_naswt_score, best_val_acc_based_on_naswt_score,
        best_test_acc_based_on_naswt_score)``. These are the same list
        objects that were passed in.
    """
    histories = (best_naswt_score, best_val_acc_based_on_naswt_score, best_test_acc_based_on_naswt_score)
    candidates = (naswt_score, val_acc, test_acc)

    improved = best_naswt_score == [] or naswt_score > best_naswt_score[-1]

    for history, candidate in zip(histories, candidates):
        history.append(candidate if improved else history[-1])

    return histories
def current_total_train_time(train_time: float, total_train_time: List[float]):
    """Append the cumulative training time after adding ``train_time``.

    Args:
        train_time: Training time of the current candidate.
        total_train_time: Running cumulative totals. It is extended in place.
            The first entry is ``train_time`` itself; every later entry is
            the previous total plus ``train_time``.

    Returns:
        The same ``total_train_time`` list object.
    """
    if total_train_time == []:
        total_train_time.append(train_time)
        return total_train_time

    total_train_time.append(total_train_time[-1] + train_time)
    return total_train_time
def current_total_naswt_calc_time(calc_time: float, total_naswt_calc_time: List[float]):
    """Append the cumulative NASWT computation time after adding ``calc_time``.

    Args:
        calc_time: NASWT computation time for the current candidate.
        total_naswt_calc_time: Running cumulative totals. It is extended in
            place. The first entry is ``calc_time``; later entries add
            ``calc_time`` to the previous total.

    Returns:
        The same ``total_naswt_calc_time`` list object.
    """
    if total_naswt_calc_time == []:
        total_naswt_calc_time.append(calc_time)
        return total_naswt_calc_time

    total_naswt_calc_time.append(total_naswt_calc_time[-1] + calc_time)
    return total_naswt_calc_time
def progress_update(val_acc: float, test_acc: float, train_time: float, best_val_acc: List[float],
                    best_test_acc_based_on_val_acc: List[float], best_test_acc: List[float], train_times: List[float],
                    total_train_time: List[float], fitness='val_acc', naswt_score=None, naswt_calc_time=None,
                    best_naswt_score_based_on_val_acc=None, best_naswt_score=None,
                    best_val_acc_based_on_naswt_score=None, best_test_acc_based_on_naswt_score=None,
                    naswt_calc_times=None, total_naswt_calc_time=None):
    """Record one evaluated candidate into every progress history.

    The updates happen in this order:

    1. Best-by-validation histories (plus the NASWT companion list when
       ``fitness == 'naswt'``).
    2. Running best test accuracy.
    3. Per-candidate training time.
    4. Cumulative training time.

    When ``fitness == 'naswt'``, three more updates follow:

    5. ``naswt_calc_time`` is appended to ``naswt_calc_times``.
    6. The best-by-NASWT-score histories are extended.
    7. The cumulative NASWT computation time is extended.

    All list arguments are mutated in place by the helper functions.

    Returns:
        When ``fitness == 'naswt'``, an 11-tuple::

            (best_val_acc, best_test_acc_based_on_val_acc,
             best_naswt_score_based_on_val_acc, best_test_acc,
             best_naswt_score, best_val_acc_based_on_naswt_score,
             best_test_acc_based_on_naswt_score, train_times,
             naswt_calc_times, total_train_time, total_naswt_calc_time)

        Otherwise, a 5-tuple::

            (best_val_acc, best_test_acc_based_on_val_acc, best_test_acc,
             train_times, total_train_time)
    """
    use_naswt = fitness == 'naswt'

    val_kwargs = dict(val_acc=val_acc, test_acc=test_acc, best_val_acc=best_val_acc,
                      best_test_acc_based_on_val_acc=best_test_acc_based_on_val_acc,
                      fitness=fitness)
    if use_naswt:
        # Only the NASWT path tracks (and gets back) the NASWT companion list.
        val_kwargs.update(naswt_score=naswt_score,
                          best_naswt_score_based_on_val_acc=best_naswt_score_based_on_val_acc)
        best_val_acc, best_test_acc_based_on_val_acc, best_naswt_score_based_on_val_acc = \
            current_best_val_acc(**val_kwargs)
    else:
        best_val_acc, best_test_acc_based_on_val_acc = current_best_val_acc(**val_kwargs)

    best_test_acc = current_best_test_acc(test_acc, best_test_acc)
    train_times.append(train_time)
    total_train_time = current_total_train_time(train_time, total_train_time)

    if not use_naswt:
        return best_val_acc, best_test_acc_based_on_val_acc, best_test_acc, train_times, total_train_time

    naswt_calc_times.append(naswt_calc_time)
    best_naswt_score, best_val_acc_based_on_naswt_score, best_test_acc_based_on_naswt_score = \
        current_best_naswt_score(naswt_score, val_acc, test_acc, best_naswt_score,
                                 best_val_acc_based_on_naswt_score, best_test_acc_based_on_naswt_score)
    total_naswt_calc_time = current_total_naswt_calc_time(naswt_calc_time, total_naswt_calc_time)

    # NOTE: the tuple order (naswt_calc_times before total_train_time) is part
    # of the existing interface; callers unpack positionally.
    return (best_val_acc, best_test_acc_based_on_val_acc, best_naswt_score_based_on_val_acc, best_test_acc,
            best_naswt_score, best_val_acc_based_on_naswt_score, best_test_acc_based_on_naswt_score,
            train_times, naswt_calc_times, total_train_time, total_naswt_calc_time)
def _write_values(folder_name: str, stem: str, run_number: int, values) -> None:
    """Write ``str(value)`` for each item of *values*, one per line, to ``<folder_name>/<stem><run_number>.txt``."""
    path = os.path.join(folder_name, stem + str(run_number) + '.txt')
    with open(path, 'w') as f:
        # str(value) + '\n' (not an f-string) keeps the exact text the
        # original per-element writes produced, e.g. for numpy scalars.
        f.writelines(str(value) + '\n' for value in values)


def save_performance(folder_name: str, exp_repeat_index: int, start_time: float, end_time: float,
                     best_val_acc: List[float], best_test_acc_based_on_val_acc: List[float],
                     best_test_acc: List[float], train_times: List[float], total_train_time: List[float],
                     fitness='val_acc', best_naswt_score_based_on_val_acc=None, best_naswt_score=None,
                     best_val_acc_based_on_naswt_score=None, best_test_acc_based_on_naswt_score=None,
                     naswt_calc_times=None, total_naswt_calc_time=None):
    """Write every progress history for one experiment repeat to text files.

    Each history is written to ``<folder_name>/<name><exp_repeat_index + 1>.txt``,
    one ``str(element)`` per line. Existing files are overwritten.
    ``execution_time<k>.txt`` holds the single value
    ``end_time - start_time``.

    The six NASWT-related files are written only when ``fitness == 'naswt'``.
    In that case all six NASWT list arguments must be iterables.

    Args:
        folder_name: Existing directory to write into.
        exp_repeat_index: 0-based repeat index; file names use it plus one.
        start_time: Start timestamp of the run.
        end_time: End timestamp of the run.
        best_val_acc: History written to ``best_val_acc<k>.txt``.
        best_test_acc_based_on_val_acc: History written to
            ``best_test_acc_based_on_val_acc<k>.txt``.
        best_test_acc: History written to ``best_test_acc<k>.txt``.
        train_times: History written to ``train_times<k>.txt``.
        total_train_time: History written to ``total_train_time<k>.txt``.
        fitness: ``'naswt'`` enables the NASWT outputs.
        best_naswt_score_based_on_val_acc: NASWT history (NASWT mode only).
        best_naswt_score: NASWT history (NASWT mode only).
        best_val_acc_based_on_naswt_score: NASWT history (NASWT mode only).
        best_test_acc_based_on_naswt_score: NASWT history (NASWT mode only).
        naswt_calc_times: NASWT history (NASWT mode only).
        total_naswt_calc_time: NASWT history (NASWT mode only).
    """
    run_number = exp_repeat_index + 1  # output file names are 1-based

    # (file stem, values) pairs, written in this order.
    series = [
        ('best_val_acc', best_val_acc),
        ('best_test_acc_based_on_val_acc', best_test_acc_based_on_val_acc),
        ('best_test_acc', best_test_acc),
        ('train_times', train_times),
        ('total_train_time', total_train_time),
        ('execution_time', [end_time - start_time]),  # in seconds
    ]
    if fitness == 'naswt':
        series += [
            ('best_naswt_score_based_on_val_acc', best_naswt_score_based_on_val_acc),
            ('best_naswt_score', best_naswt_score),
            ('best_val_acc_based_on_naswt_score', best_val_acc_based_on_naswt_score),
            ('best_test_acc_based_on_naswt_score', best_test_acc_based_on_naswt_score),
            ('naswt_calc_times', naswt_calc_times),
            ('total_naswt_calc_time', total_naswt_calc_time),
        ]

    for stem, values in series:
        _write_values(folder_name, stem, run_number, values)