From f3b6ba20acefc15ff9af397da4ebe9edcba1f2a2 Mon Sep 17 00:00:00 2001 From: Jacob Kauffmann Date: Tue, 15 Aug 2023 16:07:00 +0200 Subject: [PATCH 1/9] Neuralized K-Means - documentation in numpydoc format - pylint + flake8 stuff - KMeansCanonizer - NeuralizedKMeans layer - LogMeanExpPool layer - Distance layer - Distance type --- src/zennit/canonizers.py | 101 +++++++++++++++++++++++++++++- src/zennit/layer.py | 129 +++++++++++++++++++++++++++++++++++++++ src/zennit/types.py | 9 +++ 3 files changed, 238 insertions(+), 1 deletion(-) diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index fc4a4f2..fce76cc 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -18,10 +18,12 @@ '''Functions to produce a canonical form of models fit for LRP''' from abc import ABCMeta, abstractmethod +import copy import torch from .core import collect_leaves -from .types import Linear, BatchNorm, ConvolutionTranspose +from .types import Linear, BatchNorm, ConvolutionTranspose, Distance +from .layer import NeuralizedKMeans, LogMeanExpPool class Canonizer(metaclass=ABCMeta): @@ -329,3 +331,100 @@ def register(self): def remove(self): '''Remove this Canonizer. Nothing to do for a CompositeCanonizer.''' + + +class KMeansCanonizer(Canonizer): + '''Canonizer for k-means. + + This canonizer replaces a :py:obj:`Distance` layer with power 2 with a :py:obj:`NeuralizedKMeans` layer followed by + a :py:obj:`LogMeanExpPool` + + Parameters + ---------- + beta : float + stiffness of the :py:obj:`LogMeanExpPool` layer. Should be smaller than 0 in order to approximate the min + function. Default is -1. + + Examples + -------- + >>> from sklearn.cluster import KMeans + >>> centroids = KMeans(n_clusters=10).fit(X).cluster_centers_ + >>> model = torch.nn.Sequential(Distance(torch.from_numpy(centroids).float(), power=2)) + >>> cluster_assignment = model(x).argmin() + >>> canonizer = KMeansCanonizer(beta=-1.) 
+ >>> with Gradient(model, canonizer=[canonizer]) as attributor: + >>> output, attribution = attributor(x, torch.eye(len(centroids))[[cluster_assignment]]) + ''' + def __init__(self, beta=-1.): + self.distance = None + self.distance_unchanged = None + self.beta = beta + self.parent_module = None + self.child_name = None + + def apply(self, root_module): + '''Apply this canonizer recursively on all applicable modules. + + Iterates over all modules of the root module and applies this canonizer to all :py:obj:`Distance` layers with + power 2. + + Parameters + ---------- + root_module : :py:obj:`torch.nn.Module` + Root module containing a :py:obj:`Distance` layer with power 2 as a submodule. + ''' + instances = [] + + for full_name, module in root_module.named_modules(): + if isinstance(module, Distance) and module.power == 2: + instance = self.copy() + if '.' in full_name: + parent_name, child_name = full_name.rsplit('.', 1) + parent_module = getattr(root_module, parent_name) + else: + parent_module = root_module + child_name = full_name + + instance.parent_module = parent_module + instance.child_name = child_name + + instance.register(module) + instances.append(instance) + + return instances + + def register(self, distance_module): + '''Register the :py:obj:`Distance` layer and replace it with a :py:obj:`NeuralizedKMeans` layer followed by a + :py:obj:`LogMeanExpPool` layer. + + compute :math:`w_{ck} = 2(\\mathbf{\\mu}_c - \\mathbf{\\mu}_k)` and :math:`b_{ck} = \\|\\mathbf{\\mu}_k\\|^2 - + \\|\\mathbf{\\mu}_c\\|^2`. Weights are stored in a tensor :math:`W \\in \\mathbb{R}^{K \\times (K - 1) + \\times D}` and biases in a vector :math:`b \\in \\mathbb{R}^{K \\times (K - 1)}`. + + A :py:obj:`NeuralizedKMeans` layer is created with these weights and biases. The :py:obj:`LogMeanExpPool` layer + is created with the beta value supplied to the constructor. + + Parameters + ---------- + distance_module : list of :py:obj:`Distance` + Distance layers to replace. 
+ ''' + self.distance = distance_module + self.distance_unchanged = copy.deepcopy(self.distance) + + n_clusters, n_dims = self.distance.centroids.shape + mask = ~torch.eye(n_clusters, dtype=bool) + weight = 2 * (self.distance.centroids[:, None, :] - self.distance.centroids[None, :, :]) + weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims) + norms = torch.norm(self.distance.centroids, dim=-1) + bias = (norms[None, :]**2 - norms[:, None]**2)[mask].reshape(n_clusters, n_clusters - 1) + setattr(self.parent_module, self.child_name, + torch.nn.Sequential(NeuralizedKMeans(weight, bias), + LogMeanExpPool(self.beta))) + + def remove(self): + """Revert the changes introduced by this canonizer.""" + setattr(self.parent_module, self.child_name, self.distance_unchanged) + + def copy(self): + return KMeansCanonizer(self.beta) diff --git a/src/zennit/layer.py b/src/zennit/layer.py index bd93d90..fd29d82 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -34,3 +34,132 @@ def __init__(self, dim=-1): def forward(self, input): '''Computes the sum along a dimension.''' return torch.sum(input, dim=self.dim) + + +class Distance(torch.nn.Module): + '''Compute pairwise distances between two sets of points. + + Initialized with a set of centroids, this layer computes the pairwise distance between the input and the centroids. + + Parameters + ---------- + centroids : :py:obj:`torch.Tensor` + shape (K, D) tensor of centroids + power : float + power to raise the distance to + + Examples + -------- + >>> centroids = torch.randn(10, 2) + >>> distance = Distance(centroids) + >>> x = torch.randn(100, 2) + >>> distance(x) + + ''' + def __init__(self, centroids, power=2): + super().__init__() + self.centroids = torch.nn.Parameter(centroids) + self.power = power + + def forward(self, input): + '''Computes the pairwise distance between `input` and `self.centroids` and raises to the power `self.power`. 
+ + Parameters + ---------- + input : :py:obj:`torch.Tensor` + shape (N, D) tensor of points + + Returns + ------- + :py:obj:`torch.Tensor` + shape (N, K) tensor of distances + ''' + distance = torch.cdist(input, self.centroids)**self.power + return distance + + +class NeuralizedKMeans(torch.nn.Module): + '''Compute the k-means discriminants for a set of points. + + Technically, this is a tensor-matrix product with a bias. + + Parameters + ---------- + weight : :py:obj:`torch.Tensor` + shape (K, K-1, D) tensor of weights + bias : :py:obj:`torch.Tensor` + shape (K, K-1) tensor of biases + + Examples + -------- + >>> weight = torch.randn(10, 9, 2) + >>> bias = torch.randn(10, 9) + >>> neuralized_kmeans = NeuralizedKMeans(weight, bias) + + ''' + def __init__(self, weight, bias): + super().__init__() + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) + + def forward(self, x): + '''Computes the tensor-matrix product of `x` and `self.weight` and adds `self.bias`. + + Parameters + ---------- + x : :py:obj:`torch.Tensor` + shape (N, D) tensor of points + + Returns + ------- + :py:obj:`torch.Tensor` + shape (N, K, K-1) tensor of k-means discriminants + ''' + x = torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias + return x + + +class LogMeanExpPool(torch.nn.Module): + '''Computes a log-mean-exp pool along an axis. + + LogMeanExpPool computes :math:`\\frac{1}{\\beta} \\log \\frac{1}{N} \\sum_{i=1}^N \\exp(\\beta x_i)` + + Parameters + ---------- + beta : float + stiffness of the pool. Positive values make the pool more like a max pool, negative values make the pool + more like a min pool. Default value is -1. + dim : int + dimension over which to pool + + Examples + -------- + >>> x = torch.randn(10, 2) + >>> pool = LogMeanExpPool() + >>> pool(x) + + ''' + def __init__(self, beta=1., dim=-1): + super().__init__() + self.dim = dim + self.beta = beta + + def forward(self, input): + '''Computes the LogMeanExpPool of `input`. 
+ + If the input has shape (N1, N2, ..., Nk) and `self.dim` is `j`, then the output has shape + (N1, N2, ..., Nj-1, Nj+1, ..., Nk). + + Parameters + ---------- + input : :py:obj:`torch.Tensor` + the input tensor + + Returns + ------- + :py:obj:`torch.Tensor` + the LogMeanExpPool of `input` + ''' + n_dims = input.shape[self.dim] + return (torch.logsumexp(self.beta * input, dim=self.dim) + - torch.log(torch.tensor(n_dims, dtype=input.dtype))) / self.beta diff --git a/src/zennit/types.py b/src/zennit/types.py index 76cf78e..9641572 100644 --- a/src/zennit/types.py +++ b/src/zennit/types.py @@ -18,6 +18,8 @@ '''Type definitions for convenience.''' import torch +from .layer import Distance as DistanceLayer + class SubclassMeta(type): '''Meta class to bundle multiple subclasses.''' @@ -124,3 +126,10 @@ class Activation(metaclass=SubclassMeta): torch.nn.modules.activation.Tanhshrink, torch.nn.modules.activation.Threshold, ) + + +class Distance(metaclass=SubclassMeta): + '''Abstract base class that describes distance modules.''' + __subclass__ = ( + DistanceLayer, + ) From 077d583754506708cf24bdee0c95d9d79755e0e0 Mon Sep 17 00:00:00 2001 From: Jacob Kauffmann Date: Thu, 17 Aug 2023 18:26:46 +0200 Subject: [PATCH 2/9] Tutorial: Neuralized K-Means - Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data - I tried to adhere to guidelines - That means: random data, random weights - Code for real data and real weights in comments - Runs on colab, did not test blender - also adds the reference to docs/source/tutorial/index.rst --- docs/source/tutorial/deep-kmeans.ipynb | 312 +++++++++++++++++++++++++ docs/source/tutorial/index.rst | 1 + 2 files changed, 313 insertions(+) create mode 100644 docs/source/tutorial/deep-kmeans.ipynb diff --git a/docs/source/tutorial/deep-kmeans.ipynb b/docs/source/tutorial/deep-kmeans.ipynb new file mode 100644 index 0000000..b9ae15f --- /dev/null +++ b/docs/source/tutorial/deep-kmeans.ipynb @@ -0,0 +1,312 @@ +{ + "cells": [ + { + 
"cell_type": "markdown", + "id": "a0661ff4-9f41-405c-8453-f009c31e6a0e", + "metadata": {}, + "source": [ + "## Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3aef718-d2a0-4f30-9b91-b53f5b288299", + "metadata": {}, + "outputs": [], + "source": [ + "# for colab folks\n", + "# %pip install zennit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa6d0ce7-ea3d-46e5-a8d7-e9a8b31d9239", + "metadata": {}, + "outputs": [], + "source": [ + "# Basic boilerplate code\n", + "from torchvision import datasets, transforms\n", + "from torchvision.models import vgg16\n", + "import torch\n", + "import numpy as np\n", + "\n", + "transform_img = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])\n", + "transform_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n", + "\n", + "transform = transforms.Compose([\n", + " transform_img,\n", + " transforms.ToTensor(),\n", + " transform_norm\n", + "])" + ] + }, + { + "cell_type": "markdown", + "id": "d73397bd-14a2-48ee-8c42-46d6b5104115", + "metadata": {}, + "source": [ + "### Real data and weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5b258b8-c670-473f-858e-2f8464863e29", + "metadata": {}, + "outputs": [], + "source": [ + "# uncomment this cell for an example on real data and real weights\n", + "### Data loading\n", + "# from torch.utils.data import SubsetRandomSampler, DataLoader\n", + "\n", + "# # Attention: the next row downloads a dataset into the current folder!\n", + "# dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n", + "\n", + "# categories = ['cougar_body', 'Leopards', 'wild_cat']\n", + "\n", + "# all_indices = []\n", + "# for category in categories:\n", + "# category_idx = dataset.categories.index(category)\n", + "# category_indices = [i for i, label in enumerate(dataset.y) if label == 
category_idx]\n", + "\n", + "# num_samples = min(7, len(category_indices))\n", + "\n", + "# selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n", + "# all_indices.extend(selected_indices)\n", + "\n", + "# sampler = SubsetRandomSampler(all_indices)\n", + "# loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n", + "\n", + "# # If this line throws a shape error, just run this cell again (some images in Caltech101 are grayscale)\n", + "# images, labels = next(iter(loader))\n", + "\n", + "### Feature extractor\n", + "# features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']" + ] + }, + { + "cell_type": "markdown", + "id": "be3a0e0d-afa0-4af1-b8c2-a3f6525dcb03", + "metadata": {}, + "source": [ + "### Random data and weights for online preview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "736252a7-4445-408c-a946-163ea44903da", + "metadata": {}, + "outputs": [], + "source": [ + "# for zennit contribution guidelines\n", + "# some random data and weights\n", + "images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n", + "features = vgg16(weights=None).eval()._modules['features']" + ] + }, + { + "cell_type": "markdown", + "id": "e7f02b4d-1da8-44ea-a887-6413d150b355", + "metadata": {}, + "source": [ + "### The fun begins here\n", + "\n", + "We construct a feature map $\\phi$ from image space to feature space.\n", + "Here, we sum over spatial locations in feature space to get more or less translation invariance in pixel space." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eef79eae-f9c7-4b77-8d7c-5edff8e84aeb", + "metadata": {}, + "outputs": [], + "source": [ + "from zennit.layer import Sum\n", + "\n", + "phi = torch.nn.Sequential(\n", + " features,\n", + " Sum((2,3))\n", + ")\n", + "\n", + "Z = phi(images).detach()" + ] + }, + { + "cell_type": "markdown", + "id": "97b43d41-322a-483c-8506-93e3fa0a852d", + "metadata": {}, + "source": [ + "Use simple `scikit-learn.KMeans` on the features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87c058d4-a3e4-4d29-af50-a7f2235e78c3", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize on class means\n", + "# because we have very few data points here\n", + "centroids = np.stack([Z[labels == y].mean(0) for y in labels.unique()])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "309e1158-de08-4493-af07-32a592622a94", + "metadata": {}, + "outputs": [], + "source": [ + "### uncomment for real fun\n", + "# from sklearn.cluster import KMeans\n", + "# standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n", + "# centroids = standard_kmeans.cluster_centers_" + ] + }, + { + "cell_type": "markdown", + "id": "5d65f068-b651-4f87-81d4-54508b71c841", + "metadata": {}, + "source": [ + "Now build a deep clustering model that takes images as input and predicts the k-means assignments\n", + "\n", + "We also apply a little scaling trick that makes heatmaps nicer, but usually does not change the cluster assignments." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce2dbb2a-8a97-488d-9f88-25426881ee10", + "metadata": {}, + "outputs": [], + "source": [ + "from zennit.layer import Distance\n", + "\n", + "# it's not necessary, just looks a bit nicer\n", + "s = ((centroids**2).sum(-1, keepdims=True)**.5)\n", + "s = s / s.mean()\n", + "\n", + "model = torch.nn.Sequential(\n", + " phi,\n", + " Distance(torch.from_numpy(centroids / s).float())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f177bbce-fe8f-46b8-b7a9-b9bfb9048145", + "metadata": {}, + "source": [ + "### Enter zennit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06892de9-0add-448d-8b76-0f6ea3a0ccd7", + "metadata": {}, + "outputs": [], + "source": [ + "# import zennit\n", + "from zennit.attribution import Gradient\n", + "from zennit.composites import EpsilonGammaBox\n", + "from zennit.image import imgify\n", + "from zennit.torchvision import VGGCanonizer\n", + "from zennit.canonizers import KMeansCanonizer\n", + "from zennit.composites import LayerMapComposite, MixedComposite\n", + "from zennit.layer import NeuralizedKMeans\n", + "from zennit.rules import ZPlus, Gamma\n", + "\n", + "def data2img(x):\n", + " return (x.squeeze().permute(1,2,0) * torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aac5b8af-61cc-400b-a0fc-b036148104ad", + "metadata": {}, + "outputs": [], + "source": [ + "# compute cluster assignments and check if they are equal\n", + "# without the scaling trick above, they are definitely equal (trust me)\n", + "ypred = model(images).argmin(1)\n", + "# assert (ypred.numpy() == standard_kmeans.predict(Z)).all()" + ] + }, + { + "cell_type": "markdown", + "id": "47e38917-b4ee-499f-ba9e-55cce7cb8163", + "metadata": {}, + "source": [ + "### Everything is ready.\n", + "\n", + "You can play around with the `beta` parameter in `KMeansCanonizer` and the `gamma`
parameter in `Gamma`.\n", + "\n", + "`beta` is a contrast parameter. Keep `beta < 0`.\n", + "Small negative `beta` can be seen as *one-vs-all* explanation whereas large negative `beta` is more like *one-vs-nearest-competitor*.\n", + "\n", + "The `gamma` parameter controls the contribution of negative weights. Keep `gamma >= 0`.\n", + "In practice, small (positive) `gamma` can result in entirely negative heatmaps. Think of a thousand negative weights and a single positive weight. The positive weight could be enough to win the k-means assignment in feature space, but it's lost after a few layers because the graph is flooded with negative contributions.\n", + "\n", + "If you are trying to explain contribution to another cluster (say, $x$ is assigned to cluster $1$, but you want to see if there is some evidence for cluster $2$ in the image), then definitely crank up `gamma` or even use `ZPlus` instead of `Gamma`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690", + "metadata": {}, + "outputs": [], + "source": [ + "canonizer = KMeansCanonizer(beta=-1e-12)\n", + "\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "composite = MixedComposite([\n", + " EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n", + " LayerMapComposite([(NeuralizedKMeans, Gamma(gamma=1.))])\n", + "])\n", + "\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " for c in range(len(centroids)):\n", + " print(\"Cluster %d\"%c)\n", + " cluster_members = (ypred == c).nonzero()[:,0]\n", + " for i in cluster_members:\n", + " img = images[i].unsqueeze(0)\n", + " target = torch.eye(len(centroids))[[c]]\n", + " output, attribution = attributor(img, target)\n", + " relevance = attribution[0].sum(0)\n", + "\n", + " heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n", + " display(imgify(np.stack([data2img(img).numpy(), heatmap]),
grid=(1,2)))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index d89d94b..111cb1e 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -6,6 +6,7 @@ :maxdepth: 1 image-classification-vgg-resnet + deep-kmeans .. image-segmentation-with-unet text-classification-with-tbd From 465d8bffba46be0b370730c81b7a50191dec1282 Mon Sep 17 00:00:00 2001 From: jackmcrider <756997+jackmcrider@users.noreply.github.com> Date: Thu, 24 Aug 2023 09:45:48 +0200 Subject: [PATCH 3/9] Update src/zennit/layer.py Co-authored-by: Christopher <15217558+chr5tphr@users.noreply.github.com> --- src/zennit/layer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zennit/layer.py b/src/zennit/layer.py index fd29d82..1d536ae 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -115,8 +115,7 @@ def forward(self, x): :py:obj:`torch.Tensor` shape (N, K, K-1) tensor of k-means discriminants ''' - x = torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias - return x + return torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias class LogMeanExpPool(torch.nn.Module): From f38a66394d2f75598704a453b32e916f2a4b7f2d Mon Sep 17 00:00:00 2001 From: jackmcrider <756997+jackmcrider@users.noreply.github.com> Date: Thu, 24 Aug 2023 09:47:38 +0200 Subject: [PATCH 4/9] Update src/zennit/layer.py change `torch.log(torch.tensor(n_dims, dtype=...))` to `math.log(n_dims)` Co-authored-by: Christopher <15217558+chr5tphr@users.noreply.github.com> --- src/zennit/layer.py | 3 +-- 1 file changed, 1 insertion(+), 2 
deletions(-) diff --git a/src/zennit/layer.py b/src/zennit/layer.py index 1d536ae..f7dbf6a 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -160,5 +160,4 @@ def forward(self, input): the LogMeanExpPool of `input` ''' n_dims = input.shape[self.dim] - return (torch.logsumexp(self.beta * input, dim=self.dim) - - torch.log(torch.tensor(n_dims, dtype=input.dtype))) / self.beta + return (torch.logsumexp(self.beta * input, dim=self.dim) - math.log(n_dims)) / self.beta From 661d280f3fcaef85208e8342a469a2f7c786286a Mon Sep 17 00:00:00 2001 From: jackmcrider <756997+jackmcrider@users.noreply.github.com> Date: Thu, 24 Aug 2023 10:29:04 +0200 Subject: [PATCH 5/9] Update src/zennit/canonizers.py change `setattr(parent_module, ...)` to `parent_module.add_module(...)` Co-authored-by: Christopher <15217558+chr5tphr@users.noreply.github.com> --- src/zennit/canonizers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index fce76cc..dd7e9b8 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -418,9 +418,10 @@ def register(self, distance_module): weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims) norms = torch.norm(self.distance.centroids, dim=-1) bias = (norms[None, :]**2 - norms[:, None]**2)[mask].reshape(n_clusters, n_clusters - 1) - setattr(self.parent_module, self.child_name, - torch.nn.Sequential(NeuralizedKMeans(weight, bias), - LogMeanExpPool(self.beta))) + self.parent_module.add_module( + self.child_name, + torch.nn.Sequential(NeuralizedKMeans(weight, bias), LogMeanExpPool(self.beta)) + ) def remove(self): """Revert the changes introduced by this canonizer.""" From bf3195577cfb78a8d430bf7f29da8b39b96d002b Mon Sep 17 00:00:00 2001 From: jackmcrider <756997+jackmcrider@users.noreply.github.com> Date: Thu, 24 Aug 2023 10:31:02 +0200 Subject: [PATCH 6/9] Update src/zennit/canonizers.py add spaces around binary operators Co-authored-by: Christopher 
<15217558+chr5tphr@users.noreply.github.com> --- src/zennit/canonizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index dd7e9b8..e0b2c39 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -417,7 +417,7 @@ def register(self, distance_module): weight = 2 * (self.distance.centroids[:, None, :] - self.distance.centroids[None, :, :]) weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims) norms = torch.norm(self.distance.centroids, dim=-1) - bias = (norms[None, :]**2 - norms[:, None]**2)[mask].reshape(n_clusters, n_clusters - 1) + bias = (norms[None, :] ** 2 - norms[:, None] ** 2)[mask].reshape(n_clusters, n_clusters - 1) self.parent_module.add_module( self.child_name, torch.nn.Sequential(NeuralizedKMeans(weight, bias), LogMeanExpPool(self.beta)) From 3ee16427ec9a5ea285725ed81935e7d3bf349b8c Mon Sep 17 00:00:00 2001 From: Jacob Kauffmann Date: Fri, 8 Sep 2023 10:59:31 +0200 Subject: [PATCH 7/9] Neuralized k-means: MinPool and Backward Hooks via Conv - rename Distance to PairwiseCentroidDistance - remove LogMeanExpPool (might become relevant again, but not for now) - add MinPool1d and MinPool2d in layer.py - add MinTakesMost1d, MaxTakesMost1d, MinTakesMost2d, MaxTakesMost2d rules - largely untested. 
especially kernel_size as int or kernel_size as tuple - in principle, MaxTakesMost2d should also work for MaxPool2d layers in standard conv nets - but needs some testing - add abstract TakesMostBase class - remove type definition for Distance in types.py - adapt KMeans canonizer: - replace LogMeanExpPool with MinPool1d followed by torch.nn.Flatten - remove beta parameter; beta now sits in MinTakesMost1d - remove deepcopy and simply return the module itself - update docs/source/tutorial/deep-kmeans.ipynb - doc strings --- docs/source/tutorial/deep-kmeans.ipynb | 101 ++++++++--------- src/zennit/canonizers.py | 27 +++-- src/zennit/layer.py | 95 +++++++++++----- src/zennit/rules.py | 149 +++++++++++++++++++++++++ src/zennit/types.py | 9 -- 5 files changed, 278 insertions(+), 103 deletions(-) diff --git a/docs/source/tutorial/deep-kmeans.ipynb b/docs/source/tutorial/deep-kmeans.ipynb index b9ae15f..082232e 100644 --- a/docs/source/tutorial/deep-kmeans.ipynb +++ b/docs/source/tutorial/deep-kmeans.ipynb @@ -15,8 +15,10 @@ "metadata": {}, "outputs": [], "source": [ + "dummy = True\n", "# for colab folks\n", - "# %pip install zennit" + "# %pip install zennit\n", + "# dummy = False" ] }, { @@ -47,7 +49,7 @@ "id": "d73397bd-14a2-48ee-8c42-46d6b5104115", "metadata": {}, "source": [ - "### Real data and weights" + "### Data and weights" ] }, { @@ -57,54 +59,38 @@ "metadata": {}, "outputs": [], "source": [ - "# uncomment this cell for an example on real data and real weights\n", - "### Data loading\n", - "# from torch.utils.data import SubsetRandomSampler, DataLoader\n", + "## Data loading\n", + "if dummy:\n", + " images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n", + " features = vgg16(weights=None).eval()._modules['features']\n", + "else:\n", + " from torch.utils.data import SubsetRandomSampler, DataLoader\n", "\n", - "# # Attention: the next row downloads a dataset into the current folder!\n", - "# dataset =
datasets.Caltech101(root='.', transform=transform, download=True)\n", + " # Attention: the next row downloads a dataset into the current folder!\n", + " dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n", "\n", - "# categories = ['cougar_body', 'Leopards', 'wild_cat']\n", + " categories = ['cougar_body', 'Leopards', 'wild_cat']\n", "\n", - "# all_indices = []\n", - "# for category in categories:\n", - "# category_idx = dataset.categories.index(category)\n", - "# category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n", + " all_indices = []\n", + " for category in categories:\n", + " category_idx = dataset.categories.index(category)\n", + " category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n", "\n", - "# num_samples = min(7, len(category_indices))\n", + " num_samples = min(7, len(category_indices))\n", "\n", - "# selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n", - "# all_indices.extend(selected_indices)\n", + " selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n", + " all_indices.extend(selected_indices)\n", "\n", - "# sampler = SubsetRandomSampler(all_indices)\n", - "# loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n", + " sampler = SubsetRandomSampler(all_indices)\n", + " loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n", "\n", - "# # If this line throws a shape error, just run this cell again (some images in Caltech101 are grayscale)\n", - "# images, labels = next(iter(loader))\n", + " try:\n", + " images, labels = next(iter(loader))\n", + " except Exception as e:\n", + " print(f\"Exception: {e}\\nSimply run the cell again.\")\n", "\n", - "### Feature extractor\n", - "# features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']" - ] - }, - { - "cell_type": "markdown", - "id": "be3a0e0d-afa0-4af1-b8c2-a3f6525dcb03", - "metadata": {}, - "source": [ - "### Random data 
and weights for online preview" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "736252a7-4445-408c-a946-163ea44903da", - "metadata": {}, - "outputs": [], - "source": [ - "# for zennit contribution guidelines\n", - "# some random data and weights\n", - "images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n", - "features = vgg16(weights=None).eval()._modules['features']" + " ## Feature extractor\n", + " features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']" ] }, { @@ -162,10 +148,10 @@ "metadata": {}, "outputs": [], "source": [ - "### uncomment for real fun\n", - "# from sklearn.cluster import KMeans\n", - "# standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n", - "# centroids = standard_kmeans.cluster_centers_" + "if not dummy:\n", + " from sklearn.cluster import KMeans\n", + " standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n", + " centroids = standard_kmeans.cluster_centers_" ] }, { @@ -185,7 +171,7 @@ "metadata": {}, "outputs": [], "source": [ - "from zennit.layer import Distance\n", + "from zennit.layer import PairwiseCentroidDistance\n", "\n", "# it's not necessary, just looks a bit nicer\n", "s = ((centroids**2).sum(-1, keepdims=True)**.5)\n", @@ -193,7 +179,7 @@ "\n", "model = torch.nn.Sequential(\n", " phi,\n", - " Distance(torch.from_numpy(centroids / s).float())\n", + " PairwiseCentroidDistance(torch.from_numpy(centroids / s).float())\n", ")" ] }, @@ -219,8 +205,8 @@ "from zennit.torchvision import VGGCanonizer\n", "from zennit.canonizers import KMeansCanonizer\n", "from zennit.composites import LayerMapComposite, MixedComposite\n", - "from zennit.layer import NeuralizedKMeans\n", - "from zennit.rules import ZPlus, Gamma\n", + "from zennit.layer import NeuralizedKMeans, MinPool1d\n", + "from zennit.rules import ZPlus, Gamma, MinTakesMost1d\n", "\n", "def data2img(x):\n", " return (x.squeeze().permute(1,2,0) * 
torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])" @@ -246,7 +232,7 @@ "source": [ "### Everything is ready.\n", "\n", - "You can play around with the `beta` parameter in `KMeansCanonizer` and the `gamma` parameter in `Gamma`.\n", + "You can play around with the `beta` parameter in `MinTakesMost1d` and the `gamma` parameter in `Gamma`.\n", "\n", "`beta` is a contrast parameter. Keep `beta < 0`.\n", "Small negative `beta` can be seen as *one-vs-all* explanation whereas large negative `beta` is more like *one-vs-nearest-competitor*.\n", @@ -261,16 +247,21 @@ "cell_type": "code", "execution_count": null, "id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "canonizer = KMeansCanonizer(beta=-1e-12)\n", + "canonizer = KMeansCanonizer()\n", "\n", "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", "\n", "composite = MixedComposite([\n", " EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n", - " LayerMapComposite([(NeuralizedKMeans, Gamma(gamma=1.))])\n", + " LayerMapComposite([\n", + " (NeuralizedKMeans, Gamma(gamma=.0)),\n", + " (MinPool1d, MinTakesMost1d(beta=1e-6))\n", + " ])\n", "])\n", "\n", "with Gradient(model=model, composite=composite) as attributor:\n", @@ -304,7 +295,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py index fce76cc..822839c 100644 --- a/src/zennit/canonizers.py +++ b/src/zennit/canonizers.py @@ -18,12 +18,11 @@ '''Functions to produce a canonical form of models fit for LRP''' from abc import ABCMeta, abstractmethod -import copy import torch from .core import collect_leaves -from .types import Linear, BatchNorm, ConvolutionTranspose, Distance -from .layer import NeuralizedKMeans, LogMeanExpPool +from .types import Linear, BatchNorm, 
ConvolutionTranspose +from .layer import PairwiseCentroidDistance, NeuralizedKMeans, MinPool1d class Canonizer(metaclass=ABCMeta): @@ -351,14 +350,13 @@ class KMeansCanonizer(Canonizer): >>> centroids = KMeans(n_clusters=10).fit(X).cluster_centers_ >>> model = torch.nn.Sequential(Distance(torch.from_numpy(centroids).float(), power=2)) >>> cluster_assignment = model(x).argmin() - >>> canonizer = KMeansCanonizer(beta=-1.) + >>> canonizer = KMeansCanonizer() >>> with Gradient(model, canonizer=[canonizer]) as attributor: >>> output, attribution = attributor(x, torch.eye(len(centroids))[[cluster_assignment]]) ''' - def __init__(self, beta=-1.): + def __init__(self): self.distance = None self.distance_unchanged = None - self.beta = beta self.parent_module = None self.child_name = None @@ -376,7 +374,7 @@ def apply(self, root_module): instances = [] for full_name, module in root_module.named_modules(): - if isinstance(module, Distance) and module.power == 2: + if isinstance(module, PairwiseCentroidDistance) and module.power == 2: instance = self.copy() if '.' in full_name: parent_name, child_name = full_name.rsplit('.', 1) @@ -410,7 +408,6 @@ def register(self, distance_module): Distance layers to replace. 
''' self.distance = distance_module - self.distance_unchanged = copy.deepcopy(self.distance) n_clusters, n_dims = self.distance.centroids.shape mask = ~torch.eye(n_clusters, dtype=bool) @@ -420,11 +417,19 @@ def register(self, distance_module): bias = (norms[None, :]**2 - norms[:, None]**2)[mask].reshape(n_clusters, n_clusters - 1) setattr(self.parent_module, self.child_name, torch.nn.Sequential(NeuralizedKMeans(weight, bias), - LogMeanExpPool(self.beta))) + MinPool1d(n_clusters - 1), + torch.nn.Flatten())) def remove(self): """Revert the changes introduced by this canonizer.""" - setattr(self.parent_module, self.child_name, self.distance_unchanged) + setattr(self.parent_module, self.child_name, self.distance) def copy(self): - return KMeansCanonizer(self.beta) + '''Copy this Canonizer. + + Returns + ------- + :py:obj:`Canonizer` + A copy of this Canonizer. + ''' + return KMeansCanonizer() diff --git a/src/zennit/layer.py b/src/zennit/layer.py index fd29d82..bbf2128 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -36,8 +36,8 @@ def forward(self, input): return torch.sum(input, dim=self.dim) -class Distance(torch.nn.Module): - '''Compute pairwise distances between two sets of points. +class PairwiseCentroidDistance(torch.nn.Module): + '''Compute pairwise distances between inputs and centroids. Initialized with a set of centroids, this layer computes the pairwise distance between the input and the centroids. 
@@ -51,7 +51,7 @@ class Distance(torch.nn.Module): Examples -------- >>> centroids = torch.randn(10, 2) - >>> distance = Distance(centroids) + >>> distance = PairwiseCentroidDistance(centroids) >>> x = torch.randn(100, 2) >>> distance(x) @@ -74,8 +74,7 @@ def forward(self, input): :py:obj:`torch.Tensor` shape (N, K) tensor of distances ''' - distance = torch.cdist(input, self.centroids)**self.power - return distance + return torch.cdist(input, self.centroids)**self.power class NeuralizedKMeans(torch.nn.Module): @@ -119,36 +118,78 @@ def forward(self, x): return x -class LogMeanExpPool(torch.nn.Module): - '''Computes a log-mean-exp pool along an axis. - - LogMeanExpPool computes :math:`\\frac{1}{\\beta} \\log \\frac{1}{N} \\sum_{i=1}^N \\exp(\\beta x_i)` +class MinPool2d(torch.nn.MaxPool2d): + '''Computes a min pool. Parameters ---------- - beta : float - stiffness of the pool. Positive values make the pool more like a max pool, negative values make the pool - more like a min pool. Default value is -1. 
- dim : int - dimension over which to pool + kernel_size : int or tuple + size of the pooling window + stride : int or tuple + stride of the pooling operation + padding : int or tuple + zero-padding added to both sides of the input + dilation : int or tuple + spacing between kernel elements + return_indices : bool + if True, will return the max indices along with the outputs + ceil_mode : bool + if True, will use ceil instead of floor to compute the output shape Examples -------- - >>> x = torch.randn(10, 2) - >>> pool = LogMeanExpPool() + >>> pool = MinPool2d(2) + >>> x = torch.randn(1, 1, 4, 4) >>> pool(x) - ''' - def __init__(self, beta=1., dim=-1): - super().__init__() - self.dim = dim - self.beta = beta + def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) def forward(self, input): - '''Computes the LogMeanExpPool of `input`. + '''Computes the min pool of `input`. + + Parameters + ---------- + input : :py:obj:`torch.Tensor` + the input tensor + + Returns + ------- + :py:obj:`torch.Tensor` + the min pool of `input` + ''' + return -super().forward(-input) + + +class MinPool1d(torch.nn.MaxPool1d): + '''Computes a min pool. + + Parameters + ---------- + kernel_size : int or tuple + size of the pooling window + stride : int or tuple + stride of the pooling operation + padding : int or tuple + zero-padding added to both sides of the input + dilation : int or tuple + spacing between kernel elements + return_indices : bool + if True, will return the max indices along with the outputs + ceil_mode : bool + if True, will use ceil instead of floor to compute the output shape - If the input has shape (N1, N2, ..., Nk) and `self.dim` is `j`, then the output has shape - (N1, N2, ..., Nj-1, Nj+1, ..., Nk). 
+ Examples + -------- + >>> pool = MinPool1d(2) + >>> x = torch.randn(1, 1, 4) + >>> pool(x) + ''' + def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) + + def forward(self, input): + '''Computes the min pool of `input`. Parameters ---------- @@ -158,8 +199,6 @@ def forward(self, input): Returns ------- :py:obj:`torch.Tensor` - the LogMeanExpPool of `input` + the min pool of `input` ''' - n_dims = input.shape[self.dim] - return (torch.logsumexp(self.beta * input, dim=self.dim) - - torch.log(torch.tensor(n_dims, dtype=input.dtype))) / self.beta + return -super().forward(-input) diff --git a/src/zennit/rules.py b/src/zennit/rules.py index 4c7554e..c918916 100644 --- a/src/zennit/rules.py +++ b/src/zennit/rules.py @@ -436,3 +436,152 @@ def forward(self, module, input, output): def backward(self, module, grad_input, grad_output): '''Modify ReLU gradient to the smooth softplus gradient :cite:p:`dombrowski2019explanations`.''' return (torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0],) + + +class TakesMostBase(Hook): + '''Base class for TakesMost rules. + This class provides a common interface for rule variants that utilize a softmax-like weighting of the input + contributions based on their magnitude. + + Parameters + ---------- + beta: float + Beta parameter for controlling the sensitivity of the softmax weighting. + + Methods + ------- + max_fn(input, kernel_size, stride, padding, dilation): + Computes the maximum value in a local window for each entry in the input tensor. + sum_fn(input, kernel, stride, padding, dilation): + Computes the sum of elements in a local window for each entry in the input tensor. + forward(module, input, output): + Stores the input for later use in the backward pass. 
+ backward(module, grad_input, grad_output): + Modifies the gradient based on the softmax weighting of the input contributions. + ''' + def __init__(self, beta=1.0): + super().__init__() + self.beta = beta + self.stored_tensors = {} + + def copy(self): + '''Return a copy of this hook with the same beta parameter.''' + return self.__class__(beta=self.beta) + + def max_fn(self, input, kernel_size, stride, padding, dilation): + raise NotImplementedError("Implement in subclass") + + def sum_fn(self, input, kernel, stride, padding, dilation): + raise NotImplementedError("Implement in subclass") + + def forward(self, module, input, output): + self.stored_tensors['input'] = input + + def backward(self, module, grad_input, grad_output): + '''Modifies the gradient based on the softmax-like weighting of input contributions.''' + stored_input = self.stored_tensors['input'][0] + + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + + # For numerical stability, we subtract the maximum value from the input + max_val = self.max_fn(self.beta * stored_input, kernel_size, stride, padding, dilation) + exp_input = torch.exp(self.beta * stored_input - max_val) + summed_elements = self.sum_fn(exp_input, kernel_size, stride=stride, padding=padding, dilation=dilation) + softmax_output = exp_input / summed_elements + + return (softmax_output * grad_output[0],) + + +class MinTakesMost1d(TakesMostBase): + '''1D variant of TakesMost rule that weights the smallest contributions the most. + This rule is a 1D variant of TakesMostBase, but weights the smallest input contributions the most. + + Methods + ------- + __init__(beta=1.0): + Initializes the MinTakesMost1d class with a negative beta value. 
+ ''' + def __init__(self, beta=1.0): + super().__init__(-beta) + + def max_fn(self, input, kernel_size, stride, padding, dilation): + return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) + + def sum_fn(self, input, kernel_size, stride, padding, dilation): + in_channels = input.shape[1] + kernel = torch.ones((in_channels, 1, kernel_size), device=input.device) + return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, + groups=in_channels) + + +class MaxTakesMost1d(TakesMostBase): + '''1D variant of TakesMost rule that weights the largest contributions the most. + This rule is a 1D variant of TakesMostBase, but weights the largest input contributions the most. + + Methods + ------- + __init__(beta=1.0): + Initializes the MaxTakesMost1d class. + ''' + def __init__(self, beta=1.0): + super().__init__(beta) + + def max_fn(self, input, kernel_size, stride, padding, dilation): + return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) + + def sum_fn(self, input, kernel_size, stride, padding, dilation): + in_channels = input.shape[1] + kernel = torch.ones((in_channels, 1, kernel_size), device=input.device) + return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, + groups=in_channels) + + +class MinTakesMost2d(TakesMostBase): + '''2D variant of TakesMost rule that weights the smallest contributions the most. + This rule is a 2D variant of TakesMostBase, but weights the smallest input contributions the most. + + Methods + ------- + __init__(beta=1.0): + Initializes the MinTakesMost2d class with a negative beta value. 
+ ''' + def __init__(self, beta=1.0): + super().__init__(-beta) + + def max_fn(self, input, kernel_size, stride, padding, dilation): + return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) + + def sum_fn(self, input, kernel_size, stride, padding, dilation): + in_channels = input.shape[1] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + kernel = torch.ones((in_channels, 1, *kernel_size), device=input.device) + return torch.nn.functional.conv2d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, + groups=in_channels) + + +class MaxTakesMost2d(TakesMostBase): + '''2D variant of TakesMost rule that weights the largest contributions the most. + This rule is a 2D variant of TakesMostBase, but weights the largest input contributions the most. + + Methods + ------- + __init__(beta=1.0): + Initializes the MaxTakesMost2d class. + ''' + def __init__(self, beta=1.0): + super().__init__(beta) + + def max_fn(self, input, kernel_size, stride, padding, dilation): + return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) + + def sum_fn(self, input, kernel_size, stride, padding, dilation): + in_channels = input.shape[1] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + kernel = torch.ones((in_channels, 1, *kernel_size), device=input.device) + return torch.nn.functional.conv2d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, + groups=in_channels) diff --git a/src/zennit/types.py b/src/zennit/types.py index 9641572..76cf78e 100644 --- a/src/zennit/types.py +++ b/src/zennit/types.py @@ -18,8 +18,6 @@ '''Type definitions for convenience.''' import torch -from .layer import Distance as DistanceLayer - class SubclassMeta(type): '''Meta class to bundle multiple subclasses.''' @@ -126,10 +124,3 @@ class Activation(metaclass=SubclassMeta): torch.nn.modules.activation.Tanhshrink, 
torch.nn.modules.activation.Threshold, ) - - -class Distance(metaclass=SubclassMeta): - '''Abstract base class that describes distance modules.''' - __subclass__ = ( - DistanceLayer, - ) From 7fb44ccbbf7eba740c30e5e27dea3c4f1f82b5d2 Mon Sep 17 00:00:00 2001 From: Jacob Kauffmann Date: Fri, 8 Sep 2023 15:50:35 +0200 Subject: [PATCH 8/9] tox compliance - various non-functional changes --- src/zennit/layer.py | 6 ------ src/zennit/rules.py | 19 ++++++++++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/zennit/layer.py b/src/zennit/layer.py index 6b52539..301c402 100644 --- a/src/zennit/layer.py +++ b/src/zennit/layer.py @@ -141,9 +141,6 @@ class MinPool2d(torch.nn.MaxPool2d): >>> x = torch.randn(1, 1, 4, 4) >>> pool(x) ''' - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): - super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) - def forward(self, input): '''Computes the min pool of `input`. @@ -184,9 +181,6 @@ class MinPool1d(torch.nn.MaxPool1d): >>> x = torch.randn(1, 1, 4) >>> pool(x) ''' - def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): - super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) - def forward(self, input): '''Computes the min pool of `input`. 
diff --git a/src/zennit/rules.py b/src/zennit/rules.py index c918916..065ca50 100644 --- a/src/zennit/rules.py +++ b/src/zennit/rules.py @@ -469,12 +469,15 @@ def copy(self): return self.__class__(beta=self.beta) def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' raise NotImplementedError("Implement in subclass") - def sum_fn(self, input, kernel, stride, padding, dilation): + def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' raise NotImplementedError("Implement in subclass") def forward(self, module, input, output): + '''Stores the input for later use in the backward pass.''' self.stored_tensors['input'] = input def backward(self, module, grad_input, grad_output): @@ -508,9 +511,11 @@ def __init__(self, beta=1.0): super().__init__(-beta) def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] kernel = torch.ones((in_channels, 1, kernel_size), device=input.device) return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, @@ -526,13 +531,12 @@ class MaxTakesMost1d(TakesMostBase): __init__(beta=1.0): Initializes the MaxTakesMost1d class. 
''' - def __init__(self, beta=1.0): - super().__init__(beta) - def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] kernel = torch.ones((in_channels, 1, kernel_size), device=input.device) return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, @@ -552,9 +556,11 @@ def __init__(self, beta=1.0): super().__init__(-beta) def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) @@ -572,13 +578,12 @@ class MaxTakesMost2d(TakesMostBase): __init__(beta=1.0): Initializes the MaxTakesMost2d class. 
''' - def __init__(self, beta=1.0): - super().__init__(beta) - def max_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the maximum value in a local window for each entry in the input tensor.''' return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) def sum_fn(self, input, kernel_size, stride, padding, dilation): + '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) From b042aae4556c7f2c451c60608a03d131fdd8b9a0 Mon Sep 17 00:00:00 2001 From: Jacob Kauffmann Date: Sat, 11 Nov 2023 00:55:47 +0100 Subject: [PATCH 9/9] max takes most fix --- src/zennit/rules.py | 47 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/src/zennit/rules.py b/src/zennit/rules.py index 065ca50..0d29f6a 100644 --- a/src/zennit/rules.py +++ b/src/zennit/rules.py @@ -580,13 +580,54 @@ class MaxTakesMost2d(TakesMostBase): ''' def max_fn(self, input, kernel_size, stride, padding, dilation): '''Computes the maximum value in a local window for each entry in the input tensor.''' - return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) + # return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation) + return input.max().view(1,1,1,1) def sum_fn(self, input, kernel_size, stride, padding, dilation): '''Computes the sum of elements in a local window for each entry in the input tensor.''' in_channels = input.shape[1] if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) kernel = torch.ones((in_channels, 1, *kernel_size), device=input.device) - return torch.nn.functional.conv2d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation, - 
groups=in_channels) + summed_tensor = torch.nn.functional.conv2d(input, weight=kernel, stride=stride, padding=padding, + dilation=dilation, groups=in_channels) + expanded_sum = torch.nn.functional.conv_transpose2d(summed_tensor, weight=kernel, stride=stride, + padding=padding, dilation=dilation, groups=in_channels) + pad_height = input.shape[2] - expanded_sum.shape[2] + pad_width = input.shape[3] - expanded_sum.shape[3] + if pad_height > 0 or pad_width > 0: + expanded_sum = torch.nn.functional.pad(expanded_sum, (0, pad_width, 0, pad_height)) + + return expanded_sum + + def backward(self, module, grad_input, grad_output): + '''Modifies the gradient based on the softmax-like weighting of input contributions.''' + stored_input = self.stored_tensors['input'][0] + if isinstance(module.stride, int): + stride = (module.stride, module.stride) + else: + stride = module.stride + + kernel_size = module.kernel_size + padding = module.padding + dilation = module.dilation + + max_val = self.max_fn(self.beta * stored_input, kernel_size, stride, padding, dilation) + exp_input = torch.exp(self.beta * stored_input - max_val) + summed_elements = self.sum_fn(exp_input, kernel_size, stride=stride, padding=padding, dilation=dilation) + softmax_output = exp_input / summed_elements + softmax_output[summed_elements == 0] = 0 + + in_channels = stored_input.shape[1] + kernel = torch.ones((in_channels, 1, kernel_size, kernel_size), device=stored_input.device) + expanded_grad_output = torch.nn.functional.conv_transpose2d(grad_output[0], weight=kernel, stride=stride, + padding=padding, dilation=dilation, + groups=in_channels) + pad_height = stored_input.shape[2] - expanded_grad_output.shape[2] + pad_width = stored_input.shape[3] - expanded_grad_output.shape[3] + if pad_height > 0 or pad_width > 0: + expanded_grad_output = torch.nn.functional.pad(expanded_grad_output, (0, pad_width, 0, pad_height)) + + return (softmax_output * expanded_grad_output,)