diff --git a/notebooks/external_augmentation.ipynb b/notebooks/external_augmentation.ipynb
index 4ff452e8..3e61493b 100644
--- a/notebooks/external_augmentation.ipynb
+++ b/notebooks/external_augmentation.ipynb
@@ -3,39 +3,320 @@
 {
 "cell_type": "code",
 "execution_count": null,
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ },
+ "pycharm": {
+ "is_executing": false,
+ "name": "#%%\n"
+ }
+ },
 "outputs": [],
 "source": [
 "%reload_ext autoreload\n",
 "%autoreload 2\n",
 "%matplotlib inline"
- ],
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Integrating External Frameworks into a `rising` Augmentation Pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
 "metadata": {
- "collapsed": false,
 "pycharm": {
- "name": "#%%\n",
- "is_executing": false
+ "name": "#%% md\n"
 }
- }
+ },
+ "source": [
+ "### Using transformations from external libraries inside `rising`\n",
+ "> Note: Some external augmentation libraries are only supported at the beginning of\n",
+ "the transformation pipeline. In general, please consider creating an issue in `rising` \n",
+ "and there is a high chance we (or, if you prefer, you :) ) will add the transformation in the future :) "
+ ]
 },
 {
 "cell_type": "markdown",
+ "metadata": {},
 "source": [
- "## Using transformation from external libraries inside `rising`\n",
- "> Note: Some external augmentation libraries are only supported at the beginning of\n",
- "the transformation pipeline. Generally speaking, if you need to resort to an\n",
- "external library for augmentations, consider creating an issue in `rising` \n",
- "and there is a high chance we will add the transformation in the future :) "
- ],
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3D (Volumetric) Augmentation\n",
+ "The first part of this notebook will focus on frameworks which support volumetric transformations (like rising also does). This means data with a shape of [C, D, H, W] (C = channels, D/H/W = spatial dimensions). "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --quiet --upgrade SimpleITK\n", + "!git clone https://github.com/PhoenixDL/rising.git\n", + "!pip install --quiet --upgrade ./rising" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# download some volumetric data (here MRI Data)\n", + "from io import BytesIO\n", + "from zipfile import ZipFile\n", + "from urllib.request import urlopen\n", + "\n", + "resp = urlopen(\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\")\n", + "zipfile = ZipFile(BytesIO(resp.read()))\n", + "\n", + "img_file = zipfile.extract(\"ExBox3/T1_brain.nii.gz\")\n", + "mask_file = zipfile.extract(\"ExBox3/T1_brain_seg.nii.gz\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import SimpleITK as sitk\n", + "import numpy as np\n", + "\n", + "img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\n", + "img = img.astype(np.float32)\n", + "# sitk.WriteImage(sitk.GetImageFromArray(img), img_file)\n", + "mask = mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\n", + "mask = mask.astype(np.float32)\n", + "# sitk.WriteImage(sitk.GetImageFromArray(mask), mask_file)\n", + "\n", + "assert mask.shape == img.shape\n", + "print(f\"Image shape {img.shape}\")\n", + "print(f\"Mask shape {mask.shape}\")" + ] + }, + { + "cell_type": "markdown", "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "### Integration of `batchgenerators` transformations\n", + "Note: when batchgenerator transformations are integrated, gradients can not be propagated through its\n", + "transformations.\n", + "\n", + "`batchgenerators` transformations are based on numpy to be framework agnostic. They are also based\n", + "on dictionaries which are modified through the transformations.\n", + "\n", + "There are two steps which need to be integrated into your pipelin in order to the \n", + "`batchgenerators` transforms\n", + "\n", + "1. Exchange the `default_collate` function inside the dataloder with `numpy_collate`\n", + "2. When switching from `batchgenerators` transformations to `rising` transformations, insdert `ToTensor` transformation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": false, + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# setup transforms\n", + "from rising.transforms import *\n", + "from batchgenerators.transforms import ZeroMeanUnitVarianceTransform\n", + "\n", + "transforms = []\n", + "# convert tuple into dict\n", + "transforms.append(SeqToMap(\"data\", \"label\"))\n", + "\n", + "# batchgenerators transforms\n", + "transforms.append(ZeroMeanUnitVarianceTransform())\n", + "# ... 
+ "\n",
+ "# convert to tensor\n",
+ "transforms.append(ToTensor())\n",
+ "\n",
+ "# rising transforms\n",
+ "transforms.append(Rot90((0, 1)))\n",
+ "transforms.append(Mirror(dims=(0, 1)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ },
+ "pycharm": {
+ "is_executing": false,
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from rising.loading import DataLoader, default_transform_call, numpy_collate\n",
+ "from rising.transforms import Compose\n",
+ "\n",
+ "composed = Compose(transforms, transform_call=default_transform_call)\n",
+ "dataloader = DataLoader(dataset, batch_size=8, batch_transforms=composed,\n",
+ " num_workers=0, collate_fn=numpy_collate)\n",
+ "_iter = iter(dataloader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch = next(_iter)\n",
+ "show_batch(batch[\"data\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Integration of `torchio`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install --quiet --upgrade torchio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = None\n",
+ "\n",
+ "import torchio\n",
+ "\n",
+ "subject_a = torchio.Subject(\n",
+ " t1=torchio.Image('./ExBox3/T1_brain.nii.gz', torchio.INTENSITY),\n",
+ ")\n",
+ "rescale = torchio.transforms.RescaleIntensity((0, 1))\n",
+ "transform = torchio.transforms.Compose([rescale])\n",
+ "\n",
+ "# ImagesDataset is a subclass of torch.utils.data.Dataset\n",
+ "dataset = torchio.ImagesDataset([subject_a], transform=transform)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setup any additional rising transformations\n",
+ "from rising.transforms import *\n",
+ "\n",
+ "class SelectKeys(AbstractTransform):\n",
+ " def __init__(self, keys=[\"t1\"]):\n",
+ " super().__init__(grad=False)\n",
+ " self.keys = keys\n",
+ " \n",
+ " def forward(self, **batch):\n",
+ " for _key in self.keys:\n",
+ " batch[_key] = batch[_key][\"data\"]\n",
+ " return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rising_transforms = [\n",
+ " SelectKeys(keys=[\"t1\"]),\n",
+ " Rot90(keys=(\"t1\",), dims=(0, 1)),\n",
+ " Mirror(keys=(\"t1\",), dims=(0, 1)),\n",
+ "]\n",
+ "batch_transforms = Compose(rising_transforms)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Instead of using the native pytorch dataloader we exchange it for the dataloader from rising \n",
+ "from rising.loading import DataLoader\n",
+ "\n",
+ "dataloader = DataLoader(dataset, batch_size=1, batch_transforms=batch_transforms, num_workers=4)\n",
+ "_iter = iter(dataloader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch = next(_iter)\n",
+ "print(batch)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2D Augmentation"
+ ]
+ },
 {
 "cell_type": "code",
 "execution_count": null,
+ "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": false, + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# lets prepare a basic dataset (e.g. one from `torchvision`)\n", @@ -65,18 +346,21 @@ "\n", "dataset = torchvision.datasets.MNIST(\n", " os.getcwd(), train=True, download=True, transform=to_array)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n", - "is_executing": false - } - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": false, + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# plot shape\n", @@ -87,18 +371,21 @@ "plt.imshow(dataset[0][0][0], cmap='gray')\n", "plt.colorbar()\n", "plt.show()" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n", - "is_executing": false - } - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "is_executing": false, + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# helper function to visualize batches of images\n", @@ -109,101 +396,117 @@ " plt.imshow(grid[0], cmap='gray')\n", " # plt.colorbar()\n", " plt.show()" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n", - "is_executing": false - } - } + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ - "### Integration of `batchgenerators` transformations into the augmentation pipeline.\n", - "Note: when batchgenerator transformations are integrated, gradients can not be propagated through\n", - "the transformation pipeline.\n", - "\n", - "`batchgenerators` transformations are based on numpy to be framework agnostic. They are also based\n", - "on dictionaries which are modified through the transformations.\n", + "### Integration of `albumentation`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --quiet --upgrade albumentations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from albumentations import RandomRotate90, Flip, Compose\n", "\n", - "There are two steps which need to be integrated into your pipelin in order to the \n", - "`batchgenerators` transforms\n", + "def aug(p=0.5):\n", + " return Compose([\n", + " RandomRotate90(),\n", + " Flip(),\n", + " ], p=p)\n", "\n", - "1. Exchange the `default_collate` function inside the dataloder with `numpy_collate`\n", - "2. 
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%% md\n"
- }
- }
+ "augmentation = aug(p=0.9)"
+ ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
+ "metadata": {},
 "outputs": [],
 "source": [
- "# setup transforms\n",
- "from rising.transforms import *\n",
- "from batchgenerators.transforms import ZeroMeanUnitVarianceTransform\n",
- "\n",
- "transforms = []\n",
- "# convert tuple into dict\n",
- "transforms.append(SeqToMap(\"data\", \"label\"))\n",
- "# batchgenerators transforms\n",
- "transforms.append(ZeroMeanUnitVarianceTransform())\n",
- "# convert to tensor\n",
- "transforms.append(ToTensor())\n",
- "# rising transforms\n",
- "transforms.append(Rot90((0, 1)))\n",
- "transforms.append(Mirror(dims=(0, 1)))"
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%%\n",
- "is_executing": false
- }
- }
+ "rising_transforms = [\n",
+ " SelectKeys(keys=[\"t1\"]),\n",
+ " Rot90(keys=(\"t1\",), dims=(0, 1)),\n",
+ " Mirror(keys=(\"t1\",), dims=(0, 1)),\n",
+ "]\n",
+ "batch_transforms = Compose(rising_transforms)"
+ ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
+ "metadata": {},
 "outputs": [],
 "source": [
- "from rising.loading import DataLoader, default_transform_call, numpy_collate\n",
- "from rising.transforms import Compose\n",
- "\n",
- "composed = Compose(transforms, transform_call=default_transform_call)\n",
- "dataloader = DataLoader(dataset, batch_size=8, batch_transforms=composed,\n",
- " num_workers=0, collate_fn=numpy_collate)\n",
+ "# Instead of using the native pytorch dataloader we exchange it for the dataloader from rising \n",
+ "from rising.loading import DataLoader\n",
 "\n",
- "_iter = iter(dataloader)\n",
- "batch = next(_iter)\n",
- "show_batch(batch[\"data\"])"
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%%\n",
- "is_executing": false
- }
- }
+ "dataloader = DataLoader(dataset, batch_size=1, batch_transforms=batch_transforms, num_workers=4)\n",
+ "_iter = iter(dataloader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Integration of `imgaug`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install --quiet --upgrade imgaug"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# needs a rename transform"
+ ]
 },
 {
 "cell_type": "markdown",
+ "metadata": {},
 "source": [
- "### More libraries will be added in the future :) \n"
- ],
+ "### Integration of `torchvision`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
 "metadata": {
- "collapsed": false,
 "pycharm": {
 "name": "#%% md\n"
 }
- }
+ },
+ "source": [
+ "### You want a library which is not listed here? Just open an issue [here](https://github.com/PhoenixDL/rising/issues).\n"
+ ]
 }
 ],
 "metadata": {
@@ -215,25 +518,25 @@
 "language_info": {
 "codemirror_mode": {
 "name": "ipython",
- "version": 2
+ "version": 3
 },
 "file_extension": ".py",
 "mimetype": "text/x-python",
 "name": "python",
 "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
+ "pygments_lexer": "ipython3",
+ "version": "3.7.7"
 },
 "pycharm": {
 "stem_cell": {
 "cell_type": "raw",
- "source": [],
 "metadata": {
 "collapsed": false
- }
+ },
+ "source": []
 }
 }
 },
 "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 4
}
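
Editor's note (not part of the committed notebook): the `imgaug` cell in this diff only contains the placeholder comment `# needs a rename transform`. A minimal sketch of one way such an integration could look, mirroring the `SelectKeys`/`AbstractTransform` pattern and the `numpy_collate` + `ToTensor` setup already used in the `batchgenerators` section. The class name `ImgAugTransform` and the chosen augmenters are hypothetical, not part of `rising`.

```python
# Hypothetical sketch: wrap an imgaug augmenter as a batched, numpy-based rising transform.
import imgaug.augmenters as iaa
import numpy as np

from rising.transforms import AbstractTransform


class ImgAugTransform(AbstractTransform):
    """Apply an imgaug augmenter to the numpy batch stored under `keys`."""

    def __init__(self, augmenter, keys=("data",)):
        # numpy-based augmentation, so no gradients can flow through it
        super().__init__(grad=False)
        self.augmenter = augmenter
        self.keys = keys

    def forward(self, **batch):
        for key in self.keys:
            imgs = np.asarray(batch[key])          # expected shape [N, C, H, W]
            imgs = imgs.transpose(0, 2, 3, 1)      # imgaug expects channels-last
            imgs = self.augmenter(images=imgs)
            batch[key] = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))
        return batch


# placed before ToTensor, together with numpy_collate (see the batchgenerators example)
aug = ImgAugTransform(iaa.Sequential([iaa.Fliplr(0.5), iaa.Affine(rotate=(-10, 10))]))
```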
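Editor's note (not part of the committed notebook): the `torchvision` code cell in this diff is still empty. A minimal sketch, assuming the same MNIST setup as the 2D section: torchvision transforms run per sample inside the dataset, while `rising` batch transforms run on top via `rising.loading.DataLoader`, exactly as in the notebook's other examples.

```python
# Hypothetical sketch: per-sample torchvision transforms + batched rising transforms.
import os

import torchvision
from torchvision import transforms as tv_transforms

from rising.loading import DataLoader
from rising.transforms import Compose, Mirror, Rot90, SeqToMap

# per-sample torchvision pipeline (PIL image in, tensor out)
per_sample = tv_transforms.Compose([
    tv_transforms.RandomRotation(10),
    tv_transforms.ToTensor(),
])

dataset = torchvision.datasets.MNIST(
    os.getcwd(), train=True, download=True, transform=per_sample)

# batched rising pipeline on top (convert the (data, label) tuple into a dict first)
batch_transforms = Compose([
    SeqToMap("data", "label"),
    Rot90((0, 1)),
    Mirror(dims=(0, 1)),
])

dataloader = DataLoader(dataset, batch_size=8, batch_transforms=batch_transforms)
batch = next(iter(dataloader))
```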