Skip to content

Commit

Permalink
fix(unified checkpoint): model weights load
Browse files Browse the repository at this point in the history
when skipping model weighs save and saving master weights as model weights, unified checkpoint needs choose the model weights to load into master weights.
  • Loading branch information
DrownFish19 committed Jan 8, 2024
1 parent 0f472dc commit 1b7625e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion paddlenlp/trainer/plugins/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
os.makedirs(save_directory, exist_ok=True)

# save model weights
if skip_save_model_weight:
if not skip_save_model_weight:
state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
args, model_to_save, safe_serialization=safe_serialization
)
Expand Down Expand Up @@ -1660,6 +1660,10 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
index_filename_master_weights = (
PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
)
if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
index_filename_master_weights = (
PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
)
else:
has_master_weight = False
index_filename_master_weights = None
Expand Down

0 comments on commit 1b7625e

Please sign in to comment.