
Commit

Revert "Auxiliary commit to revert individual files from 7bd873b"
This reverts commit 232d2f0967853204e85d6fb02d7a5463acab2c0e.
jak0bw committed Mar 20, 2023
1 parent af54b53 commit d48d18d
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions monai/engines/trainer.py
@@ -129,8 +129,7 @@ class SupervisedTrainer(Trainer):
             `device`, `non_blocking`.
         amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
             https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
-        gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
-            (default: 1 (means no gradient accumulation))
+
     """

     def __init__(
@@ -158,7 +157,6 @@ def __init__(
         optim_set_to_none: bool = False,
         to_kwargs: dict | None = None,
         amp_kwargs: dict | None = None,
-        gradient_accumulation_steps: int = 1,
     ) -> None:
         super().__init__(
             device=device,
@@ -187,12 +185,7 @@ def __init__(
         self.inferer = SimpleInferer() if inferer is None else inferer
         self.optim_set_to_none = optim_set_to_none

-        if gradient_accumulation_steps <= 0:
-            raise ValueError("Gradient_accumulation_steps must be strictly positive. "
-                             "No gradient accumulation if the value set to one (default).")
-        self.gradient_accumulation_steps = gradient_accumulation_steps
-
-    def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]):
+    def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
         """
         Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
         Return below items in a dictionary:
@@ -228,23 +221,20 @@ def _compute_pred_loss():
             engine.fire_event(IterationEvents.LOSS_COMPLETED)

         engine.network.train()
-        if (engine.state.iteration - 1) % self.gradient_accumulation_steps == 0:
-            engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+        engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

         if engine.amp and engine.scaler is not None:
             with torch.cuda.amp.autocast(**engine.amp_kwargs):
                 _compute_pred_loss()
-            engine.scaler.scale(engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
+            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
-            if engine.state.iteration % self.gradient_accumulation_steps == 0:
-                engine.scaler.step(engine.optimizer)
-                engine.scaler.update()
+            engine.scaler.step(engine.optimizer)
+            engine.scaler.update()
         else:
             _compute_pred_loss()
-            (engine.state.output[Keys.LOSS] / self.gradient_accumulation_steps).backward()
+            engine.state.output[Keys.LOSS].backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
-            if engine.state.iteration % self.gradient_accumulation_steps == 0:
-                engine.optimizer.step()
+            engine.optimizer.step()
         engine.fire_event(IterationEvents.MODEL_COMPLETED)

         return engine.state.output
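
For reference, the code removed by this revert followed a common gradient accumulation pattern: zero the gradients only at the start of each accumulation window, divide the loss by the number of accumulation steps, and apply the optimizer step only once per window. The sketch below illustrates that pattern in plain PyTorch, independent of the MONAI engine; the names `model`, `loader`, `loss_fn`, and `optimizer` are hypothetical placeholders, not part of this commit.

# Minimal sketch of the gradient accumulation pattern removed by this revert
# (plain PyTorch; `model`, `loader`, `loss_fn`, `optimizer` are assumed to exist).
def train_with_accumulation(model, loader, loss_fn, optimizer, accumulation_steps: int = 4):
    if accumulation_steps <= 0:
        raise ValueError("accumulation_steps must be strictly positive.")
    model.train()
    for iteration, (inputs, targets) in enumerate(loader, start=1):
        # Zero gradients only at the first iteration of each accumulation window.
        if (iteration - 1) % accumulation_steps == 0:
            optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs), targets)
        # Divide the loss so the accumulated gradient matches a single larger-batch update.
        (loss / accumulation_steps).backward()
        # Apply the optimizer update only once per window of `accumulation_steps` iterations.
        if iteration % accumulation_steps == 0:
            optimizer.step()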
