Add gradient accumulation logic to SupervisedTrainer #6101

Draft · wants to merge 1 commit into base: dev
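For context, gradient accumulation divides each mini-batch loss by the number of accumulation steps and defers the optimizer step, so N small batches approximate one update on an N-times larger batch. The following plain-PyTorch sketch illustrates that idea only; it is not part of this diff, and the model, data, and accum_steps below are placeholder names.

# Illustrative sketch: accumulate gradients over N small batches so that one
# optimizer step approximates a single update on an N-times larger batch.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
accum_steps = 4

for step, (x, y) in enumerate([(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)], start=1):
    # zero gradients only at the start of each accumulation window
    if (step - 1) % accum_steps == 0:
        optimizer.zero_grad()
    # scale each loss by 1/accum_steps so the summed gradients match the mean over the large batch
    (loss_fn(model(x), y) / accum_steps).backward()
    # step the optimizer only once per accumulation window
    if step % accum_steps == 0:
        optimizer.step()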
60 changes: 60 additions & 0 deletions monai/engines/utils.py
@@ -18,6 +18,7 @@
import torch

from monai.config import IgniteInfo
from monai.engines import SupervisedTrainer
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys, GanKeys
@@ -39,6 +40,7 @@
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
"GradientAccumulationSupervisedTrainingStep",
]


@@ -286,3 +288,61 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:

"""
return current_metric > prev_best


class GradientAccumulationSupervisedTrainingStep:
    """Callable class implementing a supervised training step with gradient accumulation.

    The per-iteration loss is divided by ``gradient_accumulation_steps`` and the optimizer is
    only stepped every ``gradient_accumulation_steps`` iterations, so the effective batch size
    is ``batch_size * gradient_accumulation_steps``.

    Args:
        gradient_accumulation_steps: number of iterations over which gradients are accumulated
            before each optimizer step. Defaults to 1 (no gradient accumulation).
    """

    def __init__(self, gradient_accumulation_steps: int = 1) -> None:
        if gradient_accumulation_steps <= 0:
            raise ValueError(
                "gradient_accumulation_steps must be strictly positive. "
                "A value of 1 (the default) means no gradient accumulation."
            )
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def __call__(self, engine: SupervisedTrainer, batchdata: Sequence[torch.Tensor]) -> Any | tuple[torch.Tensor]:
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {CommonKeys.IMAGE: inputs, CommonKeys.LABEL: targets}

        def _compute_pred_loss():
            engine.state.output[CommonKeys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            engine.state.output[CommonKeys.LOSS] = engine.loss_function(engine.state.output[CommonKeys.PRED], targets).mean()
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        engine.network.train()
        # zero the gradients only at the start of each accumulation window
        if (engine.state.iteration - 1) % self.gradient_accumulation_steps == 0:
            engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

        if engine.amp and engine.scaler is not None:
            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                _compute_pred_loss()
            # scale the loss so the accumulated gradients match one large-batch update
            engine.scaler.scale(engine.state.output[CommonKeys.LOSS] / self.gradient_accumulation_steps).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            if engine.state.iteration % self.gradient_accumulation_steps == 0:
                engine.scaler.step(engine.optimizer)
                engine.scaler.update()
        else:
            _compute_pred_loss()
            (engine.state.output[CommonKeys.LOSS] / self.gradient_accumulation_steps).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            # step the optimizer only once per accumulation window
            if engine.state.iteration % self.gradient_accumulation_steps == 0:
                engine.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
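
A minimal usage sketch of the new step follows; it is not part of this diff. It assumes the step is passed to SupervisedTrainer through its iteration_update argument, and the network, data, and hyper-parameters are placeholders chosen for illustration.

import torch

from monai.data import DataLoader, Dataset
from monai.engines import SupervisedTrainer
from monai.engines.utils import GradientAccumulationSupervisedTrainingStep
from monai.networks.nets import UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16), strides=(2,)).to(device)
# toy dict-style dataset; default_prepare_batch extracts the "image" and "label" keys
data = [{"image": torch.rand(1, 32, 32), "label": torch.rand(1, 32, 32)} for _ in range(8)]
trainer = SupervisedTrainer(
    device=device,
    max_epochs=2,
    train_data_loader=DataLoader(Dataset(data), batch_size=2),
    network=net,
    optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
    loss_function=torch.nn.MSELoss(),
    # accumulate gradients over 4 iterations, so the optimizer steps once every 4 batches
    iteration_update=GradientAccumulationSupervisedTrainingStep(gradient_accumulation_steps=4),
)
trainer.run()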