From 4111d3d8e1bf21a6f93162f8f391553ec504972b Mon Sep 17 00:00:00 2001
From: Yichen Zhang
Date: Wed, 29 Nov 2023 22:05:16 +0800
Subject: [PATCH] add the unit test for parallel_cross_entropy

---
 .../nlp/gpt/auto/pretrain_gpt_base.yaml            |  2 +-
 scripts/distribute/ci_case_auto.sh                 | 70 +++++++++++++++++++
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/model_zoo/gpt-3/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml b/model_zoo/gpt-3/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml
index c8a5f9abf76c..a6d2d6103823 100644
--- a/model_zoo/gpt-3/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml
+++ b/model_zoo/gpt-3/ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml
@@ -20,7 +20,7 @@ Engine:
     dtype: "float16"
     level: "o2"
     scale_loss: 32768.0
-    custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
+    custom_black_list: ["reduce_sum", "softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
   save_load:
     output_dir: ./output
diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh
index 9bf12703ccb4..6eb70d8f04fb 100644
--- a/scripts/distribute/ci_case_auto.sh
+++ b/scripts/distribute/ci_case_auto.sh
@@ -820,6 +820,76 @@ function gpt_auto_sp_acc_check() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
+
+function gpt_auto_parallel_cross_entropy_acc_check() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=/workspace/PaddleNLP/:$PYTHONPATH
+    export FLAGS_infer_spmd_enable=true
+    export FLAGS_call_stack_level=2
+    mp_degree=4
+    dp_degree=2
+    pp_degree=1
+    local_batch_size=8
+
+    # parallel cross entropy off
+    ori_log_dir=./ori_mp${mp_degree}_dp${dp_degree}_pp${pp_degree}
+    rm -rf $ori_log_dir
+    PARALLEL_CROSS_ENTROPY=false python -m paddle.distributed.launch --log_dir=$ori_log_dir --devices=$devices --rank 0 tools/auto.py \
+        -c ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_345M_single_card.yaml \
+        -o Model.hidden_dropout_prob=0 \
+        -o Model.attention_probs_dropout_prob=0 \
+        -o Model.use_recompute=True \
+        -o Global.local_batch_size=$(($local_batch_size / $dp_degree)) \
+        -o Global.micro_batch_size=$(($local_batch_size / $dp_degree)) \
+        -o Distributed.dp_degree=${dp_degree} \
+        -o Distributed.mp_degree=${mp_degree} \
+        -o Distributed.pp_degree=${pp_degree} \
+        -o Distributed.sharding.sharding_degree=1 \
+        -o Distributed.sharding.sharding_stage=1 \
+        -o Distributed.schedule_mode=FThenB \
+        -o Engine.mix_precision.enable=False \
+        -o Engine.mix_precision.level=o0 \
+        -o Engine.max_steps=10 \
+        -o Engine.eval_freq=100000 \
+        -o Engine.verbose=3 \
+        -o Engine.logging_freq=1 \
+        -o Engine.save_load.output_dir="" 2>&1 | tee log.txt
+
+    # parallel cross entropy on
+    par_log_dir=./par_mp${mp_degree}_dp${dp_degree}_pp${pp_degree}
+    rm -rf $par_log_dir
+    PARALLEL_CROSS_ENTROPY=true python -m paddle.distributed.launch --log_dir=$par_log_dir --devices=$devices --rank 0 tools/auto.py \
+        -c ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_345M_single_card.yaml \
+        -o Model.hidden_dropout_prob=0 \
+        -o Model.attention_probs_dropout_prob=0 \
+        -o Model.use_recompute=True \
+        -o Global.local_batch_size=$(($local_batch_size / $dp_degree)) \
+        -o Global.micro_batch_size=$(($local_batch_size / $dp_degree)) \
+        -o Distributed.dp_degree=${dp_degree} \
+        -o Distributed.mp_degree=${mp_degree} \
+        -o Distributed.pp_degree=${pp_degree} \
+        -o Distributed.sharding.sharding_degree=1 \
+        -o Distributed.sharding.sharding_stage=1 \
+        -o Distributed.schedule_mode=FThenB \
+        -o Engine.mix_precision.enable=False \
+        -o Engine.mix_precision.level=o0 \
+        -o Engine.max_steps=10 \
+        -o Engine.eval_freq=100000 \
+        -o Engine.verbose=3 \
+        -o Engine.logging_freq=1 \
+        -o Engine.save_load.output_dir="" 2>&1 | tee log.txt
+
+    # loss diff
+    loss=`cat ${par_log_dir}/workerlog.7 | grep '10/10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=-1
+    mem=-1
+    loss_base=`cat ${ori_log_dir}/workerlog.7 | grep '10/10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips_base=-1
+    mem_base=-1
+    echo "result: loss_parallel_ce_true=$loss loss_parallel_ce_false=$loss_base"
+    check_result $FUNCNAME ${loss_base:0:8} ${loss:0:8} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
 ############ case end ############
 
 function check_result() {