[RFC] Support FSDP2 #3231
base: main
Conversation
Signed-off-by: Mehant Kammakomati <[email protected]>
@ByronHsu FYI - thoughts?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@kmehant thanks for starting this PR! I was looking at FSDP2 support in
Yes please! It looks like FSDP2 will be in a public API in the next torch release (2.6): pytorch/pytorch@d815efc, so maybe things are somewhat stable? But many of the older config parameters (like
Hmm, if the new API supported most of the V1 configurations, I would think having only a feature flag would be enough, i.e. something like
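A minimal sketch of that feature-flag idea, for illustration only (the fsdp_version attribute name and the prepare_fsdp helper are assumptions, not settled API):

import torch.nn as nn

def prepare_fsdp(model: nn.Module, fsdp_plugin, device_mesh=None) -> nn.Module:
    # hypothetical: dispatch on a single version flag instead of a separate FSDP2 config surface
    if getattr(fsdp_plugin, "fsdp_version", 1) == 2:
        # public in torch >= 2.6; earlier releases expose it under torch.distributed._composable.fsdp
        from torch.distributed.fsdp import fully_shard
        fully_shard(model, mesh=device_mesh)  # shards parameters in place and returns the same module
        return model
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    return FSDP(model, auto_wrap_policy=getattr(fsdp_plugin, "auto_wrap_policy", None))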
It looks like
cc @muellerzr, curious to know if this is already in the pipeline internally at HF!
# auto_wrap_policy is not yet supported by FSDP2
# therefore manual wrapping has to be done like below
#######
for layer in model.model.layers:
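For context, a rough sketch of what this manual wrapping looks like in full, assuming a Llama-style model whose decoder blocks sit under model.model.layers, with fsdp2_kwargs standing in for whatever mesh/mixed-precision options are built elsewhere (not the PR's exact code):

from torch.distributed.fsdp import fully_shard  # public in torch >= 2.6

# shard each transformer block individually, then the root module last
for layer in model.model.layers:
    fully_shard(layer, **fsdp2_kwargs)
fully_shard(model, **fsdp2_kwargs)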
This one doesn't seem to apply to the general use case.
Feels like it should be something like the below, which walks the module tree and applies fully_shard bottom-up.
stack = [model]
ordered_modules = []
while stack:
    current_module = stack.pop()
    # submodules are registered in _modules, so walk children() rather than __dict__
    for child in current_module.children():
        stack.append(child)
    ordered_modules.append(current_module)
# reverse the traversal so leaf modules are sharded before their parents
for each in ordered_modules[::-1]:
    fully_shard(each, **fsdp2_kwargs)
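(For what it's worth, this ordering also matches what FSDP2 expects: fully_shard is applied to submodules first and to the root model last, which the reversed traversal above preserves.)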
What does this PR do?
Prototype implementation for porting from FSDP V1 to FSDP V2. There are a couple of open questions in this PR that need comments and discussion.
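As a reference point, a minimal sketch (not the PR's actual code) of how a couple of the V1 options map onto the torch >= 2.6 fully_shard API; the dtypes are placeholders:

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

def shard_v2(model: torch.nn.Module) -> torch.nn.Module:
    # FSDP1's MixedPrecision(param_dtype=..., reduce_dtype=...) roughly maps to MixedPrecisionPolicy
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    # FSDP1's ShardingStrategy.FULL_SHARD roughly corresponds to reshard_after_forward=True
    fully_shard(model, mp_policy=mp, reshard_after_forward=True)
    return model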
Preliminary run of this PR and results
The current version of the PR has been tested for basic functionality (full shard) and compared with the previous FSDP V1 implementation.
Memory
Loss Parity
Throughput
TODO
Fixes #2873
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr