Commit
[UnifiedCheckpoint] Add tied_weight_keys for pipeline model (PaddlePaddle#9663)
* update trainer.py

* fix tied_weights_keys

* update
DesmonDay authored Dec 23, 2024
1 parent 97ae9ad commit 3374e7f
Showing 4 changed files with 13 additions and 10 deletions.
paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py (12 changes: 11 additions & 1 deletion)
@@ -587,8 +587,13 @@ def unified_optimizer_into_shards(
     static2struct_name_mappings = {}
     state_dict = get_expected_state_dict(model)
     fp32_weight = {}
+
+    extra_save_keys = {}
     for k, v in state_dict.items():
-        static2struct_name_mappings[v.name] = k
+        if v.name not in static2struct_name_mappings:
+            static2struct_name_mappings[v.name] = k
+        else:
+            extra_save_keys[v.name] = k
         if master_weights is not None and v.dtype == paddle.float32:
             if args.dataset_rank > 0:  # deal with different dataset rank.
                 continue
@@ -599,10 +604,15 @@ def unified_optimizer_into_shards(
         static_name, type_name = generate_base_static_name(key)
         new_name = static2struct_name_mappings[static_name] + "/" + type_name
         optim_state_dict[new_name] = optim_state_dict.pop(key)
+        if static_name in extra_save_keys:
+            extra_new_name = extra_save_keys[static_name] + "/" + type_name
+            optim_state_dict[extra_new_name] = optim_state_dict[new_name]
 
     if master_weights is not None:
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
+            if key in extra_save_keys:
+                master_weights[extra_save_keys[key]] = master_weights[static2struct_name_mappings[key]]
         master_weights.update(fp32_weight)
 
     # filter optimizer param
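For context on the hunks above: a minimal, standalone sketch (plain strings stand in for tensors, and the parameter/key names are hypothetical) of how the new extra_save_keys mapping behaves when two structured keys share one underlying parameter, as happens with tied embeddings in a pipeline model.

# Sketch only: the values stand in for parameters whose .name would collide.
state_dict = {
    "embed_tokens.weight": "param_0",               # first key seen for param_0
    "lm_head.weight": "param_0",                    # tied weight: same static name
    "layers.0.self_attn.q_proj.weight": "param_1",
}

static2struct_name_mappings = {}
extra_save_keys = {}
for struct_name, static_name in state_dict.items():
    if static_name not in static2struct_name_mappings:
        static2struct_name_mappings[static_name] = struct_name
    else:
        extra_save_keys[static_name] = struct_name

# static2struct_name_mappings == {"param_0": "embed_tokens.weight",
#                                 "param_1": "layers.0.self_attn.q_proj.weight"}
# extra_save_keys == {"param_0": "lm_head.weight"}
#
# The optimizer state for param_0 is then written under both
# "embed_tokens.weight/<type_name>" and "lm_head.weight/<type_name>",
# mirroring the duplicated entries added in the second hunk.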
paddlenlp/trainer/unified_checkpoint/utils.py (9 changes: 0 additions & 9 deletions)
@@ -210,15 +210,6 @@ def get_expected_state_dict(model_to_save, **kwargs):
 
     if isinstance(model_to_save, PretrainedModel):
         state_dict = model_to_save.state_dict()
-        if (
-            hasattr(model_to_save.config, "tie_word_embeddings")
-            and model_to_save.config.tie_word_embeddings
-            and hasattr(model_to_save, "_tied_weights_keys")
-            and model_to_save._tied_weights_keys is not None
-        ):
-            for key in model_to_save._tied_weights_keys:
-                if key in state_dict:
-                    state_dict.pop(key)
     elif isinstance(model_to_save, LoRAModel):
         concat_additional_adapter = kwargs.get("concat_additional_adapter", False)
         concat_init_lora = model_to_save.lora_config.loraga and concat_additional_adapter
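The deleted block above used to strip tied-weight keys from a PretrainedModel's expected state dict. A hedged before/after sketch with hypothetical keys (not PaddleNLP code) of what this change means for what reaches the checkpoint writer:

# Sketch only: illustrates the effect of dropping the tied-key filtering.
_tied_weights_keys = ["lm_head.weight"]
state_dict = {"embed_tokens.weight": "w0", "lm_head.weight": "w0"}

# Previous behavior: tied keys were popped, so "lm_head.weight" never
# reached the checkpoint writer.
old_view = {k: v for k, v in state_dict.items() if k not in _tied_weights_keys}
assert "lm_head.weight" not in old_view

# New behavior: get_expected_state_dict returns the full state dict, and the
# duplicate static name is resolved later via extra_save_keys in
# unified_optimizer_into_shards.
new_view = dict(state_dict)
assert "lm_head.weight" in new_view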
paddlenlp/transformers/llama/modeling_pp.py (1 change: 1 addition & 0 deletions)
@@ -314,6 +314,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     _get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings
     _init_weights = LlamaPretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected
+    _tied_weights_keys = ["lm_head.weight"]
 
     # DONOT Add base_model_prefix !!!!
 
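A hedged sketch of how a declaration like _tied_weights_keys is typically consulted together with the config's tie_word_embeddings flag; the helper below is hypothetical and only mirrors the checks that previously lived in get_expected_state_dict.

# Hypothetical helper, not part of PaddleNLP's API.
def declared_tied_keys(model):
    """Return the model's tied-weight keys if word-embedding tying is enabled."""
    config = getattr(model, "config", None)
    if config is None or not getattr(config, "tie_word_embeddings", False):
        return []
    return list(getattr(model, "_tied_weights_keys", None) or [])

# With LlamaForCausalLMPipe and tie_word_embeddings=True this would yield
# ["lm_head.weight"], signaling that lm_head shares storage with the input
# embedding and needs the duplicate-name handling shown above.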
paddlenlp/transformers/qwen2/modeling_pp.py (1 change: 1 addition & 0 deletions)
@@ -233,6 +233,7 @@ class Qwen2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     _get_tensor_parallel_mappings = Qwen2PretrainedModel._get_tensor_parallel_mappings
     _init_weights = Qwen2PretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = Qwen2PretrainedModel._keys_to_ignore_on_load_unexpected
+    _tied_weights_keys = ["lm_head.weight"]
 
     # DONOT Add base_model_prefix !!!!
 
