diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 6fa35c4980..30d4b294eb 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -38,6 +38,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from transformers import (
     AutoModel,
+    pipeline,
     PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
@@ -204,8 +205,10 @@ def save_model_patch(*args: Any, **kwargs: Any):
     mlflow.set_registry_uri(mlflow_logger.model_registry_uri)
 
     if is_peft:
-        transformers_in_memory_model = AutoModel.from_pretrained(
-            transformers_model,
+        transformers_in_memory_model = pipeline(
+            model=AutoModel.from_pretrained(
+                transformers_model,
+            )
         )
     else:
         transformers_in_memory_model = None
@@ -619,7 +622,9 @@ def tensor_hook(
         hooks = []
         for _, module in state_dict_model.named_modules():
-            hooks.append(module._register_state_dict_hook(tensor_hook),)
+            hooks.append(
+                module._register_state_dict_hook(tensor_hook),
+            )
 
         state_dict = get_model_state_dict(
             state_dict_model,