Small batch size support in CutMix and CutMixAugment
warner-benjamin committed Aug 8, 2022
1 parent 4965981 commit a2bb552
Showing 5 changed files with 177 additions and 82 deletions.
6 changes: 3 additions & 3 deletions docs/callback.cutmixup.html
@@ -102,7 +102,7 @@ <h2 id="MixUp" class="doc_header"><code>class</code> <code>MixUp</code><a href="


<div class="output_markdown rendered_html output_subarea ">
-<h2 id="CutMix" class="doc_header"><code>class</code> <code>CutMix</code><a href="https://github.com/warner-benjamin/fastxtend/tree/main/fastxtend/callback/cutmixup.py#L52" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CutMix</code>(<strong><code>alpha</code></strong>:<code>float</code>=<em><code>1.0</code></em>, <strong><code>uniform</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>interp_label</code></strong>:<code>bool | None</code>=<em><code>None</code></em>) :: <a href="/multiloss.html#MixHandlerX"><code>MixHandlerX</code></a></p>
+<h2 id="CutMix" class="doc_header"><code>class</code> <code>CutMix</code><a href="https://github.com/warner-benjamin/fastxtend/tree/main/fastxtend/callback/cutmixup.py#L54" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CutMix</code>(<strong><code>alpha</code></strong>:<code>float</code>=<em><code>1.0</code></em>, <strong><code>uniform</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>interp_label</code></strong>:<code>bool | None</code>=<em><code>None</code></em>) :: <a href="/multiloss.html#MixHandlerX"><code>MixHandlerX</code></a></p>
</blockquote>
<p>Implementation of <a href="https://arxiv.org/abs/1905.04899">https://arxiv.org/abs/1905.04899</a>. Supports <a href="/multiloss.html#MultiLoss"><code>MultiLoss</code></a></p>
<table>
@@ -163,7 +163,7 @@ <h2 id="CutMix" class="doc_header"><code>class</code> <code>CutMix</code><a href


<div class="output_markdown rendered_html output_subarea ">
-<h2 id="CutMixUp" class="doc_header"><code>class</code> <code>CutMixUp</code><a href="https://github.com/warner-benjamin/fastxtend/tree/main/fastxtend/callback/cutmixup.py#L115" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CutMixUp</code>(<strong><code>mix_alpha</code></strong>:<code>float</code>=<em><code>0.4</code></em>, <strong><code>cut_alpha</code></strong>:<code>float</code>=<em><code>1.0</code></em>, <strong><code>mixup_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>cutmix_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>cutmix_uniform</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>element</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>interp_label</code></strong>:<code>bool | None</code>=<em><code>None</code></em>) :: <a href="/callback.cutmixup.html#MixUp"><code>MixUp</code></a></p>
+<h2 id="CutMixUp" class="doc_header"><code>class</code> <code>CutMixUp</code><a href="https://github.com/warner-benjamin/fastxtend/tree/main/fastxtend/callback/cutmixup.py#L117" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CutMixUp</code>(<strong><code>mix_alpha</code></strong>:<code>float</code>=<em><code>0.4</code></em>, <strong><code>cut_alpha</code></strong>:<code>float</code>=<em><code>1.0</code></em>, <strong><code>mixup_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>cutmix_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>cutmix_uniform</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>element</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>interp_label</code></strong>:<code>bool | None</code>=<em><code>None</code></em>) :: <a href="/callback.cutmixup.html#MixUp"><code>MixUp</code></a></p>
</blockquote>
<p>Combo implementation of <a href="https://arxiv.org/abs/1710.09412">https://arxiv.org/abs/1710.09412</a> and <a href="https://arxiv.org/abs/1905.04899">https://arxiv.org/abs/1905.04899</a>"</p>
<p>Supports element-wise application of MixUp and CutMix on a batch.</p>
@@ -249,7 +249,7 @@ <h2 id="CutMixUp" class="doc_header"><code>class</code> <code>CutMixUp</code><a


<div class="output_markdown rendered_html output_subarea ">
-<h2 id="CutMixUpAugment" class="doc_header"><code>class</code> <code>CutMixUpAugment</code><a href="https://github.com/warner-benjamin/fastxtend/tree/main/fastxtend/callback/cutmixup.py#L179" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CutMixUpAugment</code>(<strong><code>mix_alpha</code></strong>:<code>float</code>=<em><code>0.4</code></em>, <strong><code>cut_alpha</code></strong>:<code>float</code>=<em><code>1.0</code></em>, <strong><code>mixup_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>cutmix_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>augment_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>augment_finetune</code></strong>:<code>Number | None</code>=<em><code>None</code></em>, <strong><code>cutmix_uniform</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>cutmixup_augs</code></strong>:<code>listified[Transform | Callable[..., Transform]] | None</code>=<em><code>None</code></em>, <strong><code>element</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>interp_label</code></strong>:<code>bool | None</code>=<em><code>None</code></em>) :: <a href="/callback.cutmixup.html#MixUp"><code>MixUp</code></a></p>
+<h2 id="CutMixUpAugment" class="doc_header"><code>class</code> <code>CutMixUpAugment</code><a href="https://github.com/warner-benjamin/fastxtend/tree/main/fastxtend/callback/cutmixup.py#L187" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>CutMixUpAugment</code>(<strong><code>mix_alpha</code></strong>:<code>float</code>=<em><code>0.4</code></em>, <strong><code>cut_alpha</code></strong>:<code>float</code>=<em><code>1.0</code></em>, <strong><code>mixup_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>cutmix_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>augment_ratio</code></strong>:<code>Number</code>=<em><code>1</code></em>, <strong><code>augment_finetune</code></strong>:<code>Number | None</code>=<em><code>None</code></em>, <strong><code>cutmix_uniform</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>cutmixup_augs</code></strong>:<code>listified[Transform | Callable[..., Transform]] | None</code>=<em><code>None</code></em>, <strong><code>element</code></strong>:<code>bool</code>=<em><code>True</code></em>, <strong><code>interp_label</code></strong>:<code>bool | None</code>=<em><code>None</code></em>) :: <a href="/callback.cutmixup.html#MixUp"><code>MixUp</code></a></p>
</blockquote>
<p>Combo implementation of <a href="https://arxiv.org/abs/1710.09412">https://arxiv.org/abs/1710.09412</a> and <a href="https://arxiv.org/abs/1905.04899">https://arxiv.org/abs/1905.04899</a> plus Augmentation.</p>
<p>Supports element-wise application of MixUp, CutMix, and Augmentation on a batch.</p>
2 changes: 1 addition & 1 deletion fastxtend/__init__.py
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "0.0.10"
77 changes: 47 additions & 30 deletions fastxtend/callback/cutmixup.py
@@ -9,7 +9,7 @@
 # Cell
 #nbdev_comment from __future__ import annotations
 
-from torch.distributions import Bernoulli, Categorical
+from torch.distributions import Categorical
 from torch.distributions.beta import Beta
 
 from fastcore.transform import Pipeline, Transform
@@ -44,7 +44,9 @@ def before_batch(self):
         self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))
 
     def _mixup(self, bs):
-        lam = self.distrib.sample((bs,)).squeeze().to(self.x.device)
+        lam = self.distrib.sample((bs,)).to(self.x.device)
+        if len(lam.shape) > 1:
+            lam = lam.squeeze()
         lam = torch.stack([lam, 1-lam], 1)
         return lam.max(1)[0]
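The hunk above can be reproduced in isolation. A minimal sketch of the batch-size-1 failure mode it guards against (variable names are illustrative, and the 0-dim Beta parameters are an assumption about how `self.distrib` is constructed): sampling `(bs,)` values and then calling `.squeeze()` unconditionally collapses a one-sample batch to a 0-dim scalar, which breaks the `torch.stack(..., 1)` that follows.

```python
import torch
from torch.distributions.beta import Beta

# Assumed setup: a Beta distribution built from 0-dim parameters, as in
# fastai-style MixUp callbacks, so sample((bs,)) has shape (bs,).
distrib = Beta(torch.tensor(0.4), torch.tensor(0.4))

lam = distrib.sample((1,))            # batch of one -> shape (1,)
assert lam.shape == (1,)
assert lam.squeeze().shape == ()      # the old unconditional squeeze lost the batch dim

# Patched logic: only squeeze when an extra dimension is actually present.
if len(lam.shape) > 1:
    lam = lam.squeeze()
lam = torch.stack([lam, 1 - lam], 1).max(1)[0]  # keep the larger weight
assert lam.shape == (1,)              # still a per-sample tensor
```

With `bs > 1` both code paths behave identically; the guard only matters for the degenerate one-sample batch.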

@@ -77,8 +79,8 @@ def before_batch(self):

     def _uniform_cutmix(self, xb, xb1, H, W):
         "Add uniform patches and blend labels from another random item in batch"
-        self.lam = self.distrib.sample((1,)).to(self.x.device)
-        x1, y1, x2, y2 = self.rand_bbox(W, H, self.lam)
+        lam = self.distrib.sample((1,)).to(self.x.device)
+        x1, y1, x2, y2 = self.rand_bbox(W, H, lam)
         xb[..., y1:y2, x1:x2] = xb1[..., y1:y2, x1:x2]
         lam = (1 - ((x2-x1)*(y2-y1))/float(W*H))
         return xb, lam
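The uniform CutMix step above can be sketched standalone (the free function and fixed box coordinates are illustrative stand-ins for `rand_bbox`): one rectangle from the shuffled batch is pasted into every image, and the label weight is recomputed from the area that was kept rather than from the sampled value.

```python
import torch

def uniform_cutmix(xb, xb1, x1, y1, x2, y2):
    """Paste one patch from the shuffled batch; lambda = fraction kept."""
    H, W = xb.shape[-2:]
    xb[..., y1:y2, x1:x2] = xb1[..., y1:y2, x1:x2]    # copy the patch
    lam = 1 - ((x2 - x1) * (y2 - y1)) / float(W * H)  # label weight from area
    return xb, lam

xb  = torch.zeros(4, 3, 8, 8)   # stand-in batch
xb1 = torch.ones(4, 3, 8, 8)    # stand-in shuffled batch
mixed, lam = uniform_cutmix(xb, xb1, 2, 2, 6, 6)
assert lam == 0.75               # a 4x4 patch covers 1/4 of an 8x8 image
assert mixed[0, 0, 3, 3] == 1.0  # inside the patch: pixels from the other item
assert mixed[0, 0, 0, 0] == 0.0  # outside the patch: original pixels
```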
@@ -149,17 +151,19 @@ def before_batch(self):
         xb1, self.yb1 = xb[shuffle], self.yb1[shuffle]
 
         # Apply MixUp
-        self.distrib = self.mix_distrib
-        self.lam[aug_type==0] = MixUp._mixup(self, xb[aug_type==0].shape[0])
-        xb[aug_type==0] = torch.lerp(xb1[aug_type==0], xb[aug_type==0], weight=unsqueeze(self.lam[aug_type==0], n=3))
+        if (aug_type==0).sum() > 0:
+            self.distrib = self.mix_distrib
+            self.lam[aug_type==0] = MixUp._mixup(self, xb[aug_type==0].shape[0])
+            xb[aug_type==0] = torch.lerp(xb1[aug_type==0], xb[aug_type==0], weight=unsqueeze(self.lam[aug_type==0], n=3))
 
         # Apply CutMix
         bs, _, H, W = xb[aug_type==1].size()
-        self.distrib = self.cut_distrib
-        if self.cutmix_uniform:
-            xb[aug_type==1], self.lam[aug_type==1] = CutMix._uniform_cutmix(self, xb[aug_type==1], xb1[aug_type==1], H, W)
-        else:
-            xb[aug_type==1], self.lam[aug_type==1] = CutMix._multi_cutmix(self, xb[aug_type==1], xb1[aug_type==1], H, W, bs)
+        if bs > 0:
+            self.distrib = self.cut_distrib
+            if self.cutmix_uniform:
+                xb[aug_type==1], self.lam[aug_type==1] = CutMix._uniform_cutmix(self, xb[aug_type==1], xb1[aug_type==1], H, W)
+            else:
+                xb[aug_type==1], self.lam[aug_type==1] = CutMix._multi_cutmix(self, xb[aug_type==1], xb1[aug_type==1], H, W, bs)
 
         self.learn.xb = (xb,)
         if not self.stack_y:
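The reason for the new `if bs > 0:` guard can be shown in a few lines: with a small batch, it is possible that no sample draws a given augmentation type, and the boolean mask then selects zero rows, leaving the patching code to operate on an empty tensor.

```python
import torch

xb = torch.zeros(2, 3, 8, 8)            # stand-in two-sample batch
aug_type = torch.tensor([0, 0])          # both samples drew MixUp (type 0)

bs, _, H, W = xb[aug_type == 1].size()   # CutMix mask selects nothing
assert bs == 0                           # so the CutMix branch is skipped
assert xb[aug_type == 0].shape[0] == 2   # the MixUp branch still sees both samples
```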
@@ -175,6 +179,10 @@ def before_batch(self):
         self.distrib = self.cut_distrib
         CutMix.before_batch(self)
 
+# Internal Cell
+def _do_cutmixaug(t:Tensor):
+    return t.sum().item() > 0
+
 # Cell
 class CutMixUpAugment(MixUp, CutMix):
     """
@@ -263,8 +271,10 @@ def before_batch(self):
         bs, C, H, W = xb.size()
         self.lam = torch.zeros(bs, device=xb.device)
         aug_type = self.categorical.sample((bs,))
-        shuffle = torch.randperm(xb[aug_type<2].shape[0]).to(xb.device)
-        self.yb1[aug_type<2] = self.yb1[aug_type<2][shuffle]
+        do_mix, do_cut, do_aug = _do_cutmixaug(aug_type==0), _do_cutmixaug(aug_type==1), _do_cutmixaug(aug_type==2)
+        if do_mix or do_cut:
+            shuffle = torch.randperm(xb[aug_type<2].shape[0]).to(xb.device)
+            self.yb1[aug_type<2] = self.yb1[aug_type<2][shuffle]
 
         # Apply IntToFloat to all samples
         xb = self._inttofloat_pipe(xb)
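The new `_do_cutmixaug` guards follow from simple counting: with two samples and three augmentation types, at least one type always receives zero samples. A small sketch of that logic (the `Categorical` weights and batch shapes here are illustrative):

```python
import torch
from torch.distributions import Categorical

def _do_cutmixaug(t: torch.Tensor) -> bool:
    """True if at least one sample in the batch matches this mask."""
    return t.sum().item() > 0

# Three equally weighted types (MixUp=0, CutMix=1, Augment=2), batch of two.
categorical = Categorical(torch.tensor([1.0, 1.0, 1.0]))
aug_type = categorical.sample((2,))

flags = [_do_cutmixaug(aug_type == i) for i in range(3)]
assert sum(flags) <= 2                   # two samples cover at most two types

missing = flags.index(False)             # some type always came up empty...
xb = torch.zeros(2, 3, 8, 8)
assert xb[aug_type == missing].shape[0] == 0  # ...and its mask selects nothing
```

Checking the flags once up front lets every later branch (shuffling, pipelines, normalization) be skipped when its slice of the batch would be empty.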
@@ -273,33 +283,40 @@ def before_batch(self):
         xb2 = torch.zeros([bs, C, self._size[0], self._size[1]], dtype=xb.dtype, device=xb.device) if self._size is not None else torch.zeros_like(xb)
 
         # Apply MixUp/CutMix Augmentations to MixUp and CutMix samples
-        if self._docutmixaug:
-            xb2[aug_type<2] = self._cutmixaugs_pipe(xb[aug_type<2])
-        else:
-            xb2[aug_type<2] = xb[aug_type<2]
+        if do_mix or do_cut:
+            if self._docutmixaug:
+                xb2[aug_type<2] = self._cutmixaugs_pipe(xb[aug_type<2])
+            else:
+                xb2[aug_type<2] = xb[aug_type<2]
 
         # Original Augmentations
-        xb2[aug_type==2] = self._aug_pipe(xb[aug_type==2])
+        if do_aug:
+            xb2[aug_type==2] = self._aug_pipe(xb[aug_type==2])
 
         # Possibly Resized xb and shuffled xb1
         xb = xb2
-        xb1 = xb[aug_type<2][shuffle]
+        if do_mix or do_cut:
+            xb1 = xb[aug_type<2][shuffle]
 
         # Apply MixUp
-        self.distrib = self.mix_distrib
-        self.lam[aug_type==0] = MixUp._mixup(self, xb[aug_type==0].shape[0])
-        xb[aug_type==0] = torch.lerp(xb1[aug_type[aug_type<2]==0], xb[aug_type==0], weight=unsqueeze(self.lam[aug_type==0], n=3))
+        if do_mix:
+            self.distrib = self.mix_distrib
+            self.lam[aug_type==0] = MixUp._mixup(self, xb[aug_type==0].shape[0])
+            xb[aug_type==0] = torch.lerp(xb1[aug_type[aug_type<2]==0], xb[aug_type==0], weight=unsqueeze(self.lam[aug_type==0], n=3))
 
         # Apply CutMix
-        bs, _, H, W = xb[aug_type==1].size()
-        self.distrib = self.cut_distrib
-        if self.cutmix_uniform:
-            xb[aug_type==1], self.lam[aug_type==1] = CutMix._uniform_cutmix(self, xb[aug_type==1], xb1[aug_type[aug_type<2]==1], H, W)
-        else:
-            xb[aug_type==1], self.lam[aug_type==1] = CutMix._multi_cutmix(self, xb[aug_type==1], xb1[aug_type[aug_type<2]==1], H, W, bs)
+        if do_cut:
+            bs, _, H, W = xb[aug_type==1].size()
+            self.distrib = self.cut_distrib
+            if self.cutmix_uniform:
+                xb[aug_type==1], lam = CutMix._uniform_cutmix(self, xb[aug_type==1], xb1[aug_type[aug_type<2]==1], H, W)
+                self.lam[aug_type==1] = lam.expand(bs)
+            else:
+                xb[aug_type==1], self.lam[aug_type==1] = CutMix._multi_cutmix(self, xb[aug_type==1], xb1[aug_type[aug_type<2]==1], H, W, bs)
 
         # Normalize MixUp and CutMix
-        xb[aug_type<2] = self._norm_pipe(xb[aug_type<2])
+        if do_mix or do_cut:
+            xb[aug_type<2] = self._norm_pipe(xb[aug_type<2])
 
         self.learn.xb = (xb,)
         if not self.stack_y:
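The `torch.lerp` call that both hunks use for the MixUp branch blends each image with its shuffled partner by a per-sample lambda. A self-contained sketch (fastai's `unsqueeze(lam, n=3)` adds three trailing dims; `.view(-1, 1, 1, 1)` below is assumed to be equivalent for a 4-d image batch):

```python
import torch

xb  = torch.ones(2, 3, 4, 4)     # stand-in "original" batch
xb1 = torch.zeros(2, 3, 4, 4)    # stand-in shuffled batch
lam = torch.tensor([1.0, 0.25])  # per-sample mixing weights

# lerp(start, end, weight) = start + weight * (end - start); the weight is
# broadcast from (bs, 1, 1, 1) over the channel and spatial dimensions.
mixed = torch.lerp(xb1, xb, weight=lam.view(-1, 1, 1, 1))
assert mixed[0].eq(1.0).all()    # lam=1 keeps the original image untouched
assert mixed[1].eq(0.25).all()   # lam=0.25 keeps 25% of the original
```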