Add gradient accumulation logic to SupervisedTrainer
jak0bw committed Mar 3, 2023
1 parent 10faf46 commit 6bc4652
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions monai/engines/trainer.py
@@ -129,7 +129,8 @@ class SupervisedTrainer(Trainer):
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
gradient_accumulation_steps: number of iterations over which gradients are accumulated before the optimizer takes a step.
(default: 1, i.e. no gradient accumulation)
"""

def __init__(
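To make the new docstring entry concrete, here is a small illustrative calculation (the variable names are hypothetical and not part of the commit): because the optimizer only steps once every `gradient_accumulation_steps` iterations, the gradients it applies correspond to a larger effective batch.

# illustrative arithmetic only; batch_size is a hypothetical DataLoader setting
batch_size = 8
gradient_accumulation_steps = 4
effective_batch_size = batch_size * gradient_accumulation_steps  # 32 samples per optimizer step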
@@ -157,6 +158,7 @@ def __init__(
optim_set_to_none: bool = False,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
gradient_accumulation_steps: int = 1,
) -> None:
super().__init__(
device=device,
@@ -185,7 +187,12 @@ def __init__(
self.inferer = SimpleInferer() if inferer is None else inferer
self.optim_set_to_none = optim_set_to_none

def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
if gradient_accumulation_steps <= 0:
raise ValueError("gradient_accumulation_steps must be strictly positive. "
"No gradient accumulation is performed when the value is set to 1 (the default).")
self.gradient_accumulation_steps = gradient_accumulation_steps

def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]):
"""
Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
Return below items in a dictionary:
@@ -217,24 +224,30 @@ def _compute_pred_loss():
def _compute_pred_loss():
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
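# divide the loss so that gradients summed over the accumulation window match one large-batch step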
loss = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps
engine.state.output[Keys.LOSS] = loss
engine.fire_event(IterationEvents.LOSS_COMPLETED)

engine.network.train()
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
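# clear gradients only at the start of each accumulation window (Ignite iterations count from 1)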
if (engine.state.iteration - 1) % self.gradient_accumulation_steps == 0:
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

if engine.amp and engine.scaler is not None:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
_compute_pred_loss()
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
engine.scaler.step(engine.optimizer)
engine.scaler.update()
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
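# step the optimizer through the GradScaler and update its scale only once per accumulation window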
if engine.state.iteration % self.gradient_accumulation_steps == 0:
engine.scaler.step(engine.optimizer)
engine.scaler.update()
else:
_compute_pred_loss()
engine.state.output[Keys.LOSS].backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
engine.optimizer.step()
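# apply the accumulated gradients once every gradient_accumulation_steps iterations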
if engine.state.iteration % self.gradient_accumulation_steps == 0:
engine.optimizer.step()
engine.fire_event(IterationEvents.MODEL_COMPLETED)

return engine.state.output
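
For readers outside MONAI, the accumulation pattern introduced by this commit can be sketched in plain PyTorch. This is an illustrative, simplified loop (the model, optimizer, and data below are placeholders), not the trainer's actual code:

import torch
from torch import nn

# hypothetical toy setup, for illustration only
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
data = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

gradient_accumulation_steps = 4

for iteration, (inputs, targets) in enumerate(data, start=1):
    # clear gradients only at the start of each accumulation window
    if (iteration - 1) % gradient_accumulation_steps == 0:
        optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    # scale the loss so the summed gradients match one large-batch step
    loss = loss / gradient_accumulation_steps
    loss.backward()
    # apply the accumulated gradients once per window
    if iteration % gradient_accumulation_steps == 0:
        optimizer.step()

Dividing the loss keeps the magnitude of the applied update comparable to a single step on the larger effective batch; when constructing a SupervisedTrainer, the behaviour is enabled simply by passing gradient_accumulation_steps alongside the usual arguments.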