diff --git a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
index 3c7c1e746916..0cb38bec94bb 100644
--- a/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
+++ b/paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -587,8 +587,13 @@ def unified_optimizer_into_shards(
     static2struct_name_mappings = {}
     state_dict = get_expected_state_dict(model)
     fp32_weight = {}
+
+    extra_save_keys = {}
     for k, v in state_dict.items():
-        static2struct_name_mappings[v.name] = k
+        if v.name not in static2struct_name_mappings:
+            static2struct_name_mappings[v.name] = k
+        else:
+            extra_save_keys[v.name] = k
         if master_weights is not None and v.dtype == paddle.float32:
             if args.dataset_rank > 0:  # deal with different dataset rank.
                 continue
@@ -599,10 +604,15 @@ def unified_optimizer_into_shards(
         static_name, type_name = generate_base_static_name(key)
         new_name = static2struct_name_mappings[static_name] + "/" + type_name
         optim_state_dict[new_name] = optim_state_dict.pop(key)
+        if static_name in extra_save_keys:
+            extra_new_name = extra_save_keys[static_name] + "/" + type_name
+            optim_state_dict[extra_new_name] = optim_state_dict[new_name]
 
     if master_weights is not None:
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
+            if key in extra_save_keys:
+                master_weights[extra_save_keys[key]] = master_weights[static2struct_name_mappings[key]]
         master_weights.update(fp32_weight)
 
     # filter optimizer param
diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py
index 6cc57c148a08..bbb49ae14820 100644
--- a/paddlenlp/trainer/unified_checkpoint/utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/utils.py
@@ -210,15 +210,6 @@ def get_expected_state_dict(model_to_save, **kwargs):
 
     if isinstance(model_to_save, PretrainedModel):
         state_dict = model_to_save.state_dict()
-        if (
-            hasattr(model_to_save.config, "tie_word_embeddings")
-            and model_to_save.config.tie_word_embeddings
-            and hasattr(model_to_save, "_tied_weights_keys")
-            and model_to_save._tied_weights_keys is not None
-        ):
-            for key in model_to_save._tied_weights_keys:
-                if key in state_dict:
-                    state_dict.pop(key)
     elif isinstance(model_to_save, LoRAModel):
         concat_additional_adapter = kwargs.get("concat_additional_adapter", False)
         concat_init_lora = model_to_save.lora_config.loraga and concat_additional_adapter
diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py
index 1ec4c027a72a..948ec9857338 100644
--- a/paddlenlp/transformers/llama/modeling_pp.py
+++ b/paddlenlp/transformers/llama/modeling_pp.py
@@ -314,6 +314,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     _get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings
     _init_weights = LlamaPretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected
+    _tied_weights_keys = ["lm_head.weight"]
 
     # DONOT Add base_model_prefix !!!!
 
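For context, a minimal self-contained sketch of the bookkeeping introduced above: when word embeddings are tied, two structured keys resolve to the same static parameter name, so the first key claims the primary mapping and the optimizer state is additionally saved under every extra key recorded in `extra_save_keys`. Everything in the sketch (the parameter names, the `FakeParam` stand-in, and the `"/"`-separated key format used in place of `generate_base_static_name`) is illustrative, not PaddleNLP's actual runtime objects.

```python
from collections import namedtuple

# Hypothetical stand-in for a Paddle parameter; only the static `name` field matters here.
FakeParam = namedtuple("FakeParam", ["name"])

shared = FakeParam(name="embedding_0.w_0")  # one underlying tensor
state_dict = {
    "llama.embed_tokens.weight": shared,
    "lm_head.weight": shared,  # tied head: same static name as the embedding
}

# First structured key claims the primary mapping; later duplicates are tracked separately.
static2struct_name_mappings, extra_save_keys = {}, {}
for k, v in state_dict.items():
    if v.name not in static2struct_name_mappings:
        static2struct_name_mappings[v.name] = k
    else:
        extra_save_keys[v.name] = k

# Optimizer entries are keyed by static name; rename each to its structured key, then also
# store it under every extra structured key that shares the same tensor.
optim_state_dict = {"embedding_0.w_0/moment1_0": "moment1-tensor"}
for key in list(optim_state_dict.keys()):
    static_name, type_name = key.split("/", 1)  # simplified stand-in for generate_base_static_name
    new_name = static2struct_name_mappings[static_name] + "/" + type_name
    optim_state_dict[new_name] = optim_state_dict.pop(key)
    if static_name in extra_save_keys:
        extra_new_name = extra_save_keys[static_name] + "/" + type_name
        optim_state_dict[extra_new_name] = optim_state_dict[new_name]

print(sorted(optim_state_dict))
# ['llama.embed_tokens.weight/moment1_0', 'lm_head.weight/moment1_0']
```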
diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py
index 916baad328ce..aa2d125e3e6e 100644
--- a/paddlenlp/transformers/qwen2/modeling_pp.py
+++ b/paddlenlp/transformers/qwen2/modeling_pp.py
@@ -233,6 +233,7 @@ class Qwen2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     _get_tensor_parallel_mappings = Qwen2PretrainedModel._get_tensor_parallel_mappings
     _init_weights = Qwen2PretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = Qwen2PretrainedModel._keys_to_ignore_on_load_unexpected
+    _tied_weights_keys = ["lm_head.weight"]
 
     # DONOT Add base_model_prefix !!!!
 
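As a quick illustration (not PaddleNLP code), declaring `_tied_weights_keys` on the pipeline classes, presumably mirroring the non-pipeline CausalLM classes, lets generic save/load logic ask any model which state-dict keys are aliases of a tied weight instead of special-casing each architecture. The class and helper below are hypothetical stand-ins.

```python
class Qwen2ForCausalLMPipeSketch:
    # Mirrors the attribute added in the patch; the class itself is a stand-in.
    _tied_weights_keys = ["lm_head.weight"]


def tied_weight_keys(model) -> list:
    """Return the declared tied-weight keys, or an empty list if none are declared."""
    return list(getattr(model, "_tied_weights_keys", None) or [])


print(tied_weight_keys(Qwen2ForCausalLMPipeSketch()))  # ['lm_head.weight']
```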