[Feature] Model Freezing ❄️ #189
base: develop
Conversation
…removed metadata saving in checkpoints due to corruption error on big models, fixed logging to work in the transfer learning setting
…y after changing it
…arning' into feature/transfer-learning
Hi @icedoom888!
Thanks for adding this. I was working on something similar, so I might have a refactor coming sometime next year. Until then, I think there are some improvements we should make to the implementation.
- PyTorch Lightning provides a recursive freezing capability.
- We should make these config entries optional.
Additionally, I had some concerns around unused parameters in the model. During my research I have found some hints that the training strategy may have to be adjusted to still work correctly with unused parameters.
The PyTorch documentation for this is here and describes find_unused_parameters:
https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel
Have you noticed any changes to your (computational) graph as opposed to non-frozen training?
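For context, a minimal sketch of how the training strategy could be adjusted in PyTorch Lightning when frozen or otherwise unused parameters are present; the trainer arguments below are illustrative and not taken from this PR's config:

```python
# Sketch: allowing DDP to tolerate parameters that receive no gradients,
# e.g. because submodules were frozen after the DDP wrapper was built.
# Trainer arguments are illustrative only.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    # Without this flag, DDP can hang or error while waiting for gradient
    # buckets of parameters that never take part in the backward pass.
    strategy=DDPStrategy(find_unused_parameters=True),
)
# trainer.fit(model, datamodule=datamodule)  # model/datamodule defined elsewhere
```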
    module (nn.Module): The parent module to search in.
    target_name (str): The name of the submodule to freeze.
"""
for name, child in module.named_children():
Is this actually necessary?
As far as I know, we can use PyTorch Lightning's freeze(), which also does the recursion for us (and the checking of edge cases).
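For reference, a minimal sketch of the built-in call being suggested here; the LightningModule and its layers are placeholders:

```python
# Sketch: LightningModule.freeze() sets requires_grad=False on every
# parameter and switches the module to eval mode. Placeholder model.
import pytorch_lightning as pl
import torch.nn as nn

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.decoder = nn.Linear(16, 16)

model = LitModel()
model.freeze()
print(any(p.requires_grad for p in model.parameters()))  # False
```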
I looked into PyTorch Lightning's freeze(). The function can only be called on a LightningModule class, but most submodules are torch.nn.Module classes.
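For comparison, a minimal sketch of freezing a plain torch.nn.Module directly; the helper name and toy model are hypothetical and not the code in this PR:

```python
# Sketch: freezing a named submodule of a plain torch.nn.Module.
# requires_grad_(False) applies recursively to all of its parameters.
# freeze_submodule is a hypothetical helper, not this PR's implementation.
import torch.nn as nn

def freeze_submodule(module: nn.Module, target_name: str) -> None:
    submodule = module.get_submodule(target_name)  # raises if the name is unknown
    submodule.requires_grad_(False)                # recursive over all parameters
    submodule.eval()                               # stop dropout/batch-norm updates

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
freeze_submodule(model, "0")
print(all(not p.requires_grad for p in model[0].parameters()))  # True
```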
Solves: ecmwf/anemoi-core#10
📚 Documentation preview 📚: https://anemoi-training--189.org.readthedocs.build/en/189/