
Commit

Use the loss scaled by gradient accumulation only for the backward computation and save the correct (unscaled) value

Signed-off-by: [email protected]
jak0bw committed Mar 3, 2023
1 parent de2884a commit 3c25c7c
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions monai/engines/trainer.py
@@ -224,10 +224,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]
         def _compute_pred_loss():
             engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             engine.fire_event(IterationEvents.FORWARD_COMPLETED)
-            loss = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
-            if self.gradient_accumulation_steps > 1:
-                loss = loss / self.gradient_accumulation_steps
-            engine.state.output[Keys.LOSS] = loss
+            engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
             engine.fire_event(IterationEvents.LOSS_COMPLETED)

         engine.network.train()
@@ -237,14 +234,14 @@ def _compute_pred_loss():
         if engine.amp and engine.scaler is not None:
             with torch.cuda.amp.autocast(**engine.amp_kwargs):
                 _compute_pred_loss()
-                engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
-                engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+                engine.scaler.scale(engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
+                engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             if engine.state.iteration % self.gradient_accumulation_steps == 0:
                 engine.scaler.step(engine.optimizer)
                 engine.scaler.update()
         else:
             _compute_pred_loss()
-            engine.state.output[Keys.LOSS].backward()
+            (engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             if engine.state.iteration % self.gradient_accumulation_steps == 0:
                 engine.optimizer.step()
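In short, the commit keeps the unscaled per-iteration loss in engine.state.output[Keys.LOSS] and divides by self.gradient_accumulation_steps only at the backward() call, in both the AMP and non-AMP branches. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; the toy model, data, and accum_steps are illustrative and not taken from MONAI.

import torch
from torch import nn

# Gradient accumulation with scaling applied only for backward (toy example):
# the loss kept for logging/metrics stays unscaled; only the gradients are
# divided by the number of accumulation steps.
torch.manual_seed(0)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_function = nn.MSELoss()
accum_steps = 4

optimizer.zero_grad()
for step in range(1, 9):
    inputs = torch.randn(4, 8)
    targets = torch.randn(4, 1)
    loss = loss_function(model(inputs), targets)  # unscaled value, as saved by the commit
    (loss / accum_steps).backward()               # scaled only for the gradient computation
    if step % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"step {step}: unscaled loss = {loss.item():.4f}")

In the AMP branch of the diff the same division happens inside engine.scaler.scale(...), so the value stored in engine.state.output[Keys.LOSS] is no longer affected by the accumulation factor.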
