Add gradient accumulation supervised trainer update step function
Signed-off-by: Jakob Weigand <[email protected]>
jak0bw committed Mar 28, 2023
commit dfe24ef (1 parent: 1d8415e)
Showing 1 changed file with 2 additions and 2 deletions.
monai/engines/utils.py — 2 additions, 2 deletions

@@ -18,7 +18,7 @@
 import torch
 
 from monai.config import IgniteInfo
-from monai.engines import Trainer
+from monai.engines import SupervisedTrainer
 from monai.transforms import apply_transform
 from monai.utils import ensure_tuple, min_version, optional_import
 from monai.utils.enums import CommonKeys, GanKeys
@@ -306,7 +306,7 @@ def __init__(self, gradient_accumulation_steps: int = 1) -> None:
             "No gradient accumulation if the value set to one (default).")
         self.gradient_accumulation_steps = gradient_accumulation_steps
 
-    def __call__(self, engine: Trainer, batchdata: Sequence[torch.Tensor]) -> Any | tuple[torch.Tensor]:
+    def __call__(self, engine: SupervisedTrainer, batchdata: Sequence[torch.Tensor]) -> Any | tuple[torch.Tensor]:
         if batchdata is None:
             raise ValueError("Must provide batch data for current iteration.")
         batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
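For context, the update step whose annotation is retyped here implements the standard gradient accumulation pattern: the per-batch loss is scaled down by the number of accumulation steps, and the optimizer only steps (and zeroes gradients) once every N iterations. Below is a minimal sketch of that pattern in plain PyTorch; the function name and the `model`, `optimizer`, `loss_fn`, and `loader` arguments are illustrative and not taken from the MONAI source.

import torch

def accumulating_train_step(model, optimizer, loss_fn, loader, accumulation_steps=4):
    # Illustrative sketch of gradient accumulation, not the MONAI implementation.
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(loader):
        loss = loss_fn(model(inputs), targets)
        # Scale the loss so the accumulated gradients average over the
        # accumulation window instead of summing across it.
        (loss / accumulation_steps).backward()
        # Step and reset gradients only every `accumulation_steps` batches,
        # emulating a larger effective batch size.
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

With accumulation_steps=1 this reduces to an ordinary per-batch update, which matches the docstring note in the diff ("No gradient accumulation if the value set to one (default).").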
