Skip to content

Commit

Permalink
use pipeline
Browse files — browse the repository at this point in the history
  • Loading branch information
milocress committed Jan 24, 2025
1 parent ec56132 commit 8d200b8
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoModel,
pipeline,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Expand Down Expand Up @@ -204,8 +205,10 @@ def save_model_patch(*args: Any, **kwargs: Any):
mlflow.set_registry_uri(mlflow_logger.model_registry_uri)

if is_peft:
transformers_in_memory_model = AutoModel.from_pretrained(
transformers_model,
transformers_in_memory_model = pipeline(
model=AutoModel.from_pretrained(
transformers_model,
)
)
else:
transformers_in_memory_model = None
Expand Down Expand Up @@ -619,7 +622,9 @@ def tensor_hook(

hooks = []
for _, module in state_dict_model.named_modules():
hooks.append(module._register_state_dict_hook(tensor_hook),)
hooks.append(
module._register_state_dict_hook(tensor_hook),
)

state_dict = get_model_state_dict(
state_dict_model,
Expand Down

0 comments on commit 8d200b8

Please sign in to comment.