# Metrics for track 4 - point cloud classification
import argparse
from pathlib import Path
import numpy as np
import re
import matplotlib.pyplot as plt
# Classification labels tracked by these metrics; the label list below is built and kept sorted automatically
NULL_CLASS = 0 # not classified in ground_truth
LABELS_OBJ = {2: "Ground", 5: "High Vegetation", 6: "Building", 9: "Water", 17: "Bridge Deck"}
LABELS = sorted(LABELS_OBJ.keys())
LABEL_INDEXES = {label: index for index, label in enumerate(LABELS)}
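# Illustrative note (derived from the labels above): LABELS is [2, 5, 6, 9, 17],
# and LABEL_INDEXES maps each label to its row/column index in the confusion matrix:
#
#   LABEL_INDEXES == {2: 0, 5: 1, 6: 2, 9: 3, 17: 4}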
class ClassificationScore:
"""
Used to store and compute metric scores
"""
def __init__(self, true_positive=0, false_negative=0, false_positive=0):
self.true_positive = true_positive
self.false_negative = false_negative
self.false_positive = false_positive
def get_iou(self):
denominator = self.true_positive + self.false_negative + self.false_positive
if denominator == 0:
return 1
return self.true_positive / denominator
def add(self, other):
self.true_positive += other.true_positive
self.false_positive += other.false_positive
self.false_negative += other.false_negative
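# Illustrative example (hypothetical counts): a class with 8 true positives, 2 false
# negatives and 0 false positives has IoU = TP / (TP + FN + FP) = 8 / 10 = 0.8:
#
#   score = ClassificationScore(true_positive=8, false_negative=2, false_positive=0)
#   score.get_iou()  # 0.8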
def generate_confusion_matrix(zipped):
"""
Creates a confusion matrix from two aligned data sets
    :param zipped: an iterable of (ground_truth, prediction) tuples, aligned by index
    :return: an M x (M+1) matrix of truth rows vs. prediction columns; the extra final column collects untracked prediction labels
"""
invalid_labels = []
dim = len(LABELS)
matrix = np.zeros((dim, dim + 1), np.uint64)
    skipped_count = 0
    correct_count = 0
for truth_val, prediction_val in zipped:
truth_index = LABEL_INDEXES.get(truth_val)
prediction_index = LABEL_INDEXES.get(prediction_val)
# if a label is found, but not tracked by this software, log it
if truth_index is None:
skipped_count += 1
if truth_val == NULL_CLASS:
continue # This was junk, can't reliably score it
if truth_val not in invalid_labels:
print("Invalid truth LABEL: {}".format(truth_val))
invalid_labels.append(truth_val)
continue
if prediction_index is None:
if prediction_val not in invalid_labels:
invalid_labels.append(prediction_val)
print("Invalid prediction LABEL: {}".format(prediction_val))
            prediction_index = -1  # -1 indexes the trailing "OTHER" column
        matrix[truth_index, prediction_index] += 1
        if truth_index == prediction_index:
            correct_count += 1
    if skipped_count:
        print("Skipped {} unlabeled or untracked truth points".format(skipped_count))
    print("Correctly classified points: {}".format(correct_count))
return matrix
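# Illustrative example (hypothetical point labels): the pair (truth=2, prediction=6)
# increments matrix[0, 2] because LABEL_INDEXES maps 2 to row 0 and 6 to column 2,
# while an untracked prediction label such as 3 falls into the trailing "OTHER"
# column via index -1:
#
#   m = generate_confusion_matrix([(2, 6), (2, 2), (5, 3)])
#   # m[0, 2] == 1, m[0, 0] == 1, m[1, -1] == 1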
def print_matrix(mat):
"""
Prints the confusion matrix to terminal in a human readable manner
:param mat: Confusion matrix generated by this software
:return: nothing
"""
across = "Truth"
ALL_LABELS = LABELS + ['OTHER']
for label in ALL_LABELS:
across += "{:>9} ".format("Pred({})".format(label))
print(across)
print('-' * len(across))
    for label in LABELS:
        across = "{:<5}".format(label)
        for label_inner in ALL_LABELS:
            i = LABEL_INDEXES[label]
            j = -1 if label_inner == 'OTHER' else LABEL_INDEXES[label_inner]
            across += "{:>9}|".format(mat[i][j])
        print(across)
print()
def get_overall_accuracy(matrix):
"""
Computes the overall accuracy of predictions given a confusion matrix
:param matrix: Confusion matrix for a prediction event
    :return: overall accuracy between 0 and 1, where 1 equates to 100%
"""
count = matrix.sum()
correct = 0
for ix in range(len(LABELS)):
correct += matrix[ix][ix]
if count == 0:
return 1
return correct / count
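# Worked illustration (hypothetical counts): overall accuracy is the diagonal sum
# (correct predictions) divided by matrix.sum(), which also counts the "OTHER"
# column. For example, a diagonal sum of 85 out of 100 counted points gives
# 85 / 100 = 0.85.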
def get_mean_intersection_over_union(scores):
"""
    Mean IOU (MIOU), computed as the unweighted average of the per-class IOUs
    :param scores: an array of ClassificationScore objects
    :return: a value between 0 and 1, where 1 implies perfect intersection
    """
    if not scores:
        return 0
    iou = 0
    for score in scores:
        iou += score.get_iou()
    return iou / len(scores)
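# Worked illustration (hypothetical values): MIOU is the unweighted mean of the
# per-class IoUs, so a rare class weighs as much as a common one. For example,
# per-class IoUs of [0.9, 0.8, 0.7, 0.6, 0.5] give MIOU = 3.5 / 5 = 0.7.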
def score_predictions(matrix):
"""
Calculates the number of true positives, false negatives and false positives for each class in a confusion matrix
:param matrix: Confusion matrix for a prediction event
:return: An array of ClassificationScores, one for each classification label
"""
per_class = []
for current_class_index in range(len(LABELS)):
true_positives = matrix[current_class_index][current_class_index]
# false negatives are the points where the truth is the class, but the prediction is not (so sum across the
# row and subtract off true positives)
false_negatives = np.sum(matrix[current_class_index, :]) - true_positives
        # false positives are the points where the prediction is the class, but the truth is not (so sum down the
        # column and subtract off true positives)
false_positives = np.sum(matrix[:, current_class_index]) - true_positives
per_class.append(ClassificationScore(true_positives, false_negatives, false_positives))
return per_class
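# Illustrative example (hypothetical counts) for class index 0: with row
# [8, 1, 1, 0, 0, 0] (the last entry being the "OTHER" column) and column
# [8, 2, 0, 0, 0], TP = 8, FN = 10 - 8 = 2 and FP = 10 - 8 = 2.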
def report_scores(confusion_matrix):
    """Prints the overall accuracy, the confusion matrix, the MIOU, and the per-class IOUs"""
overall_accuracy = get_overall_accuracy(confusion_matrix) * 100
print("Confusion matrix with overall accuracy: {:.2f}%".format(overall_accuracy))
print_matrix(confusion_matrix)
prediction_scores = score_predictions(confusion_matrix)
print("MIOU: {}".format(get_mean_intersection_over_union(prediction_scores)))
for index, prediction_score in enumerate(prediction_scores):
print("Class {:2d} ({:^17}), IOU: {:.4f}".format(LABELS[index], LABELS_OBJ[LABELS[index]],
prediction_score.get_iou()))
print()
def score_prediction_files(ground_truth_file, prediction_file):
"""
    Scores a single prediction file against a ground truth file
    :param ground_truth_file: Ground truth classification file (one integer label per line)
    :param prediction_file: Classification prediction file (one integer label per line)
    :return: The confusion matrix for this prediction (all zeros if the files could not be scored)
    """
print("Scoring {} against {}:".format(prediction_file, ground_truth_file))
# Create default confusion matrix
dim = len(LABELS)
    confusion_matrix = np.zeros((dim, dim + 1), np.uint64)
# Load ground truth data
with open(str(ground_truth_file), 'r') as file:
try:
gt_data = [int(line) for line in file]
except ValueError:
print("Error reading {}".format(ground_truth_file))
return confusion_matrix
# Load prediction data
with open(str(prediction_file), 'r') as file:
try:
pd_data = [int(line) for line in file]
except ValueError:
print("Error reading {}".format(prediction_file))
return confusion_matrix
# Error check number of values
    if len(gt_data) != len(pd_data):
        print("Mismatched file lengths: {} truth lines vs {} prediction lines!".format(len(gt_data), len(pd_data)))
return confusion_matrix
# Matches line i to line i, creating an array of tuples (ground_truth[i], prediction[i])
one_to_one = zip(gt_data, pd_data)
confusion_matrix = generate_confusion_matrix(one_to_one)
print("Scores for {} (truth: {}):".format(prediction_file, ground_truth_file))
report_scores(confusion_matrix)
return confusion_matrix
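# Expected input format for both files (hypothetical contents): one integer
# classification label per line, with line i of the prediction file scored
# against line i of the ground truth file, e.g.
#
#   ground_truth.txt    prediction.txt
#   2                   2
#   6                   5
#   9                   9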
def directory_type(arg_string):
"""
Allows arg parser to handle directories
:param arg_string: A path, relative or absolute to a folder
:return: A python pure path object to a directory.
"""
directory_path = Path(arg_string)
if directory_path.exists() and directory_path.is_dir():
return directory_path
raise argparse.ArgumentError("{} is not a valid directory.".format(arg_string))
def file_type(arg_string):
"""
Allows arg parser to check against files
:param arg_string: A path, relative or absolute to a file
:return: A python pure path object to a file.
"""
file_path = Path(arg_string)
if file_path.exists() and file_path.is_file():
return file_path
raise argparse.ArgumentError("{} is not a valid directory.".format(arg_string))
def get_list_of_files(directory_path):
    """Returns the sorted classification .txt files in a directory, preferring files with 'CLS' in the name"""
    p = re.compile(r'[A-Z]{3}_\d{3,}_.*')
# First try to find text files that include 'CLS' in their name
classification_files = [Path(file) for file in directory_path.glob('*CLS*.txt') if p.match(file.name)]
if not classification_files:
# Fall back to any text files
classification_files = [Path(file) for file in directory_path.glob('*.txt') if p.match(file.name)]
if not classification_files:
raise ValueError("Could not find classification files in {}".format(directory_path))
return sorted(classification_files)
def get_tile_name(file):
    """Extracts the tile identifier (three uppercase letters, an underscore, and three or more digits) from a file name"""
    p = re.compile(r'([A-Z]{3}_\d{3,})_.*')
    return p.match(file.name).group(1)
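# Illustrative example (hypothetical file name): the tile identifier is the leading
# "AAA_123"-style prefix, so a file named ABC_123_CLS.txt yields the tile name "ABC_123":
#
#   get_tile_name(Path("ABC_123_CLS.txt"))  # "ABC_123"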
def match_file_pairs(A, B):
    """Pairs each ground truth file with the prediction file that shares its tile name"""
    if not isinstance(A, list):
        a_tile = get_tile_name(A)
        b = next((b for b in B if get_tile_name(b) == a_tile), None)
        if b:
            return (A, b)
        else:
            raise ValueError("Could not match {}".format(A))
    elif not isinstance(B, list):
        return match_file_pairs(B, A)[::-1]
    else:
        return [match_file_pairs(a, B) for a in A]
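# Illustrative example (hypothetical file names): files are paired purely by tile
# name, so the ordering of the two lists does not matter:
#
#   truths = [Path("ABC_123_CLS.txt"), Path("DEF_456_CLS.txt")]
#   preds = [Path("DEF_456_CLS_pred.txt"), Path("ABC_123_CLS_pred.txt")]
#   match_file_pairs(truths, preds)
#   # [(ABC_123_CLS.txt, ABC_123_CLS_pred.txt), (DEF_456_CLS.txt, DEF_456_CLS_pred.txt)]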
# Display names for the plotted confusion matrix; index 0 ("Undefined") is excluded from the axes
labels = ["Undefined", "Ground", "High Vegetation", "Building", "Water", "Bridge Deck"]
def plot_confusion_matrix(cm, title='Confusion Matrix', cmap=plt.cm.binary):
    """Renders a normalized confusion matrix as an image with labeled class axes"""
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(labels) - 1))
    plt.xticks(xlocations, labels[1:], rotation=90)
    plt.yticks(xlocations, labels[1:])
    plt.ylabel('True label', fontsize=16)
    plt.xlabel('Predicted label', fontsize=16)
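# Example invocations (paths below are placeholders):
#
#   python track4-metrics.py -t truth/ABC_123_CLS.txt -f predictions/ABC_123_CLS.txt
#   python track4-metrics.py -g truth_directory -d prediction_directory
#
# Either single files (-t / -f) or directories (-g / -d) may be given for each side;
# directories are scanned for *CLS*.txt files, falling back to any matching *.txt files.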
if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--ground_truth_directory', type=directory_type, help='Directory of ground truth classification files')
    parser.add_argument('-t', '--ground_truth_file', type=file_type, help='Single ground truth classification file')
    parser.add_argument('-d', '--prediction_directory', type=directory_type, help='Directory of prediction classification files')
    parser.add_argument('-f', '--prediction_file', type=file_type, help='Single prediction classification file')
args = parser.parse_args()
# Get list of truth files
truth_files = []
if args.ground_truth_file is not None:
truth_files.append(args.ground_truth_file)
if args.ground_truth_directory is not None:
truth_files.extend(get_list_of_files(args.ground_truth_directory))
if not truth_files:
raise ValueError('No ground truth paths specified')
# Get list of class prediction files
prediction_files = []
if args.prediction_file is not None:
prediction_files.append(args.prediction_file)
if args.prediction_directory is not None:
prediction_files.extend(get_list_of_files(args.prediction_directory))
if not prediction_files:
raise ValueError('No prediction paths specified')
# Match truth to prediction files
file_pairs = match_file_pairs(truth_files, prediction_files)
    if not isinstance(file_pairs, list):
        file_pairs = [file_pairs]
    confusion_matrix = np.zeros((len(LABELS), len(LABELS) + 1), np.uint64)
for file_pair in file_pairs:
confusion_matrix += score_prediction_files(*file_pair)
if len(file_pairs) >= 1:
print("----- OVERALL SCORES -----")
report_scores(confusion_matrix)
    # Plot the confusion matrix, normalized by the number of ground truth points in each class
tick_marks = np.array(range(len(labels) -1 )) + 0.5
confusion_matrix = np.array(confusion_matrix[:,:-1])
np.set_printoptions(precision=3)
x = confusion_matrix.sum(axis=1)[:, np.newaxis]
for i in range(len(x)):
if not x[i]:
x[i] = 1
cm_normalized = confusion_matrix.astype('float') / x
# print(cm_normalized)
plt.figure(figsize=(12, 8), dpi=120)
ind_array = np.arange(len(labels) - 1)
x, y = np.meshgrid(ind_array, ind_array)
for x_val, y_val in zip(x.flatten(), y.flatten()):
c = cm_normalized[y_val][x_val]
if c > 0.001:
plt.text(x_val, y_val, "%0.3f" % (c,), color='red', fontsize=10, va='center', ha='center')
    # Offset the minor ticks so that grid lines are drawn between matrix cells
plt.gca().set_xticks(tick_marks, minor=True)
plt.gca().set_yticks(tick_marks, minor=True)
plt.gca().xaxis.set_ticks_position('none')
plt.gca().yaxis.set_ticks_position('none')
plt.grid(True, which='minor', linestyle='-')
plt.gcf().subplots_adjust(bottom=0.25)
plot_confusion_matrix(cm_normalized, title='OVERALL SCORES')
    # Save and display the confusion matrix figure
plt.savefig('confusion_matrix_overall_bl.png', format='png')
plt.show()