diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index c49e0577e82..6ba5ad74ae1 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -224,10 +224,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso
         def _compute_pred_loss():
             engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             engine.fire_event(IterationEvents.FORWARD_COMPLETED)
-            loss = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
-            if self.gradient_accumulation_steps > 1:
-                loss = loss / self.gradient_accumulation_steps
-            engine.state.output[Keys.LOSS] = loss
+            engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
             engine.fire_event(IterationEvents.LOSS_COMPLETED)

         engine.network.train()
@@ -237,14 +234,14 @@ def _compute_pred_loss():
         if engine.amp and engine.scaler is not None:
             with torch.cuda.amp.autocast(**engine.amp_kwargs):
                 _compute_pred_loss()
-            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
-            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+            engine.scaler.scale(engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
+            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             if engine.state.iteration % self.gradient_accumulation_steps == 0:
                 engine.scaler.step(engine.optimizer)
                 engine.scaler.update()
         else:
             _compute_pred_loss()
-            engine.state.output[Keys.LOSS].backward()
+            (engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
             if engine.state.iteration % self.gradient_accumulation_steps == 0:
                 engine.optimizer.step()
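
Not part of the patch: a minimal standalone sketch of the accumulation pattern the change moves to, written in plain PyTorch rather than MONAI's engine, with hypothetical names such as accum_steps. The point it illustrates is that the stored/reported loss (engine.state.output[Keys.LOSS] in the diff) stays unscaled, and the division by the accumulation factor is applied only to the tensor passed to backward(), with the optimizer stepping every accum_steps iterations.

    # Sketch only; assumes a toy model and synthetic data, not MONAI's trainer.
    import torch
    import torch.nn as nn

    accum_steps = 4  # hypothetical accumulation factor
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()

    batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]
    for step, (x, y) in enumerate(batches, start=1):
        loss = loss_fn(model(x), y)       # unscaled loss, e.g. for logging/handlers
        (loss / accum_steps).backward()   # scale only the gradient contribution
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

Dividing at backward() time rather than inside the loss computation keeps the logged loss comparable across different accumulation settings while the accumulated gradients still average out to the same effective batch gradient.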