TransformerEngine Integration (#1282)
* Implemented ColumnParallelLinear with Transformer-Engine

* Implemented RowParallelLinear with Transformer-Engine

* Implemented LayerNormMLP with Transformer-Engine

* Implemented MultiheadAttention with Transformer-Engine

* Cleaned up transformer.py

* Cleaned up neox_args

* Cleaned up neox_args

* - Fixed TE_MHA and added rope support
- Implemented delayed scaling

* Fixed mixed files.

* Changed get_linear name

* Added rng tracker to lnmlp and placed rope in te_mha init instead of forward

* Updated fp8 arguments to te_fp8

* Added EAI copyright

* precommit

* add sample TE config

* add te to readme

* remove pip install prefix from reqs file

* Force TE pytorch in requirements file

---------

Co-authored-by: Quentin Anthony <[email protected]>
aurelion-source and Quentin-Anthony authored Dec 19, 2024
1 parent 29080c3 commit 8900d05
Showing 8 changed files with 968 additions and 171 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
* Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

## News
**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration

**[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling

**[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform
@@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
* [Environment and Dependencies](#environment-and-dependencies)
+ [Host Setup](#host-setup)
+ [Flash Attention](#flash-attention)
+ [Transformer Engine](#transformer-engine)
+ [Multi-Node Launching](#multi-node-launching)
+ [Containerized Setup](#containerized-setup)
* [Usage](#usage)
@@ -130,7 +133,20 @@ This will automatically adapt the building process over different GPU vendors (AMD,

### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt`, or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed with versions other than those in our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.

### Transformer Engine

To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformer-engine.txt`, or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed with versions other than those in our requirements file). See [this config](./configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. TE can provide significant speed-ups over the standard kernels on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.
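
Under the hood, TE provides drop-in replacements for the standard PyTorch linear/LayerNorm/attention modules plus an FP8 autocast context, which the gpt-neox integration builds on. The snippet below is a minimal standalone sketch of that general TE usage pattern, assuming TE's PyTorch API; it is illustrative only and not the gpt-neox wiring itself:

```python
import torch
import transformer_engine.pytorch as te

# te.Linear is TE's drop-in replacement for torch.nn.Linear.
linear = te.Linear(2048, 2048, bias=True, params_dtype=torch.float16)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.float16)

# Outside the context the module runs in FP16/BF16 as usual; inside
# fp8_autocast, supported GEMMs run in FP8 (requires an FP8-capable GPU,
# e.g. Hopper-generation hardware).
with te.fp8_autocast(enabled=True):
    y = linear(x)
```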


TE provides very efficient kernels for both A100 and H100 GPUs. We've run some sample ablations on A100:

[A100 ablation results figure]

and H100:

[H100 ablation results figure]

### Multi-Node Launching
105 changes: 105 additions & 0 deletions configs/1-3B-transformer-engine.yml
@@ -0,0 +1,105 @@
# GPT-2 pretraining setup
{
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  # model settings
  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 16,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",

  # Transformer Engine settings
  "te_columnparallel": false,
  "te_rowparallel": false,
  "te_layernorm_mlp": true,
  "te_mha": true,
  "te_fp8_format": "hybrid",
  "te_fp8_wgrad": true,
  "te_fp8_amax_history_len": 1,
  "te_fp8_amax_compute_algo": "most_recent",
  "te_fp8_margin": 0,
  "te_fp8_mha": false,

  # these should provide some speedup but take a while to build; set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0002,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00002,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
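
For reference, the `te_fp8_*` keys above mirror the fields of Transformer Engine's delayed-scaling FP8 recipe. The sketch below shows how such a recipe is typically constructed with TE's PyTorch API; the mapping between config keys and recipe arguments is an illustration of intent, not a dump of the gpt-neox code:

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Values chosen to mirror the te_fp8_* settings in the config above.
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,         # "te_fp8_format": "hybrid" (E4M3 fwd, E5M2 bwd)
    margin=0,                         # "te_fp8_margin": 0
    amax_history_len=1,               # "te_fp8_amax_history_len": 1
    amax_compute_algo="most_recent",  # "te_fp8_amax_compute_algo": "most_recent"
)

# FP8-capable TE modules are then executed inside the autocast context:
# with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#     output = te_module(inputs)
```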
5 changes: 5 additions & 0 deletions megatron/model/positional_embeddings.py
@@ -67,6 +67,8 @@ def _prepare_cache(self, seq_len, precision, base):
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))

        cos_cached = emb.cos()[:, None, None, :]
        sin_cached = emb.sin()[:, None, None, :]

@@ -76,6 +78,9 @@ def _prepare_cache(self, seq_len, precision, base):
            inv_freq.to(precision),
        )

    def get_emb(self):
        return self.emb.to(self.precision).cuda()

    def forward(self, x, seq_dim=0, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
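
The new `self.emb` cache and `get_emb()` accessor expose the full `(seq_len, 1, 1, dim)` rotary table so it can be materialized once and handed to the TE attention module at construction time rather than recomputed every forward pass. A small usage sketch, assuming the `RotaryEmbedding` interface in `megatron/model/positional_embeddings.py` (constructor arguments are illustrative, and the cache is assumed to be built at init):

```python
import torch
from megatron.model.positional_embeddings import RotaryEmbedding

# Illustrative construction; argument names and values are assumptions.
rotary = RotaryEmbedding(dim=64, max_seq_len=2048, base=10000, precision=torch.bfloat16)

# Full rotary table of shape (seq_len, 1, 1, dim), cast to the model precision
# and moved to GPU, suitable for passing to an attention module once at init.
rope_emb = rotary.get_emb()
```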