Add keep_torch_compile param to unwrap_model and extract_model_from_parallel for distributed compiled model. #3282
fb3809e to f601b8c
Hi @ggoggam, this is actually intended behavior, as you can see from PR #1437.
What I can suggest is to either modify unwrap_model or add a new argument to extract_model_from_parallel. cc @muellerzr
I see. I think it would make more sense to modify …
Thanks for the PR. Regarding the implementation, …
My first commit actually fixes accelerate/tests/test_utils.py Line 252 in cb8b7c6
If I understand correctly, it should be assert compiled_model._orig_mod == compiled_model_unwrapped if …
Hmm, good question, I'll leave that to the others to answer, as I'm not sure.
That's right. But the current behavior is that we don't unwrap the compiled model with … cc @muellerzr what do you prefer, modify how …
Hey! Thanks for trying to tackle this. I think what I'd rather see instead is a new param in extract_model_from_parallel called keep_torch_compile, which, similar to keep_fp32_wrapper, should default to False for a number of versions. We should then likely flip that to True after a while, and modify the logic in extract_model_from_parallel to reflect this change.
Can you tweak this PR to do so? :)
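To make the suggestion concrete, here is a minimal sketch of what a keep_torch_compile flag could look like. It is not the accelerate implementation: DistributedWrapper, CompiledWrapper, and extract_model_sketch are hypothetical stand-ins (for a wrapper like DistributedDataParallel, for OptimizedModule, and for extract_model_from_parallel) so that the example runs without torch. The default of False follows the comment above.

```python
class Net:
    """Plain model stand-in (hypothetical, replaces a torch.nn.Module)."""


class DistributedWrapper:
    """Stand-in for a distributed wrapper that exposes the model as `.module`."""

    def __init__(self, module):
        self.module = module


class CompiledWrapper:
    """Stand-in for OptimizedModule, which exposes the model as `._orig_mod`."""

    def __init__(self, orig_mod):
        self._orig_mod = orig_mod


def extract_model_sketch(model, keep_torch_compile=False):
    """Hypothetical sketch: peel wrappers, optionally keeping the compiled one."""
    compiled = None
    # Peel wrappers in whatever order they are nested,
    # remembering the first compiled wrapper encountered.
    while isinstance(model, (DistributedWrapper, CompiledWrapper)):
        if isinstance(model, CompiledWrapper):
            if compiled is None:
                compiled = model
            model = model._orig_mod
        else:
            model = model.module
    if keep_torch_compile and compiled is not None:
        # Re-attach the compiled wrapper directly around the bare model.
        compiled._orig_mod = model
        return compiled
    return model


net = Net()
fully_unwrapped = extract_model_sketch(DistributedWrapper(CompiledWrapper(net)))
still_compiled = extract_model_sketch(
    DistributedWrapper(CompiledWrapper(net)), keep_torch_compile=True
)
```

With the flag left at False the caller gets the bare model back; with True the compiled wrapper is kept but the distributed wrapper is still removed.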
Sure thing. I added … I am also curious about what you think in the case of …
Thanks for adding this and the tests!
Also, note that we have to upstream the modification done to extract_model_from_parallel in transformers if needed.
…odel_from_parallel`.
71d8bf1 to c2e9e5b
@SunMarc @muellerzr Please let me know if there are any changes needed for this PR!
LGTM!
cc @muellerzr for final …
Thanks for all the work on this! Looks great to me
What does this PR do?
This PR fixes the unexpected behavior of Accelerator.unwrap_model (Issue #3281). Right now, if the model is wrapped in both a distributed wrapper (e.g. DistributedDataParallel or DeepSpeedEngine) and a compiled module (OptimizedModule), only the distributed wrapper is removed. This behavior arises from the following code at line 80 of utils/other.py:
accelerate/src/accelerate/utils/other.py
Line 80 in cb8b7c6
Instead of checking for a compiled model both before and after unwrapping the distributed wrapper, the current code checks only before. If the model is wrapped in both, with the distributed wrapper outermost, is_compiled is set to False and the model is not fully unwrapped. This results in unexpected behavior: users expect a fully unwrapped model before saving but get an OptimizedModule instead, which may cause an error when loading the state dict due to a key mismatch.
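The ordering problem described in this PR can be sketched with simplified stand-ins. This is an illustration under stated assumptions, not the code from utils/other.py: DistributedWrapper, CompiledWrapper, both unwrap functions, and state_dict_keys are hypothetical, and the re-wrapping that the real function may do for a compiled outer wrapper is left out. The "_orig_mod." and "module." key prefixes mirror how torch.compile and DistributedDataParallel name the wrapped module's parameters.

```python
class Net:
    """Plain model stand-in with two named parameters (hypothetical)."""

    param_names = ("weight", "bias")


class DistributedWrapper:
    """Stand-in for a distributed wrapper that exposes the model as `.module`."""

    def __init__(self, module):
        self.module = module


class CompiledWrapper:
    """Stand-in for OptimizedModule, which exposes the model as `._orig_mod`."""

    def __init__(self, orig_mod):
        self._orig_mod = orig_mod


def unwrap_check_before_only(model):
    """Checks for compilation once, before removing distributed wrappers."""
    is_compiled = isinstance(model, CompiledWrapper)
    if is_compiled:
        model = model._orig_mod
    while isinstance(model, DistributedWrapper):
        model = model.module
    return model  # can still be a CompiledWrapper


def unwrap_check_after_too(model):
    """Repeats the compilation check after the distributed wrappers are gone."""
    model = unwrap_check_before_only(model)
    if isinstance(model, CompiledWrapper):
        model = model._orig_mod
    return model


def state_dict_keys(model, prefix=""):
    """Lists parameter keys the way nested wrappers would prefix them."""
    if isinstance(model, CompiledWrapper):
        return state_dict_keys(model._orig_mod, prefix + "_orig_mod.")
    if isinstance(model, DistributedWrapper):
        return state_dict_keys(model.module, prefix + "module.")
    return [prefix + name for name in model.param_names]


net = Net()
both = DistributedWrapper(CompiledWrapper(net))
partial = unwrap_check_before_only(both)  # distributed wrapper removed only
full = unwrap_check_after_too(both)       # bare model
```

Saving from `partial` would produce keys such as `_orig_mod.weight`, while a bare model expects `weight`, which is the kind of mismatch the description refers to.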
Related Issue
Accelerate.unwrap should fully unwrap the model. #3281
Who can review?
@muellerzr @SunMarc