diff --git a/docs/source/tutorial/deep-kmeans.ipynb b/docs/source/tutorial/deep-kmeans.ipynb new file mode 100644 index 0000000..b9ae15f --- /dev/null +++ b/docs/source/tutorial/deep-kmeans.ipynb @@ -0,0 +1,312 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a0661ff4-9f41-405c-8453-f009c31e6a0e", + "metadata": {}, + "source": [ + "## Explaining Deep Cluster Assignments with Neuralized K-Means on Image Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3aef718-d2a0-4f30-9b91-b53f5b288299", + "metadata": {}, + "outputs": [], + "source": [ + "# for colab folks\n", + "# %pip install zennit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa6d0ce7-ea3d-46e5-a8d7-e9a8b31d9239", + "metadata": {}, + "outputs": [], + "source": [ + "# Basic boilerplate code\n", + "from torchvision import datasets, transforms\n", + "from torchvision.models import vgg16\n", + "import torch\n", + "import numpy as np\n", + "\n", + "transform_img = transforms.Compose([transforms.Resize(224), transforms.CenterCrop(224)])\n", + "transform_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n", + "\n", + "transform = transforms.Compose([\n", + " transform_img,\n", + " transforms.ToTensor(),\n", + " transform_norm\n", + "])" + ] + }, + { + "cell_type": "markdown", + "id": "d73397bd-14a2-48ee-8c42-46d6b5104115", + "metadata": {}, + "source": [ + "### Real data and weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5b258b8-c670-473f-858e-2f8464863e29", + "metadata": {}, + "outputs": [], + "source": [ + "# uncomment this cell for an example on real data and real weights\n", + "### Data loading\n", + "# from torch.utils.data import SubsetRandomSampler, DataLoader\n", + "\n", + "# # Attention: the next row downloads a dataset into the current folder!\n", + "# dataset = datasets.Caltech101(root='.', transform=transform, download=True)\n", + "\n", + "# categories = ['cougar_body', 'Leopards', 'wild_cat']\n", + "\n", + "# all_indices = []\n", + "# for category in categories:\n", + "# category_idx = dataset.categories.index(category)\n", + "# category_indices = [i for i, label in enumerate(dataset.y) if label == category_idx]\n", + "\n", + "# num_samples = min(7, len(category_indices))\n", + "\n", + "# selected_indices = np.random.choice(category_indices, num_samples, replace=False)\n", + "# all_indices.extend(selected_indices)\n", + "\n", + "# sampler = SubsetRandomSampler(all_indices)\n", + "# loader = DataLoader(dataset, batch_size=21, sampler=sampler)\n", + "\n", + "# # If this line throws a shape error, just run this cell again (some images in Caltech101 are grayscale)\n", + "# images, labels = next(iter(loader))\n", + "\n", + "### Feature extractor\n", + "# features = vgg16(weights='IMAGENET1K_V1').eval()._modules['features']" + ] + }, + { + "cell_type": "markdown", + "id": "be3a0e0d-afa0-4af1-b8c2-a3f6525dcb03", + "metadata": {}, + "source": [ + "### Random data and weights for online preview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "736252a7-4445-408c-a946-163ea44903da", + "metadata": {}, + "outputs": [], + "source": [ + "# for zennit contribution guidelines\n", + "# some random data and weights\n", + "images, labels = transform_norm(torch.randn(3, 3, 224, 224).clamp(min=0, max=1)), torch.tensor([0,1,2])\n", + "features = vgg16(weights=None).eval()._modules['features']" + ] + }, + { + "cell_type": "markdown", + "id": "e7f02b4d-1da8-44ea-a887-6413d150b355", + "metadata": {}, + "source": [ + "### The fun begins here\n", + "\n", + "We construct a feature map $\\phi$ from image space to feature space.\n", + "Here, we sum over spatial locations in feature space to get more or less translation invariance in pixel space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eef79eae-f9c7-4b77-8d7c-5edff8e84aeb", + "metadata": {}, + "outputs": [], + "source": [ + "from zennit.layer import Sum\n", + "\n", + "phi = torch.nn.Sequential(\n", + " features,\n", + " Sum((2,3))\n", + ")\n", + "\n", + "Z = phi(images).detach()" + ] + }, + { + "cell_type": "markdown", + "id": "97b43d41-322a-483c-8506-93e3fa0a852d", + "metadata": {}, + "source": [ + "Use simple `scikit-learn.KMeans` on the features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87c058d4-a3e4-4d29-af50-a7f2235e78c3", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize on class means\n", + "# because we have very few data points here\n", + "centroids = np.stack([Z[labels == y].mean(0) for y in labels.unique()])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "309e1158-de08-4493-af07-32a592622a94", + "metadata": {}, + "outputs": [], + "source": [ + "### uncomment for real fun\n", + "# from sklearn.cluster import KMeans\n", + "# standard_kmeans = KMeans(n_clusters=3, n_init='auto', init=centroids).fit(Z)\n", + "# centroids = standard_kmeans.cluster_centers_" + ] + }, + { + "cell_type": "markdown", + "id": "5d65f068-b651-4f87-81d4-54508b71c841", + "metadata": {}, + "source": [ + "Now build a deep clustering model that takes images as input and predicts the k-means assignments\n", + "\n", + "We also apply a little scaling trick that makes heatmaps nicer, but usually does not change the cluster assignments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce2dbb2a-8a97-488d-9f88-25426881ee10", + "metadata": {}, + "outputs": [], + "source": [ + "from zennit.layer import Distance\n", + "\n", + "# it's not necessary, just looks a bit nicer\n", + "s = ((centroids**2).sum(-1, keepdims=True)**.5)\n", + "s = s / s.mean()\n", + "\n", + "model = torch.nn.Sequential(\n", + " phi,\n", + " Distance(torch.from_numpy(centroids / s).float())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f177bbce-fe8f-46b8-b7a9-b9bfb9048145", + "metadata": {}, + "source": [ + "### Enter zennit." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06892de9-0add-448d-8b76-0f6ea3a0ccd7", + "metadata": {}, + "outputs": [], + "source": [ + "# import zennit\n", + "from zennit.attribution import Gradient\n", + "from zennit.composites import EpsilonGammaBox\n", + "from zennit.image import imgify\n", + "from zennit.torchvision import VGGCanonizer\n", + "from zennit.canonizers import KMeansCanonizer\n", + "from zennit.composites import LayerMapComposite, MixedComposite\n", + "from zennit.layer import NeuralizedKMeans\n", + "from zennit.rules import ZPlus, Gamma\n", + "\n", + "def data2img(x):\n", + " return (x.squeeze().permute(1,2,0) * torch.tensor([0.229, 0.224, 0.225])) + torch.tensor([0.485, 0.456, 0.406])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aac5b8af-61cc-400b-a0fc-b036148104ad", + "metadata": {}, + "outputs": [], + "source": [ + "# compute cluster assignments and check if they are equal\n", + "# without the scaling trick above, the are definitely equal (trust me)\n", + "ypred = model(images).argmin(1)\n", + "# assert (ypred.numpy() == standard_kmeans.predict(Z)).all()" + ] + }, + { + "cell_type": "markdown", + "id": "47e38917-b4ee-499f-ba9e-55cce7cb8163", + "metadata": {}, + "source": [ + "### Everything is ready.\n", + "\n", + "You can play around with the `beta` parameter in `KMeansCanonizer` and the `gamma` parameter in `Gamma`.\n", + "\n", + "`beta` is a contrast parameter. Keep `beta < 0`.\n", + "Small negative `beta` can be seen as *one-vs-all* explanation whereas large negative `beta` is more like *one-vs-nearest-competitor*.\n", + "\n", + "The `gamma` parameter controls the contribution of negative weights. Keep `gamma >= 0`.\n", + "In practice, small (positive) `gamma` can result in entirely negative heatmaps. Think of thousand negative weights and a single positive weight. The positive weight could be enough to win the k-means assignment in feature space, but it's lost after a few layers because the graph is flooded with negative contributions.\n", + "\n", + "If you are trying to explain contribution to another cluster (say, $x$ is assigned to cluster $1$, but you want to see if there is some evidence for cluster $2$ in the image), then definitely cramp up `gamma` or even use `ZPlus` instead of `Gamma`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa0f7ca6-3e73-4254-ba31-26a6de28e690", + "metadata": {}, + "outputs": [], + "source": [ + "canonizer = KMeansCanonizer(beta=-1e-12)\n", + "\n", + "low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))\n", + "\n", + "composite = MixedComposite([\n", + " EpsilonGammaBox(low=low, high=high, canonizers=[canonizer]),\n", + " LayerMapComposite([(NeuralizedKMeans, Gamma(gamma=1.))])\n", + "])\n", + "\n", + "with Gradient(model=model, composite=composite) as attributor:\n", + " for c in range(len(centroids)):\n", + " print(\"Cluster %d\"%c)\n", + " cluster_members = (ypred == c).nonzero()[:,0]\n", + " for i in cluster_members:\n", + " img = images[i].unsqueeze(0)\n", + " target = torch.eye(len(centroids))[[c]]\n", + " output, attribution = attributor(img, target)\n", + " relevance = attribution[0].sum(0)\n", + "\n", + " heatmap = np.array(imgify(relevance, symmetric=True, cmap='seismic').convert('RGB'))\n", + " display(imgify(np.stack([data2img(img).numpy(), heatmap]), grid=(1,2)))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index d89d94b..111cb1e 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -6,6 +6,7 @@ :maxdepth: 1 image-classification-vgg-resnet + deep-kmeans .. image-segmentation-with-unet text-classification-with-tbd