From 68177ac43780f6dbc63d3fa49f56862d1700c23a Mon Sep 17 00:00:00 2001 From: Jakob Weigand Date: Fri, 3 Mar 2023 20:33:39 +0100 Subject: [PATCH] Use the gradient-accumulation-scaled loss only for the backward computation and save the unscaled value Signed-off-by: Jakob Weigand --- monai/engines/trainer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c49e0577e82..6ba5ad74ae1 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -224,10 +224,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso def _compute_pred_loss(): engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) - loss = engine.loss_function(engine.state.output[Keys.PRED], targets).mean() - if self.gradient_accumulation_steps > 1: - loss = loss / self.gradient_accumulation_steps - engine.state.output[Keys.LOSS] = loss + engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean() engine.fire_event(IterationEvents.LOSS_COMPLETED) engine.network.train() @@ -237,14 +234,14 @@ def _compute_pred_loss(): if engine.amp and engine.scaler is not None: with torch.cuda.amp.autocast(**engine.amp_kwargs): _compute_pred_loss() - engine.scaler.scale(engine.state.output[Keys.LOSS]).backward() - engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + engine.scaler.scale(engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) if engine.state.iteration % self.gradient_accumulation_steps == 0: engine.scaler.step(engine.optimizer) engine.scaler.update() else: _compute_pred_loss() - engine.state.output[Keys.LOSS].backward() + (engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) if engine.state.iteration % 
self.gradient_accumulation_steps == 0: engine.optimizer.step()