diff --git a/docs/source/tutorial/deep-kmeans.ipynb b/docs/source/tutorial/deep-kmeans.ipynb
new file mode 100644
index 0000000..082232e
--- /dev/null
+++ b/docs/source/tutorial/deep-kmeans.ipynb
@@ -0,0 +1,303 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "a0661ff4-9f41-405c-8453-f009c31e6a0e",
+   "metadata": {},
+   "source": [
+    "## Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3aef718-d2a0-4f30-9b91-b53f5b288299",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dummy = True\n",
+    "# for Colab users:\n",
+    "# %pip install zennit\n",
+    "# dummy = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aa6d0ce7-ea3d-46e5-a8d7-e9a8b31d9239",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Basic boilerplate code\n",
+    "from torchvision import datasets, transforms\n",
+    "from torchvision.models import vgg16\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "\n",
+    "transform_img = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])\n",
+    "transform_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n",
+    "\n",
+    "transform = transforms.Compose([\n",
+    "    transform_img,\n",
+    "    transforms.ToTensor(),\n",
+    "    transform_norm\n",
+    "])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d73397bd-14a2-48ee-8c42-46d6b5104115",
+   "metadata": {},
+   "source": [
+    "### Data and weights"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d5b258b8-c670-473f-858e-2f8464863e29",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Data loading\n",
+    "if dummy:\n",
+    "    images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0, 1, 2])\n",
+    "    features = vgg16(weights=None).eval()._modules['features']\n",
+    "else:\n",
+    "    from torch.utils.data import SubsetRandomSampler, DataLoader\n",
+    "\n",
+    "    # Attention: the next line downloads a dataset into the current folder!\n",
+    "    dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n",
+    "\n",
+    "    categories = ['cougar_body', 'Leopards', 'wild_cat']\n",
+    "\n",
+    "    all_indices = []\n",
+    "    for category in categories:\n",
+    "        category_idx = dataset.categories.index(category)\n",
+    "        category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n",
+    "\n",
+    "        num_samples = min(7, len(category_indices))\n",
+    "\n",
+    "        selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n",
+    "        all_indices.extend(selected_indices)\n",
+    "\n",
+    "    sampler = SubsetRandomSampler(all_indices)\n",
+    "    loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n",
+    "\n",
+    "    # loading can occasionally fail (e.g. on grayscale images); re-running draws a new sample\n",
+    "    try:\n",
+    "        images, labels = next(iter(loader))\n",
+    "    except Exception as e:\n",
+    "        print(f\"Exception: {e}\\nSimply run the cell again.\")\n",
+    "\n",
+    "    ## Feature extractor\n",
+    "    features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e7f02b4d-1da8-44ea-a887-6413d150b355",
+   "metadata": {},
+   "source": [
+    "### The fun begins here\n",
+    "\n",
+    "We construct a feature map $\\phi$ from image space to feature space.\n",
+    "Here, we sum over spatial locations in feature space to obtain approximate translation invariance in pixel space.\n",
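+    "\n",
+    "As a toy illustration (with a random stand-in for a feature map, not real VGG16 features): a circular shift of the feature map leaves the sum over spatial locations unchanged.\n",
+    "\n",
+    "```python\n",
+    "f = torch.randn(1, 512, 7, 7)  # hypothetical feature map shape\n",
+    "shifted = torch.roll(f, shifts=1, dims=3)  # shift feature columns circularly\n",
+    "assert torch.allclose(f.sum((2, 3)), shifted.sum((2, 3)))\n",
+    "```"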
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eef79eae-f9c7-4b77-8d7c-5edff8e84aeb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from zennit.layer import Sum\n",
+    "\n",
+    "phi = torch.nn.Sequential(\n",
+    "    features,\n",
+    "    Sum((2, 3))\n",
+    ")\n",
+    "\n",
+    "Z = phi(images).detach()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "97b43d41-322a-483c-8506-93e3fa0a852d",
+   "metadata": {},
+   "source": [
+    "Run a standard `sklearn.cluster.KMeans` on the features:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "87c058d4-a3e4-4d29-af50-a7f2235e78c3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initialize on the class means,\n",
+    "# because we have very few data points here\n",
+    "centroids = np.stack([Z[labels == y].mean(0) for y in labels.unique()])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "309e1158-de08-4493-af07-32a592622a94",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if not dummy:\n",
+    "    from sklearn.cluster import KMeans\n",
+    "    standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n",
+    "    centroids = standard_kmeans.cluster_centers_"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5d65f068-b651-4f87-81d4-54508b71c841",
+   "metadata": {},
+   "source": [
+    "Now we build a deep clustering model that takes images as input and predicts the k-means assignments.\n",
+    "\n",
+    "We also apply a little scaling trick that makes the heatmaps nicer, but usually does not change the cluster assignments."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ce2dbb2a-8a97-488d-9f88-25426881ee10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from zennit.layer import PairwiseCentroidDistance\n",
+    "\n",
+    "# the rescaling is not necessary, it just makes the heatmaps look a bit nicer\n",
+    "s = ((centroids**2).sum(-1, keepdims=True)**.5)\n",
+    "s = s / s.mean()\n",
+    "\n",
+    "model = torch.nn.Sequential(\n",
+    "    phi,\n",
+    "    PairwiseCentroidDistance(torch.from_numpy(centroids / s).float())\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f177bbce-fe8f-46b8-b7a9-b9bfb9048145",
+   "metadata": {},
+   "source": [
+    "### Enter zennit.\n",
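+    "\n",
+    "Zennit's `KMeansCanonizer` will rewrite the distance layer into a `NeuralizedKMeans` layer followed by a min-pool: for a cluster $c$ and any competitor $k$, the discriminant $w_{ck}^\\top x + b_{ck}$ with $w_{ck} = 2(\\mu_c - \\mu_k)$ and $b_{ck} = \\|\\mu_k\\|^2 - \\|\\mu_c\\|^2$ equals the distance gap $\\|x - \\mu_k\\|^2 - \\|x - \\mu_c\\|^2$. A quick sanity check of this identity with toy numbers (the names below are ours, not part of the API):\n",
+    "\n",
+    "```python\n",
+    "mu = torch.randn(3, 5)  # three hypothetical centroids\n",
+    "x = torch.randn(5)\n",
+    "d = ((x - mu) ** 2).sum(-1)  # squared distances to each centroid\n",
+    "c, k = 0, 1\n",
+    "w = 2 * (mu[c] - mu[k])\n",
+    "b = (mu[k] ** 2).sum() - (mu[c] ** 2).sum()\n",
+    "assert torch.allclose(w @ x + b, d[k] - d[c], atol=1e-5)\n",
+    "```"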
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "06892de9-0add-448d-8b76-0f6ea3a0ccd7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# zennit imports\n",
+    "from zennit.attribution import Gradient\n",
+    "from zennit.composites import EpsilonGammaBox\n",
+    "from zennit.image import imgify\n",
+    "from zennit.torchvision import VGGCanonizer\n",
+    "from zennit.canonizers import KMeansCanonizer\n",
+    "from zennit.composites import LayerMapComposite, MixedComposite\n",
+    "from zennit.layer import NeuralizedKMeans, MinPool1d\n",
+    "from zennit.rules import ZPlus, Gamma, MinTakesMost1d\n",
+    "\n",
+    "def data2img(x):\n",
+    "    return (x.squeeze().permute(1, 2, 0) * torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aac5b8af-61cc-400b-a0fc-b036148104ad",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# compute the cluster assignments and check that they agree with scikit-learn's;\n",
+    "# without the scaling trick above, they are guaranteed to be equal\n",
+    "ypred = model(images).argmin(1)\n",
+    "# assert (ypred.numpy() == standard_kmeans.predict(Z)).all()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "47e38917-b4ee-499f-ba9e-55cce7cb8163",
+   "metadata": {},
+   "source": [
+    "### Everything is ready.\n",
+    "\n",
+    "You can play around with the `beta` parameter in `MinTakesMost1d` and the `gamma` parameter in `Gamma`.\n",
+    "\n",
+    "`beta` is a contrast parameter. Keep `beta > 0` (`MinTakesMost1d` negates it internally).\n",
+    "A small `beta` can be seen as a *one-vs-all* explanation, whereas a large `beta` is more like *one-vs-nearest-competitor*.\n",
+    "\n",
+    "The `gamma` parameter controls the contribution of negative weights. Keep `gamma >= 0`.\n",
+    "In practice, a small (positive) `gamma` can result in entirely negative heatmaps. Think of a thousand negative weights and a single positive weight: the positive weight could be enough to win the k-means assignment in feature space, but it is lost after a few layers because the graph is flooded with negative contributions.\n",
+    "\n",
+    "If you are trying to explain the contribution to another cluster (say, $x$ is assigned to cluster $1$, but you want to see whether there is some evidence for cluster $2$ in the image), then definitely increase `gamma`, or even use `ZPlus` instead of `Gamma`, as in the sketch below.\n",
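+    "\n",
+    "A possible variant of the composite from the next cell for such a cross-cluster explanation (this reuses `low` and `high` as defined below; the settings are a suggestion, not the only sensible choice):\n",
+    "\n",
+    "```python\n",
+    "cross_composite = MixedComposite([\n",
+    "    EpsilonGammaBox(low=low, high=high, canonizers=[KMeansCanonizer()]),\n",
+    "    LayerMapComposite([\n",
+    "        (NeuralizedKMeans, ZPlus()),  # only positive evidence\n",
+    "        (MinPool1d, MinTakesMost1d(beta=1.0)),  # larger beta: closer to one-vs-nearest-competitor\n",
+    "    ])\n",
+    "])\n",
+    "```"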
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "canonizer = KMeansCanonizer()\n",
+    "\n",
+    "# bounds of the normalized pixel value range [0, 1] for the box constraint\n",
+    "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n",
+    "\n",
+    "composite = MixedComposite([\n",
+    "    EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n",
+    "    LayerMapComposite([\n",
+    "        (NeuralizedKMeans, Gamma(gamma=.0)),\n",
+    "        (MinPool1d, MinTakesMost1d(beta=1e-6))\n",
+    "    ])\n",
+    "])\n",
+    "\n",
+    "with Gradient(model=model, composite=composite) as attributor:\n",
+    "    for c in range(len(centroids)):\n",
+    "        print(\"Cluster %d\" % c)\n",
+    "        cluster_members = (ypred == c).nonzero()[:, 0]\n",
+    "        for i in cluster_members:\n",
+    "            img = images[i].unsqueeze(0)\n",
+    "            target = torch.eye(len(centroids))[[c]]\n",
+    "            output, attribution = attributor(img, target)\n",
+    "            relevance = attribution[0].sum(0)\n",
+    "\n",
+    "            heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n",
+    "            display(imgify(np.stack([data2img(img).numpy(), heatmap]), grid=(1, 2)))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst
index d89d94b..111cb1e 100644
--- a/docs/source/tutorial/index.rst
+++ b/docs/source/tutorial/index.rst
@@ -6,6 +6,7 @@
    :maxdepth: 1
 
    image-classification-vgg-resnet
+   deep-kmeans
 
 .. image-segmentation-with-unet
    text-classification-with-tbd
diff --git a/src/zennit/canonizers.py b/src/zennit/canonizers.py
index fc4a4f2..eefb095 100644
--- a/src/zennit/canonizers.py
+++ b/src/zennit/canonizers.py
@@ -22,6 +22,7 @@
 from .core import collect_leaves
 from .types import Linear, BatchNorm, ConvolutionTranspose
+from .layer import PairwiseCentroidDistance, NeuralizedKMeans, MinPool1d
 
 
 class Canonizer(metaclass=ABCMeta):
@@ -329,3 +330,99 @@ def register(self):
 
     def remove(self):
         '''Remove this Canonizer. Nothing to do for a CompositeCanonizer.'''
+
+
+class KMeansCanonizer(Canonizer):
+    '''Canonizer for k-means.
+
+    This canonizer replaces a :py:obj:`PairwiseCentroidDistance` layer with power 2 by a :py:obj:`NeuralizedKMeans`
+    layer followed by a :py:obj:`MinPool1d` layer and a flatten operation.
+
+    Examples
+    --------
+    >>> from sklearn.cluster import KMeans
+    >>> centroids = KMeans(n_clusters=10).fit(X).cluster_centers_
+    >>> model = torch.nn.Sequential(PairwiseCentroidDistance(torch.from_numpy(centroids).float(), power=2))
+    >>> cluster_assignment = model(x).argmin()
+    >>> composite = Composite(canonizers=[KMeansCanonizer()])
+    >>> with Gradient(model, composite) as attributor:
+    ...     output, attribution = attributor(x, torch.eye(len(centroids))[[cluster_assignment]])
+    '''
+    def __init__(self):
+        self.distance = None
+        self.parent_module = None
+        self.child_name = None
+
+    def apply(self, root_module):
+        '''Apply this canonizer recursively on all applicable modules.
+
+        Iterates over all modules of the root module and applies this canonizer to all
+        :py:obj:`PairwiseCentroidDistance` layers with power 2.
+
+        Parameters
+        ----------
+        root_module : :py:obj:`torch.nn.Module`
+            Root module containing a :py:obj:`PairwiseCentroidDistance` layer with power 2 as a submodule.
+        '''
+        instances = []
+
+        for full_name, module in root_module.named_modules():
+            if isinstance(module, PairwiseCentroidDistance) and module.power == 2:
+                instance = self.copy()
+                if '.' in full_name:
+                    parent_name, child_name = full_name.rsplit('.', 1)
+                    # resolve the (possibly nested) dotted parent name
+                    parent_module = root_module.get_submodule(parent_name)
+                else:
+                    parent_module = root_module
+                    child_name = full_name
+
+                instance.parent_module = parent_module
+                instance.child_name = child_name
+
+                instance.register(module)
+                instances.append(instance)
+
+        return instances
+
+    def register(self, distance_module):
+        '''Register the :py:obj:`PairwiseCentroidDistance` layer and replace it with a :py:obj:`NeuralizedKMeans`
+        layer followed by a :py:obj:`MinPool1d` layer and a flatten operation.
+
+        Computes :math:`w_{ck} = 2(\\mathbf{\\mu}_c - \\mathbf{\\mu}_k)` and :math:`b_{ck} = \\|\\mathbf{\\mu}_k\\|^2 -
+        \\|\\mathbf{\\mu}_c\\|^2`. Weights are stored in a tensor :math:`W \\in \\mathbb{R}^{K \\times (K - 1)
+        \\times D}` and biases in a matrix :math:`B \\in \\mathbb{R}^{K \\times (K - 1)}`.
+
+        A :py:obj:`NeuralizedKMeans` layer is created with these weights and biases; the subsequent
+        :py:obj:`MinPool1d` takes the minimum over the :math:`K - 1` discriminants of each cluster.
+
+        Parameters
+        ----------
+        distance_module : :py:obj:`PairwiseCentroidDistance`
+            The distance layer to replace.
+        '''
+        self.distance = distance_module
+
+        n_clusters, n_dims = self.distance.centroids.shape
+        # drop the diagonal (c == k) pairs
+        mask = ~torch.eye(n_clusters, dtype=bool)
+        weight = 2 * (self.distance.centroids[:, None, :] - self.distance.centroids[None, :, :])
+        weight = weight[mask].reshape(n_clusters, n_clusters - 1, n_dims)
+        norms = torch.norm(self.distance.centroids, dim=-1)
+        bias = (norms[None, :] ** 2 - norms[:, None] ** 2)[mask].reshape(n_clusters, n_clusters - 1)
+        self.parent_module.add_module(
+            self.child_name,
+            torch.nn.Sequential(NeuralizedKMeans(weight, bias), MinPool1d(n_clusters - 1), torch.nn.Flatten())
+        )
+
+    def remove(self):
+        '''Revert the changes introduced by this canonizer.'''
+        setattr(self.parent_module, self.child_name, self.distance)
+
+    def copy(self):
+        '''Return an uninitialized copy of this canonizer.'''
+        return KMeansCanonizer()
diff --git a/src/zennit/layer.py b/src/zennit/layer.py
index bd93d90..301c402 100644
--- a/src/zennit/layer.py
+++ b/src/zennit/layer.py
@@ -34,3 +34,164 @@ def __init__(self, dim=-1):
     def forward(self, input):
         '''Computes the sum along a dimension.'''
         return torch.sum(input, dim=self.dim)
+
+
+class PairwiseCentroidDistance(torch.nn.Module):
+    '''Compute pairwise distances between inputs and centroids.
+
+    Initialized with a set of centroids, this layer computes the pairwise distance between its input points and the
+    centroids, raised to the power `power`.
+
+    Parameters
+    ----------
+    centroids : :py:obj:`torch.Tensor`
+        shape (K, D) tensor of centroids
+    power : float
+        power to raise the distance to
+
+    Examples
+    --------
+    >>> centroids = torch.randn(10, 2)
+    >>> distance = PairwiseCentroidDistance(centroids)
+    >>> x = torch.randn(100, 2)
+    >>> distance(x)
+    '''
+    def __init__(self, centroids, power=2):
+        super().__init__()
+        self.centroids = torch.nn.Parameter(centroids)
+        self.power = power
+
+    def forward(self, input):
+        '''Computes the pairwise distance between `input` and `self.centroids`, raised to the power `self.power`.
+
+        Parameters
+        ----------
+        input : :py:obj:`torch.Tensor`
+            shape (N, D) tensor of points
+
+        Returns
+        -------
+        :py:obj:`torch.Tensor`
+            shape (N, K) tensor of distances
+        '''
+        return torch.cdist(input, self.centroids) ** self.power
+
+
+class NeuralizedKMeans(torch.nn.Module):
+    '''Compute the k-means discriminants for a set of points.
+
+    Technically, this is a tensor-matrix product with a bias.
+
+    Parameters
+    ----------
+    weight : :py:obj:`torch.Tensor`
+        shape (K, K-1, D) tensor of weights
+    bias : :py:obj:`torch.Tensor`
+        shape (K, K-1) tensor of biases
+
+    Examples
+    --------
+    >>> weight = torch.randn(10, 9, 2)
+    >>> bias = torch.randn(10, 9)
+    >>> neuralized_kmeans = NeuralizedKMeans(weight, bias)
+    '''
+    def __init__(self, weight, bias):
+        super().__init__()
+        self.weight = torch.nn.Parameter(weight)
+        self.bias = torch.nn.Parameter(bias)
+
+    def forward(self, x):
+        '''Computes the tensor-matrix product of `x` and `self.weight` and adds `self.bias`.
+
+        Parameters
+        ----------
+        x : :py:obj:`torch.Tensor`
+            shape (N, D) tensor of points
+
+        Returns
+        -------
+        :py:obj:`torch.Tensor`
+            shape (N, K, K-1) tensor of k-means discriminants
+        '''
+        return torch.einsum('nd,kjd->nkj', x, self.weight) + self.bias
+
+
+class MinPool2d(torch.nn.MaxPool2d):
+    '''Computes a min pool, implemented as a negated max pool of the negated input.
+
+    Parameters
+    ----------
+    kernel_size : int or tuple
+        size of the pooling window
+    stride : int or tuple
+        stride of the pooling operation
+    padding : int or tuple
+        zero-padding added to both sides of the input
+    dilation : int or tuple
+        spacing between kernel elements
+    return_indices : bool
+        if True, will return the indices of the minima along with the outputs
+    ceil_mode : bool
+        if True, will use ceil instead of floor to compute the output shape
+
+    Examples
+    --------
+    >>> pool = MinPool2d(2)
+    >>> x = torch.randn(1, 1, 4, 4)
+    >>> pool(x)
+    '''
+    def forward(self, input):
+        '''Computes the min pool of `input`.
+
+        Parameters
+        ----------
+        input : :py:obj:`torch.Tensor`
+            the input tensor
+
+        Returns
+        -------
+        :py:obj:`torch.Tensor`
+            the min pool of `input`
+        '''
+        if self.return_indices:
+            output, indices = super().forward(-input)
+            return -output, indices
+        return -super().forward(-input)
+
+
+class MinPool1d(torch.nn.MaxPool1d):
+    '''Computes a min pool, implemented as a negated max pool of the negated input.
+
+    Parameters
+    ----------
+    kernel_size : int or tuple
+        size of the pooling window
+    stride : int or tuple
+        stride of the pooling operation
+    padding : int or tuple
+        zero-padding added to both sides of the input
+    dilation : int or tuple
+        spacing between kernel elements
+    return_indices : bool
+        if True, will return the indices of the minima along with the outputs
+    ceil_mode : bool
+        if True, will use ceil instead of floor to compute the output shape
+
+    Examples
+    --------
+    >>> pool = MinPool1d(2)
+    >>> x = torch.randn(1, 1, 4)
+    >>> pool(x)
+    '''
+    def forward(self, input):
+        '''Computes the min pool of `input`.
+
+        Parameters
+        ----------
+        input : :py:obj:`torch.Tensor`
+            the input tensor
+
+        Returns
+        -------
+        :py:obj:`torch.Tensor`
+            the min pool of `input`
+        '''
+        if self.return_indices:
+            output, indices = super().forward(-input)
+            return -output, indices
+        return -super().forward(-input)
diff --git a/src/zennit/rules.py b/src/zennit/rules.py
index 4c7554e..0d29f6a 100644
--- a/src/zennit/rules.py
+++ b/src/zennit/rules.py
@@ -436,3 +436,198 @@ def forward(self, module, input, output):
 
     def backward(self, module, grad_input, grad_output):
         '''Modify ReLU gradient to the smooth softplus gradient :cite:p:`dombrowski2019explanations`.'''
         return (torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0],)
+
+
+class TakesMostBase(Hook):
+    '''Base class for TakesMost rules.
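+
+    Relevance is redistributed within each pooling window in proportion to a softmax over the window's inputs,
+    :math:`p_i = \\exp(\\beta x_i) / \\sum_{i'} \\exp(\\beta x_{i'})`, so a large positive :math:`\\beta` approaches
+    winner-takes-all, :math:`\\beta \\to 0` spreads relevance uniformly, and a negative :math:`\\beta` favors the
+    smallest inputs instead.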
+
+    This class provides the common interface for rule variants that apply this softmax-like weighting of the input
+    contributions based on their magnitude; subclasses implement the pooling-specific `max_fn` and `sum_fn`.
+
+    Parameters
+    ----------
+    beta : float
+        Beta parameter controlling the sensitivity of the softmax weighting.
+
+    Methods
+    -------
+    max_fn(input, kernel_size, stride, padding, dilation):
+        Computes the maximum value in a local window for each entry in the input tensor.
+    sum_fn(input, kernel_size, stride, padding, dilation):
+        Computes the sum of elements in a local window for each entry in the input tensor.
+    forward(module, input, output):
+        Stores the input for later use in the backward pass.
+    backward(module, grad_input, grad_output):
+        Modifies the gradient based on the softmax weighting of the input contributions.
+    '''
+    def __init__(self, beta=1.0):
+        super().__init__()
+        self.beta = beta
+        self.stored_tensors = {}
+
+    def copy(self):
+        '''Return a copy of this hook with the same beta parameter.'''
+        return self.__class__(beta=self.beta)
+
+    def max_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the maximum value in a local window for each entry in the input tensor.'''
+        raise NotImplementedError("Implement in subclass")
+
+    def sum_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the sum of elements in a local window for each entry in the input tensor.'''
+        raise NotImplementedError("Implement in subclass")
+
+    def forward(self, module, input, output):
+        '''Stores the input for later use in the backward pass.'''
+        self.stored_tensors['input'] = input
+
+    def backward(self, module, grad_input, grad_output):
+        '''Modifies the gradient based on the softmax-like weighting of input contributions.'''
+        stored_input = self.stored_tensors['input'][0]
+
+        kernel_size = module.kernel_size
+        stride = module.stride
+        padding = module.padding
+        dilation = module.dilation
+
+        # for numerical stability, subtract the maximum value before exponentiating
+        max_val = self.max_fn(self.beta * stored_input, kernel_size, stride, padding, dilation)
+        exp_input = torch.exp(self.beta * stored_input - max_val)
+        summed_elements = self.sum_fn(exp_input, kernel_size, stride=stride, padding=padding, dilation=dilation)
+        # the division broadcasts over pooling windows; this is exact for a single window
+        # covering the full input, as produced by the KMeansCanonizer
+        softmax_output = exp_input / summed_elements
+
+        return (softmax_output * grad_output[0],)
+
+
+class MinTakesMost1d(TakesMostBase):
+    '''1D variant of the TakesMost rule that weights the smallest contributions the most.
+
+    The supplied `beta` is negated internally, so increasing `beta > 0` approaches a hard min-takes-all weighting.
+
+    Methods
+    -------
+    __init__(beta=1.0):
+        Initializes the MinTakesMost1d class; the given beta value is negated.
+    '''
+    def __init__(self, beta=1.0):
+        super().__init__(-beta)
+
+    def max_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the maximum value in a local window for each entry in the input tensor.'''
+        return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)
+
+    def sum_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the sum of elements in a local window for each entry in the input tensor.'''
+        in_channels = input.shape[1]
+        kernel = torch.ones((in_channels, 1, kernel_size), device=input.device)
+        return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation,
+                                          groups=in_channels)
+
+
+class MaxTakesMost1d(TakesMostBase):
+    '''1D variant of the TakesMost rule that weights the largest contributions the most.
+
+    Methods
+    -------
+    __init__(beta=1.0):
+        Initializes the MaxTakesMost1d class.
+    '''
+    def max_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the maximum value in a local window for each entry in the input tensor.'''
+        return torch.nn.functional.max_pool1d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)
+
+    def sum_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the sum of elements in a local window for each entry in the input tensor.'''
+        in_channels = input.shape[1]
+        kernel = torch.ones((in_channels, 1, kernel_size), device=input.device)
+        return torch.nn.functional.conv1d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation,
+                                          groups=in_channels)
+
+
+class MinTakesMost2d(TakesMostBase):
+    '''2D variant of the TakesMost rule that weights the smallest contributions the most.
+
+    The supplied `beta` is negated internally, so increasing `beta > 0` approaches a hard min-takes-all weighting.
+
+    Methods
+    -------
+    __init__(beta=1.0):
+        Initializes the MinTakesMost2d class; the given beta value is negated.
+    '''
+    def __init__(self, beta=1.0):
+        super().__init__(-beta)
+
+    def max_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the maximum value in a local window for each entry in the input tensor.'''
+        return torch.nn.functional.max_pool2d(input, kernel_size, stride=stride, padding=padding, dilation=dilation)
+
+    def sum_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the sum of elements in a local window for each entry in the input tensor.'''
+        in_channels = input.shape[1]
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        kernel = torch.ones((in_channels, 1, *kernel_size), device=input.device)
+        return torch.nn.functional.conv2d(input, weight=kernel, stride=stride, padding=padding, dilation=dilation,
+                                          groups=in_channels)
+
+
+class MaxTakesMost2d(TakesMostBase):
+    '''2D variant of the TakesMost rule that weights the largest contributions the most.
+
+    Unlike the other variants, this rule redistributes the pooled gradient back to input resolution with a
+    transposed convolution, so it also supports overlapping pooling windows.
+
+    Methods
+    -------
+    __init__(beta=1.0):
+        Initializes the MaxTakesMost2d class.
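+
+    Examples
+    --------
+    Typically mapped onto max-pooling layers through a composite (a sketch; any composite accepting a layer map
+    works):
+
+    >>> from zennit.composites import LayerMapComposite
+    >>> composite = LayerMapComposite([(torch.nn.MaxPool2d, MaxTakesMost2d(beta=0.5))])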
+    '''
+    def max_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes the global maximum of the input, used only to stabilize the exponentiation.'''
+        return input.max().view(1, 1, 1, 1)
+
+    def sum_fn(self, input, kernel_size, stride, padding, dilation):
+        '''Computes, for each input position, the total of the sums of all pooling windows covering it.'''
+        in_channels = input.shape[1]
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        kernel = torch.ones((in_channels, 1, *kernel_size), device=input.device)
+        # per-window sums at output resolution, redistributed back to input resolution
+        summed_tensor = torch.nn.functional.conv2d(input, weight=kernel, stride=stride, padding=padding,
+                                                   dilation=dilation, groups=in_channels)
+        expanded_sum = torch.nn.functional.conv_transpose2d(summed_tensor, weight=kernel, stride=stride,
+                                                            padding=padding, dilation=dilation, groups=in_channels)
+        # pad on the right/bottom in case the windows do not tile the input exactly
+        pad_height = input.shape[2] - expanded_sum.shape[2]
+        pad_width = input.shape[3] - expanded_sum.shape[3]
+        if pad_height > 0 or pad_width > 0:
+            expanded_sum = torch.nn.functional.pad(expanded_sum, (0, pad_width, 0, pad_height))
+
+        return expanded_sum
+
+    def backward(self, module, grad_input, grad_output):
+        '''Modifies the gradient based on the softmax-like weighting of input contributions.'''
+        stored_input = self.stored_tensors['input'][0]
+        if isinstance(module.stride, int):
+            stride = (module.stride, module.stride)
+        else:
+            stride = module.stride
+
+        kernel_size = module.kernel_size
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        padding = module.padding
+        dilation = module.dilation
+
+        max_val = self.max_fn(self.beta * stored_input, kernel_size, stride, padding, dilation)
+        exp_input = torch.exp(self.beta * stored_input - max_val)
+        summed_elements = self.sum_fn(exp_input, kernel_size, stride=stride, padding=padding, dilation=dilation)
+        softmax_output = exp_input / summed_elements
+        # positions covered by no pooling window receive no relevance
+        softmax_output[summed_elements == 0] = 0
+
+        # redistribute the pooled gradient back to input resolution
+        in_channels = stored_input.shape[1]
+        kernel = torch.ones((in_channels, 1, *kernel_size), device=stored_input.device)
+        expanded_grad_output = torch.nn.functional.conv_transpose2d(grad_output[0], weight=kernel, stride=stride,
+                                                                    padding=padding, dilation=dilation,
+                                                                    groups=in_channels)
+        pad_height = stored_input.shape[2] - expanded_grad_output.shape[2]
+        pad_width = stored_input.shape[3] - expanded_grad_output.shape[3]
+        if pad_height > 0 or pad_width > 0:
+            expanded_grad_output = torch.nn.functional.pad(expanded_grad_output, (0, pad_width, 0, pad_height))
+
+        return (softmax_output * expanded_grad_output,)