
Add gradient accumulation functionality to SupervisedTrainer #6100

Open
jak0bw opened this issue Mar 3, 2023 · 1 comment

jak0bw commented Mar 3, 2023

I am sorry if I missed any existing functionality or documentation on this topic, but I could not find anything.

Is your feature request related to a problem? Please describe.
SupervisedTrainer is missing built-in gradient accumulation functionality

Describe the solution you'd like
Add gradient accumulation functionality
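
For context, the behaviour being asked for is roughly the plain-PyTorch pattern below (model, optimizer, loss_fn, train_loader and accum_steps are just placeholder names, not MONAI API):

accum_steps = 4
optimizer.zero_grad()
for i, (inputs, targets) in enumerate(train_loader):
    loss = loss_fn(model(inputs), targets)
    # scale the loss so the accumulated gradient matches one large-batch step
    (loss / accum_steps).backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()       # apply the accumulated update every accum_steps iterations
        optimizer.zero_grad()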

Describe alternatives you've considered
Passing ignite's supervised_training_step (https://pytorch.org/ignite/generated/ignite.engine.supervised_training_step.html) as the iteration_update parameter:

from ignite.engine import supervised_training_step
from monai.engines import SupervisedTrainer, default_prepare_batch
from monai.inferers import SimpleInferer

# device, max_epochs, train_loader, net, optimizer, loss and train_handlers
# are assumed to be defined as usual
trainer = SupervisedTrainer(
    device=device,
    max_epochs=max_epochs,
    train_data_loader=train_loader,
    network=net,
    optimizer=optimizer,
    loss_function=loss,
    inferer=SimpleInferer(),
    key_train_metric=None,
    train_handlers=train_handlers,
    # replace the default MONAI iteration with ignite's update function,
    # which supports gradient accumulation
    iteration_update=supervised_training_step(
        model=net,
        optimizer=optimizer,
        loss_fn=loss,
        device=device,
        gradient_accumulation_steps=4,
        prepare_batch=default_prepare_batch,
    ),
    amp=False,
    postprocessing=None,
)

This kind of works, but it does not feel like the intended way to do it in MONAI: the ignite update function does not fire the correct MONAI events during the iteration, so handlers attached for training are ignored.
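
One possible shape for a built-in solution (only a rough sketch, not an actual implementation: the class name and constructor argument are made up here, and AMP as well as prepare_batch outputs with extra args are ignored) would be a SupervisedTrainer subclass whose _iteration accumulates gradients but still fires MONAI's IterationEvents, so attached handlers keep working:

from monai.engines import SupervisedTrainer
from monai.engines.utils import CommonKeys as Keys, IterationEvents


class AccumulationSupervisedTrainer(SupervisedTrainer):  # hypothetical, not part of MONAI
    def __init__(self, *args, gradient_accumulation_steps: int = 1, **kwargs):
        super().__init__(*args, **kwargs)
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def _iteration(self, engine, batchdata):
        inputs, targets = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
        engine.network.train()

        # reset gradients only at the start of an accumulation window
        if (engine.state.iteration - 1) % self.gradient_accumulation_steps == 0:
            engine.optimizer.zero_grad(set_to_none=True)

        engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network)
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        loss = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
        engine.state.output[Keys.LOSS] = loss
        engine.fire_event(IterationEvents.LOSS_COMPLETED)

        # scale so the accumulated gradient matches a single large-batch gradient
        (loss / self.gradient_accumulation_steps).backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)

        # step the optimizer only every gradient_accumulation_steps iterations
        if engine.state.iteration % self.gradient_accumulation_steps == 0:
            engine.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output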


jak0bw commented Mar 3, 2023

#6101 would be a first example of what I imagine the solution could look like.

(Source code is strongly (and shamelessly) influenced by https://pytorch.org/ignite/generated/ignite.engine.supervised_training_step.html.)
