Skip to content

Commit

Permalink
Merge branch 'master' of github.com:jackmcrider/zennit
Browse files Browse the repository at this point in the history
- merge changes coming from github web interface
  • Loading branch information
jacobkauffmann committed Sep 8, 2023
2 parents 3ee1642 + bf31955 commit 6a55fde
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
17 changes: 5 additions & 12 deletions src/zennit/canonizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,22 +414,15 @@ def register(self, distance_module):
weight = 2 * (self.distance.centroids[:, None, :] - self.distance.centroids[None, :, :])
weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims)
norms = torch.norm(self.distance.centroids, dim=-1)
bias = (norms[None, :]**2 - norms[:, None]**2)[mask].reshape(n_clusters, n_clusters - 1)
setattr(self.parent_module, self.child_name,
torch.nn.Sequential(NeuralizedKMeans(weight, bias),
MinPool1d(n_clusters - 1),
torch.nn.Flatten()))
bias = (norms[None, :] ** 2 - norms[:, None] ** 2)[mask].reshape(n_clusters, n_clusters - 1)
self.parent_module.add_module(
self.child_name,
torch.nn.Sequential(NeuralizedKMeans(weight, bias), MinPool1d(n_clusters - 1), torch.nn.Flatten())
)

def remove(self):
"""Revert the changes introduced by this canonizer."""
setattr(self.parent_module, self.child_name, self.distance)

def copy(self):
'''Copy this Canonizer.
Returns
-------
:py:obj:`Canonizer`
A copy of this Canonizer.
'''
return KMeansCanonizer()
3 changes: 1 addition & 2 deletions src/zennit/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def forward(self, x):
:py:obj:`torch.Tensor`
shape (N, K, K-1) tensor of k-means discriminants
'''
x = torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias
return x
return torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias


class MinPool2d(torch.nn.MaxPool2d):
Expand Down

0 comments on commit 6a55fde

Please sign in to comment.