From 5879d8e34cacaef5f5b764a2a6e010cea7d2af1f Mon Sep 17 00:00:00 2001 From: hoominchu Date: Thu, 20 Jun 2024 13:56:41 -0700 Subject: [PATCH 01/12] Using sklearn metrics and combining visualization. Adding support for validation for ViT CLIP model with timm. Other. --- notebooks/validation.py | 402 ++++++++++++++++++++++++---------------- 1 file changed, 240 insertions(+), 162 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index a9b3dd1..688dd7f 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -10,7 +10,11 @@ from torch import nn, optim from copy import deepcopy import sys -from visualize import draw_confusion_matrices +from sklearn.metrics import precision_recall_curve, auc, PrecisionRecallDisplay +import matplotlib.pyplot as plt +from sklearn.metrics import roc_curve, RocCurveDisplay +from sklearn.metrics import average_precision_score, accuracy_score, hamming_loss, f1_score +import timm sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -104,8 +108,32 @@ def forward(self, x): return x +class CLIP_Classifier(nn.Module): + def __init__(self, model_name='', n_target_classes=7): + super().__init__() + self.model = timm.create_model(model_name=model_name, pretrained=True, in_chans=3) + # Replace the final head layers in model with our own Linear layer + num_features = self.model.num_features + self.model.head = nn.Linear(num_features, 256) + self.fully_connect = nn.Sequential(nn.Linear(256, 128), + nn.ReLU(), + nn.Linear(128, n_target_classes)) + + # self.model.load_state_dict(torch.load('{}/'.format(local_directory) + model_name, map_location=torch.device(device))) + + # self.transformer = deepcopy(self.model) + + def forward(self, image): + x = self.model(image) + # Using dropout functions to randomly shutdown some of the nodes in hidden layers to prevent overfitting. + # x = self.dropout(x) + # # Concatenate the metadata into the results. 
+ x = self.fully_connect(x) + return x + + def get_labels_ref_for_run(inference_set_dir): - csv_file_path = os.path.join(inference_set_dir, '_classes.csv') + csv_file_path = os.path.join(inference_set_dir, 'test.csv') label_data = pd.read_csv(csv_file_path) # get the header row @@ -121,16 +149,13 @@ def get_labels_ref_for_run(inference_set_dir): global c12n_category_offset c12n_category_offset = validated_by_index + 1 - return labels_ref_for_run - + if params['label_type'] == 'obstacle': + if len(labels_ref_for_run) == 20: + labels_ref_for_run = labels_ref_for_run[:-3] + else: + raise ValueError('Unexpected number of labels for obstacle') -image_dimension = 256 - -# This is what DinoV2 sees -target_size = (image_dimension, image_dimension) - - -confidence_levels = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1] + return labels_ref_for_run # enum to track the classification categories @@ -139,24 +164,54 @@ def get_labels_ref_for_run(inference_set_dir): 'SEVERITY': 'severity', } +MODEL_PREFIXES = { + 'CLIP': 'clip', + 'DINO': 'dino', +} + # ------------------------------ # all the parameters to be customized for the run # ------------------------------ -label_type = 'surfaceproblem' -c12n_category = C12N_CATEGORIES['TAGS'] -inference_set_dir_name = 'test' + +params = { + 'label_type': 'curbramp', + 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], + 'dataset_type': 'validated', # 'unvalidated' or 'validated' + + # these don't really change for now + 'c12n_category': C12N_CATEGORIES['TAGS'], + 'inference_set_dir_name': 'test', +} + +# ------------------------------ + +# suppress the tags that have less than the threshold count in the plot +suppress_thresholds = { + 'crosswalk': 10, + 'obstacle': 10, + 'surfaceproblem': 10, + 'curbramp': 10 +} + +image_dimension = 256 + +if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP']: + image_dimension = 224 + +# This is what DinoV2 sees +target_size = (image_dimension, image_dimension) # temporarily skipping the cities with messy data -skip_cities = ['cdmx', 'spgg', 'newberg', 'columbus'] +# skip_cities = ['cdmx', 'spgg', 'newberg', 'columbus'] +skip_cities = [] -dataset_dirname = 'crops-' + label_type + '-' + c12n_category # example: crops-surfaceproblem-tags-archive +dataset_dirname = 'crops-' + params['label_type'] + '-' + params['c12n_category'] # example: crops-surfaceproblem-tags-archive # dataset_dirname = 'crops-' + label_type + '-' + c12n_category + '-validated' # example: crops-surfaceproblem-tags-archive dataset_dir_path = '../datasets/' + dataset_dirname # example: ../datasets/crops-surfaceproblem-tags-archive -inference_dataset_dir = Path(dataset_dir_path + "/" + inference_set_dir_name) +inference_dataset_dir = Path(dataset_dir_path + "/" + params['inference_set_dir_name']) -model_name = 'cls-b-' + label_type + '-' + c12n_category + '-best.pth' -# model_name = 'cls-b-obstacle-tags-masked-best.pth' +model_name = 'models/' + params['dataset_type'] + '-' + params['pretrained_model_prefix'] + '-cls-b-' + params['label_type'] + '-' + params['c12n_category'] + '-best.pth' # ------------------------------ local_directory = os.getcwd() @@ -168,18 +223,14 @@ def get_labels_ref_for_run(inference_set_dir): else: print("GPU not available") + +if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP']: + img_resize_multiple = 32 +else: + img_resize_multiple = 14 + data_transforms = { - "train": transforms.Compose( - [ - ResizeAndPad(target_size, 14), - # 
transforms.RandomRotation(360), - # transforms.RandomHorizontalFlip(), - # transforms.RandomVerticalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ), - "inference": transforms.Compose([ResizeAndPad(target_size, 14), + "inference": transforms.Compose([ResizeAndPad(target_size, img_resize_multiple), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) @@ -191,7 +242,7 @@ def get_labels_ref_for_run(inference_set_dir): # it's okay if the csv contains more filenames than the images in the directory # we will only load the images that are present in the directory and query the csv for labels def images_loader(dir_path, batch_size, imgsz, transform): - file_path = os.path.join(dir_path, '_classes.csv') + file_path = os.path.join(dir_path, 'test.csv') label_data = pd.read_csv(file_path) filenames = [] @@ -246,9 +297,6 @@ def data_loader(dir_path, batch_size, imgsz, transform): # ----------------------------------------------------------------- - -serverity_labels = ['s-1', 's-2', 's-3', 's-4', 's-5'] - # todo this needs a better name! labels_ref_for_run = get_labels_ref_for_run(inference_dataset_dir) @@ -265,173 +313,203 @@ def data_loader(dir_path, batch_size, imgsz, transform): images_and_labels = data_loader(inference_dataset_dir, 1, image_dimension, 'inference') + # ----------------------------------------------------------------- -def inference_on_validation_data(inference_model, confidence_level=0.5): +def inference_on_validation_data(inference_model): - # we track these for each confidence level - n_incorrect_predictions_to_filenames = {} - category_to_true_positive_counts = {} - category_to_false_positive_counts = {} - category_to_false_negative_counts = {} + y_true = [] + y_pred = [] - category_to_prediction_stats = {} - category_to_prediction_details = {} + for idx in range(len(images_and_labels)): - for img_label_filename in images_and_labels: + img_label_filename = images_and_labels[idx] img_tensor, labels, filename = img_label_filename + print('Processing image. 
Index: {}, Filename: {}'.format(idx, filename)) + input_tensor = img_tensor.to(device) labels_tensor = labels.to(device) + y_true.append(labels.tolist()[0]) + # run model on input image data with torch.no_grad(): - embeddings = inference_model.transformer(input_tensor) - x = inference_model.transformer.norm(embeddings) - output_tensor = inference_model.classifier(x) + if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP']: + output_tensor = inference_model(input_tensor) + else: + embeddings = inference_model.transformer(input_tensor) + x = inference_model.transformer.norm(embeddings) + output_tensor = inference_model.classifier(x) # Convert outputs to probabilities using sigmoid - if c12n_category == C12N_CATEGORIES['SEVERITY']: + if params['c12n_category'] == C12N_CATEGORIES['SEVERITY']: probabilities = torch.softmax(output_tensor, dim=1) - elif c12n_category == C12N_CATEGORIES['TAGS']: + elif params['c12n_category'] == C12N_CATEGORIES['TAGS']: probabilities = torch.sigmoid(output_tensor) + y_pred.append(probabilities.tolist()[0]) - # Convert probabilities to predicted classes - predicted_classes = probabilities > confidence_level - # Calculate accuracy - n_labels = labels.size(1) - n_incorrect_predictions = (predicted_classes != labels_tensor.byte()).sum().item() - correct_predictions = ((predicted_classes == labels_tensor.byte()).sum().item()) / n_labels + y_true_np = np.array(y_true) + y_pred_np = np.array(y_pred) - # updating number of incorrect predictions to file names - if n_incorrect_predictions in n_incorrect_predictions_to_filenames: - n_incorrect_predictions_to_filenames[n_incorrect_predictions].append(filename) - else: - n_incorrect_predictions_to_filenames[n_incorrect_predictions] = [filename] - - - predicted_classes_list = predicted_classes.tolist()[0] - ground_truth_labels_list = labels.tolist()[0] - - # getting the list of predicted and ground truth labels for the current crop - predicted_classes_for_crop = [] - for x in range(len(predicted_classes_list)): - if predicted_classes_list[x]: - predicted_classes_for_crop.append(labels_ref_for_run[x]) - - gt_labels_for_crop = [] - for x in range(len(ground_truth_labels_list)): - if ground_truth_labels_list[x] == 1.0: - gt_labels_for_crop.append(labels_ref_for_run[x]) - - - # updating true positives - for elem in predicted_classes_for_crop: - if elem in gt_labels_for_crop: - if elem in category_to_true_positive_counts: - category_to_true_positive_counts[elem].append(filename) - else: - category_to_true_positive_counts[elem] = [filename] - - if elem in category_to_prediction_stats: - category_to_prediction_stats[elem]['true-positive'] += 1 - category_to_prediction_details[elem]['true-positive'].append({'filename': filename, 'predicted': predicted_classes_for_crop, 'ground-truth': gt_labels_for_crop}) - else: - category_to_prediction_stats[elem] = {'true-positive': 1, 'false-positive': 0, 'false-negative': 0, 'true-negative': 0} - category_to_prediction_details[elem] = {'true-positive': [{'filename': filename, 'predicted': predicted_classes_for_crop, 'ground-truth': gt_labels_for_crop}], 'false-positive': [], 'false-negative': [], 'true-negative': []} - - # updating false positives - for elem in predicted_classes_for_crop: - if elem not in gt_labels_for_crop: - if elem in category_to_false_positive_counts: - category_to_false_positive_counts[elem].append(filename) - else: - category_to_false_positive_counts[elem] = [filename] - - if elem in category_to_prediction_stats: - category_to_prediction_stats[elem]['false-positive'] 
+= 1 - category_to_prediction_details[elem]['false-positive'].append({'filename': filename, 'predicted': predicted_classes_for_crop, 'ground-truth': gt_labels_for_crop}) - else: - category_to_prediction_stats[elem] = {'true-positive': 0, 'false-positive': 1, 'false-negative': 0, 'true-negative': 0} - category_to_prediction_details[elem] = {'true-positive': [], 'false-positive': [{'filename': filename, 'predicted': predicted_classes_for_crop, 'ground-truth': gt_labels_for_crop}], 'false-negative': [], 'true-negative': []} - - # updating false negatives - for elem in gt_labels_for_crop: - if elem not in predicted_classes_for_crop: - if elem in category_to_false_negative_counts: - category_to_false_negative_counts[elem].append(filename) - else: - category_to_false_negative_counts[elem] = [filename] - - if elem in category_to_prediction_stats: - category_to_prediction_stats[elem]['false-negative'] += 1 - category_to_prediction_details[elem]['false-negative'].append({'filename': filename, 'predicted': predicted_classes_for_crop, 'ground-truth': gt_labels_for_crop}) - else: - category_to_prediction_stats[elem] = {'true-positive': 0, 'false-positive': 0, 'false-negative': 1, 'true-negative': 0} - category_to_prediction_details[elem] = {'true-positive': [], 'false-positive': [], 'false-negative': [{'filename': filename, 'predicted': predicted_classes_for_crop, 'ground-truth': gt_labels_for_crop}], 'true-negative': []} - - print("{} | Correct percent = {} | Predicted = {} vs. " - "Ground Truth = {}:".format(filename, correct_predictions, predicted_classes_for_crop, gt_labels_for_crop)) - - # update the global variables - all_n_incorrect_predictions_to_filenames[conf_level] = n_incorrect_predictions_to_filenames - all_category_to_true_positive_counts[conf_level] = category_to_true_positive_counts - all_category_to_false_positive_counts[conf_level] = category_to_false_positive_counts - all_category_to_false_negative_counts[conf_level] = category_to_false_negative_counts - # all_category_to_true_negative_counts[confidence_threshold] | add true negative here - all_category_to_prediction_stats[conf_level] = category_to_prediction_stats - all_category_to_prediction_details[conf_level] = category_to_prediction_details + # Create a list of tuples (tag, precision, recall, n_instances) + # sum the columns of y_true_np to get the number of instances in the ground truth labels + tag_to_n_instances = [(labels_ref_for_run[i], int(np.sum(y_true_np[:, i])), i) for i in range(len(labels_ref_for_run))] + + # Sort the list based on n_instances + tag_to_n_instances.sort(key=lambda x: x[1], reverse=True) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12)) + + tags_not_plotted = [] + + all_average_precisions = [] + + for i in range(len(tag_to_n_instances)): + + tag_name, n_instances, tag_idx = tag_to_n_instances[i] + + sum_gt = np.sum(y_true_np[:, tag_idx]) + if sum_gt != n_instances: + # throw an error and stop + raise ValueError('What is happening! 
For tag {} sum of instances: {} and n_instances: {} are not equal'.format(tag_name, sum_gt, n_instances)) + + # compute precision, recall, thresholds + precision, recall, thresholds = precision_recall_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) + pr_auc = auc(recall, precision) + + average_precision_val = average_precision_score(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], average='weighted') + + all_f1_pr = 2 * precision * recall / (precision + recall) + ix_pr = np.argmax(all_f1_pr) + best_thresh_pr = thresholds[ix_pr] + precision_pr_at_best_conf = precision[ix_pr] + recall_pr_at_best_conf = recall[ix_pr] + + y_pred_class_pr = np.where(y_pred_np[:, tag_idx] > best_thresh_pr, 1, 0) + f1_score_pr = f1_score(y_true_np[:, tag_idx], y_pred_class_pr) + accuracy_score_pr = accuracy_score(y_true_np[:, tag_idx], y_pred_class_pr) + + + all_category_to_prediction_stats[tag_name] = {'n_instances': n_instances, 'precision': precision.tolist(), 'recall': recall.tolist(), + 'thresholds': thresholds.tolist(), 'pr_auc': pr_auc, 'average_precision_val': average_precision_val} + + if len(np.unique(y_true_np[:, tag_idx])) < 2 or len(np.unique(y_pred_np[:, tag_idx])) < 2: + print('For tag {} all instances of y_true: {}'.format(tag_name, np.unique(y_true_np[:, tag_idx]))) + + # Compute ROC curve + fpr, tpr, thresholds = roc_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) + roc_auc = auc(fpr, tpr) + + J_roc = tpr - fpr + ix_roc = np.argmax(J_roc) + best_thresh_roc = thresholds[ix_roc] + precision_roc_at_best_conf = precision[ix_roc] + recall_roc_at_best_conf = recall[ix_roc] + + y_pred_class_roc = np.where(y_pred_np[:, tag_idx] > best_thresh_roc, 1, 0) + f1_score_roc = f1_score(y_true_np[:, tag_idx], y_pred_class_roc) + accuracy_score_roc = accuracy_score(y_true_np[:, tag_idx], y_pred_class_roc) + + all_category_to_prediction_stats[tag_name].update({'fpr': fpr.tolist(), 'tpr': tpr.tolist(), 'roc_auc': roc_auc}) + + + # don't plot if there are no instances of the tag in the ground truth labels + st = suppress_thresholds[params['label_type']] + if params['label_type'] not in suppress_thresholds: + st = 0 + + if n_instances < st: + tags_not_plotted.append((tag_name, n_instances)) + continue + + # note: this should be done after the suppression part + all_average_precisions.append(average_precision_val) + + # Create a PrecisionRecallDisplay and plot it on the same axis + pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, acc={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(accuracy_score_pr, 2), round(f1_score_pr, 2), round(precision_pr_at_best_conf, 2), round(recall_pr_at_best_conf, 2))) + # pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2))) + + # Create a RocCurveDisplay and plot it on the same axis + roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=ax2, name=tag_name + '\n(n={}, AUC={})\n(conf={}, acc={}, f1={}\n(prec={}, rec={}))'.format(n_instances, round(roc_auc, 2), round(best_thresh_roc, 2), round(accuracy_score_roc, 2), round(f1_score_roc, 2), round(precision_roc_at_best_conf, 2), round(recall_roc_at_best_conf, 2))) + + + mean_average_precision = sum(all_average_precisions) / len(all_average_precisions) + + # Add a legend to the plot + legend1 = 
ax1.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') + # ax1.add_artist(legend1) + + # # Draw a second legend to show the classes that were not plotted + # patches = [mpatches.Patch(color='none', label=tag_name + ' (n={})'.format(n_instances)) for tag_name, n_instances in + # tags_not_plotted] + # legend2 = plt.legend(handles=patches, title='Classes not plotted', bbox_to_anchor=(1.05, 0), loc='lower left') + # ax1.add_artist(legend2) + + + # Add a legend to the ROC plot + ax2.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') + + # Set titles for the plots + ax1.set_title('Precision-Recall Curve') + ax2.set_title('ROC Curve') + + # Set the plot title + plot_title_str = 'PR and ROC Curves for label type: ' + params['label_type'] + plot_title_str += '\nTest set size: ' + str(len(images_and_labels)) + ' images' + plot_title_str += '\nModel: ' + ('ViT CLIP Base' if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP'] else 'DINOv2 Base') + plot_title_str += ' | Train dataset: ' + params['dataset_type'] + + plot_title_str += '\nmAP: ' + str(round(mean_average_precision, 2)) + + # Set title for the figure and save + plt.suptitle(plot_title_str, fontsize=16) + + model_display_name = 'ViT CLIP Base' if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP'] else 'DINOv2 Base' + + plt.tight_layout() + + pt_model_prefix = params['pretrained_model_prefix'] + inf_set_dir_name = params['inference_set_dir_name'] + dataset_type = params['dataset_type'] + + if suppress_thresholds[params['label_type']] > 0: + plt.savefig(f'{dataset_dir_path}/{dataset_type}-{pt_model_prefix}-pr-roc-curve-{inf_set_dir_name}.png') + else: + plt.savefig(f'{dataset_dir_path}/{dataset_type}-{pt_model_prefix}-pr-roc-curve-{inf_set_dir_name}-all.png') + + plt.show() nc = len(labels_ref_for_run) # number of classes. 
-classifier = DinoVisionTransformerClassifier("base", nc) +if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP']: + classifier = CLIP_Classifier("vit_base_patch16_clip_224", nc) +else: + classifier = DinoVisionTransformerClassifier("base", nc) -classifier.load_state_dict( - torch.load('{}/'.format(local_directory) + model_name, map_location=torch.device(device))) +classifier.load_state_dict(torch.load('{}/'.format(local_directory) + model_name, map_location=torch.device(device))) classifier = classifier.to(device) classifier.eval() # runs the inference for all confidence levels -for conf_level in confidence_levels: - inference_on_validation_data(inference_model=classifier, confidence_level=conf_level) -# output_file_name = dataset_dir_path + '/inference-stats' + '-' + inference_set_dir_name + '-masked.json' -output_file_name = dataset_dir_path + '/inference-stats' + '-' + inference_set_dir_name + '.json' +inference_on_validation_data(inference_model=classifier) + +output_file_name = dataset_dir_path + '/' + params['dataset_type'] + '-' + params['pretrained_model_prefix'] + '-inference-stats' + '-' + params['inference_set_dir_name'] + '.json' # save the results to a file with open(output_file_name, 'w') as f: all_stats = { - 'n_incorrect_predictions_to_filenames': all_n_incorrect_predictions_to_filenames, + # 'n_incorrect_predictions_to_filenames': all_n_incorrect_predictions_to_filenames, 'category_to_prediction_stats': all_category_to_prediction_stats, - 'category_to_prediction_details': all_category_to_prediction_details + # 'category_to_prediction_details': all_category_to_prediction_details } f.write(json.dumps(all_stats, indent=4)) - draw_confusion_matrices(all_stats, c12n_category, label_type, dataset_dir_path, inference_set_dir_name) - - for conf_level in all_n_incorrect_predictions_to_filenames: - for key in all_n_incorrect_predictions_to_filenames[conf_level]: - print("Number of incorrect predictions: {} | Count: {} | Confidence level: {}".format(key, len(all_n_incorrect_predictions_to_filenames[conf_level][key]), conf_level)) - print("-------------------") - # print("True Positives:") - # - # for category in category_to_true_positive_counts: - # print(category + ": " + str(len(category_to_true_positive_counts[category]))) - # - # print("-------------------") - # print("False Positives:") - # - # for category in category_to_false_positive_counts: - # print(category + ": " + str(len(category_to_false_positive_counts[category]))) - # - # print("-------------------") - # print("False Negatives:") - # for category in category_to_false_negative_counts: - # print(category + ": " + str(len(category_to_false_negative_counts[category]))) From 878abf3ff15d8745f62014c62c771bc3cb20610f Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 21 Jun 2024 11:37:44 -0700 Subject: [PATCH 02/12] Saving svg version of PR curve only (not including the ROC curve). 
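A rough sketch of the plotting/saving flow this patch moves to, assuming per-tag scores are held in NumPy arrays (the function, variable, and file names here are illustrative placeholders, not the script's own):

    # Illustrative sketch only: plot per-tag PR curves on one axis and save as SVG.
    # y_true and y_score are (n_samples, n_tags) arrays; tag_names is a list of tag labels.
    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_curve, PrecisionRecallDisplay

    def save_pr_curves_svg(y_true, y_score, tag_names, out_path="pr-curve.svg"):
        fig, ax = plt.subplots(figsize=(10, 10))
        for i, tag in enumerate(tag_names):
            precision, recall, _ = precision_recall_curve(y_true[:, i], y_score[:, i])
            PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax, name=tag)
        ax.set_title("Precision-Recall Curve")
        ax.legend(title="Classes")
        fig.savefig(out_path, bbox_inches="tight")  # the .svg extension selects the SVG backend
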
--- notebooks/validation.py | 38 +++++++++++++++----------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index 688dd7f..ecb203b 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -174,7 +174,7 @@ def get_labels_ref_for_run(inference_set_dir): # ------------------------------ params = { - 'label_type': 'curbramp', + 'label_type': 'crosswalk', 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], 'dataset_type': 'validated', # 'unvalidated' or 'validated' @@ -201,8 +201,7 @@ def get_labels_ref_for_run(inference_set_dir): # This is what DinoV2 sees target_size = (image_dimension, image_dimension) -# temporarily skipping the cities with messy data -# skip_cities = ['cdmx', 'spgg', 'newberg', 'columbus'] +# for temporarily skipping some cities skip_cities = [] dataset_dirname = 'crops-' + params['label_type'] + '-' + params['c12n_category'] # example: crops-surfaceproblem-tags-archive @@ -301,7 +300,6 @@ def data_loader(dir_path, batch_size, imgsz, transform): labels_ref_for_run = get_labels_ref_for_run(inference_dataset_dir) -all_n_incorrect_predictions_to_filenames = {} all_category_to_true_positive_counts = {} all_category_to_false_positive_counts = {} all_category_to_false_negative_counts = {} @@ -351,6 +349,7 @@ def inference_on_validation_data(inference_model): y_pred.append(probabilities.tolist()[0]) + y_true_np = np.array(y_true) y_pred_np = np.array(y_pred) @@ -361,7 +360,7 @@ def inference_on_validation_data(inference_model): # Sort the list based on n_instances tag_to_n_instances.sort(key=lambda x: x[1], reverse=True) - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12)) + fig, ax1 = plt.subplots(1, 1, figsize=(10, 10)) tags_not_plotted = [] @@ -429,32 +428,25 @@ def inference_on_validation_data(inference_model): all_average_precisions.append(average_precision_val) # Create a PrecisionRecallDisplay and plot it on the same axis - pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, acc={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(accuracy_score_pr, 2), round(f1_score_pr, 2), round(precision_pr_at_best_conf, 2), round(recall_pr_at_best_conf, 2))) - # pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2))) + pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + ' (n={})'.format(n_instances)) + # pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, acc={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(accuracy_score_pr, 2), round(f1_score_pr, 2), round(precision_pr_at_best_conf, 2), round(recall_pr_at_best_conf, 2))) + # Create a RocCurveDisplay and plot it on the same axis - roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=ax2, name=tag_name + '\n(n={}, AUC={})\n(conf={}, acc={}, f1={}\n(prec={}, rec={}))'.format(n_instances, round(roc_auc, 2), round(best_thresh_roc, 2), round(accuracy_score_roc, 2), round(f1_score_roc, 2), round(precision_roc_at_best_conf, 2), round(recall_roc_at_best_conf, 2))) + # roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=ax2, 
name=tag_name + '\n(n={}, AUC={})\n(conf={}, acc={}, f1={}\n(prec={}, rec={}))'.format(n_instances, round(roc_auc, 2), round(best_thresh_roc, 2), round(accuracy_score_roc, 2), round(f1_score_roc, 2), round(precision_roc_at_best_conf, 2), round(recall_roc_at_best_conf, 2))) mean_average_precision = sum(all_average_precisions) / len(all_average_precisions) # Add a legend to the plot - legend1 = ax1.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') - # ax1.add_artist(legend1) - - # # Draw a second legend to show the classes that were not plotted - # patches = [mpatches.Patch(color='none', label=tag_name + ' (n={})'.format(n_instances)) for tag_name, n_instances in - # tags_not_plotted] - # legend2 = plt.legend(handles=patches, title='Classes not plotted', bbox_to_anchor=(1.05, 0), loc='lower left') - # ax1.add_artist(legend2) - + legend1 = ax1.legend(title='Classes', fontsize='16') # Add a legend to the ROC plot - ax2.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') - - # Set titles for the plots + # ax2.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') + # + # # Set titles for the plots ax1.set_title('Precision-Recall Curve') - ax2.set_title('ROC Curve') + # ax2.set_title('ROC Curve') # Set the plot title plot_title_str = 'PR and ROC Curves for label type: ' + params['label_type'] @@ -476,9 +468,9 @@ def inference_on_validation_data(inference_model): dataset_type = params['dataset_type'] if suppress_thresholds[params['label_type']] > 0: - plt.savefig(f'{dataset_dir_path}/{dataset_type}-{pt_model_prefix}-pr-roc-curve-{inf_set_dir_name}.png') + plt.savefig(f'{dataset_dir_path}/{dataset_type}-{pt_model_prefix}-pr-curve-{inf_set_dir_name}.svg') else: - plt.savefig(f'{dataset_dir_path}/{dataset_type}-{pt_model_prefix}-pr-roc-curve-{inf_set_dir_name}-all.png') + plt.savefig(f'{dataset_dir_path}/{dataset_type}-{pt_model_prefix}-pr-curve-{inf_set_dir_name}-all.svg') plt.show() From 01779f62d4513ea248ee77d49aa7cfe2ecabad38 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Wed, 26 Jun 2024 09:53:40 -0700 Subject: [PATCH 03/12] Saving top TP, FP, FN, TNs and drawing only PR curve and not the ROC curve. 
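The TP/FP/FN/TN bucketing roughly follows this sketch, assuming predictions are thresholded per tag and then ranked by model confidence (array and argument names are illustrative placeholders):

    # Illustrative sketch only: split test instances into TP/FP/FN/TN for one tag,
    # then keep the highest-confidence examples from each bucket.
    # y_true_col / y_score_col are 1-D arrays for a single tag; filenames is the
    # matching list of crop filenames.
    import numpy as np

    def top_instances_per_bucket(y_true_col, y_score_col, filenames, threshold, k=20):
        y_pred_col = (y_score_col >= threshold).astype(int)
        buckets = {
            "tp": np.where((y_true_col == 1) & (y_pred_col == 1))[0],
            "fp": np.where((y_true_col == 0) & (y_pred_col == 1))[0],
            "fn": np.where((y_true_col == 1) & (y_pred_col == 0))[0],
            "tn": np.where((y_true_col == 0) & (y_pred_col == 0))[0],
        }
        top = {}
        for name, idxs in buckets.items():
            # sort each bucket by model confidence, highest first, and keep the top k
            ranked = sorted(((filenames[i], float(y_score_col[i])) for i in idxs),
                            key=lambda t: t[1], reverse=True)
            top[name] = ranked[:k]
        return top

Ranking by confidence keeps the most confidently right and most confidently wrong crops, which is what gets copied into the inference-results directory.
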
--- notebooks/validation.py | 134 +++++++++++++++++++++++++++++++++------- 1 file changed, 110 insertions(+), 24 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index ecb203b..50b5bb4 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -1,4 +1,5 @@ import json +import shutil import torch import os @@ -174,7 +175,7 @@ def get_labels_ref_for_run(inference_set_dir): # ------------------------------ params = { - 'label_type': 'crosswalk', + 'label_type': 'curbramp', 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], 'dataset_type': 'validated', # 'unvalidated' or 'validated' @@ -203,6 +204,7 @@ def get_labels_ref_for_run(inference_set_dir): # for temporarily skipping some cities skip_cities = [] +# skip_cities = ['oradell', 'walla_walla', 'cdmx', 'spgg', 'chicago', 'amsterdam', 'columbus', 'newberg'] dataset_dirname = 'crops-' + params['label_type'] + '-' + params['c12n_category'] # example: crops-surfaceproblem-tags-archive # dataset_dirname = 'crops-' + label_type + '-' + c12n_category + '-validated' # example: crops-surfaceproblem-tags-archive @@ -210,6 +212,9 @@ def get_labels_ref_for_run(inference_set_dir): inference_dataset_dir = Path(dataset_dir_path + "/" + params['inference_set_dir_name']) +# top tp, fp, fn, tn images are saved here +inference_results_dir = Path("../inference-results") + model_name = 'models/' + params['dataset_type'] + '-' + params['pretrained_model_prefix'] + '-cls-b-' + params['label_type'] + '-' + params['c12n_category'] + '-best.pth' # ------------------------------ @@ -300,19 +305,71 @@ def data_loader(dir_path, batch_size, imgsz, transform): labels_ref_for_run = get_labels_ref_for_run(inference_dataset_dir) -all_category_to_true_positive_counts = {} -all_category_to_false_positive_counts = {} -all_category_to_false_negative_counts = {} -all_category_to_true_negative_counts = {} - -all_category_to_prediction_stats = {} -all_category_to_prediction_details = {} - +all_tag_to_prediction_stats = {} +all_tag_to_prediction_details = {} images_and_labels = data_loader(inference_dataset_dir, 1, image_dimension, 'inference') # ----------------------------------------------------------------- + +def copy_top_instances_to_results_dir(tag, tp_filenames_and_conf, fp_filenames_and_conf, fn_filenames_and_conf, tn_filenames_and_conf): + def copy_files(filenames_and_conf, inference_dataset_dir, tag_dir_path, category): + conf_truncate_length = 5 + for i, (fn, conf) in enumerate(filenames_and_conf): + src_file_path = os.path.join(inference_dataset_dir, fn) + truncated_conf = str(conf)[:conf_truncate_length] + dst_file_name = f'{fn.replace(".png", "")}-{truncated_conf}.png' + dst_file_path = os.path.join(tag_dir_path, category, dst_file_name) + shutil.copy2(src_file_path, dst_file_path) + + # create a directory for the label type if it doesn't exist + os.makedirs(os.path.join(inference_results_dir, params['label_type']), exist_ok=True) + + # create directory for model and dataset type if it doesn't exist + model_and_dataset_dir = os.path.join(inference_results_dir, params['label_type'], params['pretrained_model_prefix'] + '-' + params['dataset_type']) + os.makedirs(model_and_dataset_dir, exist_ok=True) + + # create a directory for the tag if it doesn't exist + tag_dir_path = os.path.join(model_and_dataset_dir, tag) + os.makedirs(os.path.join(model_and_dataset_dir, tag), exist_ok=True) + + # create directories for tp, fp, fn, tn if they don't exist + os.makedirs(os.path.join(tag_dir_path, 'tp'), exist_ok=True) + 
os.makedirs(os.path.join(tag_dir_path, 'fp'), exist_ok=True) + os.makedirs(os.path.join(tag_dir_path, 'fn'), exist_ok=True) + os.makedirs(os.path.join(tag_dir_path, 'tn'), exist_ok=True) + + # clear the directories + for dir_name in ['tp', 'fp', 'fn', 'tn']: + for filename in os.listdir(os.path.join(tag_dir_path, dir_name)): + os.remove(os.path.join(tag_dir_path, dir_name, filename)) + + # copy the top 20 instances to the inference-results directory + copy_files(tp_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'tp') + copy_files(fp_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'fp') + copy_files(fn_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'fn') + copy_files(tn_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'tn') + + +def check_for_mutual_exclusivity_and_total(tp_set, fp_set, fn_set, tn_set, images_and_labels): + if len(tp_set.intersection(fp_set)) > 0: + raise ValueError('TP and FP sets are not mutually exclusive') + if len(tp_set.intersection(fn_set)) > 0: + raise ValueError('TP and FN sets are not mutually exclusive') + if len(tp_set.intersection(tn_set)) > 0: + raise ValueError('TP and TN sets are not mutually exclusive') + if len(fp_set.intersection(fn_set)) > 0: + raise ValueError('FP and FN sets are not mutually exclusive') + if len(fp_set.intersection(tn_set)) > 0: + raise ValueError('FP and TN sets are not mutually exclusive') + if len(fn_set.intersection(tn_set)) > 0: + raise ValueError('FN and TN sets are not mutually exclusive') + + if len(tp_set.union(fp_set).union(fn_set).union(tn_set)) != len(images_and_labels): + raise ValueError('The total of sets doesn\'t match the number of instances in the dataset') + + def inference_on_validation_data(inference_model): y_true = [] @@ -376,14 +433,14 @@ def inference_on_validation_data(inference_model): raise ValueError('What is happening! 
For tag {} sum of instances: {} and n_instances: {} are not equal'.format(tag_name, sum_gt, n_instances)) # compute precision, recall, thresholds - precision, recall, thresholds = precision_recall_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) + precision, recall, thresholds_pr = precision_recall_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) pr_auc = auc(recall, precision) average_precision_val = average_precision_score(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], average='weighted') all_f1_pr = 2 * precision * recall / (precision + recall) ix_pr = np.argmax(all_f1_pr) - best_thresh_pr = thresholds[ix_pr] + best_thresh_pr = thresholds_pr[ix_pr] precision_pr_at_best_conf = precision[ix_pr] recall_pr_at_best_conf = recall[ix_pr] @@ -392,27 +449,27 @@ def inference_on_validation_data(inference_model): accuracy_score_pr = accuracy_score(y_true_np[:, tag_idx], y_pred_class_pr) - all_category_to_prediction_stats[tag_name] = {'n_instances': n_instances, 'precision': precision.tolist(), 'recall': recall.tolist(), - 'thresholds': thresholds.tolist(), 'pr_auc': pr_auc, 'average_precision_val': average_precision_val} + all_tag_to_prediction_stats[tag_name] = {'n_instances': n_instances, 'precision': precision.tolist(), 'recall': recall.tolist(), + 'thresholds': thresholds_pr.tolist(), 'pr_auc': pr_auc, 'average_precision_val': average_precision_val} if len(np.unique(y_true_np[:, tag_idx])) < 2 or len(np.unique(y_pred_np[:, tag_idx])) < 2: print('For tag {} all instances of y_true: {}'.format(tag_name, np.unique(y_true_np[:, tag_idx]))) # Compute ROC curve - fpr, tpr, thresholds = roc_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) + fpr, tpr, thresholds_roc = roc_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) roc_auc = auc(fpr, tpr) - J_roc = tpr - fpr - ix_roc = np.argmax(J_roc) - best_thresh_roc = thresholds[ix_roc] - precision_roc_at_best_conf = precision[ix_roc] - recall_roc_at_best_conf = recall[ix_roc] + # J_roc = tpr - fpr + # ix_roc = np.argmax(J_roc) + # best_thresh_roc = thresholds[ix_roc] + # precision_roc_at_best_conf = precision[ix_roc] + # recall_roc_at_best_conf = recall[ix_roc] - y_pred_class_roc = np.where(y_pred_np[:, tag_idx] > best_thresh_roc, 1, 0) - f1_score_roc = f1_score(y_true_np[:, tag_idx], y_pred_class_roc) - accuracy_score_roc = accuracy_score(y_true_np[:, tag_idx], y_pred_class_roc) + # y_pred_class_roc = np.where(y_pred_np[:, tag_idx] > best_thresh_roc, 1, 0) + # f1_score_roc = f1_score(y_true_np[:, tag_idx], y_pred_class_roc) + # accuracy_score_roc = accuracy_score(y_true_np[:, tag_idx], y_pred_class_roc) - all_category_to_prediction_stats[tag_name].update({'fpr': fpr.tolist(), 'tpr': tpr.tolist(), 'roc_auc': roc_auc}) + all_tag_to_prediction_stats[tag_name].update({'fpr': fpr.tolist(), 'tpr': tpr.tolist(), 'roc_auc': roc_auc}) # don't plot if there are no instances of the tag in the ground truth labels @@ -435,6 +492,35 @@ def inference_on_validation_data(inference_model): # Create a RocCurveDisplay and plot it on the same axis # roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=ax2, name=tag_name + '\n(n={}, AUC={})\n(conf={}, acc={}, f1={}\n(prec={}, rec={}))'.format(n_instances, round(roc_auc, 2), round(best_thresh_roc, 2), round(accuracy_score_roc, 2), round(f1_score_roc, 2), round(precision_roc_at_best_conf, 2), round(recall_roc_at_best_conf, 2))) + tp_indices_for_tag = np.where((y_true_np[:, tag_idx] == 1) & (y_pred_class_pr == 1))[0] + fp_indices_for_tag = np.where((y_true_np[:, 
tag_idx] == 0) & (y_pred_class_pr == 1))[0] + fn_indices_for_tag = np.where((y_true_np[:, tag_idx] == 1) & (y_pred_class_pr == 0))[0] + tn_indices_for_tag = np.where((y_true_np[:, tag_idx] == 0) & (y_pred_class_pr == 0))[0] + + tp_filenames_and_conf = [(images_and_labels[i][2], y_pred_np[:, tag_idx][i]) for i in tp_indices_for_tag] + fp_filenames_and_conf = [(images_and_labels[i][2], y_pred_np[:, tag_idx][i]) for i in fp_indices_for_tag] + fn_filenames_and_conf = [(images_and_labels[i][2], y_pred_np[:, tag_idx][i]) for i in fn_indices_for_tag] + tn_filenames_and_conf = [(images_and_labels[i][2], y_pred_np[:, tag_idx][i]) for i in tn_indices_for_tag] + + # check if all these sets are mutually exclusive + check_for_mutual_exclusivity_and_total(set(tp_indices_for_tag), set(fp_indices_for_tag), set(fn_indices_for_tag), set(tn_indices_for_tag), images_and_labels) + + # sort the lists by confidence level in descending order + tp_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) + fp_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) + fn_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) + tn_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) + + # get the top 20 instances for each set + tp_filenames_and_conf = tp_filenames_and_conf[:20] + fp_filenames_and_conf = fp_filenames_and_conf[:20] + fn_filenames_and_conf = fn_filenames_and_conf[:20] + tn_filenames_and_conf = tn_filenames_and_conf[:20] + + # copy the crops to the results directory + copy_top_instances_to_results_dir(tag_name, tp_filenames_and_conf, fp_filenames_and_conf, fn_filenames_and_conf, tn_filenames_and_conf) + + all_tag_to_prediction_details[tag_name] = {'tp': tp_filenames_and_conf, 'fp': fp_filenames_and_conf, 'fn': fn_filenames_and_conf} mean_average_precision = sum(all_average_precisions) / len(all_average_precisions) @@ -498,7 +584,7 @@ def inference_on_validation_data(inference_model): all_stats = { # 'n_incorrect_predictions_to_filenames': all_n_incorrect_predictions_to_filenames, - 'category_to_prediction_stats': all_category_to_prediction_stats, + 'category_to_prediction_stats': all_tag_to_prediction_stats, # 'category_to_prediction_details': all_category_to_prediction_details } From 9cc6de8e2215d86a8c34b8af7046a3b8174585d7 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Thu, 27 Jun 2024 14:30:59 -0700 Subject: [PATCH 04/12] Adding micro, macro, weighted averages for F1 score. Removing ROC curve related code. Other. --- notebooks/validation.py | 65 ++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index 50b5bb4..667a61d 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -150,7 +150,10 @@ def get_labels_ref_for_run(inference_set_dir): global c12n_category_offset c12n_category_offset = validated_by_index + 1 - if params['label_type'] == 'obstacle': + # for the CLIP model we don't have the newly added tags e.g. mailbox, seating etc. + # but for the DINO model, trained on the validated data, we do have them in the training data. 
+ # we need to adjust the labels_ref_for_run for the CLIP model + if params['label_type'] == 'obstacle' and params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP']: if len(labels_ref_for_run) == 20: labels_ref_for_run = labels_ref_for_run[:-3] else: @@ -175,7 +178,7 @@ def get_labels_ref_for_run(inference_set_dir): # ------------------------------ params = { - 'label_type': 'curbramp', + 'label_type': 'obstacle', 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], 'dataset_type': 'validated', # 'unvalidated' or 'validated' @@ -254,6 +257,8 @@ def images_loader(dir_path, batch_size, imgsz, transform): labels = [] fs = os.listdir(dir_path) + # ignore the csv file and .DS_Store (if present). this is just the list of images. + fs = [x for x in fs if x.endswith('.png') or x.endswith('.jpg')] count = 0 for filename in fs: @@ -345,7 +350,7 @@ def copy_files(filenames_and_conf, inference_dataset_dir, tag_dir_path, category for filename in os.listdir(os.path.join(tag_dir_path, dir_name)): os.remove(os.path.join(tag_dir_path, dir_name, filename)) - # copy the top 20 instances to the inference-results directory + # copy the top instances to the inference-results directory copy_files(tp_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'tp') copy_files(fp_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'fp') copy_files(fn_filenames_and_conf, inference_dataset_dir, tag_dir_path, 'fn') @@ -406,10 +411,18 @@ def inference_on_validation_data(inference_model): y_pred.append(probabilities.tolist()[0]) - + # IMPORTANT variables y_true_np = np.array(y_true) y_pred_np = np.array(y_pred) + # Convert predicted probabilities to binary predictions + y_pred_binary = (y_pred_np >= 0.5).astype(int) + + # Compute micro-averaged F1 score + f1_micro = f1_score(y_true_np, y_pred_binary, average='micro') + f1_macro = f1_score(y_true_np, y_pred_binary, average='macro') + f1_weighted = f1_score(y_true_np, y_pred_binary, average='weighted') + # Create a list of tuples (tag, precision, recall, n_instances) # sum the columns of y_true_np to get the number of instances in the ground truth labels tag_to_n_instances = [(labels_ref_for_run[i], int(np.sum(y_true_np[:, i])), i) for i in range(len(labels_ref_for_run))] @@ -417,7 +430,7 @@ def inference_on_validation_data(inference_model): # Sort the list based on n_instances tag_to_n_instances.sort(key=lambda x: x[1], reverse=True) - fig, ax1 = plt.subplots(1, 1, figsize=(10, 10)) + fig, ax1 = plt.subplots(1, 1, figsize=(16, 10)) tags_not_plotted = [] @@ -443,11 +456,9 @@ def inference_on_validation_data(inference_model): best_thresh_pr = thresholds_pr[ix_pr] precision_pr_at_best_conf = precision[ix_pr] recall_pr_at_best_conf = recall[ix_pr] + f1_pr_at_best_conf = all_f1_pr[ix_pr] if not np.isnan(all_f1_pr[ix_pr]) else 0 y_pred_class_pr = np.where(y_pred_np[:, tag_idx] > best_thresh_pr, 1, 0) - f1_score_pr = f1_score(y_true_np[:, tag_idx], y_pred_class_pr) - accuracy_score_pr = accuracy_score(y_true_np[:, tag_idx], y_pred_class_pr) - all_tag_to_prediction_stats[tag_name] = {'n_instances': n_instances, 'precision': precision.tolist(), 'recall': recall.tolist(), 'thresholds': thresholds_pr.tolist(), 'pr_auc': pr_auc, 'average_precision_val': average_precision_val} @@ -455,22 +466,6 @@ def inference_on_validation_data(inference_model): if len(np.unique(y_true_np[:, tag_idx])) < 2 or len(np.unique(y_pred_np[:, tag_idx])) < 2: print('For tag {} all instances of y_true: {}'.format(tag_name, np.unique(y_true_np[:, tag_idx]))) - # Compute ROC curve - fpr, tpr, 
thresholds_roc = roc_curve(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], pos_label=1) - roc_auc = auc(fpr, tpr) - - # J_roc = tpr - fpr - # ix_roc = np.argmax(J_roc) - # best_thresh_roc = thresholds[ix_roc] - # precision_roc_at_best_conf = precision[ix_roc] - # recall_roc_at_best_conf = recall[ix_roc] - - # y_pred_class_roc = np.where(y_pred_np[:, tag_idx] > best_thresh_roc, 1, 0) - # f1_score_roc = f1_score(y_true_np[:, tag_idx], y_pred_class_roc) - # accuracy_score_roc = accuracy_score(y_true_np[:, tag_idx], y_pred_class_roc) - - all_tag_to_prediction_stats[tag_name].update({'fpr': fpr.tolist(), 'tpr': tpr.tolist(), 'roc_auc': roc_auc}) - # don't plot if there are no instances of the tag in the ground truth labels st = suppress_thresholds[params['label_type']] @@ -485,8 +480,8 @@ def inference_on_validation_data(inference_model): all_average_precisions.append(average_precision_val) # Create a PrecisionRecallDisplay and plot it on the same axis - pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + ' (n={})'.format(n_instances)) - # pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, acc={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(accuracy_score_pr, 2), round(f1_score_pr, 2), round(precision_pr_at_best_conf, 2), round(recall_pr_at_best_conf, 2))) + # pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + ' (n={})'.format(n_instances)) + pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(f1_pr_at_best_conf, 2), round(precision_pr_at_best_conf, 2), round(recall_pr_at_best_conf, 2))) # Create a RocCurveDisplay and plot it on the same axis @@ -511,11 +506,12 @@ def inference_on_validation_data(inference_model): fn_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) tn_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) - # get the top 20 instances for each set - tp_filenames_and_conf = tp_filenames_and_conf[:20] - fp_filenames_and_conf = fp_filenames_and_conf[:20] - fn_filenames_and_conf = fn_filenames_and_conf[:20] - tn_filenames_and_conf = tn_filenames_and_conf[:20] + # get the top N instances for each set + N_top_instances = 50 + tp_filenames_and_conf = tp_filenames_and_conf[:N_top_instances] + fp_filenames_and_conf = fp_filenames_and_conf[:N_top_instances] + fn_filenames_and_conf = fn_filenames_and_conf[:N_top_instances] + tn_filenames_and_conf = tn_filenames_and_conf[:N_top_instances] # copy the crops to the results directory copy_top_instances_to_results_dir(tag_name, tp_filenames_and_conf, fp_filenames_and_conf, fn_filenames_and_conf, tn_filenames_and_conf) @@ -525,7 +521,7 @@ def inference_on_validation_data(inference_model): mean_average_precision = sum(all_average_precisions) / len(all_average_precisions) # Add a legend to the plot - legend1 = ax1.legend(title='Classes', fontsize='16') + legend1 = ax1.legend(title='Classes', fontsize='14', bbox_to_anchor=(1.05, 1), loc='upper left') # Add a legend to the ROC plot # ax2.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') @@ -540,7 +536,10 @@ def inference_on_validation_data(inference_model): plot_title_str += '\nModel: ' + ('ViT 
CLIP Base' if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP'] else 'DINOv2 Base') plot_title_str += ' | Train dataset: ' + params['dataset_type'] - plot_title_str += '\nmAP: ' + str(round(mean_average_precision, 2)) + plot_title_str += ('\nmAP: ' + str(round(mean_average_precision, 2)) + + ' | ' + 'Micro F1: ' + str(round(f1_micro, 2)) + + ' | ' + 'Macro F1: ' + str(round(f1_macro, 2)) + + ' | ' + 'Weighted F1: ' + str(round(f1_weighted, 2))) # Set title for the figure and save plt.suptitle(plot_title_str, fontsize=16) From 14be4a2a428407f5d04088fd9719673067ffab49 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 15:21:36 -0700 Subject: [PATCH 05/12] Updating validation script. --- notebooks/validation.py | 74 +++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index 667a61d..ded18b7 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -8,13 +8,12 @@ from PIL import Image from torchvision import datasets, transforms import numpy as np -from torch import nn, optim +from torch import nn from copy import deepcopy import sys from sklearn.metrics import precision_recall_curve, auc, PrecisionRecallDisplay import matplotlib.pyplot as plt -from sklearn.metrics import roc_curve, RocCurveDisplay -from sklearn.metrics import average_precision_score, accuracy_score, hamming_loss, f1_score +from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score import timm sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -178,13 +177,14 @@ def get_labels_ref_for_run(inference_set_dir): # ------------------------------ params = { - 'label_type': 'obstacle', + 'label_type': 'surfaceproblem', 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], 'dataset_type': 'validated', # 'unvalidated' or 'validated' # these don't really change for now 'c12n_category': C12N_CATEGORIES['TAGS'], 'inference_set_dir_name': 'test', + 'min_threshold': 0.3 } # ------------------------------ @@ -375,6 +375,34 @@ def check_for_mutual_exclusivity_and_total(tp_set, fp_set, fn_set, tn_set, image raise ValueError('The total of sets doesn\'t match the number of instances in the dataset') +# computes micro, macro, and weighted f1 scores at the fixed confidence threshold in params +def compute_overall_f1(yt, yp): + # exclude the columns where the true positives are less than 10 + # this is to suppress the tags that are not very common + + y_true_np_selected = yt[:, np.sum(yt, axis=0) >= suppress_thresholds[params['label_type']]] + y_pred_np_selected = yp[:, np.sum(yt, axis=0) >= suppress_thresholds[params['label_type']]] + + # Convert predicted probabilities to binary predictions + y_pred_binary_selected = (y_pred_np_selected >= params['min_threshold']).astype(int) + + # Compute micro-averaged F1 score + f1_micro_selected = f1_score(y_true_np_selected, y_pred_binary_selected, average='micro') + f1_macro_selected = f1_score(y_true_np_selected, y_pred_binary_selected, average='macro') + f1_weighted_selected = f1_score(y_true_np_selected, y_pred_binary_selected, average='weighted') + + return f1_micro_selected, f1_macro_selected, f1_weighted_selected + + +# computes precision, recall, and f1 at the given threshold +def get_precision_recall_f1_at_conf(y_true, y_pred, conf): + y_pred_binary = np.where(y_pred >= conf, 1, 0) + precision = precision_score(y_true, y_pred_binary) + recall = recall_score(y_true, y_pred_binary) + f1 = f1_score(y_true, 
y_pred_binary) + return precision, recall, f1 + + def inference_on_validation_data(inference_model): y_true = [] @@ -415,13 +443,8 @@ def inference_on_validation_data(inference_model): y_true_np = np.array(y_true) y_pred_np = np.array(y_pred) - # Convert predicted probabilities to binary predictions - y_pred_binary = (y_pred_np >= 0.5).astype(int) - - # Compute micro-averaged F1 score - f1_micro = f1_score(y_true_np, y_pred_binary, average='micro') - f1_macro = f1_score(y_true_np, y_pred_binary, average='macro') - f1_weighted = f1_score(y_true_np, y_pred_binary, average='weighted') + # Compute the overall F1 score. This function internally takes into account the suppress thresholds and leaves out the tags that have less than the threshold count. + f1_micro_selected, f1_macro_selected, f1_weighted_selected = compute_overall_f1(y_true_np, y_pred_np) # Create a list of tuples (tag, precision, recall, n_instances) # sum the columns of y_true_np to get the number of instances in the ground truth labels @@ -430,11 +453,12 @@ def inference_on_validation_data(inference_model): # Sort the list based on n_instances tag_to_n_instances.sort(key=lambda x: x[1], reverse=True) - fig, ax1 = plt.subplots(1, 1, figsize=(16, 10)) + fig, ax1 = plt.subplots(1, 1, figsize=(16, 12)) tags_not_plotted = [] all_average_precisions = [] + all_f1_scores_test = [] for i in range(len(tag_to_n_instances)): @@ -454,11 +478,13 @@ def inference_on_validation_data(inference_model): all_f1_pr = 2 * precision * recall / (precision + recall) ix_pr = np.argmax(all_f1_pr) best_thresh_pr = thresholds_pr[ix_pr] - precision_pr_at_best_conf = precision[ix_pr] - recall_pr_at_best_conf = recall[ix_pr] - f1_pr_at_best_conf = all_f1_pr[ix_pr] if not np.isnan(all_f1_pr[ix_pr]) else 0 - y_pred_class_pr = np.where(y_pred_np[:, tag_idx] > best_thresh_pr, 1, 0) + if best_thresh_pr < params['min_threshold']: + best_thresh_pr = params['min_threshold'] + + precision_at_conf, recall_at_conf, f1_at_conf = get_precision_recall_f1_at_conf(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], best_thresh_pr) + + y_pred_class_pr = np.where(y_pred_np[:, tag_idx] >= best_thresh_pr, 1, 0) all_tag_to_prediction_stats[tag_name] = {'n_instances': n_instances, 'precision': precision.tolist(), 'recall': recall.tolist(), 'thresholds': thresholds_pr.tolist(), 'pr_auc': pr_auc, 'average_precision_val': average_precision_val} @@ -479,9 +505,13 @@ def inference_on_validation_data(inference_model): # note: this should be done after the suppression part all_average_precisions.append(average_precision_val) + # this is just for testing. 
+ y_pred_binary_05 = np.where(y_pred_np[:, tag_idx] > 0.5, 1, 0) + all_f1_scores_test.append(f1_score(y_true_np[:, tag_idx], y_pred_binary_05)) + # Create a PrecisionRecallDisplay and plot it on the same axis # pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + ' (n={})'.format(n_instances)) - pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(f1_pr_at_best_conf, 2), round(precision_pr_at_best_conf, 2), round(recall_pr_at_best_conf, 2))) + pr_display = PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax1, name=tag_name + '\n(n={}, AUC={}, AP={})\n(conf={}, f1={})\n(prec={}, rec={})'.format(n_instances, round(pr_auc, 2), round(average_precision_val, 2), round(best_thresh_pr, 2), round(f1_at_conf, 2), round(precision_at_conf, 2), round(recall_at_conf, 2))) # Create a RocCurveDisplay and plot it on the same axis @@ -519,9 +549,10 @@ def inference_on_validation_data(inference_model): all_tag_to_prediction_details[tag_name] = {'tp': tp_filenames_and_conf, 'fp': fp_filenames_and_conf, 'fn': fn_filenames_and_conf} mean_average_precision = sum(all_average_precisions) / len(all_average_precisions) + manual_average_f1 = sum(all_f1_scores_test) / len(all_f1_scores_test) # Add a legend to the plot - legend1 = ax1.legend(title='Classes', fontsize='14', bbox_to_anchor=(1.05, 1), loc='upper left') + legend1 = ax1.legend(title='Classes', fontsize='12', bbox_to_anchor=(1.05, 1), loc='upper left') # Add a legend to the ROC plot # ax2.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left') @@ -537,9 +568,10 @@ def inference_on_validation_data(inference_model): plot_title_str += ' | Train dataset: ' + params['dataset_type'] plot_title_str += ('\nmAP: ' + str(round(mean_average_precision, 2)) + - ' | ' + 'Micro F1: ' + str(round(f1_micro, 2)) + - ' | ' + 'Macro F1: ' + str(round(f1_macro, 2)) + - ' | ' + 'Weighted F1: ' + str(round(f1_weighted, 2))) + ' | ' + 'Micro F1: ' + str(round(f1_micro_selected, 2)) + + ' | ' + 'Macro F1: ' + str(round(f1_macro_selected, 2)) + + ' | ' + 'Weighted F1: ' + str(round(f1_weighted_selected, 2)) + + ' | ' + 'Manual avg.: ' + str(round(manual_average_f1, 2))) # Set title for the figure and save plt.suptitle(plot_title_str, fontsize=16) From c34f8d4799353892773466b048eb704622475115 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 15:23:39 -0700 Subject: [PATCH 06/12] Adding one model file --- notebooks/models/validated-dino-cls-b-crosswalk-tags-best.pth | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 notebooks/models/validated-dino-cls-b-crosswalk-tags-best.pth diff --git a/notebooks/models/validated-dino-cls-b-crosswalk-tags-best.pth b/notebooks/models/validated-dino-cls-b-crosswalk-tags-best.pth new file mode 100644 index 0000000..b44c8a2 --- /dev/null +++ b/notebooks/models/validated-dino-cls-b-crosswalk-tags-best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2399313abc4b462fa9a708695f6febd05d0bd199e1f7f715af55ee312e618f55 +size 347204602 From d763ffbeae5effca005177926e3a811f435a53e6 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 15:46:26 -0700 Subject: [PATCH 07/12] Adding DINOv2 base model --- dinov2_vitb14_reg4_pretrain.pth | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 
dinov2_vitb14_reg4_pretrain.pth diff --git a/dinov2_vitb14_reg4_pretrain.pth b/dinov2_vitb14_reg4_pretrain.pth new file mode 100644 index 0000000..a2eb1b9 --- /dev/null +++ b/dinov2_vitb14_reg4_pretrain.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73182a088cf94833c94b1666d1c99e02fe87e2007bff57b564fb6206e25dba71 +size 346393545 From 169421685d6ebab88b089a8a404da00f72506cc7 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 20:31:25 -0700 Subject: [PATCH 08/12] Fix for zero division error. --- notebooks/validation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index ded18b7..0c107b3 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -177,7 +177,7 @@ def get_labels_ref_for_run(inference_set_dir): # ------------------------------ params = { - 'label_type': 'surfaceproblem', + 'label_type': 'obstacle', 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], 'dataset_type': 'validated', # 'unvalidated' or 'validated' @@ -475,7 +475,8 @@ def inference_on_validation_data(inference_model): average_precision_val = average_precision_score(y_true_np[:, tag_idx], y_pred_np[:, tag_idx], average='weighted') - all_f1_pr = 2 * precision * recall / (precision + recall) + denominator = (precision + recall) + all_f1_pr = np.where(denominator != 0, 2 * precision * recall / denominator, 0) # to avoid zero division error ix_pr = np.argmax(all_f1_pr) best_thresh_pr = thresholds_pr[ix_pr] @@ -506,7 +507,7 @@ def inference_on_validation_data(inference_model): all_average_precisions.append(average_precision_val) # this is just for testing. - y_pred_binary_05 = np.where(y_pred_np[:, tag_idx] > 0.5, 1, 0) + y_pred_binary_05 = np.where(y_pred_np[:, tag_idx] >= params['min_threshold'], 1, 0) all_f1_scores_test.append(f1_score(y_true_np[:, tag_idx], y_pred_binary_05)) # Create a PrecisionRecallDisplay and plot it on the same axis From 160db7207a5685a399046a457985073d4120456d Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 20:59:36 -0700 Subject: [PATCH 09/12] Raise error if disagreed or unsure label found in test set. --- notebooks/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index 0c107b3..729ede9 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -286,9 +286,9 @@ def images_loader(dir_path, batch_size, imgsz, transform): if len(labels_for_image) == 0: continue + # we don't expect to see any 'disagreed' or 'unsure' labels in the test set if labels_for_image['label_type_validation'].values[0] != 'agree': - print('Disagreed or unsure label: ' + filename) - continue + raise ValueError('Disagreed or unsure label: ' + filename) images.append(torch.tensor(np.array([img], dtype=np.float32), requires_grad=True)) labels.append( From 61d4a3338fac7332b4a7a822c22bb75ba7344773 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 21:39:43 -0700 Subject: [PATCH 10/12] Raise error if image label not found in CSV but image file present in test dir. 
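Taken together, this commit and the previous one turn silent skips in the image loader into hard failures: a test image with no row in the CSV, or a row whose validation is not 'agree', now aborts the run. A minimal sketch of the combined check, assuming the filename and label_type_validation columns used by the loader (the helper name is illustrative):

import pandas as pd

def get_validated_label_row(label_data: pd.DataFrame, filename: str) -> pd.Series:
    # Illustrative only: mirrors the two checks added around the image loader.
    rows = label_data.query('filename == @filename')
    if len(rows) == 0:
        raise ValueError('No label found for image: ' + filename)
    # the test set is expected to contain only labels the validators agreed with
    if rows['label_type_validation'].values[0] != 'agree':
        raise ValueError('Disagreed or unsure label: ' + filename)
    return rows.iloc[0]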
--- notebooks/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index 729ede9..b8611a4 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -284,7 +284,7 @@ def images_loader(dir_path, batch_size, imgsz, transform): labels_for_image = label_data.query('filename == @filename') if len(labels_for_image) == 0: - continue + raise ValueError('No label found for image: ' + filename) # we don't expect to see any 'disagreed' or 'unsure' labels in the test set if labels_for_image['label_type_validation'].values[0] != 'agree': From 527f825bf87af4954b2f96c58aa2b1de2324e295 Mon Sep 17 00:00:00 2001 From: hoominchu Date: Fri, 28 Jun 2024 23:32:44 -0700 Subject: [PATCH 11/12] Minor. --- notebooks/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index b8611a4..5575cf5 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -538,7 +538,7 @@ def inference_on_validation_data(inference_model): tn_filenames_and_conf.sort(key=lambda x: x[1], reverse=True) # get the top N instances for each set - N_top_instances = 50 + N_top_instances = 30 tp_filenames_and_conf = tp_filenames_and_conf[:N_top_instances] fp_filenames_and_conf = fp_filenames_and_conf[:N_top_instances] fn_filenames_and_conf = fn_filenames_and_conf[:N_top_instances] From 181fec339c5659962bd3e05e11688a703d090b3a Mon Sep 17 00:00:00 2001 From: hoominchu Date: Sat, 29 Jun 2024 13:38:15 -0700 Subject: [PATCH 12/12] Special handling to accommodate the added obstacle tags. Better handling of including the confidence scores in the fp/fn files. Other. --- notebooks/validation.py | 44 ++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/notebooks/validation.py b/notebooks/validation.py index 5575cf5..16c2e62 100644 --- a/notebooks/validation.py +++ b/notebooks/validation.py @@ -1,4 +1,5 @@ import json +import math import shutil import torch @@ -149,10 +150,10 @@ def get_labels_ref_for_run(inference_set_dir): global c12n_category_offset c12n_category_offset = validated_by_index + 1 - # for the CLIP model we don't have the newly added tags e.g. mailbox, seating etc. + # for the unvalidated data model we don't have the newly added tags e.g. mailbox, seating etc. # but for the DINO model, trained on the validated data, we do have them in the training data. 
# we need to adjust the labels_ref_for_run for the CLIP model - if params['label_type'] == 'obstacle' and params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP']: + if params['label_type'] == 'obstacle' and params['dataset_type'] == 'unvalidated': if len(labels_ref_for_run) == 20: labels_ref_for_run = labels_ref_for_run[:-3] else: @@ -179,7 +180,7 @@ def get_labels_ref_for_run(inference_set_dir): params = { 'label_type': 'obstacle', 'pretrained_model_prefix': MODEL_PREFIXES['DINO'], - 'dataset_type': 'validated', # 'unvalidated' or 'validated' + 'dataset_type': 'unvalidated', # 'unvalidated' or 'validated' # these don't really change for now 'c12n_category': C12N_CATEGORIES['TAGS'], @@ -319,11 +320,18 @@ def data_loader(dir_path, batch_size, imgsz, transform): # ----------------------------------------------------------------- def copy_top_instances_to_results_dir(tag, tp_filenames_and_conf, fp_filenames_and_conf, fn_filenames_and_conf, tn_filenames_and_conf): + def truncate_float(f, n): + return math.trunc(f * 10 ** n) / 10 ** n + def copy_files(filenames_and_conf, inference_dataset_dir, tag_dir_path, category): - conf_truncate_length = 5 + conf_truncate_length = 4 for i, (fn, conf) in enumerate(filenames_and_conf): + + if (conf > 1.0) or (conf < 0.0): + raise ValueError('Confidence value is not in the range [0, 1]') + src_file_path = os.path.join(inference_dataset_dir, fn) - truncated_conf = str(conf)[:conf_truncate_length] + truncated_conf = truncate_float(conf, conf_truncate_length) dst_file_name = f'{fn.replace(".png", "")}-{truncated_conf}.png' dst_file_path = os.path.join(tag_dir_path, category, dst_file_name) shutil.copy2(src_file_path, dst_file_path) @@ -377,7 +385,7 @@ def check_for_mutual_exclusivity_and_total(tp_set, fp_set, fn_set, tn_set, image # computes micro, macro, and weighted f1 scores at the fixed confidence threshold in params def compute_overall_f1(yt, yp): - # exclude the columns where the true positives are less than 10 + # exclude the columns where the ground truth are less than the minimum frequency threshold # this is to suppress the tags that are not very common y_true_np_selected = yt[:, np.sum(yt, axis=0) >= suppress_thresholds[params['label_type']]] @@ -417,7 +425,6 @@ def inference_on_validation_data(inference_model): print('Processing image. Index: {}, Filename: {}'.format(idx, filename)) input_tensor = img_tensor.to(device) - labels_tensor = labels.to(device) y_true.append(labels.tolist()[0]) @@ -443,6 +450,18 @@ def inference_on_validation_data(inference_model): y_true_np = np.array(y_true) y_pred_np = np.array(y_pred) + # ------------------------------ # + + # IMPORTANT NOTE: For obstacle, we added 3 new tags in the validated data. These tags are not present in the unvalidated data. + # The 3 extra tags (mailbox, seating, and uneven-slanted) are in the test set but not in the model and the training set. + # So we need to remove these tags from the validated ground truth test set. + # Again, this is only when the model is trained on unvalidated data and tag type is obstacle. + # A similar adjustment is done in the 'get_labels_ref_for_run' function. + if params['label_type'] == 'obstacle' and params['dataset_type'] == 'unvalidated': + y_true_np = y_true_np[:, :-3] + + # ------------------------------ # + # Compute the overall F1 score. This function internally takes into account the suppress thresholds and leaves out the tags that have less than the threshold count. 
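# For reference, a hedged re-implementation of roughly what compute_overall_f1 does,
# per the comment above: drop tag columns whose ground-truth count falls below the
# label type's suppress threshold, binarize the scores at a fixed confidence, and
# compute micro/macro/weighted F1. min_count stands in for
# suppress_thresholds[params['label_type']]; the binarization threshold here is
# illustrative, not necessarily the value the script uses.
import numpy as np
from sklearn.metrics import f1_score

def overall_f1_sketch(y_true, y_prob, min_count=10, threshold=0.5):
    keep = np.sum(y_true, axis=0) >= min_count            # suppress rare tags
    y_true_kept = y_true[:, keep]
    y_pred_kept = (y_prob[:, keep] >= threshold).astype(int)
    return (f1_score(y_true_kept, y_pred_kept, average='micro'),
            f1_score(y_true_kept, y_pred_kept, average='macro'),
            f1_score(y_true_kept, y_pred_kept, average='weighted'))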
f1_micro_selected, f1_macro_selected, f1_weighted_selected = compute_overall_f1(y_true_np, y_pred_np) @@ -493,7 +512,6 @@ def inference_on_validation_data(inference_model): if len(np.unique(y_true_np[:, tag_idx])) < 2 or len(np.unique(y_pred_np[:, tag_idx])) < 2: print('For tag {} all instances of y_true: {}'.format(tag_name, np.unique(y_true_np[:, tag_idx]))) - # don't plot if there are no instances of the tag in the ground truth labels st = suppress_thresholds[params['label_type']] if params['label_type'] not in suppress_thresholds: @@ -562,23 +580,25 @@ def inference_on_validation_data(inference_model): ax1.set_title('Precision-Recall Curve') # ax2.set_title('ROC Curve') + model_display_name = 'ViT CLIP Base' if params['pretrained_model_prefix'] == MODEL_PREFIXES[ + 'CLIP'] else 'DINOv2 Base' + # Set the plot title plot_title_str = 'PR and ROC Curves for label type: ' + params['label_type'] plot_title_str += '\nTest set size: ' + str(len(images_and_labels)) + ' images' - plot_title_str += '\nModel: ' + ('ViT CLIP Base' if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP'] else 'DINOv2 Base') + plot_title_str += '\nModel: ' + model_display_name plot_title_str += ' | Train dataset: ' + params['dataset_type'] plot_title_str += ('\nmAP: ' + str(round(mean_average_precision, 2)) + ' | ' + 'Micro F1: ' + str(round(f1_micro_selected, 2)) + ' | ' + 'Macro F1: ' + str(round(f1_macro_selected, 2)) + ' | ' + 'Weighted F1: ' + str(round(f1_weighted_selected, 2)) + - ' | ' + 'Manual avg.: ' + str(round(manual_average_f1, 2))) + ' | ' + 'Manual avg.: ' + str(round(manual_average_f1, 2)) + + ' | ' + 'Threshold: ' + str(params['min_threshold'])) # Set title for the figure and save plt.suptitle(plot_title_str, fontsize=16) - model_display_name = 'ViT CLIP Base' if params['pretrained_model_prefix'] == MODEL_PREFIXES['CLIP'] else 'DINOv2 Base' - plt.tight_layout() pt_model_prefix = params['pretrained_model_prefix']
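For context, the plotting pattern these patches keep refining — one PrecisionRecallDisplay per tag drawn onto a shared axis, with per-tag statistics folded into the legend entry — reduces to a sketch like the following; the function and argument names are illustrative, not the script's own:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score, PrecisionRecallDisplay

def plot_pr_overlay(y_true, y_score, tag_names, out_path='pr_overlay.png'):
    # y_true/y_score are (n_samples, n_tags) arrays; tag_names maps columns to labels.
    fig, ax = plt.subplots(1, 1, figsize=(16, 12))
    for idx, tag in enumerate(tag_names):
        if len(np.unique(y_true[:, idx])) < 2:
            continue  # skip tags with no positive (or no negative) ground-truth instances
        p, r, _ = precision_recall_curve(y_true[:, idx], y_score[:, idx])
        ap = average_precision_score(y_true[:, idx], y_score[:, idx])
        n_instances = int(y_true[:, idx].sum())
        PrecisionRecallDisplay(precision=p, recall=r).plot(
            ax=ax, name='{} (n={}, AP={:.2f})'.format(tag, n_instances, ap))
    ax.legend(title='Classes', fontsize='12', bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_title('Precision-Recall Curve')
    plt.tight_layout()
    fig.savefig(out_path)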