Skip to content

Commit

Permalink
Apply the remove_hooks/add_hooks fix from TRL (huggingface/trl#2776)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan committed Feb 14, 2025
1 parent b8574f0 commit cb94d1f
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions trl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def setup_chat_format(

def remove_hooks(model: "DeepSpeedEngine") -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
# From https://github.com/huggingface/trl/pull/2776
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
Expand Down Expand Up @@ -164,6 +167,9 @@ def iter_params(module, recurse=False):

def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks to a DeepSpeed ZeRO-3 model."""
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
# From https://github.com/huggingface/trl/pull/2776
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
Expand Down

0 comments on commit cb94d1f

Please sign in to comment.