diff --git a/setup.py b/setup.py
index c7b4603d..46bf05f4 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 
 setuptools.setup(
     name="pytorch-metric-learning",
-    version="0.9.96.dev3",
+    version="0.9.96",
     author="Kevin Musgrave",
     author_email="tkm45@cornell.edu",
     description="The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.",
diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py
index e7e9395c..02784fbc 100644
--- a/src/pytorch_metric_learning/__init__.py
+++ b/src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "0.9.96.dev3"
+__version__ = "0.9.96"
diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py
index 80bb455d..920cb278 100644
--- a/src/pytorch_metric_learning/utils/accuracy_calculator.py
+++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py
@@ -133,9 +133,11 @@ def get_label_match_counts(query_labels, reference_labels, label_comparison_fn)
         # Labels are compared with a custom function.
         # They might be non-categorical or multidimensional labels.
         match_counts = np.array([0 for _ in unique_query_labels])
-        for ix_a, label_a in enumerate(unique_query_labels):
-            for label_b in reference_labels:
-                if label_comparison_fn(label_a[None, :], label_b[None, :]):
+        for ix_a in range(len(unique_query_labels)):
+            label_a = unique_query_labels[ix_a : ix_a + 1]
+            for ix_b in range(len(reference_labels)):
+                label_b = reference_labels[ix_b : ix_b + 1]
+                if label_comparison_fn(label_a, label_b):
                     match_counts[ix_a] += 1
 
     # faiss can only do a max of k=1024, and we have to do k+1
@@ -151,8 +153,9 @@ def get_lone_query_labels(
     label_comparison_fn,
 ):
     unique_labels, match_counts = label_counts
-    if label_comparison_fn is EQUALITY and embeddings_come_from_same_source:
-        lone_condition = match_counts <= 1
+    label_matches_itself = label_comparison_fn(unique_labels, unique_labels)
+    if embeddings_come_from_same_source:
+        lone_condition = match_counts - label_matches_itself <= 0
     else:
         lone_condition = match_counts == 0
     lone_query_labels = unique_labels[lone_condition]
diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py
index 1bd222a0..d4a4e200 100644
--- a/tests/utils/test_calculate_accuracies.py
+++ b/tests/utils/test_calculate_accuracies.py
@@ -142,7 +142,81 @@ def correct_mean_average_precision(
         else:
             return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4])
 
+    def test_get_lone_query_labels_custom(self):
+        def fn1(x, y):
+            return abs(x - y) < 2
+
+        def fn2(x, y):
+            return abs(x - y) > 99
+
+        query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
+
+        for comparison_fn in [fn1, fn2]:
+            correct_unique_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
+
+            if comparison_fn is fn1:
+                correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1])
+                correct_lone_query_labels = np.array([100])
+                correct_not_lone_query_mask = np.array(
+                    [
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        True,
+                        False,
+                    ]
+                )
+            elif comparison_fn is fn2:
+                correct_counts = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2])
+                correct_lone_query_labels = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
+                correct_not_lone_query_mask = np.array(
+                    [
+                        True,
+                        True,
+                        False,
+                        False,
+                        False,
+                        False,
+                        False,
+                        False,
+                        False,
+                        False,
+                        False,
+                        True,
+                    ]
+                )
+
+            label_counts, num_k = accuracy_calculator.get_label_match_counts(
+                query_labels,
+                query_labels,
+                comparison_fn,
+            )
+            unique_labels, counts = label_counts
+
+            self.assertTrue(np.all(unique_labels == correct_unique_labels))
+            self.assertTrue(np.all(counts == correct_counts))
+
+            (
+                lone_query_labels,
+                not_lone_query_mask,
+            ) = accuracy_calculator.get_lone_query_labels(
+                query_labels, label_counts, True, comparison_fn
+            )
+
+            self.assertTrue(np.all(lone_query_labels == correct_lone_query_labels))
+            self.assertTrue(np.all(not_lone_query_mask == correct_not_lone_query_mask))
+
     def test_get_lone_query_labels_multi_dim(self):
+        def equality2D(x, y):
+            return (x[..., 0] == y[..., 0]) & (x[..., 1] == y[..., 1])
+
         def custom_label_comparison_fn(x, y):
             return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])
 
@@ -152,38 +226,53 @@ def custom_label_comparison_fn(x, y):
                 (0, 3),
                 (0, 3),
                 (0, 3),
-                (0, 2),
                 (1, 2),
                 (4, 5),
             ]
         )
 
-        for comparison_fn in [accuracy_calculator.EQUALITY, custom_label_comparison_fn]:
+        for comparison_fn in [equality2D, custom_label_comparison_fn]:
             label_counts, num_k = accuracy_calculator.get_label_match_counts(
                 query_labels,
                 query_labels,
                 comparison_fn,
             )
 
-            if comparison_fn is accuracy_calculator.EQUALITY:
+            unique_labels, counts = label_counts
+            correct_unique_labels = np.array([[0, 3], [1, 2], [1, 3], [4, 5]])
+            if comparison_fn is equality2D:
+                correct_counts = np.array([3, 1, 1, 1])
+            else:
+                correct_counts = np.array([0, 1, 1, 0])
+
+            self.assertTrue(np.all(correct_counts == counts))
+            self.assertTrue(np.all(correct_unique_labels == unique_labels))
+
+            if comparison_fn is equality2D:
                 correct = [
                     (
                         True,
-                        np.array([[0, 2], [1, 2], [1, 3], [4, 5]]),
-                        np.array([False, True, True, True, False, False, False]),
+                        np.array([[1, 2], [1, 3], [4, 5]]),
+                        np.array([False, True, True, True, False, False]),
                     ),
                     (
                         False,
                         np.array([[]]),
-                        np.array([True, True, True, True, True, True, True]),
+                        np.array([True, True, True, True, True, True]),
                     ),
                 ]
             else:
-                correct_lone = np.array([[4, 5]])
-                correct_mask = np.array([True, True, True, True, True, True, False])
                 correct = [
-                    (True, correct_lone, correct_mask),
-                    (False, correct_lone, correct_mask),
+                    (
+                        True,
+                        np.array([[0, 3], [4, 5]]),
+                        np.array([True, False, False, False, True, False]),
+                    ),
+                    (
+                        False,
+                        np.array([[0, 3], [4, 5]]),
+                        np.array([True, False, False, False, True, False]),
+                    ),
                 ]
 
             for same_source, correct_lone, correct_mask in correct:
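Note on the behavioral change above: `get_lone_query_labels` previously special-cased `EQUALITY` with `match_counts <= 1`, which silently assumes every label matches itself exactly once. That assumption breaks for custom comparison functions, where a label need not match itself. The patched version asks `label_comparison_fn` directly and only discounts a self-match when one is actually reported. Below is a minimal, self-contained numpy sketch of the new logic, not the library's own code path; it reuses `fn2` and the labels from `test_get_lone_query_labels_custom` in the diff.

```python
import numpy as np

def comparison_fn(x, y):
    # fn2 from the test: two labels "match" only when they differ
    # by more than 99, so no label ever matches itself
    return np.abs(x - y) > 99

query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
unique_labels = np.unique(query_labels)

# Count matches the way the patched get_label_match_counts does:
# slicing (labels[i : i + 1]) instead of label[None, :], so plain 1-D
# label arrays like the one above work with custom comparison functions.
match_counts = np.zeros(len(unique_labels), dtype=int)
for ix_a in range(len(unique_labels)):
    label_a = unique_labels[ix_a : ix_a + 1]
    for ix_b in range(len(query_labels)):
        label_b = query_labels[ix_b : ix_b + 1]
        if comparison_fn(label_a, label_b):
            match_counts[ix_a] += 1
assert np.array_equal(match_counts, [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2])

# New lone-query logic: subtract a self-match only if the comparison
# function reports one (here it never does, so counts of 0 stay lone).
label_matches_itself = comparison_fn(unique_labels, unique_labels)
lone_condition = match_counts - label_matches_itself <= 0
print(unique_labels[lone_condition])  # [1 2 3 4 5 6 7 8 9], as the test expects
```

Under plain equality the sketch reduces to the old behavior (`label_matches_itself` is all True, so the condition is `match_counts <= 1`), which is why the existing equality-based tests keep their expected outputs.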