This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

[Feature] Model Freezing ❄️ #189

Open
wants to merge 59 commits into base: develop

Conversation
Conversation

@icedoom888 (Contributor) commented Dec 6, 2024

icedoom888 and others added 30 commits October 9, 2024 15:25
…removed metadata saving in checkpoints due to a corruption error on big models; fixed logging to work in the transfer learning setting
@icedoom888 icedoom888 added enhancement New feature or request contributor labels Dec 6, 2024
@icedoom888 icedoom888 self-assigned this Dec 6, 2024
@icedoom888 icedoom888 requested a review from gabrieloks December 6, 2024 10:26
@icedoom888 icedoom888 changed the title Feature/model freezing ❄️ Feature/Model Freezing ❄️ Dec 6, 2024
@JesperDramsch (Member) left a comment

Hi @icedoom888!

Thanks for adding this! I was working on something similar, so I might have a refactor coming sometime next year. Until then, I think there are some improvements we should make to the implementation.

  1. PyTorch Lightning provides a recursive freezing capability.
  2. We should make these config entries optional.
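On point 2, a minimal sketch of how an optional freezing entry could be read so that existing configs without it keep working. The key names here are assumptions for illustration, not the PR's actual schema:

```python
# Hypothetical config lookup: the freezing entry is optional, so configs
# that omit it behave exactly as before (key names are illustrative).
def get_frozen_submodules(config: dict) -> list:
    training = config.get("training") or {}
    return training.get("submodules_to_freeze") or []

print(get_frozen_submodules({}))  # []
print(get_frozen_submodules({"training": {"submodules_to_freeze": ["encoder"]}}))  # ['encoder']
```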

Additionally, I had some concerns around unused parameters in the model. During my research I have found some hints that the training strategy may have to be adjusted to still work correctly with unused parameters.

The PyTorch one is here:
https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel
It implements find_unused_parameters. Have you noticed any changes to your (computational) graph as opposed to non-frozen training?
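For context, a self-contained sketch (plain PyTorch, not the project's actual trainer) of why a frozen or skipped branch looks like "unused parameters": its weights never receive gradients, which is exactly the case DDP's find_unused_parameters flag exists to handle:

```python
import torch
import torch.nn as nn

# Toy model with a branch that is never used in forward().
class Gated(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.skipped = nn.Linear(4, 4)  # frozen/skipped branch

    def forward(self, x):
        return self.used(x)

model = Gated()
model(torch.randn(2, 4)).sum().backward()

# The skipped branch never receives gradients. Under DDP, the default
# reducer waits for every parameter's gradient, so this situation needs
#   DistributedDataParallel(model, find_unused_parameters=True)
print(model.used.weight.grad is not None)   # True
print(model.skipped.weight.grad is None)    # True
```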

Resolved (outdated) review threads on:
CHANGELOG.md
src/anemoi/training/utils/checkpoint.py
src/anemoi/training/train/train.py (two threads)
module (nn.Module): The parent module to search in.
target_name (str): The name of the submodule to freeze.
"""
for name, child in module.named_children():
A Member left a comment:
Is this actually necessary?

As far as I know, we can use Pytorch Lightning .freeze() that also does the recursion for us (and the checking of edge cases).

@icedoom888 (Contributor, PR author) replied:

I looked into PyTorch Lightning's freeze(). It can only be called on a LightningModule, while most of our submodules are plain torch.nn.Module instances.
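Given that, a minimal sketch of recursive freezing over plain torch.nn.Module (the function name is hypothetical). Note that requires_grad_(False) already applies to all parameters of the matched child, so the recursion is only needed to locate the target by name:

```python
import torch.nn as nn

def freeze_submodule(module: nn.Module, target_name: str) -> None:
    """Recursively search for submodules named target_name and freeze them."""
    for name, child in module.named_children():
        if name == target_name:
            # Disables gradients for the child and all of its descendants.
            child.requires_grad_(False)
        else:
            freeze_submodule(child, target_name)

model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
freeze_submodule(model, "encoder")
print(all(not p.requires_grad for p in model["encoder"].parameters()))  # True
print(all(p.requires_grad for p in model["decoder"].parameters()))      # True
```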

@icedoom888 icedoom888 changed the title Feature/Model Freezing ❄️ [Feature] Model Freezing ❄️ Dec 17, 2024
Labels: contributor, enhancement (New feature or request)
Projects: None yet
Successfully merging this pull request may close these issues: [Feature] Model Freezing ❄️
2 participants