
Commit

Merge pull request #260 from mlopezantequera/fix_custom_label_comparison_fn

Custom comparison fn plays well with embeddings_come_from_same_source
Kevin Musgrave authored Jan 12, 2021
2 parents 8fac226 + 026145a commit 62d6ad9
Showing 4 changed files with 109 additions and 17 deletions.
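
For context: the accuracy calculation utilities let callers supply a custom label comparison function instead of plain equality (the label_comparison_fn arguments in the diff below). When the query and reference embeddings are the same set, each query trivially matches itself, so a label whose only reference match is itself must be treated as a lone query and skipped; before this change that self-match was only discounted for the default equality comparison. A minimal sketch of the before/after condition, using a made-up comparison rule and toy counts that are not part of this commit:

    import numpy as np

    # Hypothetical rule: two scalar labels "match" when they differ by at most 1.
    def within_one(x, y):
        return np.abs(x - y) <= 1

    unique_labels = np.array([0, 1, 5])
    # Matches of each unique label against the reference labels [0, 1, 5]:
    # 0 matches {0, 1}, 1 matches {0, 1}, 5 matches only itself.
    match_counts = np.array([2, 2, 1])

    # Old behavior with a custom comparison fn and same-source embeddings:
    # only zero-match labels were considered lone, so 5 wrongly survived.
    print(unique_labels[match_counts == 0])  # []

    # New behavior: subtract each label's self-match before checking.
    label_matches_itself = within_one(unique_labels, unique_labels)
    print(unique_labels[match_counts - label_matches_itself <= 0])  # [5]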
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@
 
 setuptools.setup(
     name="pytorch-metric-learning",
-    version="0.9.96.dev3",
+    version="0.9.96",
     author="Kevin Musgrave",
     author_email="[email protected]",
     description="The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.",
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "0.9.96.dev3"
+__version__ = "0.9.96"
13 changes: 8 additions & 5 deletions src/pytorch_metric_learning/utils/accuracy_calculator.py
@@ -133,9 +133,11 @@ def get_label_match_counts(query_labels, reference_labels, label_comparison_fn):
         # Labels are compared with a custom function.
         # They might be non-categorical or multidimensional labels.
         match_counts = np.array([0 for _ in unique_query_labels])
-        for ix_a, label_a in enumerate(unique_query_labels):
-            for label_b in reference_labels:
-                if label_comparison_fn(label_a[None, :], label_b[None, :]):
+        for ix_a in range(len(unique_query_labels)):
+            label_a = unique_query_labels[ix_a : ix_a + 1]
+            for ix_b in range(len(reference_labels)):
+                label_b = reference_labels[ix_b : ix_b + 1]
+                if label_comparison_fn(label_a, label_b):
                     match_counts[ix_a] += 1
 
     # faiss can only do a max of k=1024, and we have to do k+1
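
The rewritten loop above takes one-element slices (unique_query_labels[ix_a : ix_a + 1]) instead of indexing with [None, :], so the comparison function receives same-shaped arrays whether labels are scalars in a 1-D array or rows of a multi-dimensional array. A self-contained sketch of the same counting logic; the comparison rule and toy labels are illustrative, not taken from the changed files:

    import numpy as np

    # Illustrative comparison rule, not from the library.
    def within_one(x, y):
        return np.abs(x - y) <= 1

    unique_query_labels = np.array([0, 1, 5])
    reference_labels = np.array([0, 1, 5])

    match_counts = np.array([0 for _ in unique_query_labels])
    for ix_a in range(len(unique_query_labels)):
        label_a = unique_query_labels[ix_a : ix_a + 1]  # 1-element slice keeps the array's dimensionality
        for ix_b in range(len(reference_labels)):
            label_b = reference_labels[ix_b : ix_b + 1]
            if within_one(label_a, label_b):
                match_counts[ix_a] += 1

    print(match_counts)  # [2 2 1]: 0 matches {0, 1}, 1 matches {0, 1}, 5 matches only itself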
@@ -151,8 +153,9 @@ def get_lone_query_labels(
     label_comparison_fn,
 ):
     unique_labels, match_counts = label_counts
-    if label_comparison_fn is EQUALITY and embeddings_come_from_same_source:
-        lone_condition = match_counts <= 1
+    label_matches_itself = label_comparison_fn(unique_labels, unique_labels)
+    if embeddings_come_from_same_source:
+        lone_condition = match_counts - label_matches_itself <= 0
     else:
         lone_condition = match_counts == 0
     lone_query_labels = unique_labels[lone_condition]
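
Taken together, the two utilities above now treat a query label as lone when, under the supplied comparison function, its only reference match is itself. A small sketch of calling them directly, with the signatures exercised by the new tests in this commit; the comparison rule and labels are illustrative:

    import numpy as np
    from pytorch_metric_learning.utils import accuracy_calculator

    # Illustrative comparison rule, not from the library.
    def within_one(x, y):
        return np.abs(x - y) <= 1

    query_labels = np.array([0, 1, 5])

    label_counts, num_k = accuracy_calculator.get_label_match_counts(
        query_labels, query_labels, within_one
    )
    lone_labels, not_lone_mask = accuracy_calculator.get_lone_query_labels(
        query_labels, label_counts, True, within_one  # True = embeddings_come_from_same_source
    )

    print(lone_labels)    # [5]  -> 5 only matches itself
    print(not_lone_mask)  # [ True  True False]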
109 changes: 99 additions & 10 deletions tests/utils/test_calculate_accuracies.py
@@ -142,7 +142,81 @@ def correct_mean_average_precision(
else:
return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4])

def test_get_lone_query_labels_custom(self):
def fn1(x, y):
return abs(x - y) < 2

def fn2(x, y):
return abs(x - y) > 99

query_labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

for comparison_fn in [fn1, fn2]:
correct_unique_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

if comparison_fn is fn1:
correct_counts = np.array([3, 4, 3, 3, 3, 3, 3, 3, 3, 2, 1])
correct_lone_query_labels = np.array([100])
correct_not_lone_query_mask = np.array(
[
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
False,
]
)
elif comparison_fn is fn2:
correct_counts = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2])
correct_lone_query_labels = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
correct_not_lone_query_mask = np.array(
[
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
]
)

label_counts, num_k = accuracy_calculator.get_label_match_counts(
query_labels,
query_labels,
comparison_fn,
)
unique_labels, counts = label_counts

self.assertTrue(np.all(unique_labels == correct_unique_labels))
self.assertTrue(np.all(counts == correct_counts))

(
lone_query_labels,
not_lone_query_mask,
) = accuracy_calculator.get_lone_query_labels(
query_labels, label_counts, True, comparison_fn
)

self.assertTrue(np.all(lone_query_labels == correct_lone_query_labels))
self.assertTrue(np.all(not_lone_query_mask == correct_not_lone_query_mask))

def test_get_lone_query_labels_multi_dim(self):
def equality2D(x, y):
return (x[..., 0] == y[..., 0]) & (x[..., 1] == y[..., 1])

def custom_label_comparison_fn(x, y):
return (x[..., 0] == y[..., 0]) & (x[..., 1] != y[..., 1])

@@ -152,38 +226,53 @@ def custom_label_comparison_fn(x, y):
(0, 3),
(0, 3),
(0, 3),
(0, 2),
(1, 2),
(4, 5),
]
)

for comparison_fn in [accuracy_calculator.EQUALITY, custom_label_comparison_fn]:
for comparison_fn in [equality2D, custom_label_comparison_fn]:
label_counts, num_k = accuracy_calculator.get_label_match_counts(
query_labels,
query_labels,
comparison_fn,
)

if comparison_fn is accuracy_calculator.EQUALITY:
unique_labels, counts = label_counts
correct_unique_labels = np.array([[0, 3], [1, 2], [1, 3], [4, 5]])
if comparison_fn is equality2D:
correct_counts = np.array([3, 1, 1, 1])
else:
correct_counts = np.array([0, 1, 1, 0])

self.assertTrue(np.all(correct_counts == counts))
self.assertTrue(np.all(correct_unique_labels == unique_labels))

if comparison_fn is equality2D:
correct = [
(
True,
np.array([[0, 2], [1, 2], [1, 3], [4, 5]]),
np.array([False, True, True, True, False, False, False]),
np.array([[1, 2], [1, 3], [4, 5]]),
np.array([False, True, True, True, False, False]),
),
(
False,
np.array([[]]),
np.array([True, True, True, True, True, True, True]),
np.array([True, True, True, True, True, True]),
),
]
else:
correct_lone = np.array([[4, 5]])
correct_mask = np.array([True, True, True, True, True, True, False])
correct = [
(True, correct_lone, correct_mask),
(False, correct_lone, correct_mask),
(
True,
np.array([[0, 3], [4, 5]]),
np.array([True, False, False, False, True, False]),
),
(
False,
np.array([[0, 3], [4, 5]]),
np.array([True, False, False, False, True, False]),
),
]

for same_source, correct_lone, correct_mask in correct:

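
At the user-facing level, the same behavior would be exercised through AccuracyCalculator. The sketch below is hedged: it assumes this release's AccuracyCalculator accepts a label_comparison_fn constructor argument and that get_accuracy takes (query, reference, query_labels, reference_labels, embeddings_come_from_same_source); the embeddings and comparison rule are made up for illustration, and the labels mirror the new test.

    import numpy as np
    from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

    # Illustrative comparison rule: labels match when they differ by at most 1.
    def within_one(x, y):
        return np.abs(x - y) <= 1

    # Made-up embeddings (float32 for faiss) and the label set used by the new test.
    embeddings = np.random.randn(12, 8).astype(np.float32)
    labels = np.array([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

    # Assumption: label_comparison_fn is a supported constructor argument in this release.
    # Restricting to precision_at_1 sidesteps the clustering-based metrics, which do not
    # apply to a custom comparison rule.
    calc = AccuracyCalculator(include=("precision_at_1",), label_comparison_fn=within_one)
    accuracy = calc.get_accuracy(
        embeddings, embeddings, labels, labels, embeddings_come_from_same_source=True
    )
    print(accuracy)  # e.g. {'precision_at_1': ...}; with same-source data, label 100 is lone and skipped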