diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index d412a4d97f..e77dbc62a6 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -37,7 +37,6 @@ ) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import ( - AutoModel, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, @@ -132,14 +131,12 @@ def _maybe_get_license_filename( def _log_model_with_multi_process( mlflow_logger: MLFlowLogger, python_logging_level: int, - transformers_model: str, - tokenizer: PreTrainedTokenizerBase, + transformers_model: Union[dict[str, Any], str], artifact_path: str, pretrained_model_name: str, registered_model_name: Optional[str], await_registration_for: int, mlflow_logging_config: dict[str, Any], - is_peft: bool = False, ): """Call MLFlowLogger.log_model. @@ -204,17 +201,9 @@ def save_model_patch(*args: Any, **kwargs: Any): if mlflow_logger.model_registry_uri is not None: mlflow.set_registry_uri(mlflow_logger.model_registry_uri) - if is_peft: - transformers_in_memory_model = { - 'model': AutoModel.from_pretrained(transformers_model), - 'tokenizer': tokenizer, - } - else: - transformers_in_memory_model = None - register_model_path = f'{mlflow_logger.model_registry_prefix}.{registered_model_name}' if mlflow_logger.model_registry_prefix and registered_model_name else registered_model_name mlflow_logger.log_model( - transformers_model=transformers_in_memory_model or transformers_model, + transformers_model=transformers_model, flavor='transformers', artifact_path=artifact_path, registered_model_name=register_model_path, @@ -731,8 +720,10 @@ def _register_hf_model( mlflow_logger, 'python_logging_level': logging.getLogger('llmfoundry').level, - 'transformers_model': - register_save_dir, + 'transformers_model': { + 'model': new_model_instance, + 'tokenizer': original_tokenizer + } if self.using_peft else register_save_dir, 'artifact_path': 'final_model_checkpoint', 'pretrained_model_name': @@ -743,10 +734,6 @@ def _register_hf_model( 3600, 'mlflow_logging_config': self.mlflow_logging_config, - 'is_peft': - self.using_peft, - 'tokenizer': - original_tokenizer }, )