
Commit

Use the gradient-accumulation-scaled loss only for the backward computation and save the correct (unscaled) value

Signed-off-by: Jakob Weigand <[email protected]>
jak0bw committed Mar 3, 2023
1 parent 892b482 commit 68177ac
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions monai/engines/trainer.py
@@ -224,10 +224,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]
         def _compute_pred_loss():
             engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             engine.fire_event(IterationEvents.FORWARD_COMPLETED)
-            loss = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
-            if self.gradient_accumulation_steps > 1:
-                loss = loss / self.gradient_accumulation_steps
-            engine.state.output[Keys.LOSS] = loss
+            engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
             engine.fire_event(IterationEvents.LOSS_COMPLETED)
 
         engine.network.train()
@@ -237,14 +234,14 @@ def _compute_pred_loss():
         if engine.amp and engine.scaler is not None:
             with torch.cuda.amp.autocast(**engine.amp_kwargs):
                 _compute_pred_loss()
-            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
-            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+            engine.scaler.scale(engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
+            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             if engine.state.iteration % self.gradient_accumulation_steps == 0:
                 engine.scaler.step(engine.optimizer)
                 engine.scaler.update()
         else:
             _compute_pred_loss()
-            engine.state.output[Keys.LOSS].backward()
+            (engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             if engine.state.iteration % self.gradient_accumulation_steps == 0:
                 engine.optimizer.step()
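In effect, the value saved to engine.state.output[Keys.LOSS] is now the unscaled per-iteration loss, and only the tensor passed to backward() is divided by gradient_accumulation_steps, so accumulated gradients still average out while handlers and loggers see the true loss. Below is a minimal sketch of the same pattern outside the MONAI engine; the model, optimizer, loss, and data are placeholder assumptions for illustration, not part of MONAI's API.

import torch
from torch import nn

# Placeholder model, optimizer, loss, and data (assumptions for this sketch only).
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_function = nn.MSELoss()
gradient_accumulation_steps = 4
data = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

optimizer.zero_grad()
for iteration, (inputs, targets) in enumerate(data, start=1):
    loss = loss_function(model(inputs), targets)
    # Report the unscaled loss (the value the commit now stores under Keys.LOSS) ...
    print(f"iteration {iteration}: loss={loss.item():.4f}")
    # ... but backpropagate the scaled loss so the accumulated gradient is an average.
    (loss / gradient_accumulation_steps).backward()
    if iteration % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

With AMP, the same division is applied to the loss before engine.scaler.scale(...).backward(), as in the first branch of the diff.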
