Add gradient accumulation supervised trainer update step function
Signed-off-by: Jakob Weigand <[email protected]>
jak0bw committed Mar 28, 2023
1 parent be3d138 commit 1d8415e
Showing 1 changed file with 60 additions and 0 deletions.
monai/engines/utils.py
@@ -18,6 +18,7 @@
import torch

from monai.config import IgniteInfo
if TYPE_CHECKING:
    # imported for type annotations only: importing monai.engines at runtime here
    # would be circular, since monai.engines.__init__ imports this module
    from monai.engines import Trainer
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys, GanKeys
@@ -39,6 +40,7 @@
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
"GradientAccumulationSupervisedTrainingStep",
]


@@ -286,3 +288,61 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
"""
return current_metric > prev_best


class GradientAccumulationSupervisedTrainingStep:
    """
    Callable class implementing a supervised training update step with gradient accumulation.

    The per-iteration loss is divided by ``gradient_accumulation_steps``, gradients are zeroed
    only at the start of each accumulation window, and the optimizer only steps at the end of
    each window, so the effective batch size grows by that factor.

    Args:
        gradient_accumulation_steps: number of steps the gradients should be accumulated across.
            Defaults to 1, which means no gradient accumulation.
    """

    def __init__(self, gradient_accumulation_steps: int = 1) -> None:
        if gradient_accumulation_steps <= 0:
            raise ValueError(
                "gradient_accumulation_steps must be strictly positive. "
                "A value of 1 (the default) means no gradient accumulation."
            )
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def __call__(self, engine: Trainer, batchdata: Sequence[torch.Tensor]) -> dict[str, torch.Tensor]:
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {CommonKeys.IMAGE: inputs, CommonKeys.LABEL: targets}

        def _compute_pred_loss():
            engine.state.output[CommonKeys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            engine.state.output[CommonKeys.LOSS] = engine.loss_function(
                engine.state.output[CommonKeys.PRED], targets
            ).mean()
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        engine.network.train()
        # zero gradients only at the start of each accumulation window (engine.state.iteration is 1-based)
        if (engine.state.iteration - 1) % self.gradient_accumulation_steps == 0:
            engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

        if engine.amp and engine.scaler is not None:
            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                _compute_pred_loss()
            # scale the loss down so the accumulated gradient equals the gradient of the mean loss over the window
            engine.scaler.scale(engine.state.output[CommonKeys.LOSS] / self.gradient_accumulation_steps).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            # step the optimizer only at the end of each accumulation window
            if engine.state.iteration % self.gradient_accumulation_steps == 0:
                engine.scaler.step(engine.optimizer)
                engine.scaler.update()
        else:
            _compute_pred_loss()
            (engine.state.output[CommonKeys.LOSS] / self.gradient_accumulation_steps).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            if engine.state.iteration % self.gradient_accumulation_steps == 0:
                engine.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
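
For context, here is a minimal sketch of how the new step function could be wired into a trainer. It is not part of the commit: the toy data, network, and hyperparameters below are illustrative placeholders, and it assumes the usual MONAI pattern of passing a custom update callable via SupervisedTrainer's iteration_update argument.

import torch
from monai.engines import SupervisedTrainer
from monai.engines.utils import GradientAccumulationSupervisedTrainingStep

# toy data in the dict format expected by MONAI's default prepare_batch
data = [{"image": torch.randn(16), "label": torch.tensor(i % 2)} for i in range(32)]
train_loader = torch.utils.data.DataLoader(data, batch_size=8)
net = torch.nn.Linear(16, 2)

trainer = SupervisedTrainer(
    device=torch.device("cpu"),
    max_epochs=1,
    train_data_loader=train_loader,
    network=net,
    optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
    loss_function=torch.nn.CrossEntropyLoss(),
    # accumulate over 4 iterations: the optimizer steps every 4th batch,
    # giving an effective batch size of 8 * 4 = 32
    iteration_update=GradientAccumulationSupervisedTrainingStep(gradient_accumulation_steps=4),
)
trainer.run()

Dividing each loss by gradient_accumulation_steps keeps the accumulated gradient equal to the gradient of the mean loss over the whole window, so training with accumulation behaves like training with a proportionally larger batch at the same learning rate.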
