diff --git a/megatron/arguments.py b/megatron/arguments.py
index 4ab086b..0f8ee6a 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -401,6 +401,10 @@ def _add_network_size_args(parser):
                        'attention. This is set to '
                        '   args.hidden_size // args.num_attention_heads '
                        'if not provided.')
+    group.add_argument('--attention-head-type', type=str, default='multihead',
+                       choices=['multihead', 'multiquery'],
+                       help='Type of attention heads. `multihead` is the standard multi-head attention. '
+                            '`multiquery` shares the keys and values across attention heads.')
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')
@@ -477,6 +481,9 @@ def _add_logging_args(parser):
                        help="Name of wandb entity for reporting")
     group.add_argument('--wandb-project-name', type=str, default=None,
                        help="Name of wandb project")
+    group.add_argument('--transformer-timers', action='store_true',
+                       help="If set, activate the timers within the transformer layers. "
+                            "Only for debugging, as this slows down the model.")
 
     return parser
 
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index f011770..57d6d9f 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -27,7 +27,7 @@ from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
+from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_linear_layer
 from .glu_activations import GLU_ACTIVATIONS
@@ -233,6 +233,7 @@ def forward(self, query_layer, key_layer,
         # ===================================
         # Raw attention scores. [b, np, s, s]
         # ===================================
+        np = query_layer.size(2)
 
         # [b, np, sq, sk]
         output_size = (query_layer.size(1),
@@ -253,6 +254,7 @@ def forward(self, query_layer, key_layer,
                 (output_size[0]*output_size[1], output_size[2], output_size[3]),
                 query_layer.dtype, "mpu")
         else:
+            # alibi: (batch_size * num_attention_heads, 1, max_seq_len)
             matmul_input_buffer = alibi[:output_size[0]*output_size[1], :, :output_size[3]]
 
         # Raw attention scores. [b * np, sq, sk]
@@ -307,7 +309,7 @@ def forward(self, query_layer, key_layer,
 
         # context layer shape: [b, np, sq, hn]
         output_size = (value_layer.size(1),
-                       value_layer.size(2),
+                       np,
                        query_layer.size(0),
                        value_layer.size(3))
 
@@ -336,6 +338,127 @@ def forward(self, query_layer, key_layer,
 
         return context_layer
 
+
+class MultiQueryCoreAttention(CoreAttention):
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
+        # ===================================
+        # Raw attention scores. [b, np, s, s]
+        # ===================================
+        sq = query_layer.size(0)
+        bs = query_layer.size(1)
+        np = query_layer.size(2)
+
+        sk = key_layer.size(0)
+        # Only one head for keys and values.
+        assert key_layer.size(2) == 1 and value_layer.size(2) == 1
+
+        # [b, np, sq, sk]
+        output_size = (query_layer.size(1),
+                       query_layer.size(2),
+                       query_layer.size(0),
+                       key_layer.size(0))
+
+        # [sq, b, np, hn] -> [b, np * sq, hn]
+        query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1)
+        # [sk, b, 1, hn] -> [b, hn, sk]
+        key_layer = key_layer.squeeze(2).permute(1, 2, 0)
+        # [sk, b, 1, hn] -> [sk, b * np, hn]
+        # key_layer = key_layer.expand(output_size[3], output_size[0], np, -1)
+        # key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1)
+
+        if alibi is None:
+            # Preallocating input tensor: [b, np * sq, sk]
+            matmul_input_buffer = get_global_memory_buffer().get_tensor(
+                (bs, np * sq, sk),
+                query_layer.dtype, "mpu")
+        else:
+            # alibi: (batch_size * num_attention_heads, 1, max_seq_len)
+            # TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk)
+            matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk)
+            matmul_input_buffer = matmul_input_buffer.repeat(1, sq, 1)  # [b, np * sq, sk]
+
+        if alibi is None:
+            # Raw attention scores. [b, np * sq, sk]
+            matmul_result = torch.baddbmm(
+                matmul_input_buffer,
+                query_layer,  # [b, np * sq, hn]
+                key_layer,  # [b, hn, sk]
+                beta=0.0, alpha=(1.0/self.norm_factor))
+        else:
+            if not hasattr(self, "logged_alibi"):
+                print("Using Alibi.")
+                self.logged_alibi = True
+
+            if self.apply_query_key_layer_scaling:
+                beta = 1.0 / self.layer_number
+            else:
+                beta = 1.0
+
+            matmul_result = torch.baddbmm(
+                matmul_input_buffer,
+                query_layer,
+                key_layer,
+                beta=beta, alpha=(1.0 / self.norm_factor))
+
+        # change view to [b, np, sq, sk]
+        attention_scores = matmul_result.view(bs, np, sq, sk)
+
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # attention scores and attention mask [b, np, sq, sk]
+        attention_probs = self.scale_mask_softmax(attention_scores,
+                                                  attention_mask)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+
+        if not self.sequence_parallel:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+        else:
+            attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [sq, b, hp]
+        # =========================
+
+        # value_layer -> context layer.
+        # [sk, b, np, hn] --> [b, np, sq, hn]
+
+        # context layer shape: [b, np, sq, hn]
+        output_size = (value_layer.size(1),
+                       np,
+                       query_layer.size(0),
+                       value_layer.size(3))
+
+        # [sk, b, 1, hn] -> [b, sk, hn]
+        value_layer = value_layer.squeeze(2).transpose(0, 1)
+
+        # change view [b, np * sq, sk]
+        attention_probs = attention_probs.view(bs, np * sq, -1)
+
+        # matmul: [b, np * sq, hn]
+        context_layer = torch.bmm(attention_probs, value_layer)
+
+        # change view [b, np, sq, hn]
+        context_layer = context_layer.view(bs, np, sq, -1)
+
+        # [b, np, sq, hn] --> [sq, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [sq, b, np, hn] --> [sq, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + \
+            (self.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        return context_layer
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
@@ -353,6 +476,7 @@ def __init__(self, init_method,
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.attention_head_type = args.attention_head_type
 
         projection_size = args.kv_channels * args.num_attention_heads
 
@@ -364,13 +488,28 @@ def __init__(self, init_method,
             args.num_attention_heads, world_size)
 
         # Strided linear layer.
-        if attention_type == AttnType.self_attn:
+        if attention_type == AttnType.self_attn and self.attention_head_type == 'multihead':
             self.query_key_value = mpu.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
                 gather_output=False,
                 init_method=init_method)
-        else:
+        elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
+            # TODO: Find a way to merge the query and key-value computations?
+            self.query = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                projection_size,
+                gather_output=False,
+                init_method=init_method)
+            # In MultiQuery attention, keys and values are shared across heads.
+            # Use 2 * args.kv_channels instead of 2 * projection_size.
+            # No `.fork()` of the rng tracker, so tensor-parallel ranks share the same rng state and initialize identical (replicated) weights.
+            # with mpu.get_cuda_rng_tracker():
+            self.key_value = get_linear_layer(
+                args.hidden_size,
+                2 * args.kv_channels,
+                init_method=init_method)
+        elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead':
             assert attention_type == AttnType.cross_attn
             self.query = mpu.ColumnParallelLinear(
                 args.hidden_size,
@@ -383,9 +522,14 @@ def __init__(self, init_method,
                 2 * projection_size,
                 gather_output=False,
                 init_method=init_method)
+        else:
+            raise NotImplementedError("Multiquery attention is not implemented for cross-attention.")
 
-        self.core_attention = CoreAttention(self.layer_number,
-                                            self.attn_mask_type)
+        if self.attention_head_type == 'multihead':
+            self.core_attention = CoreAttention(self.layer_number,
+                                                self.attn_mask_type)
+        else:
+            self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
 
         # Output.
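
Note (illustration, not part of the patch): a minimal standalone sketch of the projection sizes chosen above, with made-up values for hidden_size, num_attention_heads and kv_channels, and tensor parallelism ignored. The multi-query variant keeps a full per-head query projection but replaces the per-head key/value projections with one shared pair of width kv_channels each, which is what makes the key_value layer (and, later in this patch, the inference key/value cache) about num_heads times smaller.

import torch

# Illustrative sizes only (not taken from the patch).
hidden_size, num_heads, kv_channels = 2048, 16, 128
projection_size = kv_channels * num_heads

# multihead: one fused projection produces Q, K and V for every head.
qkv = torch.nn.Linear(hidden_size, 3 * projection_size)

# multiquery: per-head queries, but a single key/value head shared by all heads.
query = torch.nn.Linear(hidden_size, projection_size)
key_value = torch.nn.Linear(hidden_size, 2 * kv_channels)


def n_params(module):
    return sum(p.numel() for p in module.parameters())


print(n_params(qkv))                          # 3 * hidden_size * projection_size (+ biases)
print(n_params(query) + n_params(key_value))  # hidden_size * (projection_size + 2 * kv_channels) (+ biases)
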
@@ -419,15 +563,15 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size):
         return torch.empty(
             inference_max_sequence_len,
             batch_size,
-            self.num_attention_heads_per_partition,
+            self.num_attention_heads_per_partition if self.attention_head_type == "multihead" else 1,
             self.hidden_size_per_attention_head,
             dtype=self.params_dtype,
             device=torch.cuda.current_device())
 
+
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None, alibi=None):
         # hidden_states: [sq, b, h]
-
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
@@ -449,7 +593,7 @@ def forward(self, hidden_states, attention_mask,
         # =====================
         # Query, Key, and Value
         # =====================
 
-        if self.attention_type == AttnType.self_attn:
+        if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead':
             # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)
@@ -463,6 +607,35 @@ def forward(self, hidden_states, attention_mask,
             (query_layer,
              key_layer,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+        elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
+            # Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
+            mixed_kv_layer = self.key_value(hidden_states)
+
+            # [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
+            # new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+            #     (self.num_attention_heads_per_partition,
+            #      2 * self.hidden_size_per_attention_head)
+            # mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape)
+
+            # [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn]
+            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+                (1,
+                 2 * self.hidden_size_per_attention_head)
+            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+            # [sq, b, 1, 2 * hn] --> 2 [sq, b, 1, hn]
+            (key_layer,
+             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
+
+            # Attention heads [sq, b, h] --> [sq, b, np * hn]
+            query_layer, _ = self.query(hidden_states)
+            # [sq, b, np * hn] --> [sq, b, np, hn]
+            new_tensor_shape = query_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)
+            query_layer = query_layer.view(*new_tensor_shape)
+
+            # [sq, b, np, hn] -> [b, np * sq, hn]
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -489,6 +662,7 @@ def forward(self, hidden_states, attention_mask,
         # ==================================
         # Adjust key and value for inference
         # ==================================
+
         if inference_params:
             batch_start = inference_params.batch_size_offset
             batch_end = batch_start + key_layer.size(1)
@@ -520,7 +694,6 @@ def forward(self, hidden_states, attention_mask,
         # =================
         # Output. [sq, b, h]
         # =================
-
         output, bias = self.dense(context_layer)
 
         return output, bias
@@ -963,6 +1136,10 @@ def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
         # hidden_states: [s, b, h]
+        timers = get_timers()
+        args = get_args()
+
+        if args.transformer_timers: timers("Transformer forward").start()
 
         # Checks.
         if inference_params:
@@ -1020,4 +1197,6 @@ def forward(self, hidden_states, attention_mask,
         if self.post_process and self.post_layer_norm:
             hidden_states = self.final_layernorm(hidden_states)
 
+        if args.transformer_timers: timers("Transformer forward").stop()
+
         return hidden_states
diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py
index 3ee9db2..ac78c3a 100644
--- a/megatron/mpu/layers.py
+++ b/megatron/mpu/layers.py
@@ -264,7 +264,8 @@ def backward(ctx, grad_output):
             handle.wait()
 
         # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+        # TODO: Is the reshape preventing us from getting a speedup here?
+        grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1],
                                        grad_output.shape[2])
         total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
                                        total_input.shape[2])
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index b265145..6e83e65 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -265,6 +265,19 @@ def allreduce_embedding_grads(self, args):
         """All-reduce both word and position embeddings."""
         self.allreduce_word_embedding_grads(args)
         self.allreduce_position_embedding_grads(args)
+
+    def allreduce_key_value_grads(self, args):
+        # TODO: models[0] ?
+        unwrapped_model = self.models[0]
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+        for layer in unwrapped_model.language_model.encoder.layers:
+            kv_weight = layer.self_attention.key_value.weight
+            if args.DDP_impl == 'local':
+                grad = kv_weight.main_grad
+            else:
+                grad = kv_weight.grad
+            torch.distributed.all_reduce(grad, group=mpu.get_tensor_model_parallel_group())
 
 
     def allreduce_layernorm_grads(self, args):
@@ -310,6 +323,13 @@ def reduce_model_grads(self, args, timers):
         self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()
 
+        # All-reduce key-value grads if needed.
+        if args.attention_head_type == "multiquery":
+            timers('backward-key-value-all-reduce').start()
+            self.allreduce_key_value_grads(args)
+            timers('backward-key-value-all-reduce').stop()
+
+
 
 class MixedPrecisionOptimizer(MegatronOptimizer):
     """Base class for both the float-16 and the distributed optimizer.
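
Note (illustration, not part of the patch): allreduce_key_value_grads exists because the multi-query key_value projection is a plain get_linear_layer, replicated on every tensor-parallel rank rather than sharded like ColumnParallelLinear. Each rank back-propagates only its own attention heads' contribution into the shared weight, so the partial gradients must be summed across the tensor-parallel group to keep the replicas in sync. A minimal sketch of that pattern, assuming two processes launched with torchrun --nproc_per_node=2 and the gloo backend; the different per-rank inputs stand in for the different per-rank head contributions.

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")

    kv = torch.nn.Linear(16, 2 * 8, bias=False)  # replicated key/value projection
    dist.broadcast(kv.weight.data, src=0)        # start from identical replicas

    torch.manual_seed(dist.get_rank())           # each rank contributes a different partial gradient
    kv(torch.randn(4, 16)).sum().backward()

    # Sum the partial gradients; without this, every rank would apply a different
    # update and the replicated weights would drift apart.
    dist.all_reduce(kv.weight.grad, op=dist.ReduceOp.SUM)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
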
diff --git a/tools/text_generation_benchmark.py b/tools/text_generation_benchmark.py
new file mode 100644
index 0000000..ee458f3
--- /dev/null
+++ b/tools/text_generation_benchmark.py
@@ -0,0 +1,163 @@
+
+"""Sample Generate GPT"""
+import os
+import sys
+import re
+sys.path.append(os.path.abspath(os.path.join(
+    os.getcwd(),
+    "Megatron-LM",
+)))
+from megatron import get_args
+from megatron import print_rank_0
+from megatron import get_timers
+from megatron import mpu
+from megatron.checkpointing import load_checkpoint
+from megatron.initialize import initialize_megatron
+from megatron.model import GPTModel
+from megatron.training import get_model
+from megatron.text_generation import generate_and_post_process
+import torch
+from human_eval.data import write_jsonl, read_problems
+from tqdm import tqdm
+
+
+GENERATE_NUM = 0
+
+# End on unindented code
+# EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
+
+
+BATCH_SIZE = 512
+TOKENS_TO_GENERATE = 128
+PROMPT_LENGTH = 128
+NUM_BATCHES = 8
+
+
+# NUM_SAMPLES_PER_TASK = 5
+# # Number of human-eval tasks
+# NUM_TASKS = 200
+
+def send_do_generate():
+    choice = torch.cuda.LongTensor([GENERATE_NUM])
+    torch.distributed.broadcast(choice, 0)
+
+
+def model_provider(pre_process=True, post_process=True):
+    """Build the model."""
+
+    print_rank_0('building GPT model ...')
+    model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
+
+    return model
+
+def get_batches(prompts, batch_size):
+    for start_idx in tqdm(range(0, len(prompts), batch_size)):
+        actual_batch_size = min(batch_size, len(prompts) - start_idx)
+        yield prompts[start_idx: start_idx + actual_batch_size]
+
+
+def unbatch(d: dict):
+    return [dict(zip(d.keys(), t)) for t in zip(*d.values())]
+
+
+# Use fixed-length prompts
+def load_evaluation_data(args):
+    # HumanEval data
+    # problems = read_problems()
+
+    # batches = get_batches(
+    #     [
+    #         problems[task_id]["prompt"]
+    #         for task_id in problems
+    #         for _ in range(5)
+    #     ],
+    #     BATCH_SIZE
+    # )
+    # return batches
+
+    prompt = " ".join(["one"] * PROMPT_LENGTH)
+    prompts = [prompt] * (BATCH_SIZE * NUM_BATCHES)
+
+    batches = get_batches(prompts, BATCH_SIZE)
+    return batches
+
+
+if __name__ == "__main__":
+    # Initialize Megatron
+    initialize_megatron(extra_args_provider=None,
+                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                                       'no_load_rng': True,
+                                       'no_load_optim': True})
+
+    args = get_args()
+    timers = get_timers()
+
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+    # Setup model and load checkpoint
+    model = get_model(model_provider, wrap_with_ddp=False)
+
+    if args.load is not None:
+        iteration = load_checkpoint(model, None, None, iteration=None)
+    else:
+        iteration = None
+
+    assert len(model) == 1
+    model = model[0]
+
+    def generate(prompts):
+        response, response_seg, response_logprobs, tokens = \
+            generate_and_post_process(
+                model,
+                prompts=prompts,
+                tokens_to_generate=TOKENS_TO_GENERATE,
+                return_output_log_probs=True,
+                use_eod_token_for_early_termination=False)
+
+        assert all([r.startswith(p) for r, p in zip(response, prompts)])
+        result = {
+            "response": response,
+            "response_seg": response_seg,
+            "raw_completion": [r[len(p):] for r, p in zip(response, prompts)]
+        }
+        # The "completion" field contains the string that is actually going to be evaluated by the HumanEval script
+        # result["completion"] = [post_process_completion(c) for c in result["raw_completion"]]
+        # Return a list of dicts
+        return unbatch(result)
+
+    # if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+    #     server = MegatronServer(model)
+    #     server.run("0.0.0.0")
+
+    # while True:
+    #     choice = torch.cuda.LongTensor(1)
+    #     torch.distributed.broadcast(choice, 0)
+    #     if choice[0].item() == 0:
+    #         generate_and_post_process(model)
+
+
+    # Evaluation data iterator
+    batches = load_evaluation_data(args)
+
+    timers('generate').start()
+    # Generate
+    samples = [
+        generate_dict
+        for batch in batches
+        for generate_dict in generate(batch)
+    ]
+    timers('generate').stop()
+
+    elapsed = timers.timers['generate'].elapsed(reset=False)
+    num_tokens = TOKENS_TO_GENERATE * NUM_BATCHES * BATCH_SIZE
+    print(f"{elapsed * 1000 / (num_tokens)} ms per token")
+    timers.log(['generate'])
+    if args.transformer_timers:
+        timers.log(["Transformer forward"])
+    print("DONE")
+
+    # Write results to file
+    # if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+    #     write_jsonl(args.output_file.format(iteration), samples)
+