From e1d8ad31a349545ea2e06fa159d74e0370b907c7 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Tue, 20 Sep 2022 08:17:59 +0500
Subject: [PATCH 1/2] Encode reshaping guide

---
 megatron/checkpointing.py                   |  54 ++++++++-
 tools/convert_checkpoint/ds_to_universal.py | 118 ++++++--------------
 2 files changed, 89 insertions(+), 83 deletions(-)

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index dacbec7dc..cad801407 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -31,6 +31,19 @@
                       get_tokenizer)
 from megatron.enums import PositionEmbeddingType
 
+
+from deepspeed.checkpoint import (
+    ORIGINAL_VOCAB_SIZE,
+    PADDED_VOCAB_SIZE,
+    UNIVERSAL_CHECKPOINT_INFO,
+    UNIVERSAL_CHECKPOINT_VERSION_KEY,
+    UNIVERSAL_CHECKPOINT_VERSION_VALUE,
+    VOCABULARY_PARAMETERS_PATTERN,
+    PIPELINE_REPLICATED_PARAMETERS_PATTERN,
+    PARAMETERS_TO_AVERAGE_PATTERN,
+    PARAMETERS_WITH_ROW_PARALLELISM_PATTERN,
+)
+
 _CHECKPOINT_VERSION = None
 
 def set_checkpoint_version(value):
@@ -133,6 +146,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['iteration'] = iteration
         state_dict['tokens'] = args.consumed_train_tokens
         state_dict['checkpoint_info'] = _checkpoint_info()
+        state_dict[UNIVERSAL_CHECKPOINT_INFO] = _universal_checkpoint_info()
 
         # DeepSpeed saves the model/optimizer/scheduler
         if not args.deepspeed:
@@ -480,4 +494,42 @@ def _checkpoint_info():
     return {
         "padded_vocab_size": args.padded_vocab_size,
         "original_vocab_size": tokenizer.vocab_size,
-    }
\ No newline at end of file
+    }
+
+def _universal_checkpoint_info():
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    info = dict()
+    info[UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE
+    info[ORIGINAL_VOCAB_SIZE] = tokenizer.vocab_size
+    info[PADDED_VOCAB_SIZE] = args.padded_vocab_size
+
+    # Vocabulary parameters (embeddings) that require special handling due to padding.
+    info[VOCABULARY_PARAMETERS_PATTERN] = ["word_embeddings.weight"]
+
+    # Replicated (shared) parameters on the pipeline dimension
+    info[PIPELINE_REPLICATED_PARAMETERS_PATTERN] = ["word_embeddings.weight"]
+
+    # Parameter slices that should be averaged not concatenated.
+    info[PARAMETERS_TO_AVERAGE_PATTERN] = [
+        r"tied_modules.embed.word_embeddings.norm.weight",
+        r"tied_modules.embed.word_embeddings.norm.bias",
+        r"\d+.input_layernorm.weight",
+        r"\d+.input_layernorm.bias",
+        r"\d+.post_attention_layernorm.weight",
+        r"\d+.post_attention_layernorm.bias",
+        r"\d+.self_attention.dense.bias",
+        r"\d+.mlp.dense_4h_to_h.bias",
+        r"\d+.weight",
+        r"\d+.bias",
+    ]
+
+    # Parameter that are sliced on the row dimension
+    info[PARAMETERS_WITH_ROW_PARALLELISM_PATTERN] = [
+        "dense_4h_to_h.weight",
+        "self_attention.dense.weight",
+    ]
+
+    return info
+
diff --git a/tools/convert_checkpoint/ds_to_universal.py b/tools/convert_checkpoint/ds_to_universal.py
index 9a5dd1154..0a282cab9 100755
--- a/tools/convert_checkpoint/ds_to_universal.py
+++ b/tools/convert_checkpoint/ds_to_universal.py
@@ -23,21 +23,23 @@
 if root_repo_path not in sys.path:
     sys.path.insert(0, root_repo_path)
 
-from deepspeed.checkpoint import DeepSpeedCheckpoint
-
-MODEL_KEY = 'model'
-ARGS_KEY = 'args'
-LANGUGAGE_MODEL_KEY = 'language_model'
-EMBEDDING_KEY = 'embedding'
-ENCODER_KEY = 'encoder'
-WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head'
-WORD_EMBEDDINGS_KEY = 'word_embeddings'
-FINAL_LAYER_NORM_KEY = 'final_layernorm'
-CHECKPOINT_VERSION_KEY = 'checkpoint_version'
-CHECKPOINT_VERSION_VALUE = 3.0
-ITERATION_KEY = 'iteration'
-
+from deepspeed.checkpoint import (
+    OPTIMIZER_STATE_DICT,
+    BASE_OPTIMIZER_STATE,
+    SINGLE_PARTITION_OF_FP32_GROUPS,
+    PARAM_SLICE_MAPPINGS,
+    PARAM_SHAPES,
+    PARAM,
+    CAT_DIM,
+    VOCAB_DIVISIBILITY_PADDING_TENSOR,
+    ORIGINAL_VOCAB_SIZE,
+    UNIVERSAL_CHECKPOINT_INFO,
+    VOCABULARY_PARAMETERS_PATTERN,
+    PIPELINE_REPLICATED_PARAMETERS_PATTERN,
+    PARAMETERS_TO_AVERAGE_PATTERN,
+    PARAMETERS_WITH_ROW_PARALLELISM_PATTERN,
+)
 
 def parse_arguments():
     parser = argparse.ArgumentParser()
@@ -72,16 +74,6 @@ def parse_arguments():
     return args
 
 
-def _convert_ds_transformer_state(sd_list):
-    new_sd = OrderedDict()
-    for i, sd in enumerate(sd_list):
-        for key, value in sd.items():
-            new_key = f'layers.{i}.{key}'
-            new_sd[new_key] = value
-
-    return new_sd
-
-
 def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
     path_list = []
     iter_folder = f'iter_{iteration:07d}'
@@ -96,17 +88,6 @@ def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
     return path_list
 
 
-def _create_megatron_dict():
-    language_model_dict = {EMBEDDING_KEY: {}, ENCODER_KEY: {}}
-    megatron_dict = {
-        MODEL_KEY: {
-            LANGUGAGE_MODEL_KEY: language_model_dict
-        },
-        CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE
-    }
-    return megatron_dict
-
-
 def _save_checkpoint(file_path, chkpt_sd):
     dir, _ = os.path.split(file_path)
     os.makedirs(dir, exist_ok=True)
@@ -123,13 +104,14 @@ def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D):
 
     #pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")
 
-    optim_sd = sd["optimizer_state_dict"]
-    param_slice_mappings = optim_sd["param_slice_mappings"]
-
+    optim_sd = sd[OPTIMIZER_STATE_DICT]
+    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
+    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
+    pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETERS_PATTERN, [])
     # dict
-    state_groups = optim_sd["base_optimizer_state"]["state"]
+    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
     # list
-    fp32_groups = optim_sd["single_partition_of_fp32_groups"]
+    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
     param_groups_cnt = len(state_groups)
 
     for param_group_id in range(param_groups_cnt):
@@ -141,7 +123,7 @@ def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D):
         )
 
         for name,fragment_mapping in param_slice_mappings[param_group_id].items():
-            if "word_embeddings.weight" in name and pp_index > 0:
+            if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
                 # Skip tied weights that are replicated in first and last pp stages
                 continue
 
@@ -176,7 +158,6 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
     for tp_index in range(tp_degree):
         prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
        paths = sorted(list(glob.glob(f"{prefix_path}.0*")))
-        #print(paths)
         shards = [torch.load(p) for p in paths]
         slice = torch.cat(shards, dim=0).reshape(slice_shape)
         slices.append(slice)
@@ -184,33 +165,6 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
     return slices
 
 
-ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
-def _strip_vocab_padding(ds_checkpoint, padded_vocab_tensor):
-    checkpoint_info = ds_checkpoint.get_checkpoint_info()
-    padding_tensor = padded_vocab_tensor.narrow(0, checkpoint_info[ORIGINAL_VOCAB_SIZE], padded_vocab_tensor.shape[0]-checkpoint_info[ORIGINAL_VOCAB_SIZE])
-    #print(f'{padded_vocab_tensor[checkpoint_info[ORIGINAL_VOCAB_SIZE]-3:,:]=}')
-    return padded_vocab_tensor.narrow(0, 0, checkpoint_info[ORIGINAL_VOCAB_SIZE])
-
-
-WEIGHTS_TO_AVERAGE_PATTERNS = [
-    r"tied_modules.embed.word_embeddings.norm.weight",
-    r"tied_modules.embed.word_embeddings.norm.bias",
-    r"\d+.input_layernorm.weight",
-    r"\d+.input_layernorm.bias",
-    r"\d+.post_attention_layernorm.weight",
-    r"\d+.post_attention_layernorm.bias",
-    r"\d+.self_attention.dense.bias",
-    r"\d+.mlp.dense_4h_to_h.bias",
-    r"\d+.weight",
-    r"\d+.bias",
-]
-
-WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
-    "dense_4h_to_h.weight",
-    "self_attention.dense.weight",
-]
-
-
 def _get_vocab_divisibility_padding_tensor(ds_checkpoint, padded_vocab_tensor):
     checkpoint_info = ds_checkpoint.get_checkpoint_info()
     if padded_vocab_tensor.shape[0] > checkpoint_info[ORIGINAL_VOCAB_SIZE]:
@@ -223,6 +177,10 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
     slice_base_path = os.path.join(slice_dir, name)
     param_base_path = os.path.join(dir, name)
 
+    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
+    parameters_to_average = universal_checkpoint_info.get(PARAMETERS_TO_AVERAGE_PATTERN, [])
+    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETERS_WITH_ROW_PARALLELISM_PATTERN, [])
+    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETERS_PATTERN, [])
     for state in ("fp32", "exp_avg", "exp_avg_sq"):
         slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
         final_path = os.path.join(param_base_path, f"{state}.pt")
@@ -230,30 +188,27 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
         #print(f"Expected shape: {shape}")
         #print(f"Fragment sizes:", list(frag.shape for frag in slices))
         ckpt_dict = {}
-        if any(re.match(pattern, name) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
+        if any(re.match(pattern, name) for pattern in parameters_to_average):
             param = sum(slices) / len(slices)
         else:
-            cat_dim = 1 if any(text in name for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
+            cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
             #print(f"CAT DIM: {cat_dim}")
             param = torch.cat(slices, dim=cat_dim)
-            ckpt_dict['cat_dim'] = cat_dim
+            ckpt_dict[CAT_DIM] = cat_dim
 
-        if "word_embeddings.weight" in name:
+        if any(re.match(pattern, name) for pattern in vocabulary_parameters):
             #print(f"Before {param.shape=}")
             # strip padding
             #param = _strip_vocab_padding(ds_checkpoint, param)
-            ckpt_dict['vocab_divisibility_padding_tensor'] = _get_vocab_divisibility_padding_tensor(ds_checkpoint, param)
+            ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(ds_checkpoint, param)
             #print(f"After {param.shape=}")
 
         #print(f"Final shape: {param.shape}")
-        ckpt_dict['param'] = param
+        ckpt_dict[PARAM] = param
         _save_checkpoint(final_path, ckpt_dict)
 
 
-
-
-
 def _get_chunks(l, n):
     for i in range(0, len(l), n):
         yield l[i:i + n]
@@ -268,9 +223,9 @@ def _do_parallel_work(do_work, work_chunks, num_workers):
 
 def _extract_zero_shard_files(args, ds_checkpoint, slice_shapes, temp_dir):
     _3d_range_list = list(itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree)))
-    #pprint(_3d_range_list)
+    #pprint(f'{_3d_range_list=}')
     work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
-    #pprint(work_chunks)
+    #pprint(f'{work_chunks=}')
     do_work = partial(extract_zero_shards, temp_dir, slice_shapes, ds_checkpoint)
     _do_parallel_work(do_work, work_chunks, args.num_extract_workers)
 
@@ -295,7 +250,6 @@ def main():
     )
 
     ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)#, 1, 2) # args.target_tp, args.target_pp)
-
     iteration = ds_checkpoint.get_iteration()
     #_create_latest_file(args.output_folder, iteration)
     checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration,
@@ -305,7 +259,7 @@ def main():
     slice_shapes = []
     for mp_rank_file in ds_checkpoint.mp_rank_files:
         mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
-        slice_shapes += mp_sd["param_shapes"]
+        slice_shapes += mp_sd[PARAM_SHAPES]
 
     # fix back to normal flat dict, merge duplicates for tp>1
     slice_shapes = dict((k,v) for d in slice_shapes for k,v in d.items() )

From 06a039736d850aabee0a8c4b300f5987cda9c78f Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Tue, 20 Sep 2022 17:59:52 +0500
Subject: [PATCH 2/2] Naming fixes

---
 megatron/checkpointing.py                   | 16 +++---
 requirements.txt                            |  2 +-
 run_bf16.sh                                 | 22 ++++----
 run_fp16.sh                                 | 58 +++++++++++++--------
 run_universal_bf16.sh                       | 26 ++++-----
 tools/convert_checkpoint/ds_to_universal.py | 16 +++---
 6 files changed, 78 insertions(+), 62 deletions(-)

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index cad801407..5c3de9d3a 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -38,10 +38,10 @@
     UNIVERSAL_CHECKPOINT_INFO,
     UNIVERSAL_CHECKPOINT_VERSION_KEY,
     UNIVERSAL_CHECKPOINT_VERSION_VALUE,
-    VOCABULARY_PARAMETERS_PATTERN,
-    PIPELINE_REPLICATED_PARAMETERS_PATTERN,
-    PARAMETERS_TO_AVERAGE_PATTERN,
-    PARAMETERS_WITH_ROW_PARALLELISM_PATTERN,
+    VOCABULARY_PARAMETER_PATTERNS,
+    PIPELINE_REPLICATED_PARAMETER_PATTERNS,
+    PARAMETER_TO_AVERAGE_PATTERNS,
+    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
 )
 
 _CHECKPOINT_VERSION = None
@@ -506,13 +506,13 @@ def _universal_checkpoint_info():
     info[PADDED_VOCAB_SIZE] = args.padded_vocab_size
 
     # Vocabulary parameters (embeddings) that require special handling due to padding.
-    info[VOCABULARY_PARAMETERS_PATTERN] = ["word_embeddings.weight"]
+    info[VOCABULARY_PARAMETER_PATTERNS] = ["word_embeddings.weight"]
 
     # Replicated (shared) parameters on the pipeline dimension
-    info[PIPELINE_REPLICATED_PARAMETERS_PATTERN] = ["word_embeddings.weight"]
+    info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = ["word_embeddings.weight"]
 
     # Parameter slices that should be averaged not concatenated.
-    info[PARAMETERS_TO_AVERAGE_PATTERN] = [
+    info[PARAMETER_TO_AVERAGE_PATTERNS] = [
         r"tied_modules.embed.word_embeddings.norm.weight",
         r"tied_modules.embed.word_embeddings.norm.bias",
         r"\d+.input_layernorm.weight",
         r"\d+.input_layernorm.bias",
         r"\d+.post_attention_layernorm.weight",
         r"\d+.post_attention_layernorm.bias",
         r"\d+.self_attention.dense.bias",
         r"\d+.mlp.dense_4h_to_h.bias",
         r"\d+.weight",
@@ -526,7 +526,7 @@ def _universal_checkpoint_info():
     ]
 
     # Parameter that are sliced on the row dimension
-    info[PARAMETERS_WITH_ROW_PARALLELISM_PATTERN] = [
+    info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [
         "dense_4h_to_h.weight",
         "self_attention.dense.weight",
     ]
diff --git a/requirements.txt b/requirements.txt
index da76b5e44..aba50f72f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ six
 tensorboard
 torch>=1.7
 transformers
-DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
+#DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
 # versions from HF transformers
 black==21.4b0
 isort>=5.5.4
diff --git a/run_bf16.sh b/run_bf16.sh
index fc884d4af..36af9d8df 100755
--- a/run_bf16.sh
+++ b/run_bf16.sh
@@ -30,40 +30,42 @@ CONFIG_JSON="$script_dir/ds_config.json"
 
 USE_DEEPSPEED=1
 ZERO_STAGE=0
+DTYPE="bf16"
 
 #TP=4
 #PP=4
 
 # Debug
-DEBUG_MODE=0
+DEBUG_MODE=1
 if [[ $DEBUG_MODE == 1 ]]; then
         LAYERS=4
         HIDDEN=512
         SEQ=512
-        EXIT_INTERVAL=3
+        EXIT_INTERVAL=100
+        RUN_TAG="toy"
 else
         HIDDEN=1024
         LAYERS=24
         SEQ=1024
-        EXIT_INTERVAL=10
+        EXIT_INTERVAL=100
+        RUN_TAG="big"
 fi
 
 TP=2
 PP=2
-DP=4
+DP=2
 WORLD_SIZE=$((TP*PP*DP))
 GLOBAL_BATCH=4
 
 MICRO_BATCH=1
 TRAIN_ITERS=100000
-CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
-LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
+CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}_$RUN_TAG
+LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}_$RUN_TAG
 LR=6.0e-4
 MIN_LR=6.0e-5
-DTYPE="bf16"
 
-EXP_DIR=${HOME}/experiments/results/ckpt_reshape
-LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont"
+EXP_DIR=${HOME}/experiments/results/uni_ckpt
+LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont_$RUN_TAG"
 mkdir -p $LOG_DIR
 
 while [[ $# -gt 0 ]]
@@ -166,7 +168,7 @@ cat < $CONFIG_JSON
 }
 EOT
 
-#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
+WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
 #WORKER_STR="-i worker-0:0,1,2,3"
 #run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
 #run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
diff --git a/run_fp16.sh b/run_fp16.sh
index dcb2a0143..6d96c9fdf 100755
--- a/run_fp16.sh
+++ b/run_fp16.sh
@@ -12,7 +12,7 @@ DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
 #DATASET_3=""
 #DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"
 
-BASE_DATA_PATH=/data/Megatron-LM/data
+BASE_DATA_PATH=/vc_data/Megatron-LM/data
 DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron
 VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
 MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
@@ -20,39 +20,46 @@ MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
 script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
 
-#CONFIG_JSON="$script_dir/ds_config.json"
-CONFIG_JSON="/tmp/ds_config.json"
+CONFIG_JSON="$script_dir/ds_config.json"
+#CONFIG_JSON="/tmp/ds_config.json"
 
 USE_DEEPSPEED=1
-ZERO_STAGE=0
-
+ZERO_STAGE=2
+DTYPE="fp16"
 
-# Debug
 #TP=4
 #PP=4
-#LAYERS=8
-#HIDDEN=512
-#SEQ=1024
-#GLOBAL_BATCH=128
-#WORKER_STR="-i worker-0"
 
+# Debug
+DEBUG_MODE=1
+if [[ $DEBUG_MODE == 1 ]]; then
+        LAYERS=4
+        HIDDEN=512
+        SEQ=512
+        EXIT_INTERVAL=100
+        RUN_TAG="toy"
+else
+        HIDDEN=1024
+        LAYERS=24
+        SEQ=1024
+        EXIT_INTERVAL=100
+        RUN_TAG="big"
+fi
 
 TP=1
 PP=1
-DP=2
+DP=1
 WORLD_SIZE=$((TP*PP*DP))
-HIDDEN=1024
-LAYERS=24
-SEQ=1024
-GLOBAL_BATCH=1
-WORKER_STR=""
+GLOBAL_BATCH=4
 
 MICRO_BATCH=1
+TRAIN_ITERS=100000
+CHECKPOINT_PATH=checkpoints/gpt2/z${ZERO_STAGE}/$DTYPE/tp${TP}_pp${PP}_dp${DP}_$RUN_TAG
+LOAD_CHECKPOINT_PATH=checkpoints/gpt2/z${ZERO_STAGE}/$DTYPE/tp${TP}_pp${PP}_dp${DP}_$RUN_TAG
 LR=6.0e-4
 MIN_LR=6.0e-5
-DTYPE="fp16"
 
-EXP_DIR=${HOME}/experiments/results/bf16
-LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3"
+EXP_DIR="${HOME}/experiments/results/z${ZERO_STAGE}_uni_ckpt"
+LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont_$RUN_TAG"
 mkdir -p $LOG_DIR
 
 while [[ $# -gt 0 ]]
@@ -88,7 +95,7 @@ options=" \
         --max-position-embeddings $SEQ \
         --micro-batch-size $MICRO_BATCH \
         --global-batch-size $GLOBAL_BATCH \
-        --train-iters 1000 \
+        --train-iters $TRAIN_ITERS \
         --lr $LR \
         --min-lr $MIN_LR \
         --lr-decay-style cosine \
@@ -98,7 +105,7 @@ options=" \
         --data-path ${DATASET} \
         --vocab-file ${VOCAB_PATH} \
         --merge-file ${MERGE_PATH} \
-        --save-interval 10000 \
+        --save-interval 1000 \
         --split 98,2,0 \
         --clip-grad 1.0 \
         --weight-decay 0.1 \
@@ -107,7 +114,12 @@ options=" \
         --init-method-std 0.006 \
         --${DTYPE} \
         --checkpoint-activations \
-        --exit-interval 10000 \
+        --exit-interval ${EXIT_INTERVAL} \
+        --save ${CHECKPOINT_PATH} \
+        --load ${LOAD_CHECKPOINT_PATH} \
+        --position-embedding-type alibi \
+        --override-lr-scheduler \
+        --embed-layernorm \
         --tensorboard-dir $LOG_DIR
         "
diff --git a/run_universal_bf16.sh b/run_universal_bf16.sh
index 7a60c34c1..52fa256ad 100755
--- a/run_universal_bf16.sh
+++ b/run_universal_bf16.sh
@@ -30,40 +30,42 @@ CONFIG_JSON="$script_dir/ds_config.json"
 
 USE_DEEPSPEED=1
 ZERO_STAGE=0
+DTYPE="bf16"
 
 #TP=4
 #PP=4
 
 # Debug
-DEBUG_MODE=0
+DEBUG_MODE=1
 if [[ $DEBUG_MODE == 1 ]]; then
         LAYERS=4
         HIDDEN=512
         SEQ=512
-        EXIT_INTERVAL=3
+        EXIT_INTERVAL=100
+        RUN_TAG="toy"
 else
         HIDDEN=1024
         LAYERS=24
         SEQ=1024
-        EXIT_INTERVAL=10
+        EXIT_INTERVAL=100
+        RUN_TAG="big"
 fi
 
-TP=2
-PP=2
-DP=4
+TP=1
+PP=1
+DP=2
 WORLD_SIZE=$((TP*PP*DP))
 GLOBAL_BATCH=4
 
 MICRO_BATCH=1
 TRAIN_ITERS=100000
-CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
-LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp2_pp2_dp4
+CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}_$RUN_TAG
+LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp2_pp2_dp2_$RUN_TAG
 LR=6.0e-4
 MIN_LR=6.0e-5
-DTYPE="bf16"
 
-EXP_DIR=${HOME}/experiments/results/ckpt_reshape
-LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_uni"
+EXP_DIR=${HOME}/experiments/results/uni_ckpt
+LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_ref_uni_$RUN_TAG"
 mkdir -p $LOG_DIR
 
 while [[ $# -gt 0 ]]
@@ -167,7 +169,7 @@ cat < $CONFIG_JSON
 }
 EOT
 
-#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
+WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
 #WORKER_STR="-i worker-0:0,1,2,3"
 #run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
 #run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
diff --git a/tools/convert_checkpoint/ds_to_universal.py b/tools/convert_checkpoint/ds_to_universal.py
index 0a282cab9..eb4383e22 100755
--- a/tools/convert_checkpoint/ds_to_universal.py
+++ b/tools/convert_checkpoint/ds_to_universal.py
@@ -35,10 +35,10 @@
     VOCAB_DIVISIBILITY_PADDING_TENSOR,
     ORIGINAL_VOCAB_SIZE,
     UNIVERSAL_CHECKPOINT_INFO,
-    VOCABULARY_PARAMETERS_PATTERN,
-    PIPELINE_REPLICATED_PARAMETERS_PATTERN,
-    PARAMETERS_TO_AVERAGE_PATTERN,
-    PARAMETERS_WITH_ROW_PARALLELISM_PATTERN,
+    VOCABULARY_PARAMETER_PATTERNS,
+    PIPELINE_REPLICATED_PARAMETER_PATTERNS,
+    PARAMETER_TO_AVERAGE_PATTERNS,
+    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
 )
 
 def parse_arguments():
@@ -107,7 +107,7 @@ def extract_zero_shards(dir, slice_shapes, ds_checkpoint, indices_3D):
     optim_sd = sd[OPTIMIZER_STATE_DICT]
     param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
     universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
-    pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETERS_PATTERN, [])
+    pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
     # dict
     state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
     # list
@@ -178,9 +178,9 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
     param_base_path = os.path.join(dir, name)
 
     universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
-    parameters_to_average = universal_checkpoint_info.get(PARAMETERS_TO_AVERAGE_PATTERN, [])
-    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETERS_WITH_ROW_PARALLELISM_PATTERN, [])
-    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETERS_PATTERN, [])
+    parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
+    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
+    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
 for state in ("fp32", "exp_avg", "exp_avg_sq"):
         slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
         final_path = os.path.join(param_base_path, f"{state}.pt")
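
Illustration (not part of the patches above): merge_tp_slices() writes one directory per parameter containing fp32.pt, exp_avg.pt and exp_avg_sq.pt, each holding the merged tensor plus the reshaping hints a loader needs, namely a concatenation dimension for sharded parameters and a vocab-divisibility padding tensor for embeddings. The sketch below is a minimal, hypothetical example of how such a file could be cut back into slices for a new tensor-parallel degree: the helper name load_tp_slice is invented, the literal 'param' and 'cat_dim' key strings are taken from the pre-rename values in PATCH 1/2, an even split is assumed, and vocabulary re-padding is ignored; the real loading path lives in DeepSpeed, not here.

# Illustrative sketch only; see the assumptions stated above.
import os
import torch

def load_tp_slice(param_dir, state, tp_rank, tp_degree):
    """Load one merged state file ('fp32', 'exp_avg' or 'exp_avg_sq') and return this rank's slice."""
    ckpt = torch.load(os.path.join(param_dir, f"{state}.pt"), map_location="cpu")
    full = ckpt["param"]           # merged tensor stored under the PARAM key
    cat_dim = ckpt.get("cat_dim")  # absent for parameters that were averaged, not concatenated

    if cat_dim is None:
        # Averaged parameters (layernorm weights, most biases) are identical on every TP rank.
        return full.clone()
    # Concatenated parameters are split back along the dimension they were merged on;
    # this assumes the merged size divides evenly by the new TP degree.
    return torch.chunk(full, tp_degree, dim=cat_dim)[tp_rank].clone()

# Usage with hypothetical paths:
#   fp32_slice = load_tp_slice("universal_ckpt/<param_name>", "fp32", tp_rank=0, tp_degree=4)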