diff --git a/docs/callback.cutmixup.html b/docs/callback.cutmixup.html index 91dba53..a854339 100644 --- a/docs/callback.cutmixup.html +++ b/docs/callback.cutmixup.html @@ -102,7 +102,7 @@
class
MixUp
-class
CutMix
[source]+
CutMix
(alpha
:float
=1.0
,uniform
:bool
=True
,interp_label
:bool | None
=None
) ::MixHandlerX
class
CutMix
[source]
CutMix
(alpha
:float
=1.0
,uniform
:bool
=True
,interp_label
:bool | None
=None
) ::MixHandlerX
Implementation of https://arxiv.org/abs/1905.04899. Supports
MultiLoss
\n", " \n", + " \n", + "" + ], + "text/plain": [ + "0 \n", - "2.052253 \n", - "1.683779 \n", - "00:17 \n", + "2.041265 \n", + "1.650648 \n", + "00:14 \n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#|hide\n", + "#|slow\n", + "#|cuda\n", + "imagenette = untar_data(URLs.IMAGENETTE_160)\n", + "\n", + "with less_random():\n", + " dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n", + " splitter=GrandparentSplitter(valid_name='val'),\n", + " get_items=get_image_files, get_y=parent_label,\n", + " item_tfms=Resize(128),\n", + " batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])\n", + " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", + "\n", + " learn = Learner(dls, resnet34(num_classes=dls.c), cbs=CutMixUp(cutmix_uniform=True)).to_channelslast()\n", + " learn.fit_one_cycle(1, 3e-3)\n", + " free_gpu_memory(learn, dls)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " \n", + "
" @@ -1203,7 +1281,7 @@ " get_items=get_image_files, get_y=parent_label,\n", " item_tfms=Resize(128),\n", " batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])\n", - " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", + " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", "\n", " learn = Learner(dls, resnet34(num_classes=dls.c), cbs=CutMixUp(cutmix_uniform=False)).to_channelslast()\n", " learn.fit_one_cycle(1, 3e-3)\n", @@ -1230,8 +1308,8 @@ " \n", "\n", + " \n", + " \n", + " \n", + "epoch \n", + "train_loss \n", + "valid_loss \n", + "time \n", + "\n", + " \n", " \n", "0 \n", + "2.042219 \n", + "1.635136 \n", + "00:22 \n", "\n", " \n", " \n", @@ -1257,10 +1335,10 @@ " get_items=get_image_files, get_y=parent_label,\n", " item_tfms=Resize(128),\n", " batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])\n", - " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", + " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", "\n", " learn = Learner(dls, resnet34(num_classes=dls.c), \n", - " cbs=CutMixUpAugment(cutmix_uniform=False, cutmixup_augs=aug_transforms())).to_channelslast()\n", + " cbs=CutMixUpAugment(cutmix_uniform=True, cutmixup_augs=aug_transforms())).to_channelslast()\n", " learn.fit_one_cycle(1, 3e-3)\n", " free_gpu_memory(learn, dls)" ] @@ -1285,9 +1363,9 @@ " \n", "0 \n", - "2.277530 \n", - "1.957705 \n", + "2.275790 \n", + "1.934458 \n", "00:15 \n", "\n", " \n", " \n", "" @@ -1312,7 +1390,7 @@ " get_items=get_image_files, get_y=parent_label,\n", " item_tfms=Resize(128),\n", " batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])\n", - " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", + " dls = dblock.dataloaders(imagenette, bs=128, num_workers=num_cpus(), pin_memory=True)\n", "\n", " learn = Learner(dls, resnet34(num_classes=dls.c), cbs=CutMixUpAugment(cutmix_uniform=False, element=False)).to_channelslast()\n", " learn.fit_one_cycle(1, 3e-3)\n", diff --git a/settings.ini b/settings.ini index dfbd0a7..918ecd6 100644 --- a/settings.ini +++ b/settings.ini @@ -8,7 +8,7 @@ author = Benjamin Warner author_email = me@benjaminwarner.dev copyright = Benjamin Warner branch = main -version = 0.1.0 +version = 0.0.10 min_python = 3.7 audience = Developers language = English0 \n", - "1.956891 \n", - "1.655091 \n", - "00:13 \n", + "1.937060 \n", + "1.594086 \n", + "00:17 \n", "