Skip to content

Commit

Permalink
[Object Detection] Model Overview Cache Support Extension for torchme…
Browse files Browse the repository at this point in the history
…trics (#2170)

* cache ext ckpt

* cache ext working ckpt

* python lint fixes

* auto lint fixes
  • Loading branch information
Advitya17 authored Jul 13, 2023
1 parent f3d8cbb commit 5dc4159
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -657,27 +657,47 @@ export class ModelOverview extends React.Component<
new AbortController().signal
)
.then((result) => {
// Assumption: the lengths of `result` and `selectionIndexes` are the same.
const [allCohortMetrics, cohortClasses] = result;

// Assumption: the lengths of `allCohortMetrics` and `selectionIndexes` are the same.
const updatedMetricStats: ILabeledStatistic[][] = [];

for (const [
cohortIndex,
[meanAveragePrecision, averagePrecision, averageRecall]
] of result.entries()) {
cohortMetrics
] of allCohortMetrics.entries()) {
const count = selectionIndexes[cohortIndex].length;

const key: [number[], string, string, number] = [
selectionIndexes[cohortIndex],
this.state.aggregateMethod,
this.state.className,
this.state.iouThreshold
];
if (!this.objectDetectionCache.has(key.toString())) {
this.objectDetectionCache.set(key.toString(), [
meanAveragePrecision,
averagePrecision,
averageRecall
]);
let meanAveragePrecision = -1;
let averagePrecision = -1;
let averageRecall = -1;

// checking 2D array of computed metrics to cache
if (
Array.isArray(cohortMetrics) &&
cohortMetrics.every((subArray) => Array.isArray(subArray))
) {
for (const [i, cohortMetric] of cohortMetrics.entries()) {
const [mAP, aP, aR] = cohortMetric;

const key: [number[], string, string, number] = [
selectionIndexes[cohortIndex],
this.state.aggregateMethod,
cohortClasses[i],
this.state.iouThreshold
];
if (!this.objectDetectionCache.has(key.toString())) {
this.objectDetectionCache.set(key.toString(), [mAP, aP, aR]);
}

if (this.state.className === cohortClasses[i]) {
[meanAveragePrecision, averagePrecision, averageRecall] =
cohortMetric;
}
}
} else if (Array.isArray(cohortMetrics)) {
[meanAveragePrecision, averagePrecision, averageRecall] =
cohortMetrics;
}

const updatedCohortMetricStats = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1147,22 +1147,27 @@ def compute_object_detection_metrics(
cohort_classes = list(set([classes[i - 1]
for i in pred_labels + gt_labels]))
cohort_classes.sort(
key=lambda class_name: classes.index(class_name))
key=lambda cname: classes.index(cname))
# to catch if the class is not in the cohort
try:
index = cohort_classes.index(class_name)
except ValueError:
if class_name not in cohort_classes:
all_cohort_metrics.append([-1, -1, -1])
else:
metric_OD.update(cohort_pred,
cohort_gt)
object_detection_values = metric_OD.compute()
mAP = round(object_detection_values
['map'].item(), 2)
AP = round(object_detection_values
['map_per_class'][index].item(), 2)
AR = round(object_detection_values
['mar_100_per_class'][index].item(), 2)
all_cohort_metrics.append([mAP, AP, AR])
APs = [round(value, 2) for value in
object_detection_values['map_per_class']
.detach().tolist()]
ARs = [round(value, 2) for value in
object_detection_values['mar_100_per_class']
.detach().tolist()]

assert len(APs) == len(ARs) == len(cohort_classes)

all_submetrics = [[mAP, APs[i], ARs[i]]
for i in range(len(APs))]
all_cohort_metrics.append(all_submetrics)

return all_cohort_metrics
return [all_cohort_metrics, cohort_classes]

0 comments on commit 5dc4159

Please sign in to comment.