From 920a05447a3c8673f866a51b7f9bd85f17add2c3 Mon Sep 17 00:00:00 2001
From: Soumik Rakshit <19soumik.rakshit96@gmail.com>
Date: Fri, 10 Nov 2023 14:38:53 +0530
Subject: [PATCH] Add example to fine-tune torchvision model with Keras (#457)
* add: example to fine-tune torchvision model with Keras
* add: colab badge
* update: wandb.init
* update: torchvision + keras notebook
* update: timm + keras notebook
* fix: timm_keras notebook
* add: keras + monai example
* update: colab links
* update: torchvision + keras notebook
* update: torchvision + keras notebook
* update: torchvision + keras notebook
---
colabs/README.md | 4 +-
.../keras_core/monai_medmnist_keras.ipynb | 477 +++++++++
colabs/keras/keras_core/timm_keras.ipynb | 377 +++++++
.../keras/keras_core/torchvision-keras.ipynb | 922 ++++++++++++++++++
4 files changed, 1779 insertions(+), 1 deletion(-)
create mode 100644 colabs/keras/keras_core/monai_medmnist_keras.ipynb
create mode 100644 colabs/keras/keras_core/timm_keras.ipynb
create mode 100644 colabs/keras/keras_core/torchvision-keras.ipynb
diff --git a/colabs/README.md b/colabs/README.md
index 7913dc56..141c049b 100644
--- a/colabs/README.md
+++ b/colabs/README.md
@@ -26,7 +26,9 @@
| Kaolin-Wisp | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/vqad-colab) |
| Super Gradients | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://wandb.me/yolo-nas-colab) |
| 🎸 Generating music with AudioCraft | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/audiocraft/AudioCraft.ipynb) |
-
+| 🦄 Fine-tune a Torchvision Model with KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision-keras.ipynb) |
+| 🦄 Fine-tune a Timm Model with KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb) |
+| 🦄 Medical Image Classification Tutorial using MONAI and KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb) |
# 🏋🏽♂️ W&B Features
diff --git a/colabs/keras/keras_core/monai_medmnist_keras.ipynb b/colabs/keras/keras_core/monai_medmnist_keras.ipynb
new file mode 100644
index 00000000..747baa6f
--- /dev/null
+++ b/colabs/keras/keras_core/monai_medmnist_keras.ipynb
@@ -0,0 +1,477 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2eqSE_8rl6yo"
+ },
+ "source": [
+ "# 🩺 Medical Image Classification Tutorial using MonAI and Keras\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb)\n",
+ "\n",
+ "This notebook demonstrates\n",
+ "- an end-to-end training using [MonAI](https://github.com/Project-MONAI/MONAI) and [KerasCore](https://github.com/keras-team/keras-core).\n",
+ "- how we can use the backend-agnostic Keras callbacks for [Weights & Biases](https://wandb.ai/site) to manage and track our experiment.\n",
+ "\n",
+ "Original Notebook: https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ny_HOlvymX6O"
+ },
+ "source": [
+ "## Installing and Importing the Dependencies"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6SvDmUxVmatW"
+ },
+ "source": [
+ "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), this lets us use the latest feature merged in KerasCore.\n",
+ "- We install [monai](https://github.com/Project-MONAI/MONAI), a PyTorch-based, open-source framework for deep learning in healthcare imaging, part of the PyTorch Ecosystem.\n",
+ "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZO1xjFtvVdqQ"
+ },
+ "outputs": [],
+ "source": [
+ "# install the `main` branch of KerasCore\n",
+ "!pip install -qq namex\n",
+ "!apt install python3.10-venv\n",
+ "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n",
+ "\n",
+ "# install monai and wandb-addons\n",
+ "!pip install -qq git+https://github.com/soumik12345/wandb-addons\n",
+ "!pip install -q \"monai-weekly[pillow, tqdm]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c6vd3NZ-mhxs"
+ },
+ "source": [
+ "We specify the Keras backend to be using `torch` by explicitly specifying the environment variable `KERAS_BACKEND`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d5ZiQmMkW-h3"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
+ "\n",
+ "import shutil\n",
+ "import tempfile\n",
+ "import matplotlib.pyplot as plt\n",
+ "import PIL\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "from sklearn.metrics import classification_report\n",
+ "\n",
+ "import keras_core as keras\n",
+ "from keras_core.utils import TorchModuleWrapper\n",
+ "\n",
+ "from monai.apps import download_and_extract\n",
+ "from monai.config import print_config\n",
+ "from monai.data import decollate_batch, DataLoader\n",
+ "from monai.metrics import ROCAUCMetric\n",
+ "from monai.networks.nets import DenseNet121\n",
+ "from monai.transforms import (\n",
+ " Activations,\n",
+ " EnsureChannelFirst,\n",
+ " AsDiscrete,\n",
+ " Compose,\n",
+ " LoadImage,\n",
+ " RandFlip,\n",
+ " RandRotate,\n",
+ " RandZoom,\n",
+ " ScaleIntensity,\n",
+ ")\n",
+ "from monai.utils import set_determinism\n",
+ "\n",
+ "import wandb\n",
+ "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1BvtLvxKmkSp"
+ },
+ "source": [
+ "We initialize a [wandb run](https://docs.wandb.ai/guides/runs) and set the configs for the experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WUt7PAsado4j"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"keras-torch\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "config.batch_size = 128\n",
+ "config.num_epochs = 1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NE938skxmoMI"
+ },
+ "source": [
+ "## Setup data directory\n",
+ "\n",
+ "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.\n",
+ "This allows you to save results and reuse downloads.\n",
+ "If not specified a temporary directory will be used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "82KX1Sj6XXY1"
+ },
+ "outputs": [],
+ "source": [
+ "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
+ "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
+ "print(root_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nkVQ-tojmzwf"
+ },
+ "source": [
+ "## Download dataset\n",
+ "\n",
+ "The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),\n",
+ "[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),\n",
+ "and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).\n",
+ "\n",
+ "The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)\n",
+ "under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/).\n",
+ "\n",
+ "If you use the MedNIST dataset, please acknowledge the source."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DAHybyvdXZoH"
+ },
+ "outputs": [],
+ "source": [
+ "resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\"\n",
+ "md5 = \"0bc7306e7427e00ad1c5526a6677552d\"\n",
+ "\n",
+ "compressed_file = os.path.join(root_dir, \"MedNIST.tar.gz\")\n",
+ "data_dir = os.path.join(root_dir, \"MedNIST\")\n",
+ "if not os.path.exists(data_dir):\n",
+ " download_and_extract(resource, compressed_file, root_dir, md5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "u6m2Uas1nMi8"
+ },
+ "source": [
+ "## Read image filenames from the dataset folders\n",
+ "\n",
+ "First of all, check the dataset files and show some statistics. \n",
+ "There are 6 folders in the dataset: Hand, AbdomenCT, CXR, ChestCT, BreastMRI, HeadCT, \n",
+ "which should be used as the labels to train our classification model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qtLr1T0gXydq"
+ },
+ "outputs": [],
+ "source": [
+ "class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))\n",
+ "num_class = len(class_names)\n",
+ "image_files = [\n",
+ " [os.path.join(data_dir, class_names[i], x) for x in os.listdir(os.path.join(data_dir, class_names[i]))]\n",
+ " for i in range(num_class)\n",
+ "]\n",
+ "num_each = [len(image_files[i]) for i in range(num_class)]\n",
+ "image_files_list = []\n",
+ "image_class = []\n",
+ "for i in range(num_class):\n",
+ " image_files_list.extend(image_files[i])\n",
+ " image_class.extend([i] * num_each[i])\n",
+ "num_total = len(image_class)\n",
+ "image_width, image_height = PIL.Image.open(image_files_list[0]).size\n",
+ "\n",
+ "print(f\"Total image count: {num_total}\")\n",
+ "print(f\"Image dimensions: {image_width} x {image_height}\")\n",
+ "print(f\"Label names: {class_names}\")\n",
+ "print(f\"Label counts: {num_each}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "k4WOUb4IX6KK"
+ },
+ "outputs": [],
+ "source": [
+ "plt.subplots(3, 3, figsize=(8, 8))\n",
+ "for i, k in enumerate(np.random.randint(num_total, size=9)):\n",
+ " im = PIL.Image.open(image_files_list[k])\n",
+ " arr = np.array(im)\n",
+ " plt.subplot(3, 3, i + 1)\n",
+ " plt.xlabel(class_names[image_class[k]])\n",
+ " plt.imshow(arr, cmap=\"gray\", vmin=0, vmax=255)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wOGja_mEnQ5m"
+ },
+ "source": [
+ "## Prepare training, validation and test data lists\n",
+ "\n",
+ "Randomly select 10% of the dataset as validation and 10% as test."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BmHhKta1X8ZM"
+ },
+ "outputs": [],
+ "source": [
+ "val_frac = 0.1\n",
+ "test_frac = 0.1\n",
+ "length = len(image_files_list)\n",
+ "indices = np.arange(length)\n",
+ "np.random.shuffle(indices)\n",
+ "\n",
+ "test_split = int(test_frac * length)\n",
+ "val_split = int(val_frac * length) + test_split\n",
+ "test_indices = indices[:test_split]\n",
+ "val_indices = indices[test_split:val_split]\n",
+ "train_indices = indices[val_split:]\n",
+ "\n",
+ "train_x = [image_files_list[i] for i in train_indices]\n",
+ "train_y = [image_class[i] for i in train_indices]\n",
+ "val_x = [image_files_list[i] for i in val_indices]\n",
+ "val_y = [image_class[i] for i in val_indices]\n",
+ "test_x = [image_files_list[i] for i in test_indices]\n",
+ "test_y = [image_class[i] for i in test_indices]\n",
+ "\n",
+ "print(f\"Training count: {len(train_x)}, Validation count: \" f\"{len(val_x)}, Test count: {len(test_x)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Lkhtr3p8nT1G"
+ },
+ "source": [
+ "## Define MONAI transforms, Dataset and Dataloader to pre-process data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2b5H2WuUYLsc"
+ },
+ "outputs": [],
+ "source": [
+ "train_transforms = Compose(\n",
+ " [\n",
+ " LoadImage(image_only=True),\n",
+ " EnsureChannelFirst(),\n",
+ " ScaleIntensity(),\n",
+ " RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),\n",
+ " RandFlip(spatial_axis=0, prob=0.5),\n",
+ " RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])\n",
+ "\n",
+ "y_pred_trans = Compose([Activations(softmax=True)])\n",
+ "y_trans = Compose([AsDiscrete(to_onehot=num_class)])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QdFnaFUDYOaq"
+ },
+ "outputs": [],
+ "source": [
+ "class MedNISTDataset(torch.utils.data.Dataset):\n",
+ " def __init__(self, image_files, labels, transforms):\n",
+ " self.image_files = image_files\n",
+ " self.labels = labels\n",
+ " self.transforms = transforms\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.image_files)\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " return self.transforms(self.image_files[index]), self.labels[index]\n",
+ "\n",
+ "\n",
+ "train_ds = MedNISTDataset(train_x, train_y, train_transforms)\n",
+ "train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, num_workers=2)\n",
+ "\n",
+ "val_ds = MedNISTDataset(val_x, val_y, val_transforms)\n",
+ "val_loader = DataLoader(val_ds, batch_size=config.batch_size, num_workers=2)\n",
+ "\n",
+ "test_ds = MedNISTDataset(test_x, test_y, val_transforms)\n",
+ "test_loader = DataLoader(test_ds, batch_size=config.batch_size, num_workers=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FIhKqMrVnZV9"
+ },
+ "source": [
+ "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model is tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n",
+ "\n",
+ "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!\n",
+ "\n",
+ "The idea of the `TorchModuleWrapper` was proposed by Keras' creator [François Chollet](https://github.com/fchollet) on [this issue thread](https://github.com/keras-team/keras-core/issues/604)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7zbinqy4ZsEy"
+ },
+ "outputs": [],
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "inputs = keras.Input(shape=(1, 64, 64))\n",
+ "outputs = TorchModuleWrapper(\n",
+ " DenseNet121(\n",
+ " spatial_dims=2, in_channels=1, out_channels=num_class\n",
+ " )\n",
+ ")(inputs)\n",
+ "model = keras.Model(inputs, outputs)\n",
+ "\n",
+ "# model = MedMnistModel()\n",
+ "model(next(iter(train_loader))[0].to(device)).shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "opjBI87nneYU"
+ },
+ "source": [
+ "**Note:** It is actually possible to use torch modules inside a Keras Model without having to explicitly have them wrapped with the `TorchModuleWrapper` as evident by [this tweet](https://twitter.com/fchollet/status/1697381832164290754) from François Chollet. However, this doesn't seem to work at the point of time this example was created, as reported in [this issue](https://github.com/keras-team/keras-core/issues/834)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "y3H-cUKraRDh"
+ },
+ "outputs": [],
+ "source": [
+ "# Compile the model\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(1e-5),\n",
+ " metrics=[\"accuracy\"],\n",
+ ")\n",
+ "\n",
+ "# Define the backend-agnostic WandB callbacks for KerasCore\n",
+ "callbacks = [\n",
+ " # Track experiment metrics\n",
+ " WandbMetricsLogger(log_freq=\"batch\")\n",
+ "]\n",
+ "\n",
+ "# Train the model by calling model.fit\n",
+ "model.fit(\n",
+ " train_loader,\n",
+ " validation_data=val_loader,\n",
+ " epochs=config.num_epochs,\n",
+ " callbacks=callbacks,\n",
+ ")\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
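+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate the model on the test dataset\n",
+    "\n",
+    "Finally, let us evaluate the fine-tuned model on the held-out test split. The cell below is a minimal sketch adapted from the original MONAI tutorial: it assumes the model outputs raw class logits and reuses the `test_loader`, `device`, `class_names`, and `classification_report` objects defined above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run inference on the test split without tracking gradients and collect\n",
+    "# the predicted and ground-truth class indices\n",
+    "y_true, y_pred = [], []\n",
+    "with torch.no_grad():\n",
+    "    for images, labels in test_loader:\n",
+    "        outputs = model(images.to(device))\n",
+    "        preds = outputs.argmax(dim=1).cpu()\n",
+    "        y_true.extend(labels.numpy().tolist())\n",
+    "        y_pred.extend(preds.numpy().tolist())\n",
+    "\n",
+    "# Print per-class precision, recall and F1-score for the test split\n",
+    "print(classification_report(y_true, y_pred, target_names=class_names, digits=4))"
+   ]
+  },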
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Q0E-kiGBeCnZ"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "V100",
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/colabs/keras/keras_core/timm_keras.ipynb b/colabs/keras/keras_core/timm_keras.ipynb
new file mode 100644
index 00000000..1cea4769
--- /dev/null
+++ b/colabs/keras/keras_core/timm_keras.ipynb
@@ -0,0 +1,377 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🔥 Fine-tune a [Timm](https://huggingface.co/docs/timm/index) Model with Keras and WandB 🦄\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb)\n",
+ "\n",
+ "This notebook demonstrates\n",
+ "- how we can fine-tune a pre-trained model from timm using [KerasCore](https://github.com/keras-team/keras-core).\n",
+ "- how we can use the backend-agnostic Keras callbacks for [Weights & Biases](https://wandb.ai/site) to manage and track our experiment."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Installing and Importing the Dependencies"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), this lets us use the latest feature merged in KerasCore.\n",
+ "- We install [timm](https://huggingface.co/docs/timm/index), a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.\n",
+ "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# install the `main` branch of KerasCore\n",
+ "!pip install -qq namex\n",
+ "!apt install python3.10-venv\n",
+ "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n",
+ "\n",
+ "# install timm and wandb-addons\n",
+ "!pip install -qq git+https://github.com/soumik12345/wandb-addons"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We specify the Keras backend to be using `torch` by explicitly specifying the environment variable `KERAS_BACKEND`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
+ "\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import timm\n",
+ "from timm.data import resolve_data_config\n",
+ "\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "from torchvision.transforms.functional import InterpolationMode\n",
+ "\n",
+ "import wandb\n",
+ "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We initialize a [wandb run](https://docs.wandb.ai/guides/runs) and set the configs for the experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"keras-torch\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "config.model_name = \"xception41\"\n",
+ "config.freeze_backbone = False\n",
+ "config.preprocess_config = resolve_data_config({}, model=config.model_name)\n",
+ "config.dropout_rate = 0.5\n",
+ "config.batch_size = 4\n",
+ "config.num_epochs = 25"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## A PyTorch-based Input Pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will be using the [ImageNette](https://github.com/fastai/imagenette) dataset for this experiment. Imagenette is a subset of 10 easily classified classes from [Imagenet](https://www.image-net.org/) (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).\n",
+ "\n",
+ "First, let's download this dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz -P imagenette\n",
+ "!tar zxf imagenette/imagenette2-320.tgz -C imagenette\n",
+ "!gzip -d imagenette/imagenette2-320.tgz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we create our standard torch-based data loading pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define pre-processing and augmentation transforms for the train and validation sets\n",
+ "data_transforms = {\n",
+ " 'train': transforms.Compose([\n",
+ " transforms.RandomResizedCrop(\n",
+ " size=config.preprocess_config[\"input_size\"][1],\n",
+ " interpolation=InterpolationMode.BICUBIC,\n",
+ " ),\n",
+ " transforms.RandomHorizontalFlip(),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(\n",
+ " config.preprocess_config[\"mean\"],\n",
+ " config.preprocess_config[\"std\"]\n",
+ " )\n",
+ " ]),\n",
+ " 'val': transforms.Compose([\n",
+ " transforms.Resize(256),\n",
+ " transforms.CenterCrop(config.preprocess_config[\"input_size\"][1]),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(\n",
+ " config.preprocess_config[\"mean\"],\n",
+ " config.preprocess_config[\"std\"]\n",
+ " )\n",
+ " ]),\n",
+ "}\n",
+ "\n",
+ "# Define the train and validation datasets\n",
+ "data_dir = 'imagenette/imagenette2-320'\n",
+ "image_datasets = {\n",
+ " x: datasets.ImageFolder(\n",
+ " os.path.join(data_dir, x), data_transforms[x]\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "\n",
+ "# Define the torch dataloaders corresponding to the train and validation dataset\n",
+ "dataloaders = {\n",
+ " x: torch.utils.data.DataLoader(\n",
+ " image_datasets[x],\n",
+ " batch_size=config.batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=4\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
+ "class_names = image_datasets['train'].classes\n",
+ "\n",
+ "# Specify the global device\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's take a look at a few of the samples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def imshow(inp, title=None):\n",
+ " \"\"\"Display image for Tensor.\"\"\"\n",
+ " inp = inp.numpy().transpose((1, 2, 0))\n",
+ " mean = np.array(config.preprocess_config[\"mean\"])\n",
+ " std = np.array(config.preprocess_config[\"std\"])\n",
+ " inp = std * inp + mean\n",
+ " inp = np.clip(inp, 0, 1)\n",
+ " plt.imshow(inp)\n",
+ " if title is not None:\n",
+ " plt.title(title)\n",
+ " plt.pause(0.001)\n",
+ "\n",
+ "\n",
+ "# Get a batch of training data\n",
+ "inputs, classes = next(iter(dataloaders['train']))\n",
+ "print(inputs.shape, classes.shape)\n",
+ "\n",
+ "# Make a grid from batch\n",
+ "out = torchvision.utils.make_grid(inputs)\n",
+ "\n",
+ "imshow(out, title=[class_names[x] for x in classes])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating and Training our Classifier"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model is tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n",
+ "\n",
+ "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!\n",
+ "\n",
+ "The idea of the `TorchModuleWrapper` was proposed by Keras' creator [François Chollet](https://github.com/fchollet) on [this issue thread](https://github.com/keras-team/keras-core/issues/604)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import keras_core as keras\n",
+ "from keras_core.utils import TorchModuleWrapper\n",
+ "\n",
+ "\n",
+ "class TimmClassifier(keras.Model):\n",
+ "\n",
+ " def __init__(self, model_name, freeze_backbone, dropout_rate, num_classes, *args, **kwargs):\n",
+ " super().__init__(*args, **kwargs)\n",
+ " \n",
+ " # Define the pre-trained module from timm\n",
+ " self.backbone = TorchModuleWrapper(\n",
+ " timm.create_model(model_name, pretrained=True)\n",
+ " )\n",
+ " self.backbone.trainable = not freeze_backbone\n",
+ " \n",
+ " # Build the classification head using keras layers\n",
+ " self.global_average_pooling = keras.layers.GlobalAveragePooling2D()\n",
+ " self.dropout = keras.layers.Dropout(dropout_rate)\n",
+ " self.classification_head = keras.layers.Dense(num_classes)\n",
+ "\n",
+ " def call(self, inputs):\n",
+ " # We get the unpooled features from the timm backbone by calling `forward_features`\n",
+ " # on the torch module corresponding to the backbone.\n",
+ " x = self.backbone.module.forward_features(inputs)\n",
+ " x = self.global_average_pooling(x)\n",
+ " x = self.dropout(x)\n",
+ " x = self.classification_head(x)\n",
+ " return keras.activations.softmax(x, axis=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Note:** It is actually possible to use torch modules inside a Keras Model without having to explicitly have them wrapped with the `TorchModuleWrapper` as evident by [this tweet](https://twitter.com/fchollet/status/1697381832164290754) from François Chollet. However, this doesn't seem to work at the point of time this example was created, as reported in [this issue](https://github.com/keras-team/keras-core/issues/834)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Now, we define the model and pass a random tensor to check the output shape\n",
+ "model = TimmClassifier(\n",
+ " model_name=config.model_name,\n",
+ " freeze_backbone=config.freeze_backbone,\n",
+ " dropout_rate=config.dropout_rate,\n",
+ " num_classes=len(class_names)\n",
+ ")\n",
+ "model(torch.ones(1, *config.preprocess_config[\"input_size\"]).to(device)).shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, in standard Keras fashion, all we need to do is compile the model and call `model.fit()`!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create exponential decay learning rate scheduler\n",
+ "decay_steps = config.num_epochs * len(dataloaders[\"train\"]) // config.batch_size\n",
+ "lr_scheduler = keras.optimizers.schedules.ExponentialDecay(\n",
+ " initial_learning_rate=1e-3, decay_steps=decay_steps, decay_rate=0.1,\n",
+ ")\n",
+ "\n",
+ "# Compile the model\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(lr_scheduler),\n",
+ " metrics=[\"accuracy\"],\n",
+ ")\n",
+ "\n",
+ "# Define the backend-agnostic WandB callbacks for KerasCore\n",
+ "callbacks = [\n",
+ " # Track experiment metrics\n",
+ " WandbMetricsLogger(log_freq=\"batch\"),\n",
+ " # Track and version model checkpoints\n",
+ " WandbModelCheckpoint(\"model.keras\")\n",
+ "]\n",
+ "\n",
+ "# Train the model by calling model.fit\n",
+ "model.fit(\n",
+ " dataloaders[\"train\"],\n",
+ " validation_data=dataloaders[\"val\"],\n",
+ " epochs=config.num_epochs,\n",
+ " callbacks=callbacks,\n",
+ ")"
+ ]
+ },
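+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Once training is complete, we close the W&B run so that all the logged metrics and model checkpoints are synced to the dashboard."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Mark the W&B run as finished and flush any pending logs\n",
+    "wandb.finish()"
+   ]
+  },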
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In order to know more about the backend-agnostic Keras callbacks for Weights & Biases, check out the [docs for wandb-addons](https://geekyrakshit.dev/wandb-addons/keras/keras_core/)."
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/colabs/keras/keras_core/torchvision-keras.ipynb b/colabs/keras/keras_core/torchvision-keras.ipynb
new file mode 100644
index 00000000..30783184
--- /dev/null
+++ b/colabs/keras/keras_core/torchvision-keras.ipynb
@@ -0,0 +1,922 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jLRorPqA1-h4"
+ },
+ "source": [
+ "
\n",
+ "
\n",
+ "\n",
+ "\n",
+ "# 🔥 Fine-tune a TorchVision Model with Keras and WandB 🦄\n",
+ "\n",
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision_keras.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aEl-j2hq2w25"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "[TorchVision](https://pytorch.org/vision/stable/index.html) is a library part of the [PyTorch](http://pytorch.org/) project that consists of popular datasets, model architectures, and common image transformations for computer vision. This example demonstrates how we can perform transfer learning for image classification using a pre-trained backbone model from TorchVision on the [Imagenette dataset](https://github.com/fastai/imagenette) using KerasCore. We will also demonstrate the compatibility of KerasCore with an input system consisting of [Torch Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).\n",
+ "\n",
+ "### References:\n",
+ "\n",
+ "- [Customizing what happens in `fit()` with PyTorch](https://keras.io/keras_core/guides/custom_train_step_in_torch/)\n",
+ "- [PyTorch Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)\n",
+ "- [Transfer learning for Computer Vision using PyTorch](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)\n",
+ "\n",
+ "## Setup\n",
+ "\n",
+ "- We install the `main` branch of [KerasCore](https://github.com/keras-team/keras-core), this lets us use the latest feature merged in KerasCore.\n",
+ "- We also install [wandb-addons](https://github.com/soumik12345/wandb-addons), a library that hosts the backend-agnostic callbacks compatible with KerasCore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "r4rfNRPgiy9v",
+ "outputId": "ce5ba027-567a-4577-a638-ca8802ee1f84"
+ },
+ "outputs": [],
+ "source": [
+ "# install the `main` branch of KerasCore\n",
+ "!pip install -qq namex\n",
+ "!apt install python3.10-venv\n",
+ "!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install\n",
+ "\n",
+ "# install wandb-addons\n",
+ "!pip install -qq git+https://github.com/soumik12345/wandb-addons"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7nudAUt8jHRB"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
+ "\n",
+ "import numpy as np\n",
+ "from tqdm.auto import tqdm\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import torchvision\n",
+ "from torchvision import datasets, models, transforms\n",
+ "\n",
+ "import keras_core as keras\n",
+ "from keras_core.utils import TorchModuleWrapper\n",
+ "\n",
+ "import wandb\n",
+ "from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pS1c-ySo7nty"
+ },
+ "source": [
+ "## Define the Hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2ovtXSUA7ksk",
+ "outputId": "725bef1a-0e68-473e-8c24-1f1ffd28506c"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(project=\"keras-torch\", entity=\"ml-colabs\", job_type=\"torchvision/train\")\n",
+ "\n",
+ "config = wandb.config\n",
+ "config.batch_size = 32\n",
+ "config.image_size = 224\n",
+ "config.freeze_backbone = True\n",
+ "config.initial_learning_rate = 1e-3\n",
+ "config.num_epochs = 5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b5uU_Q_H74GO"
+ },
+ "source": [
+ "## Creating the Torch Datasets and Dataloaders\n",
+ "\n",
+ "In this example, we would train an image classification model on the [Imagenette dataset](https://github.com/fastai/imagenette). Imagenette is a subset of 10 easily classified classes from [Imagenet](https://www.image-net.org/) (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gWV2hNGo8vW7",
+ "outputId": "cc70b598-fc66-480a-a98c-7077fa634a22"
+ },
+ "outputs": [],
+ "source": [
+ "# Fetch the imagenette dataset\n",
+ "data_dir = keras.utils.get_file(\n",
+ " fname=\"imagenette2-320.tgz\",\n",
+ " origin=\"https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz\",\n",
+ " extract=True,\n",
+ ")\n",
+ "data_dir = data_dir.replace(\".tgz\", \"\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ka7TDUMn9IcG"
+ },
+ "source": [
+ "Next, we define pre-processing and augmentation transforms from TorchVision for the train and validation sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rFZKJafF9H6y"
+ },
+ "outputs": [],
+ "source": [
+ "data_transforms = {\n",
+ " 'train': transforms.Compose([\n",
+ " transforms.RandomResizedCrop(config.image_size),\n",
+ " transforms.RandomHorizontalFlip(),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ " 'val': transforms.Compose([\n",
+ " transforms.Resize(256),\n",
+ " transforms.CenterCrop(config.image_size),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+ " ]),\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aE3VFQHm9srl"
+ },
+ "source": [
+ "Finally, we will use TorchVision and the [`torch.utils.data`](https://pytorch.org/docs/stable/data.html) packages for creating the dataloaders for trainig and validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0N81UNjtjMZ4",
+ "outputId": "bfb2af1a-977c-4408-cb1c-8da78d03d13d"
+ },
+ "outputs": [],
+ "source": [
+ "# Define the train and validation datasets\n",
+ "image_datasets = {\n",
+ " x: datasets.ImageFolder(\n",
+ " os.path.join(data_dir, x), data_transforms[x]\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "\n",
+ "# Define the torch dataloaders corresponding to the\n",
+ "# train and validation dataset\n",
+ "dataloaders = {\n",
+ " x: torch.utils.data.DataLoader(\n",
+ " image_datasets[x],\n",
+ " batch_size=config.batch_size,\n",
+ " shuffle=True,\n",
+ " num_workers=4\n",
+ " )\n",
+ " for x in ['train', 'val']\n",
+ "}\n",
+ "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
+ "class_names = image_datasets['train'].classes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AY6kOIL--EdP"
+ },
+ "source": [
+ "Let us visualize a few samples from the training dataloader."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 829
+ },
+ "id": "yffdD4LxjOQG",
+ "outputId": "38e6f182-11b1-4830-e5ec-a5588c7bdbf9"
+ },
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(10, 10))\n",
+ "sample_images, sample_labels = next(iter(dataloaders['train']))\n",
+ "sample_images = sample_images.numpy()\n",
+ "sample_labels = sample_labels.numpy()\n",
+ "for idx in range(9):\n",
+ " ax = plt.subplot(3, 3, idx + 1)\n",
+ " image = sample_images[idx].transpose((1, 2, 0))\n",
+ " mean = np.array([0.485, 0.456, 0.406])\n",
+ " std = np.array([0.229, 0.224, 0.225])\n",
+ " image = std * image + mean\n",
+ " image = np.clip(image, 0, 1)\n",
+ " plt.imshow(image)\n",
+ " plt.title(\"Ground Truth Label: \" + class_names[int(sample_labels[idx])])\n",
+ " plt.axis(\"off\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0zvfzD04-ce9"
+ },
+ "source": [
+ "## The Image Classification Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZBLzXVwk-mLP"
+ },
+ "source": [
+ "We typically define a model in PyTorch using [`torch.nn.Module`s](https://pytorch.org/docs/stable/notes/modules.html) which act as the building blocks of stateful computation. Let us define the ResNet18 model from the TorchVision package as a `torch.nn.Module` pre-trained on the [Imagenet1K dataset](https://huggingface.co/datasets/imagenet-1k)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tOGUiI9K_BRk",
+ "outputId": "97aa387c-7b3d-41cf-959b-de60eb572912"
+ },
+ "outputs": [],
+ "source": [
+ "# Define the pre-trained resnet18 module from TorchVision\n",
+ "resnet_18 = models.resnet18(weights='IMAGENET1K_V1')\n",
+ "\n",
+ "# We set the classification head of the pre-trained ResNet18\n",
+ "# module to an identity module\n",
+ "resnet_18.fc = nn.Identity()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "afxbMwcYF-Yz"
+ },
+ "source": [
+ "ven though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a [`keras_core.Model`](https://keras.io/keras_core/api/models/), because trainable variables inside a Keras Model is tracked exclusively via [Keras Layers](https://keras.io/keras_core/api/layers/).\n",
+ "\n",
+ "KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JLuCIAy5F6L1"
+ },
+ "outputs": [],
+ "source": [
+ "# We set the trainable ResNet18 backbone to be a Keras Layer\n",
+ "# using `TorchModuleWrapper`\n",
+ "backbone = TorchModuleWrapper(resnet_18)\n",
+ "\n",
+ "# We set this to `False` if you want to freeze the backbone\n",
+ "backbone.trainable = config.freeze_backbone"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7y28txoVHHk8"
+ },
+ "source": [
+ "Now, we will build a Keras functional model with the backbone layer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 314
+ },
+ "id": "l2rxqA8vjR3W",
+ "outputId": "206ad321-5fb5-41e1-a391-37cef3b10edb"
+ },
+ "outputs": [],
+ "source": [
+ "inputs = keras.Input(shape=(3, config.image_size, config.image_size))\n",
+ "x = backbone(inputs)\n",
+ "x = keras.layers.Dropout(0.5)(x)\n",
+ "x = keras.layers.Dense(len(class_names))(x)\n",
+ "outputs = keras.activations.softmax(x, axis=1)\n",
+ "model = keras.Model(inputs, outputs, name=\"ResNet18_Classifier\")\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BXbvYzDnjyDQ",
+ "outputId": "9d7e51b2-85e7-4717-993b-0f6efac2999d"
+ },
+ "outputs": [],
+ "source": [
+ "# Create exponential decay learning rate scheduler\n",
+ "decay_steps = config.num_epochs * len(dataloaders[\"train\"]) // config.batch_size\n",
+ "lr_scheduler = keras.optimizers.schedules.ExponentialDecay(\n",
+ " initial_learning_rate=config.initial_learning_rate,\n",
+ " decay_steps=decay_steps,\n",
+ " decay_rate=0.1,\n",
+ ")\n",
+ "\n",
+ "# Compile the model\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(lr_scheduler),\n",
+ " metrics=[\"accuracy\"],\n",
+ ")\n",
+ "\n",
+ "# Define the backend-agnostic WandB callbacks for KerasCore\n",
+ "callbacks = [\n",
+ " # Track experiment metrics with WandB\n",
+ " WandbMetricsLogger(log_freq=\"batch\"),\n",
+ " # Save best model checkpoints to WandB\n",
+ " WandbModelCheckpoint(\n",
+ " filepath=\"model.weights.h5\",\n",
+ " monitor=\"val_loss\",\n",
+ " save_best_only=True,\n",
+ " save_weights_only=True,\n",
+ " )\n",
+ "]\n",
+ "\n",
+ "# Train the model by calling model.fit\n",
+ "history = model.fit(\n",
+ " dataloaders[\"train\"],\n",
+ " validation_data=dataloaders[\"val\"],\n",
+ " epochs=config.num_epochs,\n",
+ " callbacks=callbacks,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "upJf2M92JBlD",
+ "outputId": "6993ffbe-2b69-4ff8-9a2f-9c502ca1414d"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I58qvNnzJmHD"
+ },
+ "source": [
+ "## Evaluation and Inference\n",
+ "\n",
+ "Now, we let us load the best model weights checkpoint and evaluate the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "cNpfuELqnDI_",
+ "outputId": "18bfe7bb-adea-4ac2-807b-924dce6fbcc8"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(\n",
+ " project=\"keras-torch\", entity=\"ml-colabs\", job_type=\"torchvision/eval\"\n",
+ ")\n",
+ "artifact = wandb.use_artifact(\n",
+ " 'ml-colabs/keras-torch/run_hiceci7f_model:latest', type='model'\n",
+ ")\n",
+ "artifact_dir = artifact.download()\n",
+ "\n",
+ "model.load_weights(os.path.join(artifact_dir, \"model.weights.h5\"))\n",
+ "\n",
+ "_, val_accuracy = model.evaluate(dataloaders[\"val\"])\n",
+ "wandb.log({\"Validation-Accuracy\": val_accuracy})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vE2vAvBNKAI9"
+ },
+ "source": [
+ "Finally, let us visualize the some predictions of the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 49,
+ "referenced_widgets": [
+ "1b82953b5f134926ae8a11c4fedca385",
+ "639fec0083134ab18b6a22999203536e",
+ "5d85942dc6e44640aed593b7a8494493",
+ "757ca7f811ec4db5a4dd517c5aec2bb8",
+ "122edd2a0300448cac1fcc8645fd8708",
+ "b1df464c3a2d47dbb49b36fa5d9912bb",
+ "dd95b1ba48ec4bc5b494f319ad41eedf",
+ "84dfa1f6d6f04c50b87e8683a0abdf55",
+ "634e163cb1354824a764dc49f7d9f2fa",
+ "5b1e8c0d14d84e0f923a997727794c89",
+ "2e912ebb924c4269a07fb7c1204eb923"
+ ]
+ },
+ "id": "ugrP307SpxMj",
+ "outputId": "ead193f0-bfd9-4bbc-c675-d741d64fd70f"
+ },
+ "outputs": [],
+ "source": [
+ "table = wandb.Table(\n",
+ " columns=[\n",
+ " \"Image\", \"Ground-Truth\", \"Prediction\"\n",
+ " ] + [\"Confidence-\" + cls for cls in class_names]\n",
+ ")\n",
+ "\n",
+ "sample_images, sample_labels = next(iter(dataloaders['train']))\n",
+ "\n",
+ "# We perform inference and detach the predicted probabilities from the Torch\n",
+ "# computation graph with a tensor that does not require gradient computation.\n",
+ "sample_pred_probas = model(sample_images.to(\"cuda\")).detach()\n",
+ "sample_pred_logits = keras.ops.argmax(sample_pred_probas, axis=1)\n",
+ "sample_pred_logits = sample_pred_logits.to(\"cpu\").numpy()\n",
+ "sample_pred_probas = sample_pred_probas.to(\"cpu\").numpy()\n",
+ "\n",
+ "sample_images = sample_images.numpy()\n",
+ "sample_labels = sample_labels.numpy()\n",
+ "\n",
+ "for idx in tqdm(range(sample_images.shape[0])):\n",
+ " image = sample_images[idx].transpose((1, 2, 0))\n",
+ " mean = np.array([0.485, 0.456, 0.406])\n",
+ " std = np.array([0.229, 0.224, 0.225])\n",
+ " image = std * image + mean\n",
+ " image = np.clip(image, 0, 1)\n",
+ " table.add_data(\n",
+ " wandb.Image(image),\n",
+ " class_names[int(sample_labels[idx])],\n",
+ " class_names[int(sample_pred_logits[idx])],\n",
+ " *sample_pred_probas[idx].tolist(),\n",
+ " )\n",
+ "\n",
+ "wandb.log({\"Evaluation-Table\": table})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "a90XmoRR65SJ",
+ "outputId": "88e5b0b1-a3db-4366-ba71-f6a21a877676"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QA6ytgUaSxsS"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "122edd2a0300448cac1fcc8645fd8708": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1b82953b5f134926ae8a11c4fedca385": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_639fec0083134ab18b6a22999203536e",
+ "IPY_MODEL_5d85942dc6e44640aed593b7a8494493",
+ "IPY_MODEL_757ca7f811ec4db5a4dd517c5aec2bb8"
+ ],
+ "layout": "IPY_MODEL_122edd2a0300448cac1fcc8645fd8708"
+ }
+ },
+ "2e912ebb924c4269a07fb7c1204eb923": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "5b1e8c0d14d84e0f923a997727794c89": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5d85942dc6e44640aed593b7a8494493": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_84dfa1f6d6f04c50b87e8683a0abdf55",
+ "max": 32,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_634e163cb1354824a764dc49f7d9f2fa",
+ "value": 32
+ }
+ },
+ "634e163cb1354824a764dc49f7d9f2fa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "639fec0083134ab18b6a22999203536e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b1df464c3a2d47dbb49b36fa5d9912bb",
+ "placeholder": "",
+ "style": "IPY_MODEL_dd95b1ba48ec4bc5b494f319ad41eedf",
+ "value": "100%"
+ }
+ },
+ "757ca7f811ec4db5a4dd517c5aec2bb8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5b1e8c0d14d84e0f923a997727794c89",
+ "placeholder": "",
+ "style": "IPY_MODEL_2e912ebb924c4269a07fb7c1204eb923",
+ "value": " 32/32 [00:01<00:00, 17.84it/s]"
+ }
+ },
+ "84dfa1f6d6f04c50b87e8683a0abdf55": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b1df464c3a2d47dbb49b36fa5d9912bb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd95b1ba48ec4bc5b494f319ad41eedf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}