diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx
index d3e726afc0..01d324290e 100644
--- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx
+++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx
@@ -657,27 +657,47 @@ export class ModelOverview extends React.Component<
         new AbortController().signal
       )
       .then((result) => {
-        // Assumption: the lengths of `result` and `selectionIndexes` are the same.
+        const [allCohortMetrics, cohortClasses] = result;
+
+        // Assumption: the lengths of `allCohortMetrics` and `selectionIndexes` are the same.
         const updatedMetricStats: ILabeledStatistic[][] = [];
         for (const [
           cohortIndex,
-          [meanAveragePrecision, averagePrecision, averageRecall]
-        ] of result.entries()) {
+          cohortMetrics
+        ] of allCohortMetrics.entries()) {
           const count = selectionIndexes[cohortIndex].length;
-          const key: [number[], string, string, number] = [
-            selectionIndexes[cohortIndex],
-            this.state.aggregateMethod,
-            this.state.className,
-            this.state.iouThreshold
-          ];
-          if (!this.objectDetectionCache.has(key.toString())) {
-            this.objectDetectionCache.set(key.toString(), [
-              meanAveragePrecision,
-              averagePrecision,
-              averageRecall
-            ]);
+          let meanAveragePrecision = -1;
+          let averagePrecision = -1;
+          let averageRecall = -1;
+
+          // a 2D array carries per-class metrics; cache one entry per class
+          if (
+            Array.isArray(cohortMetrics) &&
+            cohortMetrics.every((subArray) => Array.isArray(subArray))
+          ) {
+            for (const [i, cohortMetric] of cohortMetrics.entries()) {
+              const [mAP, aP, aR] = cohortMetric;
+
+              const key: [number[], string, string, number] = [
+                selectionIndexes[cohortIndex],
+                this.state.aggregateMethod,
+                cohortClasses[i],
+                this.state.iouThreshold
+              ];
+              if (!this.objectDetectionCache.has(key.toString())) {
+                this.objectDetectionCache.set(key.toString(), [mAP, aP, aR]);
+              }
+
+              if (this.state.className === cohortClasses[i]) {
+                [meanAveragePrecision, averagePrecision, averageRecall] =
+                  cohortMetric;
+              }
+            }
+          } else if (Array.isArray(cohortMetrics)) {
+            [meanAveragePrecision, averagePrecision, averageRecall] =
+              cohortMetrics;
           }
           const updatedCohortMetricStats = [
diff --git a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py
index 945566afcf..55a9e68cd7 100644
--- a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py
+++ b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py
@@ -1147,11 +1147,9 @@ def compute_object_detection_metrics(
             cohort_classes = list(set([classes[i - 1]
                                        for i in pred_labels + gt_labels]))
             cohort_classes.sort(
-                key=lambda class_name: classes.index(class_name))
+                key=lambda cname: classes.index(cname))
             # to catch if the class is not in the cohort
-            try:
-                index = cohort_classes.index(class_name)
-            except ValueError:
+            if class_name not in cohort_classes:
                 all_cohort_metrics.append([-1, -1, -1])
             else:
                 metric_OD.update(cohort_pred,
@@ -1159,10 +1157,17 @@
                 object_detection_values = metric_OD.compute()
                 mAP = round(object_detection_values
                             ['map'].item(), 2)
-                AP = round(object_detection_values
-                           ['map_per_class'][index].item(), 2)
-                AR = round(object_detection_values
-                           ['mar_100_per_class'][index].item(), 2)
-                all_cohort_metrics.append([mAP, AP, AR])
+                APs = [round(value, 2) for value in
+                       object_detection_values['map_per_class']
+                       .detach().tolist()]
+                ARs = [round(value, 2) for value in
+                       object_detection_values['mar_100_per_class']
+                       .detach().tolist()]
+
+                assert len(APs) == len(ARs) == len(cohort_classes)
+
+                all_submetrics = [[mAP, APs[i], ARs[i]]
+                                  for i in range(len(APs))]
+                all_cohort_metrics.append(all_submetrics)
 
-        return all_cohort_metrics
+        return [all_cohort_metrics, cohort_classes]
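
For context: this patch changes the contract between the Python backend and the TypeScript frontend. compute_object_detection_metrics now returns [all_cohort_metrics, cohort_classes] rather than one flat [mAP, AP, AR] triple per cohort. Each cohort entry is either the sentinel [-1, -1, -1] (the selected class does not appear in the cohort) or a 2D list with one [mAP, AP, AR] row per class in cohort_classes, where the cohort-level mAP is repeated on every row. The Array.isArray checks on the TypeScript side distinguish exactly these two cases. A sketch of the expected shape, with invented class names and numbers:

    # Illustrative only: the structure the patched
    # compute_object_detection_metrics returns; all values are made up.
    cohort_classes = ["car", "person"]
    all_cohort_metrics = [
        # cohort 0: one [mAP, AP, AR] row per class in cohort_classes;
        # the cohort-level mAP (0.52) is repeated on each row
        [[0.52, 0.61, 0.70],   # "car"
         [0.52, 0.43, 0.58]],  # "person"
        # cohort 1: sentinel when class_name is absent from the cohort
        [-1, -1, -1],
    ]
    result = [all_cohort_metrics, cohort_classes]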
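The keys 'map', 'map_per_class', and 'mar_100_per_class' suggest metric_OD is a torchmetrics MeanAveragePrecision instance, presumably constructed with class_metrics=True (otherwise the per-class keys hold a -1 placeholder rather than per-class tensors); the patch itself does not show the constructor, so treat this as an assumption. A minimal, self-contained sketch of where the per-class tensors consumed above come from; the boxes, scores, and labels are made up:

    # Assumes metric_OD in the patch is a torchmetrics
    # MeanAveragePrecision built with class_metrics=True.
    from torch import tensor
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    metric_OD = MeanAveragePrecision(class_metrics=True)
    preds = [dict(
        boxes=tensor([[10.0, 10.0, 50.0, 50.0], [60.0, 60.0, 90.0, 95.0]]),
        scores=tensor([0.9, 0.8]),
        labels=tensor([0, 1]))]
    targets = [dict(
        boxes=tensor([[12.0, 11.0, 48.0, 52.0], [61.0, 59.0, 91.0, 94.0]]),
        labels=tensor([0, 1]))]
    metric_OD.update(preds, targets)
    values = metric_OD.compute()

    # 'map' is a 0-dim tensor; the per-class keys are 1-D tensors with
    # one entry per class seen in the targets, in ascending label order,
    # which is why the patch asserts len(APs) == len(cohort_classes).
    mAP = round(values["map"].item(), 2)
    APs = [round(v, 2) for v in values["map_per_class"].detach().tolist()]
    ARs = [round(v, 2) for v in values["mar_100_per_class"].detach().tolist()]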