TransformerEngine Integration (#1282)
* Implemented ColumnParallelLinear with Transformer-Engine
* Implemented RowParallelLinear with Transformer-Engine
* Implemented LayerNormMLP with Transformer-Engine
* Implemented MultiheadAttention with Transformer-Engine
* Cleaned up transformer.py
* Cleaned up neox_args
* Fixed TE_MHA and added rope support; implemented delayed scaling
* Fixed mixed files
* Changed get_linear name
* Added rng tracker to lnmlp and placed rope in te_mha init instead of forward
* Updated fp8 arguments to te_fp8
* Added EAI copyright
* precommit
* add sample TE config
* add te to readme
* remove pip install prefix from reqs file
* Force TE pytorch in requirements file

---------

Co-authored-by: Quentin Anthony <[email protected]>
1 parent 29080c3 · commit 8900d05
Showing 8 changed files with 968 additions and 171 deletions.
@@ -0,0 +1,105 @@
# GPT-2 pretraining setup
{
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  # model settings
  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 16,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",

  # Transformer Engine settings
  "te_columnparallel": false,
  "te_rowparallel": false,
  "te_layernorm_mlp": true,
  "te_mha": true,
  "te_fp8_format": "hybrid",
  "te_fp8_wgrad": true,
  "te_fp8_amax_history_len": 1,
  "te_fp8_amax_compute_algo": "most_recent",
  "te_fp8_margin": 0,
  "te_fp8_mha": false,

  # these should provide some speedup but take a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0002,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00002,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "data_impl": "mmap",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "fp16": {
    "fp16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "checkpoint_factor": 10000,
  "eval_interval": 1000,
  "eval_iters": 10,

  # logging
  "log_interval": 100,
  "steps_per_print": 10,
  "keep_last_n_checkpoints": 4,
  "wall_clock_breakdown": true,
}
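The te_fp8_* keys in the sample config above mirror the knobs of Transformer Engine's delayed-scaling FP8 recipe. The following is a minimal sketch, not the PR's implementation, of how those values could be turned into a recipe and applied per forward pass. `DelayedScaling`, `Format`, and `fp8_autocast` are real Transformer Engine APIs; the `neox_args` dict, the tensor sizes, and in particular the mapping of te_fp8_wgrad onto `override_linear_precision` are assumptions. Running it requires FP8-capable hardware (e.g. Hopper-class GPUs).

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Hypothetical stand-in for the parsed config values shown above.
neox_args = {
    "te_fp8_format": "hybrid",            # "hybrid" = E4M3 forward, E5M2 backward
    "te_fp8_wgrad": True,
    "te_fp8_amax_history_len": 1,
    "te_fp8_amax_compute_algo": "most_recent",
    "te_fp8_margin": 0,
}

fp8_format = {"hybrid": Format.HYBRID, "e4m3": Format.E4M3}[neox_args["te_fp8_format"]]

fp8_recipe = DelayedScaling(
    margin=neox_args["te_fp8_margin"],
    fp8_format=fp8_format,
    amax_history_len=neox_args["te_fp8_amax_history_len"],
    amax_compute_algo=neox_args["te_fp8_amax_compute_algo"],
    # Assumed mapping for te_fp8_wgrad: when it is disabled, run the wgrad GEMM in
    # higher precision. override_linear_precision is (fprop, dgrad, wgrad).
    override_linear_precision=(False, False, not neox_args["te_fp8_wgrad"]),
)

# Any TE module can run under the FP8 autocast; sizes here are illustrative.
layer = te.Linear(2048, 2048, params_dtype=torch.float16).cuda()
x = torch.randn(16, 2048, device="cuda", dtype=torch.float16)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)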