From 920a05447a3c8673f866a51b7f9bd85f17add2c3 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Fri, 10 Nov 2023 14:38:53 +0530 Subject: [PATCH] Add example to fine-tune torchvision model with Keras (#457) * add: example to fine-tune torchvision model with Keras * add: colab badge * update: wandb.init * update: torchvision + keras notebook * update: timm + keras notebook * fix: timm_keras notebook * add: keras + monai example * update: colab links * update: torchvision + keras notebook * update: torchvision + keras notebook * update: torchvision + keras notebook --- colabs/README.md | 4 +- .../keras_core/monai_medmnist_keras.ipynb | 477 +++++++++ colabs/keras/keras_core/timm_keras.ipynb | 377 +++++++ .../keras/keras_core/torchvision-keras.ipynb | 922 ++++++++++++++++++ 4 files changed, 1779 insertions(+), 1 deletion(-) create mode 100644 colabs/keras/keras_core/monai_medmnist_keras.ipynb create mode 100644 colabs/keras/keras_core/timm_keras.ipynb create mode 100644 colabs/keras/keras_core/torchvision-keras.ipynb diff --git a/colabs/README.md b/colabs/README.md index 7913dc56..141c049b 100644 --- a/colabs/README.md +++ b/colabs/README.md @@ -26,7 +26,9 @@ | Kaolin-Wisp | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/vqad-colab) | | Super Gradients | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/yolo-nas-colab) | | 🎸 Generating music with AudioCraft | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/audiocraft/AudioCraft.ipynb) | - +| 🦄 Fine-tune a Torchvision Model with KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision_keras.ipynb) | +| 🦄 Fine-tune a Timm Model with KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb) | +| 🦄 Medical Image Classification Tutorial using MonAI and KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb) | # 🏋🏽‍♂️ W&B Features diff --git a/colabs/keras/keras_core/monai_medmnist_keras.ipynb b/colabs/keras/keras_core/monai_medmnist_keras.ipynb new file mode 100644 index 00000000..747baa6f --- /dev/null +++ b/colabs/keras/keras_core/monai_medmnist_keras.ipynb @@ -0,0 +1,477 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "dMq6cFtJl2vR" + }, + "source": [ + "\"Keras\"\n", + "\"Weights\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2eqSE_8rl6yo" + }, + "source": [ + "# 🩺 Medical Image Classification Tutorial using MonAI and Keras\n", + "\n", + "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb)\n", + "\n", + "This notebook demonstrates\n", + "- an end-to-end training using [MonAI](https://github.com/Project-MONAI/MONAI) and [KerasCore](https://github.com/keras-team/keras-core).\n", + "- how we can use the backend-agnostic Keras callbacks for [Weights & Biases](https://wandb.ai/site) to manage and track our experiment.\n", + "\n", + "Original Notebook: 
https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ny_HOlvymX6O" + }, + "source": [ + "## Installing and Importing the Dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6SvDmUxVmatW" + }, + "source": [ + "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), this lets us use the latest feature merged in KerasCore.\n", + "- We install [monai](https://github.com/Project-MONAI/MONAI), a PyTorch-based, open-source framework for deep learning in healthcare imaging, part of the PyTorch Ecosystem.\n", + "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZO1xjFtvVdqQ" + }, + "outputs": [], + "source": [ + "# install the `main` branch of KerasCore\n", + "!pip install -qq namex\n", + "!apt install python3.10-venv\n", + "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n", + "\n", + "# install monai and wandb-addons\n", + "!pip install -qq git+https://github.com/soumik12345/wandb-addons\n", + "!pip install -q \"monai-weekly[pillow, tqdm]\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c6vd3NZ-mhxs" + }, + "source": [ + "We specify the Keras backend to be using `torch` by explicitly specifying the environment variable `KERAS_BACKEND`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d5ZiQmMkW-h3" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n", + "\n", + "import shutil\n", + "import tempfile\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import torch\n", + "import numpy as np\n", + "from sklearn.metrics import classification_report\n", + "\n", + "import keras_core as keras\n", + "from keras_core.utils import TorchModuleWrapper\n", + "\n", + "from monai.apps import download_and_extract\n", + "from monai.config import print_config\n", + "from monai.data import decollate_batch, DataLoader\n", + "from monai.metrics import ROCAUCMetric\n", + "from monai.networks.nets import DenseNet121\n", + "from monai.transforms import (\n", + " Activations,\n", + " EnsureChannelFirst,\n", + " AsDiscrete,\n", + " Compose,\n", + " LoadImage,\n", + " RandFlip,\n", + " RandRotate,\n", + " RandZoom,\n", + " ScaleIntensity,\n", + ")\n", + "from monai.utils import set_determinism\n", + "\n", + "import wandb\n", + "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1BvtLvxKmkSp" + }, + "source": [ + "We initialize a [wandb run](https://docs.wandb.ai/guides/runs) and set the configs for the experiment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WUt7PAsado4j" + }, + "outputs": [], + "source": [ + "wandb.init(project=\"keras-torch\")\n", + "\n", + "config = wandb.config\n", + "config.batch_size = 128\n", + "config.num_epochs = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NE938skxmoMI" + }, + "source": [ + "## Setup data directory\n", + "\n", + "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.\n", + "This allows you to save results and reuse downloads.\n", + "If not specified a temporary directory will be used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "82KX1Sj6XXY1" + }, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nkVQ-tojmzwf" + }, + "source": [ + "## Download dataset\n", + "\n", + "The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),\n", + "[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),\n", + "and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).\n", + "\n", + "The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)\n", + "under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).\n", + "\n", + "If you use the MedNIST dataset, please acknowledge the source." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DAHybyvdXZoH" + }, + "outputs": [], + "source": [ + "resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\"\n", + "md5 = \"0bc7306e7427e00ad1c5526a6677552d\"\n", + "\n", + "compressed_file = os.path.join(root_dir, \"MedNIST.tar.gz\")\n", + "data_dir = os.path.join(root_dir, \"MedNIST\")\n", + "if not os.path.exists(data_dir):\n", + " download_and_extract(resource, compressed_file, root_dir, md5)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u6m2Uas1nMi8" + }, + "source": [ + "## Read image filenames from the dataset folders\n", + "\n", + "First of all, check the dataset files and show some statistics. \n", + "There are 6 folders in the dataset: Hand, AbdomenCT, CXR, ChestCT, BreastMRI, HeadCT, \n", + "which should be used as the labels to train our classification model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qtLr1T0gXydq" + }, + "outputs": [], + "source": [ + "class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))\n", + "num_class = len(class_names)\n", + "image_files = [\n", + " [os.path.join(data_dir, class_names[i], x) for x in os.listdir(os.path.join(data_dir, class_names[i]))]\n", + " for i in range(num_class)\n", + "]\n", + "num_each = [len(image_files[i]) for i in range(num_class)]\n", + "image_files_list = []\n", + "image_class = []\n", + "for i in range(num_class):\n", + " image_files_list.extend(image_files[i])\n", + " image_class.extend([i] * num_each[i])\n", + "num_total = len(image_class)\n", + "image_width, image_height = PIL.Image.open(image_files_list[0]).size\n", + "\n", + "print(f\"Total image count: {num_total}\")\n", + "print(f\"Image dimensions: {image_width} x {image_height}\")\n", + "print(f\"Label names: {class_names}\")\n", + "print(f\"Label counts: {num_each}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "k4WOUb4IX6KK" + }, + "outputs": [], + "source": [ + "plt.subplots(3, 3, figsize=(8, 8))\n", + "for i, k in enumerate(np.random.randint(num_total, size=9)):\n", + " im = PIL.Image.open(image_files_list[k])\n", + " arr = np.array(im)\n", + " plt.subplot(3, 3, i + 1)\n", + " plt.xlabel(class_names[image_class[k]])\n", + " plt.imshow(arr, cmap=\"gray\", vmin=0, vmax=255)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wOGja_mEnQ5m" + }, + "source": [ + "## Prepare training, validation and test data lists\n", + "\n", + "Randomly select 10% of the dataset as validation and 10% as test." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BmHhKta1X8ZM" + }, + "outputs": [], + "source": [ + "val_frac = 0.1\n", + "test_frac = 0.1\n", + "length = len(image_files_list)\n", + "indices = np.arange(length)\n", + "np.random.shuffle(indices)\n", + "\n", + "test_split = int(test_frac * length)\n", + "val_split = int(val_frac * length) + test_split\n", + "test_indices = indices[:test_split]\n", + "val_indices = indices[test_split:val_split]\n", + "train_indices = indices[val_split:]\n", + "\n", + "train_x = [image_files_list[i] for i in train_indices]\n", + "train_y = [image_class[i] for i in train_indices]\n", + "val_x = [image_files_list[i] for i in val_indices]\n", + "val_y = [image_class[i] for i in val_indices]\n", + "test_x = [image_files_list[i] for i in test_indices]\n", + "test_y = [image_class[i] for i in test_indices]\n", + "\n", + "print(f\"Training count: {len(train_x)}, Validation count: \" f\"{len(val_x)}, Test count: {len(test_x)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lkhtr3p8nT1G" + }, + "source": [ + "## Define MONAI transforms, Dataset and Dataloader to pre-process data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2b5H2WuUYLsc" + }, + "outputs": [], + "source": [ + "train_transforms = Compose(\n", + " [\n", + " LoadImage(image_only=True),\n", + " EnsureChannelFirst(),\n", + " ScaleIntensity(),\n", + " RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),\n", + " RandFlip(spatial_axis=0, prob=0.5),\n", + " RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),\n", + " ]\n", + ")\n", + "\n", + "val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])\n", + "\n", + "y_pred_trans = 
Compose([Activations(softmax=True)])\n", + "y_trans = Compose([AsDiscrete(to_onehot=num_class)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QdFnaFUDYOaq" + }, + "outputs": [], + "source": [ + "class MedNISTDataset(torch.utils.data.Dataset):\n", + " def __init__(self, image_files, labels, transforms):\n", + " self.image_files = image_files\n", + " self.labels = labels\n", + " self.transforms = transforms\n", + "\n", + " def __len__(self):\n", + " return len(self.image_files)\n", + "\n", + " def __getitem__(self, index):\n", + " return self.transforms(self.image_files[index]), self.labels[index]\n", + "\n", + "\n", + "train_ds = MedNISTDataset(train_x, train_y, train_transforms)\n", + "train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, num_workers=2)\n", + "\n", + "val_ds = MedNISTDataset(val_x, val_y, val_transforms)\n", + "val_loader = DataLoader(val_ds, batch_size=config.batch_size, num_workers=2)\n", + "\n", + "test_ds = MedNISTDataset(test_x, test_y, val_transforms)\n", + "test_loader = DataLoader(test_ds, batch_size=config.batch_size, num_workers=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FIhKqMrVnZV9" + }, + "source": [ + "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model is tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n", + "\n", + "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!\n", + "\n", + "The idea of the `TorchModuleWrapper` was proposed by Keras' creator [François Chollet](https://github.com/fchollet) on [this issue thread](https://github.com/keras-team/keras-core/issues/604)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7zbinqy4ZsEy" + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "inputs = keras.Input(shape=(1, 64, 64))\n", + "outputs = TorchModuleWrapper(\n", + " DenseNet121(\n", + " spatial_dims=2, in_channels=1, out_channels=num_class\n", + " )\n", + ")(inputs)\n", + "model = keras.Model(inputs, outputs)\n", + "\n", + "# model = MedMnistModel()\n", + "model(next(iter(train_loader))[0].to(device)).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "opjBI87nneYU" + }, + "source": [ + "**Note:** It is actually possible to use torch modules inside a Keras Model without having to explicitly have them wrapped with the `TorchModuleWrapper` as evident by [this tweet](https://twitter.com/fchollet/status/1697381832164290754) from François Chollet. However, this doesn't seem to work at the point of time this example was created, as reported in [this issue](https://github.com/keras-team/keras-core/issues/834)." 
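, + "\n", + "\n", + "As a quick sanity check (a minimal sketch that reuses the `model` built above), we can confirm that the wrapped DenseNet121's parameters are indeed tracked by Keras as trainable weights, which is what allows a plain `model.fit()` to update them:\n", + "\n", + "```python\n", + "# Sketch: the wrapped torch module's parameters now appear as Keras-tracked weights.\n", + "print(f\"Number of trainable weight tensors: {len(model.trainable_weights)}\")\n", + "print(f\"Total trainable parameters: {sum(int(np.prod(w.shape)) for w in model.trainable_weights)}\")\n", + "```"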
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "y3H-cUKraRDh" + }, + "outputs": [], + "source": [ + "# Compile the model\n", + "model.compile(\n", + " loss=\"sparse_categorical_crossentropy\",\n", + " optimizer=keras.optimizers.Adam(1e-5),\n", + " metrics=[\"accuracy\"],\n", + ")\n", + "\n", + "# Define the backend-agnostic WandB callbacks for KerasCore\n", + "callbacks = [\n", + " # Track experiment metrics\n", + " WandbMetricsLogger(log_freq=\"batch\")\n", + "]\n", + "\n", + "# Train the model by calling model.fit\n", + "model.fit(\n", + " train_loader,\n", + " validation_data=val_loader,\n", + " epochs=config.num_epochs,\n", + " callbacks=callbacks,\n", + ")\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q0E-kiGBeCnZ" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "V100", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/colabs/keras/keras_core/timm_keras.ipynb b/colabs/keras/keras_core/timm_keras.ipynb new file mode 100644 index 00000000..1cea4769 --- /dev/null +++ b/colabs/keras/keras_core/timm_keras.ipynb @@ -0,0 +1,377 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Keras\"\n", + "\"Weights\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🔥 Fine-tune a [Timm](https://huggingface.co/docs/timm/index) Model with Keras and WandB 🦄\n", + "\n", + "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb)\n", + "\n", + "This notebook demonstrates\n", + "- how we can fine-tune a pre-trained model from timm using [KerasCore](https://github.com/keras-team/keras-core).\n", + "- how we can use the backend-agnostic Keras callbacks for [Weights & Biases](https://wandb.ai/site) to manage and track our experiment." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installing and Importing the Dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), this lets us use the latest feature merged in KerasCore.\n", + "- We install [timm](https://huggingface.co/docs/timm/index), a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.\n", + "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# install the `main` branch of KerasCore\n", + "!pip install -qq namex\n", + "!apt install python3.10-venv\n", + "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n", + "\n", + "# install timm and wandb-addons\n", + "!pip install -qq git+https://github.com/soumik12345/wandb-addons" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We specify the Keras backend to be using `torch` by explicitly specifying the environment variable `KERAS_BACKEND`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import timm\n", + "from timm.data import resolve_data_config\n", + "\n", + "import torchvision\n", + "from torchvision import datasets, models, transforms\n", + "from torchvision.transforms.functional import InterpolationMode\n", + "\n", + "import wandb\n", + "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We initialize a [wandb run](https://docs.wandb.ai/guides/runs) and set the configs for the experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.init(project=\"keras-torch\")\n", + "\n", + "config = wandb.config\n", + "config.model_name = \"xception41\"\n", + "config.freeze_backbone = False\n", + "config.preprocess_config = resolve_data_config({}, model=config.model_name)\n", + "config.dropout_rate = 0.5\n", + "config.batch_size = 4\n", + "config.num_epochs = 25" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A PyTorch-based Input Pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will be using the [ImageNette](https://github.com/fastai/imagenette) dataset for this experiment. Imagenette is a subset of 10 easily classified classes from [Imagenet](https://www.image-net.org/) (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).\n", + "\n", + "First, let's download this dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz -P imagenette\n", + "!tar zxf imagenette/imagenette2-320.tgz -C imagenette\n", + "!gzip -d imagenette/imagenette2-320.tgz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we create our standard torch-based data loading pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define pre-processing and augmentation transforms for the train and validation sets\n", + "data_transforms = {\n", + " 'train': transforms.Compose([\n", + " transforms.RandomResizedCrop(\n", + " size=config.preprocess_config[\"input_size\"][1],\n", + " interpolation=InterpolationMode.BICUBIC,\n", + " ),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(\n", + " config.preprocess_config[\"mean\"],\n", + " config.preprocess_config[\"std\"]\n", + " )\n", + " ]),\n", + " 'val': transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(config.preprocess_config[\"input_size\"][1]),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(\n", + " config.preprocess_config[\"mean\"],\n", + " config.preprocess_config[\"std\"]\n", + " )\n", + " ]),\n", + "}\n", + "\n", + "# Define the train and validation datasets\n", + "data_dir = 'imagenette/imagenette2-320'\n", + "image_datasets = {\n", + " x: datasets.ImageFolder(\n", + " os.path.join(data_dir, x), data_transforms[x]\n", + " )\n", + " for x in ['train', 'val']\n", + "}\n", + "\n", + "# Define the torch dataloaders corresponding to the train and validation dataset\n", + "dataloaders = {\n", + " x: torch.utils.data.DataLoader(\n", + " image_datasets[x],\n", + " batch_size=config.batch_size,\n", + " shuffle=True,\n", + " num_workers=4\n", + " )\n", + " for x in ['train', 'val']\n", + "}\n", + "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n", + "class_names = image_datasets['train'].classes\n", + "\n", + "# Specify the global device\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's take a look at a few of the samples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def imshow(inp, title=None):\n", + " \"\"\"Display image for Tensor.\"\"\"\n", + " inp = inp.numpy().transpose((1, 2, 0))\n", + " mean = np.array(config.preprocess_config[\"mean\"])\n", + " std = np.array(config.preprocess_config[\"std\"])\n", + " inp = std * inp + mean\n", + " inp = np.clip(inp, 0, 1)\n", + " plt.imshow(inp)\n", + " if title is not None:\n", + " plt.title(title)\n", + " plt.pause(0.001)\n", + "\n", + "\n", + "# Get a batch of training data\n", + "inputs, classes = next(iter(dataloaders['train']))\n", + "print(inputs.shape, classes.shape)\n", + "\n", + "# Make a grid from batch\n", + "out = torchvision.utils.make_grid(inputs)\n", + "\n", + "imshow(out, title=[class_names[x] for x in classes])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating and Training our Classifier" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model is tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n", + "\n", + "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!\n", + "\n", + "The idea of the `TorchModuleWrapper` was proposed by Keras' creator [François Chollet](https://github.com/fchollet) on [this issue thread](https://github.com/keras-team/keras-core/issues/604)." 
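, + "\n", + "\n", + "Before we write the full classifier, here is a minimal sketch (reusing the `config` defined above) of the wrapping pattern it relies on. Note that the wrapped torch module stays reachable through the wrapper's `module` attribute, which is how the classifier below calls timm's `forward_features` to obtain unpooled features:\n", + "\n", + "```python\n", + "# A small sketch of the wrapping pattern used by the classifier below.\n", + "import keras_core as keras\n", + "from keras_core.utils import TorchModuleWrapper\n", + "\n", + "backbone = TorchModuleWrapper(timm.create_model(config.model_name, pretrained=True))\n", + "# The wrapper is a regular Keras layer, so Keras tracks its trainable weights...\n", + "print(isinstance(backbone, keras.layers.Layer))\n", + "# ...while the underlying torch module stays accessible via `.module`.\n", + "print(type(backbone.module))\n", + "```"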
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import keras_core as keras\n", + "from keras_core.utils import TorchModuleWrapper\n", + "\n", + "\n", + "class TimmClassifier(keras.Model):\n", + "\n", + " def __init__(self, model_name, freeze_backbone, dropout_rate, num_classes, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " \n", + " # Define the pre-trained module from timm\n", + " self.backbone = TorchModuleWrapper(\n", + " timm.create_model(model_name, pretrained=True)\n", + " )\n", + " self.backbone.trainable = not freeze_backbone\n", + " \n", + " # Build the classification head using keras layers\n", + " self.global_average_pooling = keras.layers.GlobalAveragePooling2D()\n", + " self.dropout = keras.layers.Dropout(dropout_rate)\n", + " self.classification_head = keras.layers.Dense(num_classes)\n", + "\n", + " def call(self, inputs):\n", + " # We get the unpooled features from the timm backbone by calling `forward_features`\n", + " # on the torch module corresponding to the backbone.\n", + " x = self.backbone.module.forward_features(inputs)\n", + " x = self.global_average_pooling(x)\n", + " x = self.dropout(x)\n", + " x = self.classification_head(x)\n", + " return keras.activations.softmax(x, axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** It is actually possible to use torch modules inside a Keras Model without having to explicitly have them wrapped with the `TorchModuleWrapper` as evident by [this tweet](https://twitter.com/fchollet/status/1697381832164290754) from François Chollet. However, this doesn't seem to work at the point of time this example was created, as reported in [this issue](https://github.com/keras-team/keras-core/issues/834)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now, we define the model and pass a random tensor to check the output shape\n", + "model = TimmClassifier(\n", + " model_name=config.model_name,\n", + " freeze_backbone=config.freeze_backbone,\n", + " dropout_rate=config.dropout_rate,\n", + " num_classes=len(class_names)\n", + ")\n", + "model(torch.ones(1, *config.preprocess_config[\"input_size\"]).to(device)).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, in standard Keras fashion, all we need to do is compile the model and call `model.fit()`!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create exponential decay learning rate scheduler\n", + "decay_steps = config.num_epochs * len(dataloaders[\"train\"]) // config.batch_size\n", + "lr_scheduler = keras.optimizers.schedules.ExponentialDecay(\n", + " initial_learning_rate=1e-3, decay_steps=decay_steps, decay_rate=0.1,\n", + ")\n", + "\n", + "# Compile the model\n", + "model.compile(\n", + " loss=\"sparse_categorical_crossentropy\",\n", + " optimizer=keras.optimizers.Adam(lr_scheduler),\n", + " metrics=[\"accuracy\"],\n", + ")\n", + "\n", + "# Define the backend-agnostic WandB callbacks for KerasCore\n", + "callbacks = [\n", + " # Track experiment metrics\n", + " WandbMetricsLogger(log_freq=\"batch\"),\n", + " # Track and version model checkpoints\n", + " WandbModelCheckpoint(\"model.keras\")\n", + "]\n", + "\n", + "# Train the model by calling model.fit\n", + "model.fit(\n", + " dataloaders[\"train\"],\n", + " validation_data=dataloaders[\"val\"],\n", + " epochs=config.num_epochs,\n", + " callbacks=callbacks,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to know more about the backend-agnostic Keras callbacks for Weights & Biases, check out the [docs for wandb-addons](https://geekyrakshit.dev/wandb-addons/keras/keras_core/)." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/colabs/keras/keras_core/torchvision-keras.ipynb b/colabs/keras/keras_core/torchvision-keras.ipynb new file mode 100644 index 00000000..30783184 --- /dev/null +++ b/colabs/keras/keras_core/torchvision-keras.ipynb @@ -0,0 +1,922 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "jLRorPqA1-h4" + }, + "source": [ + "\"Keras\"\n", + "\"Weights\n", + "\n", + "\n", + "# 🔥 Fine-tune a TorchVision Model with Keras and WandB 🦄\n", + "\n", + "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision_keras.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aEl-j2hq2w25" + }, + "source": [ + "## Introduction\n", + "\n", + "[TorchVision](https://pytorch.org/vision/stable/index.html) is a library part of the [PyTorch](http://pytorch.org/) project that consists of popular datasets, model architectures, and common image transformations for computer vision. This example demonstrates how we can perform transfer learning for image classification using a pre-trained backbone model from TorchVision on the [Imagenette dataset](https://github.com/fastai/imagenette) using KerasCore. 
We will also demonstrate the compatibility of KerasCore with an input system consisting of [Torch Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).\n", + "\n", + "### References:\n", + "\n", + "- [Customizing what happens in `fit()` with PyTorch](https://keras.io/keras_core/guides/custom_train_step_in_torch/)\n", + "- [PyTorch Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)\n", + "- [Transfer learning for Computer Vision using PyTorch](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)\n", + "\n", + "## Setup\n", + "\n", + "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), this lets us use the latest feature merged in KerasCore.\n", + "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "r4rfNRPgiy9v", + "outputId": "ce5ba027-567a-4577-a638-ca8802ee1f84" + }, + "outputs": [], + "source": [ + "# install the `main` branch of KerasCore\n", + "!pip install -qq namex\n", + "!apt install python3.10-venv\n", + "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n", + "\n", + "# install wandb-addons\n", + "!pip install -qq git+https://github.com/soumik12345/wandb-addons" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7nudAUt8jHRB" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n", + "\n", + "import numpy as np\n", + "from tqdm.auto import tqdm\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torchvision\n", + "from torchvision import datasets, models, transforms\n", + "\n", + "import keras_core as keras\n", + "from keras_core.utils import TorchModuleWrapper\n", + "\n", + "import wandb\n", + "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pS1c-ySo7nty" + }, + "source": [ + "## Define the Hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2ovtXSUA7ksk", + "outputId": "725bef1a-0e68-473e-8c24-1f1ffd28506c" + }, + "outputs": [], + "source": [ + "wandb.init(project=\"keras-torch\", entity=\"ml-colabs\", job_type=\"torchvision/train\")\n", + "\n", + "config = wandb.config\n", + "config.batch_size = 32\n", + "config.image_size = 224\n", + "config.freeze_backbone = True\n", + "config.initial_learning_rate = 1e-3\n", + "config.num_epochs = 5" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b5uU_Q_H74GO" + }, + "source": [ + "## Creating the Torch Datasets and Dataloaders\n", + "\n", + "In this example, we would train an image classification model on the [Imagenette dataset](https://github.com/fastai/imagenette). Imagenette is a subset of 10 easily classified classes from [Imagenet](https://www.image-net.org/) (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gWV2hNGo8vW7", + "outputId": "cc70b598-fc66-480a-a98c-7077fa634a22" + }, + "outputs": [], + "source": [ + "# Fetch the imagenette dataset\n", + "data_dir = keras.utils.get_file(\n", + " fname=\"imagenette2-320.tgz\",\n", + " origin=\"https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz\",\n", + " extract=True,\n", + ")\n", + "data_dir = data_dir.replace(\".tgz\", \"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ka7TDUMn9IcG" + }, + "source": [ + "Next, we define pre-processing and augmentation transforms from TorchVision for the train and validation sets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rFZKJafF9H6y" + }, + "outputs": [], + "source": [ + "data_transforms = {\n", + " 'train': transforms.Compose([\n", + " transforms.RandomResizedCrop(config.image_size),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + " ]),\n", + " 'val': transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(config.image_size),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + " ]),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aE3VFQHm9srl" + }, + "source": [ + "Finally, we will use TorchVision and the [`torch.utils.data`](https://pytorch.org/docs/stable/data.html) packages for creating the dataloaders for trainig and validation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0N81UNjtjMZ4", + "outputId": "bfb2af1a-977c-4408-cb1c-8da78d03d13d" + }, + "outputs": [], + "source": [ + "# Define the train and validation datasets\n", + "image_datasets = {\n", + " x: datasets.ImageFolder(\n", + " os.path.join(data_dir, x), data_transforms[x]\n", + " )\n", + " for x in ['train', 'val']\n", + "}\n", + "\n", + "# Define the torch dataloaders corresponding to the\n", + "# train and validation dataset\n", + "dataloaders = {\n", + " x: torch.utils.data.DataLoader(\n", + " image_datasets[x],\n", + " batch_size=config.batch_size,\n", + " shuffle=True,\n", + " num_workers=4\n", + " )\n", + " for x in ['train', 'val']\n", + "}\n", + "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n", + "class_names = image_datasets['train'].classes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AY6kOIL--EdP" + }, + "source": [ + "Let us visualize a few samples from the training dataloader." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 829 + }, + "id": "yffdD4LxjOQG", + "outputId": "38e6f182-11b1-4830-e5ec-a5588c7bdbf9" + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(10, 10))\n", + "sample_images, sample_labels = next(iter(dataloaders['train']))\n", + "sample_images = sample_images.numpy()\n", + "sample_labels = sample_labels.numpy()\n", + "for idx in range(9):\n", + "    ax = plt.subplot(3, 3, idx + 1)\n", + "    image = sample_images[idx].transpose((1, 2, 0))\n", + "    mean = np.array([0.485, 0.456, 0.406])\n", + "    std = np.array([0.229, 0.224, 0.225])\n", + "    image = std * image + mean\n", + "    image = np.clip(image, 0, 1)\n", + "    plt.imshow(image)\n", + "    plt.title(\"Ground Truth Label: \" + class_names[int(sample_labels[idx])])\n", + "    plt.axis(\"off\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0zvfzD04-ce9" + }, + "source": [ + "## The Image Classification Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZBLzXVwk-mLP" + }, + "source": [ + "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Let us define the ResNet18 model from the TorchVision package, pre-trained on the [Imagenet1K dataset](https://huggingface.co/datasets/imagenet-1k), as a `torch.nn.Module`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tOGUiI9K_BRk", + "outputId": "97aa387c-7b3d-41cf-959b-de60eb572912" + }, + "outputs": [], + "source": [ + "# Define the pre-trained resnet18 module from TorchVision\n", + "resnet_18 = models.resnet18(weights='IMAGENET1K_V1')\n", + "\n", + "# We set the classification head of the pre-trained ResNet18\n", + "# module to an identity module\n", + "resnet_18.fc = nn.Identity()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "afxbMwcYF-Yz" + }, + "source": [ + "Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model are tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n", + "\n", + "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JLuCIAy5F6L1" + }, + "outputs": [], + "source": [ + "# We wrap the pre-trained ResNet18 backbone as a Keras Layer\n", + "# using `TorchModuleWrapper`\n", + "backbone = TorchModuleWrapper(resnet_18)\n", + "\n", + "# Freeze the backbone if `config.freeze_backbone` is set to `True`\n", + "backbone.trainable = not config.freeze_backbone" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7y28txoVHHk8" + }, + "source": [ + "Now, we will build a Keras functional model with the backbone layer."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 314 + }, + "id": "l2rxqA8vjR3W", + "outputId": "206ad321-5fb5-41e1-a391-37cef3b10edb" + }, + "outputs": [], + "source": [ + "inputs = keras.Input(shape=(3, config.image_size, config.image_size))\n", + "x = backbone(inputs)\n", + "x = keras.layers.Dropout(0.5)(x)\n", + "x = keras.layers.Dense(len(class_names))(x)\n", + "outputs = keras.activations.softmax(x, axis=1)\n", + "model = keras.Model(inputs, outputs, name=\"ResNet18_Classifier\")\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BXbvYzDnjyDQ", + "outputId": "9d7e51b2-85e7-4717-993b-0f6efac2999d" + }, + "outputs": [], + "source": [ + "# Create exponential decay learning rate scheduler\n", + "decay_steps = config.num_epochs * len(dataloaders[\"train\"]) // config.batch_size\n", + "lr_scheduler = keras.optimizers.schedules.ExponentialDecay(\n", + " initial_learning_rate=config.initial_learning_rate,\n", + " decay_steps=decay_steps,\n", + " decay_rate=0.1,\n", + ")\n", + "\n", + "# Compile the model\n", + "model.compile(\n", + " loss=\"sparse_categorical_crossentropy\",\n", + " optimizer=keras.optimizers.Adam(lr_scheduler),\n", + " metrics=[\"accuracy\"],\n", + ")\n", + "\n", + "# Define the backend-agnostic WandB callbacks for KerasCore\n", + "callbacks = [\n", + " # Track experiment metrics with WandB\n", + " WandbMetricsLogger(log_freq=\"batch\"),\n", + " # Save best model checkpoints to WandB\n", + " WandbModelCheckpoint(\n", + " filepath=\"model.weights.h5\",\n", + " monitor=\"val_loss\",\n", + " save_best_only=True,\n", + " save_weights_only=True,\n", + " )\n", + "]\n", + "\n", + "# Train the model by calling model.fit\n", + "history = model.fit(\n", + " dataloaders[\"train\"],\n", + " validation_data=dataloaders[\"val\"],\n", + " epochs=config.num_epochs,\n", + " callbacks=callbacks,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "upJf2M92JBlD", + "outputId": "6993ffbe-2b69-4ff8-9a2f-9c502ca1414d" + }, + "outputs": [], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I58qvNnzJmHD" + }, + "source": [ + "## Evaluation and Inference\n", + "\n", + "Now, we let us load the best model weights checkpoint and evaluate the model." 
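, + "\n", + "\n", + "The artifact referenced below points at a checkpoint logged from the authors' training run. If you are running this notebook yourself, point `wandb.use_artifact` at the model artifact logged by your own run instead; a hypothetical sketch, assuming the `run_<run-id>_model` naming used by the checkpoint callback above:\n", + "\n", + "```python\n", + "# Hypothetical example: substitute your own entity and run id.\n", + "artifact = wandb.use_artifact(\"your-entity/keras-torch/run_<your-run-id>_model:latest\", type=\"model\")\n", + "```"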
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cNpfuELqnDI_", + "outputId": "18bfe7bb-adea-4ac2-807b-924dce6fbcc8" + }, + "outputs": [], + "source": [ + "wandb.init(\n", + " project=\"keras-torch\", entity=\"ml-colabs\", job_type=\"torchvision/eval\"\n", + ")\n", + "artifact = wandb.use_artifact(\n", + " 'ml-colabs/keras-torch/run_hiceci7f_model:latest', type='model'\n", + ")\n", + "artifact_dir = artifact.download()\n", + "\n", + "model.load_weights(os.path.join(artifact_dir, \"model.weights.h5\"))\n", + "\n", + "_, val_accuracy = model.evaluate(dataloaders[\"val\"])\n", + "wandb.log({\"Validation-Accuracy\": val_accuracy})" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vE2vAvBNKAI9" + }, + "source": [ + "Finally, let us visualize the some predictions of the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "1b82953b5f134926ae8a11c4fedca385", + "639fec0083134ab18b6a22999203536e", + "5d85942dc6e44640aed593b7a8494493", + "757ca7f811ec4db5a4dd517c5aec2bb8", + "122edd2a0300448cac1fcc8645fd8708", + "b1df464c3a2d47dbb49b36fa5d9912bb", + "dd95b1ba48ec4bc5b494f319ad41eedf", + "84dfa1f6d6f04c50b87e8683a0abdf55", + "634e163cb1354824a764dc49f7d9f2fa", + "5b1e8c0d14d84e0f923a997727794c89", + "2e912ebb924c4269a07fb7c1204eb923" + ] + }, + "id": "ugrP307SpxMj", + "outputId": "ead193f0-bfd9-4bbc-c675-d741d64fd70f" + }, + "outputs": [], + "source": [ + "table = wandb.Table(\n", + " columns=[\n", + " \"Image\", \"Ground-Truth\", \"Prediction\"\n", + " ] + [\"Confidence-\" + cls for cls in class_names]\n", + ")\n", + "\n", + "sample_images, sample_labels = next(iter(dataloaders['train']))\n", + "\n", + "# We perform inference and detach the predicted probabilities from the Torch\n", + "# computation graph with a tensor that does not require gradient computation.\n", + "sample_pred_probas = model(sample_images.to(\"cuda\")).detach()\n", + "sample_pred_logits = keras.ops.argmax(sample_pred_probas, axis=1)\n", + "sample_pred_logits = sample_pred_logits.to(\"cpu\").numpy()\n", + "sample_pred_probas = sample_pred_probas.to(\"cpu\").numpy()\n", + "\n", + "sample_images = sample_images.numpy()\n", + "sample_labels = sample_labels.numpy()\n", + "\n", + "for idx in tqdm(range(sample_images.shape[0])):\n", + " image = sample_images[idx].transpose((1, 2, 0))\n", + " mean = np.array([0.485, 0.456, 0.406])\n", + " std = np.array([0.229, 0.224, 0.225])\n", + " image = std * image + mean\n", + " image = np.clip(image, 0, 1)\n", + " table.add_data(\n", + " wandb.Image(image),\n", + " class_names[int(sample_labels[idx])],\n", + " class_names[int(sample_pred_logits[idx])],\n", + " *sample_pred_probas[idx].tolist(),\n", + " )\n", + "\n", + "wandb.log({\"Evaluation-Table\": table})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a90XmoRR65SJ", + "outputId": "88e5b0b1-a3db-4366-ba71-f6a21a877676" + }, + "outputs": [], + "source": [ + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QA6ytgUaSxsS" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + 
"name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "122edd2a0300448cac1fcc8645fd8708": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1b82953b5f134926ae8a11c4fedca385": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_639fec0083134ab18b6a22999203536e", + "IPY_MODEL_5d85942dc6e44640aed593b7a8494493", + "IPY_MODEL_757ca7f811ec4db5a4dd517c5aec2bb8" + ], + "layout": "IPY_MODEL_122edd2a0300448cac1fcc8645fd8708" + } + }, + "2e912ebb924c4269a07fb7c1204eb923": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5b1e8c0d14d84e0f923a997727794c89": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": 
null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5d85942dc6e44640aed593b7a8494493": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_84dfa1f6d6f04c50b87e8683a0abdf55", + "max": 32, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_634e163cb1354824a764dc49f7d9f2fa", + "value": 32 + } + }, + "634e163cb1354824a764dc49f7d9f2fa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "639fec0083134ab18b6a22999203536e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b1df464c3a2d47dbb49b36fa5d9912bb", + "placeholder": "​", + "style": "IPY_MODEL_dd95b1ba48ec4bc5b494f319ad41eedf", + "value": "100%" + } + }, + "757ca7f811ec4db5a4dd517c5aec2bb8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5b1e8c0d14d84e0f923a997727794c89", + "placeholder": "​", + "style": "IPY_MODEL_2e912ebb924c4269a07fb7c1204eb923", + "value": " 32/32 [00:01<00:00, 17.84it/s]" + } + }, + "84dfa1f6d6f04c50b87e8683a0abdf55": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b1df464c3a2d47dbb49b36fa5d9912bb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dd95b1ba48ec4bc5b494f319ad41eedf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}