Skip to content

Commit

Permalink
Merge pull request #621 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v2.1.1
  • Loading branch information
Kevin Musgrave authored May 3, 2023
2 parents d9cd2ae + ad2e8b5 commit c57ebdd
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.1.0"
__version__ = "2.1.1"
4 changes: 3 additions & 1 deletion src/pytorch_metric_learning/distances/base_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def set_default_stats(
):
if self.collect_stats:
with torch.no_grad():
self.initial_avg_query_norm: torch.mean(self.get_norm(query_emb)).item()
self.initial_avg_query_norm = torch.mean(
self.get_norm(query_emb)
).item()
self.initial_avg_ref_norm = torch.mean(self.get_norm(ref_emb)).item()
self.final_avg_query_norm = torch.mean(
self.get_norm(query_emb_normalized)
Expand Down
20 changes: 20 additions & 0 deletions tests/distances/test_collected_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest

import torch

from pytorch_metric_learning.distances import LpDistance

from .. import WITH_COLLECT_STATS


class TestCollectedStats(unittest.TestCase):
@unittest.skipUnless(WITH_COLLECT_STATS, "WITH_COLLECT_STATS is false")
def test_collected_stats(self):
x = torch.randn(32, 128)
d = LpDistance()
d(x)

self.assertNotEqual(d.initial_avg_query_norm, 0)
self.assertNotEqual(d.initial_avg_ref_norm, 0)
self.assertNotEqual(d.final_avg_query_norm, 0)
self.assertNotEqual(d.final_avg_ref_norm, 0)

0 comments on commit c57ebdd

Please sign in to comment.