diff --git a/README.md b/README.md
index 006f5964f..95aaf35bd 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
 * Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).
 
 ## News
+**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration
+
 **[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling
 
 **[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform
@@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
 * [Environment and Dependencies](#environment-and-dependencies)
   + [Host Setup](#host-setup)
   + [Flash Attention](#flash-attention)
+  + [Transformer Engine](#transformer-engine)
   + [Multi-Node Launching](#multi-node-launching)
   + [Containerized Setup](#containerized-setup)
 * [Usage](#usage)
@@ -130,7 +133,20 @@ This will automatically adapts building process over different GPU vendors (AMD,
 
 ### Flash Attention
 
-To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed with versions other than those pinned in our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+
+### Transformer Engine
+
+To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformer-engine.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed with versions other than those pinned in our requirements file). See [this config](./configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. This can provide significant speed-ups on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.
+
+
+TE provides very efficient kernels for both A100 and H100 GPUs.
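+For reference, the TE block from that example config looks like this (FP8 multi-head attention is left off by default):
+
+```yaml
+  # Transformer Engine settings
+  "te_columnparallel": false,               # TE wrapper for ColumnParallelLinear
+  "te_rowparallel": false,                  # TE wrapper for RowParallelLinear
+  "te_layernorm_mlp": true,                 # fused LayerNorm + MLP block
+  "te_mha": true,                           # TE MultiheadAttention
+  "te_fp8_format": "hybrid",                # E4M3 in the forward pass, E5M2 in the backward pass
+  "te_fp8_wgrad": true,
+  "te_fp8_amax_history_len": 1,
+  "te_fp8_amax_compute_algo": "most_recent",
+  "te_fp8_margin": 0,
+  "te_fp8_mha": false,                      # FP8 attention off by default
+```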
We've run some sample ablations on A100: + + + +and H100: + + ### Multi-Node Launching diff --git a/configs/1-3B-transformer-engine.yml b/configs/1-3B-transformer-engine.yml new file mode 100644 index 000000000..079a5c31d --- /dev/null +++ b/configs/1-3B-transformer-engine.yml @@ -0,0 +1,105 @@ +# GPT-2 pretraining setup +{ + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 24, + "hidden_size": 2048, + "num_attention_heads": 16, + "seq_length": 2048, + "max_position_embeddings": 2048, + "norm": "layernorm", + "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + # Transformer Engine settings + "te_columnparallel": false, + "te_rowparallel": false, + "te_layernorm_mlp": true, + "te_mha": true, + "te_fp8_format": "hybrid", + "te_fp8_wgrad": true, + "te_fp8_amax_history_len": 1, + "te_fp8_amax_compute_algo": "most_recent", + "te_fp8_margin": 0, + "te_fp8_mha": false, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00002, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "fp16": { + "fp16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 100, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, +} diff --git a/configs/eleutherai_cluster.yml b/configs/eleutherai_cluster.yml index 36e75d8b3..3cf5bb007 100644 --- a/configs/eleutherai_cluster.yml +++ b/configs/eleutherai_cluster.yml @@ -24,6 +24,7 @@ "tensorboard_dir": "/mnt/ssd-1/tensorboard", "log_dir": "/mnt/ssd-1/logs", "wandb_team": "eleutherai", + #"wandb_run_name": "experiment" "wandb_project": "neox", "wandb_group": "example" } diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index c08b60151..a73cf2a68 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -24,6 +24,7 @@ from megatron.data.blendable_dataset import BlendableDataset from megatron.data.gpt2_dataset import GPT2Dataset from megatron.data.pairwise_dataset import PairwiseDataset +from megatron.data.online_dataset import OnlineDataset from megatron.data.samplers import DistributedBatchSampler @@ -532,7 +533,56 @@ def build_train_valid_test_data_loaders(neox_args): pipe_load = True # Data loader only on rank 0 of each model parallel group. - if mpu.get_model_parallel_rank() == 0 and pipe_load: + if ( + pipe_load + and (neox_args.dataset_impl == "online") + and (mpu.get_model_parallel_rank() == 0) + ): + # Can skip most of the work... + train_iters = neox_args.train_iters + eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters + test_iters = neox_args.eval_iters + # Build datasets... + print( + f"train_iters: {train_iters}, eval_iters: {eval_iters}, test_iters: {test_iters}" + ) + train_datasets = OnlineDataset( + leave_one_out=neox_args.reinforce_leave_one_out, + data_split="train", + num_samples=train_iters * neox_args.train_batch_size, + seq_length=neox_args.seq_length, + dataserver_ips=neox_args.online_dataserver_ips, + dataserver_ports=neox_args.online_dataserver_ports, + ) + valid_datasets = OnlineDataset( + leave_one_out=neox_args.reinforce_leave_one_out, + data_split="valid", + num_samples=eval_iters * neox_args.train_batch_size, + seq_length=neox_args.seq_length, + dataserver_ips=neox_args.online_dataserver_ips, + dataserver_ports=neox_args.online_dataserver_ports, + ) + test_datasets = OnlineDataset( + leave_one_out=neox_args.reinforce_leave_one_out, + data_split="test", + num_samples=test_iters * neox_args.train_batch_size, + seq_length=neox_args.seq_length, + dataserver_ips=neox_args.online_dataserver_ips, + dataserver_ports=neox_args.online_dataserver_ports, + ) + # print length of datasets + # Build dataloders. + train_dataloader = make_data_loader(train_datasets, neox_args=neox_args) + valid_dataloader = make_data_loader(valid_datasets, neox_args=neox_args) + test_dataloader = make_data_loader(test_datasets, neox_args=neox_args) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and neox_args.train_iters > 0 + do_valid = valid_dataloader is not None and neox_args.eval_iters > 0 + do_test = test_dataloader is not None and neox_args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) + elif mpu.get_model_parallel_rank() == 0 and pipe_load: # Number of train/valid/test samples. 
if neox_args.train_iters is not None: train_iters = neox_args.train_iters diff --git a/megatron/data/online_dataset.py b/megatron/data/online_dataset.py new file mode 100644 index 000000000..9a12c1875 --- /dev/null +++ b/megatron/data/online_dataset.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Online dataset.""" +from typing import Union, List + +import numpy as np +import torch +import torch.utils.data +import socket +import pickle +from megatron.mpu.initialize import get_data_parallel_rank + + +class OnlineDataset(torch.utils.data.Dataset): + def __init__( + self, + num_samples, + seq_length, + leave_one_out=False, + data_split="train", + dataserver_ips: Union[str, List[str]] = "localhost", + dataserver_ports: Union[int, List[int]] = 10000, + ): + self.num_samples = num_samples + self.global_rank = get_data_parallel_rank() + self.leave_one_out = leave_one_out + self.reward_buffer = [] + self.online_batching_data = [] + self.data_split = data_split + self.seq_length = seq_length + self.dataserver_ips = dataserver_ips + self.dataserver_ports = dataserver_ports + + def __len__(self): + # dummy value since it's decided by the Online Trainer + return self.num_samples + + def update_online_batches(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if isinstance(self.dataserver_ips, str): + ipaddr = self.dataserver_ips + else: + ipaddr = self.dataserver_ips[self.global_rank] + if isinstance(self.dataserver_ports, int): + # simply add over the global rank + port = self.dataserver_ports + else: + # in case we want to use different ports for different ranks, e.g. per machine sampling + port = self.dataserver_ports[self.global_rank] + print(f"Connecting to {ipaddr}:{port}") + s.connect((ipaddr, port)) + s.send(self.data_split.encode()) + data = b"" + while True: + chunk = s.recv(4096) + if not chunk: + break + data += chunk + batch_data = pickle.loads(data) + s.close() + print(f"Received {len(batch_data)} samples from the server.") + for data in batch_data: + if self.leave_one_out: + rewards = list() + for i in range(len(data["rewards"])): + rewards.append( + data["rewards"][i] + - np.mean( + [ + data["rewards"][j] + for j in range(len(data["rewards"])) + if j != i + ] + ) + ) + data["raw_rewards"] = data["rewards"] + data["rewards"] = rewards + else: + moving_average = 0 + if len(self.reward_buffer) > 0: + moving_average = np.mean(self.reward_buffer) + self.reward_buffer.append(np.mean(data["rewards"])) + if len(self.reward_buffer) > 100: + self.reward_buffer.pop(0) + # For metrics... 
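+                # keep the unshifted rewards for logging; the training rewards below are centered
+                # with a simple baseline (the running mean of the last 100 batch-mean rewards)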
+ data["raw_rewards"] = data["rewards"] + data["rewards"] = [r - moving_average for r in data["rewards"]] + for i in range(len(data["completions"])): + self.online_batching_data.append( + [ + data["prefix"], + data["completions"][i], + data["rewards"][i], + data["raw_rewards"][i], + ] + ) + + def __getitem__(self, idx): + if len(self.online_batching_data) == 0: + self.update_online_batches() + batch = self.online_batching_data.pop(0) + text = batch[0] + batch[1] + label = [-100 for _ in batch[0]] + batch[1] + # +1 because of causal masking + if len(text) <= self.seq_length: + text = text + [0] * ((self.seq_length + 1) - len(text)) + label = label + [-100] * ((self.seq_length + 1) - len(label)) + return { + "text": np.array(text, dtype=np.int64), + "label": np.array(label, dtype=np.int64), + "reward": np.array([batch[2]], dtype=np.float32), + "raw_reward": np.array([batch[3]], dtype=np.float32), + } diff --git a/megatron/logging.py b/megatron/logging.py index af8a41fe5..48481c047 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -80,8 +80,12 @@ def human_readable_flops(num) -> str: return "%.1f%s" % (num, "Yi") -def get_flops(neox_args, iter_time_s) -> float: +def get_actual_flops(neox_args, iter_time_s) -> float: """ + This function finds the actual FLOPs achieved accounting for implementation and hardware details. Also used for HFU. + + For more detail on flop calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc + Use FLOPS calculation from Megatron-DeepSpeed: https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253 They get it from https://arxiv.org/pdf/2104.04473.pdf @@ -156,6 +160,83 @@ def get_flops(neox_args, iter_time_s) -> float: return flops_per_iteration / (iter_time_s * world_size) +def get_forward_backward_flops(neox_args, iter_time_s) -> float: + """ + This function finds the estimated FLOPs required by a single forward+backward pass without accounting for implementation and hardware details. Also used for MFU. + + Mostly duplicated from get_actual_flops with just a change in activation checkpointing for now, but these may diverge over time as implementation details accumulate so I think 2 separate functions are appropriate. 
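+    (In practice the distinction is that HFU counts the work actually executed on the hardware, including activation recomputation from checkpointing, while MFU counts only the FLOPs a forward+backward pass strictly requires.)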
+ + For more detail on flop calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc + + Use FLOPS calculation from Megatron-DeepSpeed: + https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253 + They get it from https://arxiv.org/pdf/2104.04473.pdf + """ + world_size = torch.distributed.get_world_size() + vocab_size = neox_args.padded_vocab_size + batch_size = neox_args.train_batch_size + seq_len = neox_args.seq_length + hidden_size = neox_args.hidden_size + num_layers = neox_args.num_layers + fwd_bwd_factor = 3 # 1 for fwd, 2 for bwd and weight update + if "rwkv" in neox_args.attention_config: + num_heads = neox_args.num_attention_heads + + flops_per_iteration = ( + batch_size + * seq_len + * ( + 78 * hidden_size * hidden_size * num_layers + + 84 * hidden_size * num_layers + + 16 * hidden_size + + 12 * hidden_size * vocab_size + + 18 * hidden_size * hidden_size * num_layers / num_heads + ) + ) + elif "mamba" in neox_args.attention_config: + # from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py + if neox_args.expansion_factor: + d_inner = neox_args.hidden_size * neox_args.expansion_factor + elif neox_args.intermediate_size: + d_inner = neox_args.intermediate_size + else: + d_inner = neox_args.hidden_size * 2 # default expansion factor + d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here + conv_dimension = 4 # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here + dt_rank = math.ceil(neox_args.hidden_size / 16) + ssm_flops = ( + fwd_bwd_factor + * d_inner + * seq_len + * batch_size + * (11 * d_state + 4 * dt_rank + 1) + ) + mamba_projectors_flops = ( + fwd_bwd_factor * seq_len * batch_size * 6 * d_inner * hidden_size + ) + mamba_conv_flops = ( + fwd_bwd_factor * seq_len * batch_size * 2 * d_inner * conv_dimension + ) + mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops + embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size + flops_per_iteration = mamba_flops * num_layers + embedding_flops + else: + flops_per_iteration = ( + 24 + * fwd_bwd_factor + * batch_size + * seq_len + * num_layers + * (hidden_size**2) + * ( + 1.0 + + (seq_len / (6.0 * hidden_size)) + + (vocab_size / (16.0 * num_layers * hidden_size)) + ) + ) + return flops_per_iteration / (iter_time_s * world_size) + + def training_log( neox_args, timers, @@ -350,6 +431,8 @@ def add_to_logging(name): elapsed_time = timers("interval time").elapsed() iteration_time = elapsed_time / neox_args.log_interval samples_per_sec = neox_args.train_batch_size / iteration_time + steps_per_sec = 1 / iteration_time + tokens_per_sec = samples_per_sec * neox_args.seq_length log_string = " samples/sec: {:.3f} |".format(samples_per_sec) tb_wandb_log( "runtime/samples_per_sec", @@ -367,6 +450,22 @@ def add_to_logging(name): tensorboard_writer=neox_args.tensorboard_writer, comet_experiment=neox_args.comet_experiment, ) + tb_wandb_log( + "runtime/steps_per_sec", + steps_per_sec, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + comet_experiment=neox_args.comet_experiment, + ) + tb_wandb_log( + "runtime/tokens_per_sec", + tokens_per_sec, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + comet_experiment=neox_args.comet_experiment, + ) log_string += " iteration {:8d}/{:8d} |".format( iteration, 
neox_args.train_iters ) @@ -390,7 +489,7 @@ def add_to_logging(name): ) # log tflop / gpu - flops_per_s_per_gpu = get_flops(neox_args, iteration_time) + flops_per_s_per_gpu = get_actual_flops(neox_args, iteration_time) log_string += ( f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |" @@ -404,6 +503,39 @@ def add_to_logging(name): comet_experiment=neox_args.comet_experiment, ) + if neox_args.peak_theoretical_tflops: + # Convert peak theoretical TFLOPS to FLOPS for consistent units + peak_theoretical_flops = neox_args.peak_theoretical_tflops * (10**12) + + # Calculate MFU and HFU as percentages + mfu = ( + get_forward_backward_flops(neox_args, iteration_time) + / peak_theoretical_flops + ) * 100 + hfu = (flops_per_s_per_gpu / peak_theoretical_flops) * 100 + + # Add to log string + log_string += f" MFU: {mfu:.2f}% | HFU: {hfu:.2f}% |" + + # Log to tracking systems + tb_wandb_log( + "runtime/model_flops_utilization", + mfu, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + comet_experiment=neox_args.comet_experiment, + ) + + tb_wandb_log( + "runtime/hardware_flops_utilization", + hfu, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + comet_experiment=neox_args.comet_experiment, + ) + for key in total_loss_dict: if key not in [skipped_iters_key, got_nan_key]: v = ( diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index fcded9e96..072aad8b4 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -67,6 +67,8 @@ def _prepare_cache(self, seq_len, precision, base): freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) + self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1)) + cos_cached = emb.cos()[:, None, None, :] sin_cached = emb.sin()[:, None, None, :] @@ -76,6 +78,9 @@ def _prepare_cache(self, seq_len, precision, base): inv_freq.to(precision), ) + def get_emb(self): + return self.emb.to(self.precision).cuda() + def forward(self, x, seq_dim=0, seq_len=None): if seq_len is None: seq_len = x.shape[seq_dim] diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c670fd4bf..e60fbbe41 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -55,6 +55,8 @@ except ImportError: swiglu = None +from .utils import get_parallel_linear + # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) @@ -114,6 +116,8 @@ def __init__( self.bias_gelu_fusion = neox_args.bias_gelu_fusion self.multiple_of = multiple_of + ColumnParallelLinear, RowParallelLinear = get_parallel_linear(neox_args) + if neox_args.intermediate_size: ffn_dim = neox_args.intermediate_size elif neox_args.expansion_factor: @@ -141,8 +145,7 @@ def __init__( ffn_dim_in = int( self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) ) - - self.linear1 = mpu.ColumnParallelLinear( + self.linear1 = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=ffn_dim, @@ -154,7 +157,7 @@ def __init__( bias=neox_args.use_bias_in_mlp, ) # Project back to h. 
- self.linear2 = mpu.RowParallelLinear( + self.linear2 = RowParallelLinear( neox_args=neox_args, input_size=ffn_dim_in, output_size=neox_args.hidden_size, @@ -170,7 +173,6 @@ def __init__( def forward(self, hidden_states): # [s, b, intermediate_size] intermediate_parallel, bias_parallel = self.linear1(hidden_states) - if self.is_gated or (self.activation_type == "gelu" and self.bias_gelu_fusion): intermediate_parallel = self.activation_func( intermediate_parallel, bias_parallel @@ -217,10 +219,13 @@ def __init__( is_last_layer=False, ): super().__init__() + + ColumnParallelLinear, RowParallelLinear = get_parallel_linear(neox_args) + self.is_rm = neox_args.train_impl == "rm" parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row" if parallelism == "column": - self.final_linear = mpu.ColumnParallelLinear( + self.final_linear = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.padded_vocab_size, @@ -249,7 +254,7 @@ def __init__( # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here # ) else: # Not using cross entropy loss for RMs - self.rm_linear = mpu.RowParallelLinear( + self.rm_linear = RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=1, @@ -335,6 +340,8 @@ def __init__( ): super().__init__() + ColumnParallelLinear, RowParallelLinear = get_parallel_linear(neox_args) + self.fp16 = neox_args.precision == "fp16" self.bf16 = neox_args.precision == "bfloat16" self.attention_mask_func = attention_mask_func @@ -388,7 +395,7 @@ def __init__( if not self.gqa: # Strided linear layer. - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=3 * neox_args.hidden_size, @@ -398,7 +405,7 @@ def __init__( ) else: # QKV proj is smaller if we are using GQA / MQA - self.query_key_value = mpu.ColumnParallelLinear( + self.query_key_value = ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size + 2 * self.kv_hidden_size, @@ -506,7 +513,7 @@ def __init__( self.attention_dropout = nn.Dropout(self.dropout_p) # Output. - self.dense = mpu.RowParallelLinear( + self.dense = RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, output_size=neox_args.hidden_size, @@ -838,13 +845,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None): # ===================== # Query, Key, and Value # ===================== - if not self.gqa: # QKV projection for MHA. # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -864,7 +869,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None): query_layer, key_layer, value_layer = self.gqa_project( hidden_states, attention_mask, layer_past=layer_past ) - # QK Normalization https://arxiv.org/abs/2302.05442 if self.use_qk_layernorm: query_layer = self.qk_layernorm(query_layer) @@ -1002,17 +1006,33 @@ def __init__( ) # Self attention. 
- self.attention = ParallelSelfAttention( - neox_args=neox_args, - attention_mask_func=attention_mask_func, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - layer_number=layer_number, - rpe=rpe, - use_cache=self.use_cache, - rotary=rotary, - parallel_output=self.gpt_j_residual, - ) + if neox_args.te_mha or neox_args.te_fp8_mha: + from megatron.model.transformer_engine import TEMultiheadAttention + + self.attention = TEMultiheadAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) + + else: + self.attention = ParallelSelfAttention( + neox_args=neox_args, + attention_mask_func=attention_mask_func, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + rpe=rpe, + use_cache=self.use_cache, + rotary=rotary, + parallel_output=self.gpt_j_residual, + ) # Layernorm on the output of the attention layer. # If GPT-J residuals are used, this is surpurfulous but leaving it in @@ -1030,6 +1050,18 @@ def get_mlp(**kw): **kw, ) + def get_te_lnmlp(**kw): + from megatron.model.transformer_engine import TELayerNormMLP + + return TELayerNormMLP( + neox_args=neox_args, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, + **kw, + ) + self.num_experts = ( neox_args.moe_num_experts if layer_number % neox_args.expert_interval == 0 @@ -1037,7 +1069,10 @@ def get_mlp(**kw): ) args = neox_args if self.num_experts <= 1: - self.mlp = get_mlp() + if neox_args.te_layernorm_mlp: + self.mlp = get_te_lnmlp() + else: + self.mlp = get_mlp() else: from torch import distributed as dist @@ -1146,123 +1181,144 @@ def forward(self, x, attention_mask, layer_past=None): bias_dropout_fn = self._get_bias_dropout() moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype) # x: [b, s, h] - if self.gpt_j_residual: - # pseudocode: - # x = x + attn(ln(x)) + mlp(ln(x)) - # this means we can avoid doing the allreduce in the attn / mlp outputs - # to save communication time (we can do a single allreduce after we add mlp / attn outputs). - # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. 
This is non-desirable, but - # we preserve the functionality for backwards compatibility - - residual = x - # applies the correct normalization depending on if the norms are tied - if self.gpt_j_tied: - x = self.input_layernorm(x) - x1, x2 = x, x - else: - x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - # attention operator - attention_output, attention_bias = self.attention( - x1, attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - - if attention_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(attention_output), - residual=None, - prob=self.hidden_dropout, - ) - - # mlp operator - mlp_output, mlp_bias = self.mlp(x2) - if mlp_bias is not None: - with torch.enable_grad() if not self.eval else nullcontext(): - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(mlp_output), - residual=attention_output, - prob=self.hidden_dropout, - ) - else: - output = mlp_output + # Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer. + if self.neox_args.te_fp8_mha: + from megatron.model.transformer_engine import TEDelayedScaling - # output = (x + attn(ln(x)) + mlp(ln(x)) - output = residual + self.reduce(output) + fp8_recipe = TEDelayedScaling(neox_args=self.neox_args) + fp8_context = fp8_recipe.get_context() else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) + from contextlib import nullcontext + + fp8_context = nullcontext() + + with fp8_context: + if self.gpt_j_residual: + # pseudocode: + # x = x + attn(ln(x)) + mlp(ln(x)) + # this means we can avoid doing the allreduce in the attn / mlp outputs + # to save communication time (we can do a single allreduce after we add mlp / attn outputs). + # due to a bug, the two layernorms are not tied in GPT-NeoX-20B. 
This is non-desirable, but + # we preserve the functionality for backwards compatibility + + residual = x + # applies the correct normalization depending on if the norms are tied + if self.gpt_j_tied and not self.neox_args.te_layernorm_mlp: + x = self.input_layernorm(x) + x1, x2 = x, x + elif self.gpt_j_tied and self.neox_args.te_layernorm_mlp: + x2 = x + x = self.input_layernorm(x) + x1 = x + elif self.neox_args.te_layernorm_mlp: + x1, x2 = self.input_layernorm(x), x + else: + x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x) - residual = x + # attention operator + attention_output, attention_bias = self.attention( + x1, attention_mask, layer_past=layer_past + ) + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents - # x = x + attn(ln1(x)) - attention_output, attention_bias = self.attention( - self.input_layernorm(x), attention_mask, layer_past=layer_past - ) - if self.use_cache: - attention_output, presents = attention_output - self.layer_past = presents - with torch.enable_grad() if not self.eval else nullcontext(): if attention_bias is not None: - # Use special bias_dropout_fn if we have a bias term from the above attention layer - attention_output = bias_dropout_fn( - attention_output, - bias=attention_bias.expand_as(residual), - residual=residual, - prob=self.hidden_dropout, - ) - else: - # Otherwise just apply dropout + residual - attention_output = ( - torch.nn.functional.dropout( + with torch.enable_grad() if not self.eval else nullcontext(): + attention_output = bias_dropout_fn( attention_output, - p=self.hidden_dropout, - training=self.training, + bias=attention_bias.expand_as(attention_output), + residual=None, + prob=self.hidden_dropout, ) - + residual - ) - # output = x + mlp(ln2(x)) - layernorm_output = self.post_attention_layernorm(attention_output) - mlp_bias = torch.tensor( - 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype - ) + # mlp operator + mlp_output, mlp_bias = self.mlp(x2) + if mlp_bias is not None: + with torch.enable_grad() if not self.eval else nullcontext(): + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(mlp_output), + residual=attention_output, + prob=self.hidden_dropout, + ) + else: + output = mlp_output - if self.num_experts == 1: - mlp_output, mlp_bias = self.mlp(layernorm_output) + # output = (x + attn(ln(x)) + mlp(ln(x)) + output = residual + self.reduce(output) else: - if self.moe_type == "deepspeed": - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = ( - None # deepspeed.moe.layer.MoE.forward ignores the bias term - ) - elif self.moe_type == "megablocks": - mlp_output, mlp_bias = self.mlp(layernorm_output) + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + + residual = x + + # x = x + attn(ln1(x)) + attention_output, attention_bias = self.attention( + self.input_layernorm(x), attention_mask, layer_past=layer_past + ) + + if self.use_cache: + attention_output, presents = attention_output + self.layer_past = presents + with torch.enable_grad() if not self.eval else nullcontext(): + if attention_bias is not None: + # Use special bias_dropout_fn if we have a bias term from the above attention layer + attention_output = bias_dropout_fn( + attention_output, + bias=attention_bias.expand_as(residual), + residual=residual, + prob=self.hidden_dropout, + ) + else: + # Otherwise just apply dropout + residual + attention_output = ( + torch.nn.functional.dropout( + attention_output, + p=self.hidden_dropout, + training=self.training, + 
) + + residual + ) + + # output = x + mlp(ln2(x)) + if self.neox_args.te_layernorm_mlp: + layernorm_output = attention_output else: - raise KeyError(self.moe_type) - - with torch.enable_grad() if not self.eval else nullcontext(): - if mlp_bias == None or ( - self.num_experts > 1 and self.moe_type == "deepspeed" - ): - # No dropout either - assert mlp_bias is None - output = mlp_output + attention_output + layernorm_output = self.post_attention_layernorm(attention_output) + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype + ) + + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) else: - output = bias_dropout_fn( - mlp_output, - bias=mlp_bias.expand_as(attention_output), - residual=attention_output, - prob=self.hidden_dropout, - ) + if self.moe_type == "deepspeed": + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = None # deepspeed.moe.layer.MoE.forward ignores the bias term + elif self.moe_type == "megablocks": + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + raise KeyError(self.moe_type) + + with torch.enable_grad() if not self.eval else nullcontext(): + if mlp_bias == None or ( + self.num_experts > 1 and self.moe_type == "deepspeed" + ): + # No dropout either + assert mlp_bias is None + output = mlp_output + attention_output + else: + output = bias_dropout_fn( + mlp_output, + bias=mlp_bias.expand_as(attention_output), + residual=attention_output, + prob=self.hidden_dropout, + ) - return output, moe_loss + return output, moe_loss class ParallelTransformerLayerPipe(ParallelTransformerLayer): diff --git a/megatron/model/transformer_engine.py b/megatron/model/transformer_engine.py index 338513a97..e67071f88 100644 --- a/megatron/model/transformer_engine.py +++ b/megatron/model/transformer_engine.py @@ -1,4 +1,45 @@ +# Copyright (c) 2024, EleutherAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
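+
+"""Thin wrappers adapting NVIDIA Transformer Engine modules (Linear, LayerNormMLP, MultiheadAttention, DelayedScaling) to the interfaces GPT-NeoX expects from its megatron/mpu layers."""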
+ +import math + import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from megatron.model.transformer import Gated_Activation +from megatron.model.activations import get_activation +from megatron.mpu.initialize import get_model_parallel_rank +from megatron.mpu.initialize import get_model_parallel_world_size +from megatron.mpu.initialize import get_tensor_model_parallel_group +from megatron.mpu.mappings import copy_to_model_parallel_region +from megatron.mpu.mappings import gather_from_model_parallel_region +from megatron.mpu.mappings import reduce_from_model_parallel_region +from megatron.mpu.mappings import scatter_to_model_parallel_region +from megatron.mpu.mappings import reduce_scatter_to_sequence_parallel_region +from megatron.mpu.mappings import gather_from_sequence_parallel_region +from megatron.mpu.layers import ( + _initialize_affine_weight_gpu, + _initialize_affine_weight_cpu, +) +from megatron.mpu.random import get_cuda_rng_tracker +from megatron.mpu.utils import divide +from megatron.mpu.utils import VocabUtility +from functools import partial +from megatron.model.positional_embeddings import RotaryEmbedding +from megatron import mpu try: import transformer_engine as te @@ -58,73 +99,546 @@ class TELinear(te.pytorch.Linear): Wrapper for the Transformer-Engine's `Linear` layer. """ - def __init__(self): - # TODO - return + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + init_method=init.xavier_normal_, + stride=1, + skip_bias_add=False, + mup_rescale_parameters=False, + seq_dim=0, + ): + self.input_size = input_size + self.output_size = output_size + + self.skip_bias_add = skip_bias_add + self.use_bias = bias + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype = neox_args.params_dtype + + super(TELinear, self).__init__( + in_features=self.input_size, + out_features=self.output_size, + bias=self.use_bias, + init_method=self.init_method, + get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), + return_bias=self.skip_bias_add, + params_dtype=self.params_dtype, + ) - def forward(self, x): - # TODO - return + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TELinear, self).forward(inp, **kwargs) + + if self.skip_bias_add: + return output + else: + return output, None -class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear): +class TELayerNormMLP(te.pytorch.LayerNormMLP): """ - Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines - layernorm and linear layers + Wrapper for the Transformer-Engine's `LayerNormMLP` layer that combines + layernorm and followed by the MLP module, consisting of 2 successive + linear transformations, separated by the GeLU activation. 
""" - def __init__(self): - # TODO - return + def __init__( + self, + neox_args, + init_method, + output_layer_init_method, + parallel_output=False, + multiple_of=256, + MOE=False, + MoE_mp_size=1, + bias=True, + ): + self.activation_func, self.is_gated = get_activation(neox_args) + self.activation_type = neox_args.activation + self.multiple_of = multiple_of + self.bias = bias + self.init_method = init_method + self.output_layer_init_method = output_layer_init_method + + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.sequence_parallel = neox_args.sequence_parallel + self.seq_len = neox_args.seq_length + self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu + self.params_dtype = neox_args.params_dtype + self.set_parallel_mode = False + if world_size > 1: + self.set_parallel_mode = True + + if neox_args.intermediate_size: + ffn_dim = neox_args.intermediate_size + elif neox_args.expansion_factor: + ffn_dim = int(neox_args.expansion_factor * neox_args.hidden_size) + else: + # 4h is default for ffn_dim + ffn_dim = 4 * neox_args.hidden_size + ffn_dim_in = ffn_dim + if self.is_gated: + # set activation function to be gated implementation + self.activation_func = Gated_Activation(self.activation_func) + # auto scale so gated activations has equal parameters + ffn_dim = int(ffn_dim * 2 / 3) + ffn_dim_in = ffn_dim // 2 + # set multiple + ffn_dim = int( + (2 * self.multiple_of) + * ((ffn_dim + (2 * multiple_of) - 1) // (2 * multiple_of)) + ) + ffn_dim_in = int( + self.multiple_of * ((ffn_dim_in + multiple_of - 1) // multiple_of) + ) - def forward(self, x): - # TODO - return + if neox_args.norm in ["layernorm", "te_layernorm"]: + self.eps = 1.0e-5 + self.normalization = "LayerNorm" + elif neox_args.norm in ["rmsnorm", "te_rmsnorm"]: + self.eps = 1.0e-8 + self.normalization = "RMSNorm" + else: + raise ValueError( + "Only LayerNorm and RMSNorm are supported with TransformerEngine" + ) + + if self.activation_type not in [ + "gelu", + "geglu", + "relu", + "reglu", + "squared_relu", + "swiglu", + "qgelu", + "srelu", + ]: + raise ValueError( + "Only gelu, geglu, relu, reglu, squared_relu, swiglu, qgelu, and srelu are supported with TransformerEngine" + ) + + super(TELayerNormMLP, self).__init__( + hidden_size=neox_args.hidden_size, + ffn_hidden_size=ffn_dim, + eps=self.eps, + bias=self.bias, + normalization=self.normalization, + activation=self.activation_type, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + device=torch.cuda.current_device(), + set_parallel_mode=self.set_parallel_mode, + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + return_bias=True, + params_dtype=self.params_dtype, + seq_length=self.seq_len, + get_rng_state_tracker=get_cuda_rng_tracker, + micro_batch_size=self.micro_batch_size, + ) -class TEColumnParallelLinear(TELinear): +class TEColumnParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `ColumnParallelLinear` layer. - """ - def __init__(self): - # TODO - return + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. 
Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ - def forward(self, x): - # TODO - return + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + MOE=False, + MoE_mp_size=1, + mup_rescale_parameters=False, + seq_dim=0, + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.use_bias = bias + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype = neox_args.params_dtype + self.parallel_mode = "column" + + super(TEColumnParallelLinear, self).__init__( + in_features=self.input_size, + out_features=self.output_size, + bias=self.use_bias, + init_method=self.init_method, + get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + parallel_mode=self.parallel_mode, + return_bias=self.skip_bias_add, + params_dtype=self.params_dtype, + ) + # Copied from Mup + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() -class TERowParallelLinear(TELinear): + # Copied from Mup + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. " + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." 
+ ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def mup_reinitialize_weights(self, neox_args): + if neox_args.use_cpu_initialization: + self.master_weight = _initialize_affine_weight_cpu( + neox_args, + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + partial(self.init_method, use_mup=True), + stride=self.stride, + return_master_weight=self.keep_master_weight_for_test, + ) + else: + _initialize_affine_weight_gpu( + self.weight, + partial(self.init_method, use_mup=True), + partition_dim=1, + stride=self.stride, + ) + + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TEColumnParallelLinear, self).forward(inp, **kwargs) + if self.skip_bias_add: + return output + else: + return output, None + + +class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer but specialized similar to megatron's `RowParallelLinear` layer. - """ - def __init__(self): - # TODO - return + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: This was added to enable performance optimations where bias + can be fused with other elementwise operations. we skip + adding bias but instead return it. + """ - def forward(self, x): - # TODO - return + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + MOE=False, + MoE_mp_size=1, + parallel_output=False, + mup_rescale_parameters=False, + ): + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + # Divide the weight matrix along the last dimension. 
+ world_size = MoE_mp_size if MOE else get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.use_bias = bias + self.input_is_parallel = input_is_parallel + self.sequence_parallel = neox_args.sequence_parallel + + self.init_method = init_method + self.stride = stride + self.mup_rescale_parameters = mup_rescale_parameters + self.use_mup = neox_args.use_mup + self.params_dtype = neox_args.params_dtype + self.parallel_mode = "row" + + super(TERowParallelLinear, self).__init__( + in_features=self.input_size, + out_features=self.output_size, + bias=self.use_bias, + init_method=self.init_method, + get_rng_state_tracker=get_cuda_rng_tracker, + device=torch.cuda.current_device(), + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + parallel_mode=self.parallel_mode, + return_bias=self.skip_bias_add, + params_dtype=self.params_dtype, + ) + # Copied from Mup + def width_mult(self): + assert hasattr(self.weight, "infshape"), ( + "Please call set_base_shapes(...). If using torch.nn.DataParallel, " + "switch to distributed training with " + "torch.nn.parallel.DistributedDataParallel instead" + ) + return self.weight.infshape.width_mult() -class TEDotProductAttention(te.pytorch.DotProductAttention): + # Copied from Mup + def _rescale_parameters(self): + """Rescale parameters to convert SP initialization to μP initialization. + Warning: This method is NOT idempotent and should be called only once + unless you know what you are doing. + """ + if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params: + raise RuntimeError( + "`_rescale_parameters` has been called once before already. " + "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n" + "If you called `set_base_shapes` on a model loaded from a checkpoint, " + "or just want to re-set the base shapes of an existing model, " + "make sure to set the flag `rescale_params=False`.\n" + "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call." + ) + if self.bias is not None: + self.bias.data *= self.width_mult() ** 0.5 + self.weight.data *= self.width_mult() ** 0.5 + self._has_rescaled_params = True + + def mup_reinitialize_weights(self, neox_args): + if neox_args.use_cpu_initialization: + self.master_weight = _initialize_affine_weight_cpu( + neox_args, + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + partial(self.init_method, use_mup=True), + stride=self.stride, + return_master_weight=self.keep_master_weight_for_test, + ) + else: + _initialize_affine_weight_gpu( + self.weight, + partial(self.init_method, use_mup=True), + partition_dim=1, + stride=self.stride, + ) + + def forward(self, inp, **kwargs): + if self.use_mup and self.mup_rescale_parameters: + input_ /= self.width_mult() + + output = super(TERowParallelLinear, self).forward(inp, **kwargs) + + if self.skip_bias_add: + return output + else: + return output, None + + +class TEMultiheadAttention(te.pytorch.MultiheadAttention): """ - Wrapper for the Transformer-Engine's `DotProductAttention` layer that also + Wrapper for the Transformer-Engine's `MultiheadAttention` layer that also has "flash attention" enabled. 
""" - def __init__(self): - # TODO - return + def __init__( + self, + neox_args, + attention_mask_func, + init_method, + output_layer_init_method, + layer_number, + rpe=None, + rotary=False, + use_cache=False, + parallel_output=False, + ): + + self.neox_args = neox_args + self.attention_mask_func = attention_mask_func + self.init_method = init_method + self.output_layer_init_method = output_layer_init_method + self.layer_number = layer_number + 1 + + world_size = get_model_parallel_world_size() + self.world_size = world_size + self.tp_group = get_tensor_model_parallel_group() + self.sequence_parallel = neox_args.sequence_parallel + self.seq_len = neox_args.seq_length + self.micro_batch_size = neox_args.train_micro_batch_size_per_gpu + self.params_dtype = neox_args.params_dtype + self.set_parallel_mode = False + if world_size > 1: + self.set_parallel_mode = True + + if neox_args.norm in ["layernorm", "te_layernorm"]: + self.eps = 1.0e-5 + self.normalization = "LayerNorm" + elif neox_args.norm == ["rmsnorm", "te_rmsnorm"]: + self.eps = 1.0e-8 + self.normalization = "RMSNorm" + + if ( + not neox_args.num_kv_heads + or neox_args.num_kv_heads == neox_args.num_attention_heads + ): + self.gqa = False + self.num_kv_heads = None + else: + self.gqa = True + self.num_kv_heads = neox_args.num_kv_heads + + super(TEMultiheadAttention, self).__init__( + hidden_size=neox_args.hidden_size, + num_attention_heads=neox_args.num_attention_heads, + attention_dropout=neox_args.attention_dropout, + layernorm_epsilon=self.eps, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + layer_number=self.layer_number, + window_size=neox_args.sliding_window_width, + num_gqa_groups=self.num_kv_heads, + input_layernorm=False, + normalization=self.normalization, + bias=True, + device=torch.cuda.current_device(), + get_rng_state_tracker=get_cuda_rng_tracker, + set_parallel_mode=self.set_parallel_mode, + sequence_parallel=self.sequence_parallel, + tp_group=self.tp_group, + tp_size=self.world_size, + params_dtype=self.params_dtype, + return_bias=True, + qkv_format="sbhd", + fuse_qkv_params=True, + ) - def forward(self, x): - # TODO - return + if neox_args.pos_emb == "rotary": + self.hidden_size_per_attention_head = mpu.divide( + neox_args.hidden_size, neox_args.num_attention_heads + ) + + if neox_args.rotary_pct == 1: + self.rotary_ndims = None + else: + assert neox_args.rotary_pct < 1 + self.rotary_ndims = int( + self.hidden_size_per_attention_head * neox_args.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else self.hidden_size_per_attention_head + ) + self.rotary_embeddings = RotaryEmbedding( + dim, + base=neox_args.rotary_emb_base, + max_seq_len=neox_args.seq_length, + precision=neox_args.params_dtype, + save_inv_freqs=neox_args.rotary_save_freqs_buffer, + ) + self.rope_emb = self.rotary_embeddings.get_emb() + + def forward( + self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs + ): + output = super(TEMultiheadAttention, self).forward( + hidden_states, attention_mask, rotary_pos_emb=self.rope_emb, **kwargs + ) + return output class TEDelayedScaling(te.common.recipe.DelayedScaling): @@ -132,6 +646,36 @@ class TEDelayedScaling(te.common.recipe.DelayedScaling): Wrapper for the Transformer-Engine's `DelayedScaling` layer. 
""" - def __init__(self): - # TODO - return + ##TODO Test with H100 + def __init__(self, neox_args): + + self.neox_args = neox_args + self.tp_group = get_tensor_model_parallel_group() + + if neox_args.te_fp8_format == "e4m3": + fp8_format = te.common.recipe.Format.E4M3 + elif neox_args.te_fp8_format == "hybrid": + fp8_format = te.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + override_linear_precision = (False, False, not neox_args.te_fp8_wgrad) + + super().__init__( + margin=neox_args.fp8_margin, + fp8_format=te_fp8_format, + amax_compute_algo=neox_args.te_fp8_amax_compute_algo, + amax_history_len=neox_args.te_fp8_amax_history_len, + override_linear_precision=override_linear_precision, + fp8_mha=neox_args.te_fp8_mha, + ) + + def fp8_context(self): + fp8_group = None + if self.tp_group: + fp8_group = self.tp_group + fp8_context = te.pytorch.fp8_autocast( + enabled=True, fp8_recipe=self, fp8_group=fp8_group + ) + + return get_context diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 8176f1f7a..5515c41f5 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -402,3 +402,20 @@ def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): for name, param in module_.named_parameters(): if param.requires_grad: param.register_hook(reduce_weight_grads_from_model_parallel_region) + + +def get_parallel_linear(neox_args): + if neox_args.te_columnparallel: + from megatron.model.transformer_engine import ( + TEColumnParallelLinear as ColumnParallelLinear, + ) + else: + from megatron.mpu import ColumnParallelLinear + if neox_args.te_rowparallel: + from megatron.model.transformer_engine import ( + TERowParallelLinear as RowParallelLinear, + ) + else: + from megatron.mpu import RowParallelLinear + + return ColumnParallelLinear, RowParallelLinear diff --git a/megatron/model/weight_server.py b/megatron/model/weight_server.py new file mode 100644 index 000000000..987db3434 --- /dev/null +++ b/megatron/model/weight_server.py @@ -0,0 +1,64 @@ +from typing import Union, List + +import torch +import socket +import pickle + + +def send_tensor(state_dict_key, data, sock, end: bool): + storage = data.storage() + ( + storage_device, + storage_handle, + storage_size_bytes, + storage_offset_bytes, + ref_counter_handle, + ref_counter_offset, + event_handle, + event_sync_required, + ) = storage._share_cuda_() + sock.send( + pickle.dumps( + { + "state_dict_key": state_dict_key, + "dtype": data.dtype, + "tensor_size": data.shape, + "tensor_stride": data.stride(), + "tensor_offset": data.storage_offset(), # !Not sure about this one. 
+ "storage_cls": type(storage), + "storage_device": storage_device, + "storage_handle": storage_handle, + "storage_size_bytes": storage_size_bytes, + "storage_offset_bytes": storage_offset_bytes, + "requires_grad": False, + "ref_counter_handle": ref_counter_handle, + "ref_counter_offset": ref_counter_offset, + "event_handle": event_handle, + "event_sync_required": event_sync_required, + "end": end, + } + ) + ) + + +def send_state_dict(state_dict, sock): + for i, key in enumerate(state_dict.keys()): + print(key) + end = i == len(state_dict.keys()) - 1 + send_tensor(key, state_dict[key], sock, end) + sock.recv(4096) + + +def start_server(model, ports: Union[int, List[int]] = 6000): + global_rank = torch.distributed.get_rank() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if type(ports) == int: + port = ports + global_rank + else: + port = ports[global_rank] + s.bind(("localhost", port)) + s.listen(1) + conn, addr = s.accept() + state_dict = model.state_dict() + send_state_dict(state_dict, conn) + conn.close() diff --git a/megatron/neox_arguments/__init__.py b/megatron/neox_arguments/__init__.py index 025464cbf..087dbe6b7 100644 --- a/megatron/neox_arguments/__init__.py +++ b/megatron/neox_arguments/__init__.py @@ -18,7 +18,7 @@ * NeoXArgs.from_ymls(["path_to_yaml1", "path_to_yaml2", ...]): load yaml configuration files and instantiate with the values provided; checks for duplications and unknown arguments are performed * NeoXArgs.from_dict({"num_layers": 12, ...}): load attribute values from dict; checks unknown arguments are performed -* NeoXArgs.consume_deepy_args(): entry point for deepy.py configuring and consuming command line arguments (i.e. user_script, conf_dir, conf_file, wandb_group, wandb_team); neox_args.get_deepspeed_main_args() produces a list of command line arguments to feed to deepspeed.launcher.runner.main +* NeoXArgs.consume_deepy_args(): entry point for deepy.py configuring and consuming command line arguments (i.e. user_script, conf_dir, conf_file, wandb_group, wandb_run_name, wandb_team); neox_args.get_deepspeed_main_args() produces a list of command line arguments to feed to deepspeed.launcher.runner.main * NeoXArgs.consume_neox_args(): In the call stack deepy.py -> deepspeed -> pretrain_gpt2.py; arguments are passed to pretrain_gpt2.py by neox_args.get_deepspeed_main_args(). So produced arguments can be read with consume_neox_args() to instantiate a NeoXArgs instance. 
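For context on how `megatron/model/weight_server.py` above is meant to be consumed: `send_tensor` ships only CUDA IPC metadata (handles, shapes, strides), so the weights themselves never leave GPU memory. Below is a minimal, hypothetical sketch of a receiving process, assuming the consumer rebuilds tensors with `torch.multiprocessing.reductions.rebuild_cuda_tensor`; the `receive_state_dict` helper is made up for illustration, and the actual synth-vllm consumer may differ.

```python
# Hypothetical counterpart to start_server()/send_state_dict(): rebuild each shared
# CUDA tensor from the pickled IPC metadata. Sketch only, not the real consumer.
import pickle
import socket

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor


def receive_state_dict(host="localhost", port=6000):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((host, port))
    state_dict = {}
    while True:
        # Each message holds handles and shapes, not weights; a single recv is
        # assumed sufficient here for simplicity.
        meta = pickle.loads(sock.recv(2**20))
        state_dict[meta["state_dict_key"]] = rebuild_cuda_tensor(
            torch.Tensor,
            meta["tensor_size"],
            meta["tensor_stride"],
            meta["tensor_offset"],
            meta["storage_cls"],
            meta["dtype"],
            meta["storage_device"],
            meta["storage_handle"],
            meta["storage_size_bytes"],
            meta["storage_offset_bytes"],
            meta["requires_grad"],
            meta["ref_counter_handle"],
            meta["ref_counter_offset"],
            meta["event_handle"],
            meta["event_sync_required"],
        )
        sock.send(b"ack")  # send_state_dict() blocks on recv() between tensors
        if meta["end"]:
            break
    sock.close()
    return state_dict
```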
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 3b49cea32..c0a33d4a4 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -335,6 +335,12 @@ def consume_deepy_args(cls, input_args=None):
 default=None,
 help='Weights & Biases group name - used to group together "runs".',
 )
+ group.add_argument(
+ "--wandb_run_name",
+ type=str,
+ default=None,
+ help="Weights & Biases run name for the current experiment.",
+ )
 group.add_argument(
 "--wandb_team",
 type=str,
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index c64f67d32..4846a5718 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -502,11 +502,86 @@ class NeoXArgsModel(NeoXArgsTemplate):
 # Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905)
 output_layer_parallelism: Literal["column"] = "column"
-
 """
 Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
 """
+ serve_model_weights: bool = False
+ """
+ If true, serve model weight pointers over a socket connection
+ """
+
+ weight_server_port: Union[int, List[int]] = 6000
+ """
+ Port(s) to serve model weights over
+ If an integer is provided, the port for each GPU will be 6000 + global rank
+ If a list is provided, the ports will be used in order, e.g. rank0 will be weight_server_port[0]
+ """
+
+ online_dataserver_ips: Union[str, List[str]] = "localhost"
+ """
+ ip addresses to connect to for online data serving, defaults to localhost
+ """
+
+ online_dataserver_ports: Union[int, List[int]] = 10000
+ """
+ Port(s) to connect to for online data serving, defaults to 10000
+ """
+
+ te_columnparallel: bool = False
+ """
+ Use TransformerEngine for ColumnParallelLinear layer.
+ """
+
+ te_rowparallel: bool = False
+ """
+ Use TransformerEngine for RowParallelLinear layer.
+ """
+
+ te_layernorm_mlp: bool = False
+ """
+ Use TransformerEngine for LayerNormMLP layer.
+ """
+
+ te_mha: bool = False
+ """
+ Use TransformerEngine for MultiheadAttention layer.
+ """
+
+ te_fp8_format: Literal["e4m3", "hybrid"] = "hybrid"
+ """
+ Controls the FP8 data format used during forward and backward pass by TransformerEngine.
+ Hybrid uses E4M3 during forward pass, E5M2 during backward pass.
+ """
+
+ te_fp8_wgrad: bool = True
+ """
+ When set to False, override FP8 config options and do the wgrad computation
+ in higher precision.
+ """
+
+ te_fp8_amax_history_len: int = 1
+ """
+ The length of the amax history window used for scaling factor computation.
+ """
+
+ te_fp8_amax_compute_algo: str = "most_recent"
+ """
+ Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
+ predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
+ always chooses the most recently seen value.
+ """
+
+ te_fp8_margin: int = 0
+ """
+ Margin for the scaling factor computation.
+ """
+
+ te_fp8_mha: bool = False
+ """
+ When set to True, use the FP8 implementation of Multi Head Attention.
+ """
+
 dim_att: int = None
 """
 Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
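To make the relationship between these flags and the new wrappers concrete, here is a hedged sketch (not the actual `training.py` wiring) of how the `te_fp8_*` options are expected to flow through the `TEDelayedScaling` class and its `fp8_context()` helper added earlier in this diff; `model`, `neox_args`, `tokens`, `position_ids`, and `attention_mask` are assumed to come from the surrounding training code.

```python
# Illustrative only: the te_fp8_* options build a Transformer Engine DelayedScaling
# recipe, and fp8_context() wraps the forward pass in te.pytorch.fp8_autocast.
from megatron.model.transformer_engine import TEDelayedScaling

fp8_recipe = TEDelayedScaling(neox_args)  # consumes te_fp8_format, te_fp8_margin, te_fp8_amax_*, ...
with fp8_recipe.fp8_context():  # fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tp_group)
    outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
```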
@@ -629,12 +704,16 @@ class NeoXArgsLogging(NeoXArgsTemplate): Logging Arguments """ + ### BEGIN WANDB ARGS ### use_wandb: bool = None """Flag indicating if wandb is to be used.""" wandb_group: str = None """Weights and Biases group name - used to group together "runs".""" + wandb_run_name: str = None + """Weights and Biases run name for the current experiment""" + wandb_team: str = None """Team name for Weights and Biases.""" @@ -646,6 +725,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): wandb_init_all_ranks: bool = False """Initialize wandb on all ranks.""" + ### END WANDB ARGS ### git_hash: str = get_git_commit_hash() """current git hash of repository""" @@ -655,6 +735,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): Directory to save logs to. """ + ### BEGIN TENSORBOARD ARGS ### tensorboard_writer = None """ initialized tensorboard writer @@ -664,7 +745,9 @@ class NeoXArgsLogging(NeoXArgsTemplate): """ Write TensorBoard logs to this directory. """ + ### END TENSORBOARD ARGS ### + ### BEGIN COMET ARGS ### use_comet: bool = None """Flag indicating if comet is to be used.""" @@ -697,6 +780,12 @@ class NeoXArgsLogging(NeoXArgsTemplate): """ Initialized comet experiment object used to log data """ + ### END COMET ARGS ### + + peak_theoretical_tflops: float = None + """ + The peak hardware flops with which to compute MFU and HFU, in units of teraflops. Automatic detection is more trouble than it's worth, so this is left to the user. Helpful table listed at https://github.com/stas00/ml-engineering/tree/master/compute/accelerator#tflops-comparison-table + """ log_interval: int = 100 """ @@ -716,8 +805,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): log_grad_norm: bool = False """ Log the frob norm of the gradients to wandb / tensorboard (useful for debugging). - (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because - deepspeed.) + (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.) """ log_optimizer_states: bool = False @@ -740,6 +828,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): Whether to offload the buffered gradients to cpu when measuring gradient noise scale. """ + ### BEGIN PROFILING ARGS memory_profiling: bool = False """ Whether to take a memory snapshot of the model. Useful for debugging memory issues. @@ -772,6 +861,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): """ Step to stop profiling at. """ + ### END PROFILING ARGS ### @dataclass @@ -1068,14 +1158,14 @@ class NeoXArgsTraining(NeoXArgsTemplate): warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets """ - dataset_impl: Literal["gpt2", "pairwise"] = "gpt2" + dataset_impl: Literal["gpt2", "pairwise", "online"] = "gpt2" """ - Dataset implementation, can be one of "gpt2" or "pairwise" + Dataset implementation, can be one of "gpt2", "pairwise", or "online" """ - train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal" + train_impl: Literal["normal", "dpo", "rm", "kto", "reinforce"] = "normal" """ - Training implementation, can be one of "normal", "dpo", "kto", or "rm" + Training implementation, can be one of "normal", "dpo", "kto", "reinforce", or "rm" """ dpo_fp32: bool = True @@ -1120,6 +1210,27 @@ class NeoXArgsTraining(NeoXArgsTemplate): Beta value for KTO """ + fp32_reinforce: bool = True + """ + Whether to cast logits to fp32 for Reinforce loss calculation. 
+ """ + + kl_impl: Literal["abs", "mse", "kl", "full"] = "mse" + """ + KL divergence implementation, can be one of "abs", "mse", "kl", or "full" + """ + + kl_div_beta: float = 0.1 + """ + Beta value for KL divergence in Reinforce loss calculation. + """ + + reinforce_leave_one_out: bool = False + """ + Whether to use reinforce leave one out for training + (from https://arxiv.org/abs/2402.14740 and https://api.semanticscholar.org/CorpusID:198489118) + """ + allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. diff --git a/megatron/training.py b/megatron/training.py index 1965faea8..3def74860 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -62,6 +62,7 @@ get_total_params, CharCounter, ) +from megatron.model.weight_server import start_server from megatron.model.gpt2_model import cross_entropy from megatron.mpu import vocab_parallel_cross_entropy @@ -253,6 +254,13 @@ def pretrain(neox_args): ) timers("model and optimizer").stop() + if neox_args.serve_model_weights: + start_server(model) + # sync... + torch.distributed.barrier() + + # Start data stuff: + # Make and configure iterators timers("train/valid/test data iterators").start() ( @@ -382,7 +390,7 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - if neox_args.train_impl in ["normal", "kto"]: + if neox_args.train_impl in ["normal", "kto", "reinforce"]: keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] elif neox_args.train_impl in ["dpo", "rm"]: keys = ( @@ -427,6 +435,20 @@ def get_batch(neox_args, data_iterator): else None ) return tup + (rw_data, ref_data) + elif neox_args.train_impl == "reinforce": + + tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"] + raw_rw_data = mpu.broadcast_data(["raw_reward"], data, torch.float)[ + "raw_reward" + ] + return tup + (rw_data, raw_rw_data) elif neox_args.train_impl in ["dpo", "rm"]: pos_tup = _get_batch( neox_args=neox_args, @@ -604,6 +626,16 @@ def forward_step( rewards, ref_logp, ) = get_batch(neox_args=neox_args, data_iterator=data_iterator) + elif neox_args.train_impl == "reinforce": + ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + rewards, + raw_rewards, + ) = get_batch(neox_args=neox_args, data_iterator=data_iterator) if neox_args.train_impl in ["dpo", "rm"]: tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( neox_args=neox_args, data_iterator=data_iterator @@ -841,6 +873,70 @@ def forward_step( # print(loss.shape) loss = loss.mean() # print(loss.shape) + elif neox_args.train_impl == "reinforce": + if reference_model is not None: + with torch.no_grad(): + ref_outputs = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_outputs) is tuple: + ref_outputs, _ = ref_outputs + ref_outputs = ref_outputs + if neox_args.kl_impl == "full": + # Have to do the loss over all tokens... 
+ ref_outputs = gather_from_model_parallel_region(ref_outputs) + if neox_args.fp32_reinforce: + ref_outputs = ref_outputs.float() + ref_logp = ref_outputs.log_softmax(dim=-1).detach() + ref_per_token_logp = torch.gather( + ref_logp.clone(), dim=2, index=labels.unsqueeze(2) + ).squeeze(2) + else: + ref_per_token_logp = get_logp( + ref_outputs, labels, neox_args.fp32_reinforce + ) + metrics["ref_logp"] = ref_per_token_logp.clone().detach().mean() + outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(outputs) is tuple: + outputs, _ = outputs + if neox_args.kl_impl == "full": + # Have to do the loss over all tokens... + outputs = gather_from_model_parallel_region(outputs) + if neox_args.fp32_reinforce: + outputs = outputs.float() + logp = outputs.log_softmax(dim=-1) + per_token_logp = torch.gather( + logp.clone(), dim=2, index=labels.unsqueeze(2) + ).squeeze(2) + else: + per_token_logp = get_logp(outputs, labels, neox_args.fp32_reinforce) + with torch.no_grad(): + metrics["logp"] = per_token_logp.clone().detach().mean() + metrics["reward"] = raw_rewards.clone().detach().mean() + metrics["reward_std"] = raw_rewards.clone().detach().std() + loss_mask_sum = loss_mask.sum() + if reference_model is not None: + if neox_args.kl_impl == "full": + # Following along with + # https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1109 + kl = F.kl_div(ref_logp, logp, log_target=True, reduction="none").sum(-1) + else: + kl = per_token_logp - ref_per_token_logp + if neox_args.kl_impl == "abs": + kl = kl.abs() + elif neox_args.kl_impl == "mse": + kl = 0.5 * (kl).square() + elif neox_args.kl_impl == "kl": + pass + with torch.no_grad(): + metrics["kl"] = kl.clone().detach().mean() + loss = (-per_token_logp * rewards) + (neox_args.kl_div_beta * kl) + loss = (loss * loss_mask).sum(-1) / loss_mask_sum + loss = loss.mean() + else: + loss = -(rewards * per_token_logp) + loss = (loss * loss_mask).sum(-1) / loss_mask_sum + loss = loss.mean() if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: @@ -1146,10 +1242,17 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): """Setup model and optimizer.""" needs_reference_model = ( - (neox_args.train_impl == "dpo") - and (neox_args.precompute_model_name is None) - and (not neox_args.dpo_reference_free) - ) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)) + ( + (neox_args.train_impl == "dpo") + and (neox_args.precompute_model_name is None) + and (not neox_args.dpo_reference_free) + ) + or ( + (neox_args.train_impl == "kto") + and (neox_args.precompute_model_name is None) + ) + or ((neox_args.train_impl == "reinforce") and (neox_args.kl_div_beta > 0.0)) + ) model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: reference_model = get_model(neox_args=neox_args, use_cache=use_cache) @@ -1281,7 +1384,6 @@ def train_step( reference_model=None, ): """Single training step.""" - # Pipeline parallelism schedules forward/backward/step if neox_args.is_pipe_parallel: reduced_loss = train_step_pipe( diff --git a/megatron/utils.py b/megatron/utils.py index 507c44179..fc2f80dad 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -166,12 +166,12 @@ def init_wandb(neox_args): neox_args.update_value("use_wandb", use_wandb) if neox_args.use_wandb: group_name = neox_args.wandb_group - name = f"{socket.gethostname()}-{local_rank()}" if group_name else None + run_name = 
neox_args.wandb_run_name
 try:
 wandb.init(
 project=neox_args.wandb_project,
 group=group_name,
- name=name,
+ name=run_name,
 save_code=False,
 force=False,
 entity=neox_args.wandb_team,
diff --git a/post-training/README.md b/post-training/README.md
index fb7ac8eb4..940cef428 100644
--- a/post-training/README.md
+++ b/post-training/README.md
@@ -2,6 +2,8 @@
 Examples for running post-training with ultrafeedback data for SFT/DPO/RM training.
+For [REINFORCE](https://arxiv.org/abs/2402.14740) style training, see [Online Training](online_training.md).
+
 ```bash
 python tools/ckpts/convert_hf_llama_to_neox.py --tp 4 --model meta-llama/Meta-Llama-3-8B-Instruct --model_path checkpoints/neox_converted/llama3-8b-instruct
 ```
diff --git a/post-training/configs/llama3-8b-reinforce.yml b/post-training/configs/llama3-8b-reinforce.yml
new file mode 100644
index 000000000..8d8e04462
--- /dev/null
+++ b/post-training/configs/llama3-8b-reinforce.yml
@@ -0,0 +1,119 @@
+{
+ "pipe_parallel_size": 0,
+ "model_parallel_size": 4,
+ "make_vocab_size_divisible_by": 1,
+
+ # model settings
+ "num_layers": 32,
+ "hidden_size": 4096,
+ "num_attention_heads": 32,
+ "num_kv_heads": 8,
+ # llama3 supports more than this but this is just for testing.
+ "seq_length": 1024,
+ "max_position_embeddings": 1024,
+ "pos_emb": "rotary",
+ "rotary_pct": 1,
+ "rotary_emb_base": 500000,
+ "rope_fusion": true,
+ "no_weight_tying": true,
+ "gpt_j_residual": false,
+ "output_layer_parallelism": "column",
+ "norm": "rmsnorm",
+ "rms_norm_epsilon": 1.0e-5,
+
+ "attention_config": [[["flash"], 32]],
+
+ "scaled_upper_triang_masked_softmax_fusion": true,
+ "bias_gelu_fusion": false,
+ "use_bias_in_norms": false,
+ "use_bias_in_attn_linear": false,
+ "use_bias_in_mlp": false,
+ "use_flashattn_swiglu": true,
+ "activation": "swiglu",
+ "intermediate_size": 14336,
+ "mlp_multiple_of": 14336,
+
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.00001,
+ "betas": [0.9, 0.95],
+ "eps": 1.0e-8
+ }
+ },
+ "min_lr": 0.000001,
+
+ "zero_optimization": {
+ "stage": 1,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 1260000000,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 1260000000,
+ "contiguous_gradients": true,
+ "cpu_offload": false
+ },
+
+ "train_impl": "reinforce",
+ "dataset_impl": "online",
+ "reinforce_leave_one_out": true,
+ "fp32_reinforce": true,
+ "kl_impl": "abs",
+ "online_dataserver_ports": [10000, 10001],
+ "serve_model_weights": true,
+ "train_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ],
+ "test_label_data_paths": [ "data/sft/llama3_test_messages_label_document" ],
+ "valid_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ],
+ "train_data_paths": [ "data/sft/llama3_train_messages_document" ],
+ "test_data_paths": [ "data/sft/llama3_test_messages_document" ],
+ "valid_data_paths": [ "data/sft/llama3_train_messages_document" ],
+
+ "train_micro_batch_size_per_gpu": 8,
+ "gradient_accumulation_steps": 4,
+ "data_impl": "mmap",
+ "pack_impl": "unpacked",
+ "num_workers": 1,
+
+ "checkpoint_activations": true,
+ "checkpoint_num_layers": 1,
+ "partition_activations": true,
+ "synchronize_each_layer": true,
+
+ "gradient_clipping": 1.0,
+ "weight_decay": 0.1,
+ "hidden_dropout": 0,
+ "attention_dropout": 0,
+
+ "precision": "bfloat16",
+ "fp32_allreduce": true,
+ "bf16": {
+ "enabled": true
+ },
+ "data_types": {
+ "grad_accum_dtype": "fp32"
+ },
+
+ "train_iters": 477,
+ "lr_decay_iters": 477,
+ "distributed_backend": "nccl",
+ "lr_decay_style":
"cosine", + "warmup": 0.1, + "checkpoint_factor": 1000, + "eval_interval": 100, + "eval_iters": 10, + + "log_interval": 1, + "steps_per_print": 1, + "wall_clock_breakdown": true, + + + "save": "checkpoints/reinforce/llama3/llama3-8b-instruct", + #"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save" + "load": "checkpoints/neox_converted/llama3-8b-instruct", + "vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json", + "use_wandb": true, + "wandb_group": "llama3-8b-instruct", + "wandb_project": "reinforce-test", + "finetune": true, # set to false once resuming from intermediate finetuning step + "tokenizer_type": "HFTokenizer", +} diff --git a/post-training/online_data_example_llama3.py b/post-training/online_data_example_llama3.py new file mode 100644 index 000000000..bdd902512 --- /dev/null +++ b/post-training/online_data_example_llama3.py @@ -0,0 +1,177 @@ +import socket +import threading +import datasets +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +import requests +import pickle +from collections import defaultdict +import time + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + + +def http_bot(url, pload): + for i in range(10): + try: + headers = {"User-Agent": "vLLM Client"} + response = requests.post(url, headers=headers, json=pload, stream=True) + data = response.json() + return data + except Exception as e: + # give it a few seconds to recover + time.sleep(5) + print(e) + continue + raise Exception("Failed to connect to server") + + +def threaded_data_gatherer( + prefix, + max_completion_len, + tokenizer, + model_name, + num_completions, + i, + dp_idx, + data_to_send, + rm_pipeline, +): + pload = { + "temperature": 1.0, + "max_tokens": 0, + "stop": "<|eot_id|>", + "stream": False, + "model": model_name, + "prompt": "", + "n": num_completions, + } + # Grab tokens... 
+ prefix_tokens = tokenizer.encode(prefix) + prompt = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": "Please write a mildly negative movie review starting with " + + prefix, + } + ], + add_generation_prompt=True, + tokenize=False, + ) + prompt_tokens = tokenizer.encode(prompt) + pload["max_tokens"] = max_completion_len - len(prefix_tokens) + pload["prompt"] = prompt + prefix + completions = http_bot(f"http://localhost:{8000+dp_idx}/v1/completions", pload) + completions = [completion["text"].strip() for completion in completions["choices"]] + + def reward_fn(samples, **kwargs): + sentiments = list(map(get_positive_score, rm_pipeline(samples))) + return sentiments + + rewards = reward_fn([prefix + " " + completion for completion in completions]) + if i == 0 and dp_idx == 0: + print(completions) + completions = [ + tokenizer.encode(completion + "<|eot_id|>") for completion in completions + ] + data_to_send.append( + {"prefix": prompt_tokens, "completions": completions, "rewards": rewards} + ) + + +def data_generator( + bs_per_dp, + dataset, + tokenizer, + model_name, + max_prefix_len, + max_completion_len, + num_completions, + dp_idx, + dp_size, + tp_size, + rm_pipeline, +): + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind( + ("localhost", 10000 + dp_idx) + ) # only one data loader per data parallel group + split_counter = defaultdict(lambda: dp_idx) + while True: + server.listen(1) + conn, addr = server.accept() + split = conn.recv(4096).decode() + if split == "valid": + split = "unsupervised" + data_to_send = list() + threads = list() + for i in range(bs_per_dp): + prefix = " ".join( + dataset[split][split_counter[split]]["text"].split()[:5] + ) # grab a few words to prompt it... + split_counter[split] = (split_counter[split] + dp_size) % len( + dataset[split] + ) + threads.append( + threading.Thread( + target=threaded_data_gatherer, + args=( + prefix, + max_completion_len, + tokenizer, + model_name, + num_completions, + i, + dp_idx, + data_to_send, + rm_pipeline, + ), + ) + ) + threads[-1].start() + for thread in threads: + thread.join() + conn.send(pickle.dumps(data_to_send)) + conn.close() + print( + f"Sent data to {dp_idx} for {split} split at iter {split_counter[split]}..." 
+ ) + + +if __name__ == "__main__": + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device="cpu", + ) + dataset = datasets.load_dataset("imdb") + threads = list() + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + for i in range(2): + threads.append( + threading.Thread( + target=data_generator, + args=( + 64, # bs_per_dp + dataset, # dataset + tokenizer, # tokenizer + "meta-llama/Meta-Llama-3-8B-Instruct", # model_name + 128, # max_prefix_len + 256, # max_completion_len + 4, # num_completions + i, # dp_idx + 2, # dp_size + 4, # tp_size + sentiment_fn, # rm_pipeline + ), + ) + ) + threads[-1].start() + for thread in threads: + thread.join() diff --git a/post-training/online_example.sh b/post-training/online_example.sh new file mode 100644 index 000000000..abe601faa --- /dev/null +++ b/post-training/online_example.sh @@ -0,0 +1,7 @@ +# Launch vllm +CUDA_VISIBLE_DEVICES=0,1,2,3 conda run --no-capture-output -n vllm python -m vllm.entrypoints.openai.api_server --model=meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --from-remote-program --tensor-parallel-size=4 --enforce-eager --gpu-memory-utilization=0.2 --port 8000 --max-model-len=1024 --max-num-seqs=512 & + +CUDA_VISIBLE_DEVICES=4,5,6,7 conda run --no-capture-output -n vllm python -m vllm.entrypoints.openai.api_server --model=meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --from-remote-program --tensor-parallel-size=4 --enforce-eager --gpu-memory-utilization=0.2 --port 8001 --max-model-len=1024 --max-num-seqs=512 & + +# Launch training +conda run --no-capture-output -n neox python deepy.py train.py post-training/configs/llama3-8b-reinforce.yml diff --git a/post-training/online_training.md b/post-training/online_training.md new file mode 100644 index 000000000..28f45c7cd --- /dev/null +++ b/post-training/online_training.md @@ -0,0 +1,56 @@ +# Online Training + +## Prerequisites +Want to use [REINFORCE](https://arxiv.org/abs/2402.14740) to train your model? First you'll need to build a custom vllm package. + +[synth-vllm](https://github.com/SynthLabsAI/synth-vllm) is a fork of [vllm](https://github.com/vllm-project/vllm) maintained by [SynthLabs](https://www.synthlabs.ai/) +that has been modified to support using the weights in GPT-NeoX by sharing the GPU memory location of the model weights. + +It currently supports Llama and Pythia models. + +### Building the package + +Here is a reference on how the package has been built before, using conda: +(Note this should be taken as a reference, and may not work as is due to your system configuration) + +```bash +# cd to the synth vllm directory... +conda create -n vllm python=3.10 +conda deactivate +conda activate vllm +conda install -y pytorch pytorch-cuda=12.1 -c pytorch -c nvidia +conda install -y nvidia/label/cuda-12.1.0::cuda-toolkit +conda install -y nvidia/label/cuda-12.1.0::cuda-cudart +conda install -y nvidia/label/cuda-12.1.0::cuda-compiler +conda install -y nvidia/label/cuda-12.1.0::cuda-nvcc +conda install -y nvidia/label/cuda-12.1.0::cuda-profiler-api +conda install -y nvidia/label/cuda-12.1.0::cuda-cudarty +conda install -y -c nvidia cuda-nvprof=12.1 +conda install -y conda-forge::cuda-version=12.1 +conda install -y gcc_linux-64=12.3.0 +conda install -y -c conda-forge gxx_linux-64=12.3.0 +pip install -e . 
+``` + +## Training + +If you haven't already, run this command to generate a copy of the Llama-3 weights in GPT-NeoX format: +```bash +python tools/ckpts/convert_hf_llama_to_neox.py --tp 4 --model meta-llama/Meta-Llama-3-8B-Instruct --model_path checkpoints/neox_converted/llama3-8b-instruct +``` + +[online_example.sh](online_example.sh), [online_data_example_llama3.py](online_data_example_llama3.py) is an example of +how to train a model using the synth-vllm package on a single node. + +This assumes you are using a conda environment with GPT-NeoX installed under the name `neox`. + +To run the example, execute the following commands: + +```bash +# It may be preferable to run these in two separate terminals +python post-training/online_data_example_llama3.py & +bash post-training/online_example.sh +``` + +This will train a model using the synth-vllm package on the llama3-8b-instruct model. It will optimize a positive reward +from a sentiment classifier. diff --git a/requirements/requirements-transformerengine.txt b/requirements/requirements-transformerengine.txt index 2050d7566..10a1f3b82 100644 --- a/requirements/requirements-transformerengine.txt +++ b/requirements/requirements-transformerengine.txt @@ -1 +1 @@ -pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable +transformer-engine[pytorch]
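
A note on the online data path used by the REINFORCE example above: `data_generator` in `post-training/online_data_example_llama3.py` listens on the `online_dataserver_ports`, expects the trainer to send a split name, and replies with a pickled list of `{"prefix", "completions", "rewards"}` records before closing the connection. The sketch below shows that consumer side under those assumptions; `fetch_online_batch` is a hypothetical helper, not the actual `dataset_impl: "online"` implementation in GPT-NeoX.

```python
# Hedged sketch of the trainer-side half of the online data protocol shown above.
import pickle
import socket


def fetch_online_batch(split="train", host="localhost", port=10000):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((host, port))
    sock.send(split.encode())  # data_generator() reads the split name first
    chunks = []
    while True:  # the server sends one pickled list, then closes the connection
        chunk = sock.recv(2**20)
        if not chunk:
            break
        chunks.append(chunk)
    sock.close()
    return pickle.loads(b"".join(chunks))


# Example: pull one batch of prefixes/completions/rewards from the first data server.
if __name__ == "__main__":
    batch = fetch_online_batch("train", port=10000)
    print(len(batch), batch[0]["rewards"])
```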