From e60fb87932fa410b2b430f6c1b3e6986666f68b5 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 15 Oct 2024 16:00:22 -0700 Subject: [PATCH 01/20] Simplify Tensor Parallel implementation with PyTorch TP --- src/transformers/modeling_utils.py | 25 +++++ .../models/llama/modeling_llama.py | 95 +++++++------------ 2 files changed, 61 insertions(+), 59 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cb0d743b0a90ae..9f7fc2800d69bf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4979,6 +4979,31 @@ def _is_quantized_training_enabled(self): return self.hf_quantizer.is_trainable + def tensor_parallel(self, device_mesh): + """ + Tensor parallelize the model across the given device mesh. + + Args: + device_mesh (`torch.distributed.DeviceMesh`): + The device mesh to use for tensor parallelism. + """ + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. + # No op if `_tp_plan` attribute does not exist under the module. + # This is a helper function to be used with `model.apply` to recursively + # parallelize a model. + def tplize(mod: torch.nn.Module) -> None: + tp_plan = getattr(mod, "_tp_plan", None) + if tp_plan: + torch.distributed.tensor.parallel.parallelize_module( + mod, + device_mesh=device_mesh, + parallelize_plan=tp_plan, + ) + + # `apply` is a native method of `nn.Module` that recursively applies a + # function to every submodule. + self.apply(tplize) + PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dde017bbb92797..eee8dc437593d9 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -24,6 +24,11 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn +from torch.distributed.tensor import Replicate +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, +) from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -226,6 +231,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class LlamaMLP(nn.Module): + _tp_plan = { + "gate_proj": ColwiseParallel(), + "up_proj": ColwiseParallel(), + "down_proj": RowwiseParallel(), + } + def __init__(self, config): super().__init__() self.config = config @@ -237,25 +248,7 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * 
self.up_proj(x)) return down_proj @@ -274,6 +267,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" + _tp_plan = { + "q_proj": ColwiseParallel(), + "k_proj": ColwiseParallel(), + "v_proj": ColwiseParallel(), + "o_proj": RowwiseParallel(), + } + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config @@ -317,31 +317,14 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -383,12 +366,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -559,9 +537,10 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, 
-1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -1102,6 +1081,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = { + "lm_head": ColwiseParallel(output_layouts=Replicate()), + } def __init__(self, config): super().__init__(config) @@ -1198,13 +1180,8 @@ def forward( ) hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: From fd7f7c721ca6a8c59e9ed0c507daee95f0f4d713 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 22 Oct 2024 19:29:49 -0700 Subject: [PATCH 02/20] Move tp_plan to config --- src/transformers/modeling_utils.py | 12 +++++++++ .../models/llama/configuration_llama.py | 10 +++++++ .../models/llama/modeling_llama.py | 25 +++--------------- src/transformers/pytorch_utils.py | 26 +++++++++++++++++++ 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9f7fc2800d69bf..9155063fd05080 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -54,6 +54,7 @@ prune_conv1d_layer, prune_layer, prune_linear_layer, + translate_to_torch_parallel_style, ) from .quantizers import AutoHfQuantizer, HfQuantizer from .quantizers.quantizers_utils import get_module_from_name @@ -4994,6 +4995,17 @@ def tensor_parallel(self, device_mesh): def tplize(mod: torch.nn.Module) -> None: tp_plan = getattr(mod, "_tp_plan", None) if tp_plan: + logger.debug( + f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}" + ) + # In model configs, we use a neutral type (string) to specify + # parallel styles, here we translate them into torch TP types. + # Using tree_map because `tp_plan` is a dict. + tp_plan = torch.utils._pytree.tree_map( + translate_to_torch_parallel_style, + tp_plan, + ) + # Apply TP to current module. 
torch.distributed.tensor.parallel.parallelize_module( mod, device_mesh=device_mesh, diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index a3667e06534564..0f831a376d1024 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig): model_type = "llama" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `LlamaModel` + _base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index eee8dc437593d9..94316ca67d4095 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -24,11 +24,6 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.distributed.tensor import Replicate -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - RowwiseParallel, -) from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -231,12 +226,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class LlamaMLP(nn.Module): - _tp_plan = { - "gate_proj": ColwiseParallel(), - "up_proj": ColwiseParallel(), - "down_proj": RowwiseParallel(), - } - def __init__(self, config): super().__init__() self.config = config @@ -267,13 +256,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - _tp_plan = { - "q_proj": ColwiseParallel(), - "k_proj": ColwiseParallel(), - "v_proj": ColwiseParallel(), - "o_proj": RowwiseParallel(), - } - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config @@ -824,8 +806,9 @@ def __init__(self, config: LlamaConfig): ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False + self.gradient_checkpointing = False + self._tp_plan = config._base_model_tp_plan # Initialize weights and apply final processing self.post_init() @@ -1081,9 +1064,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] - _tp_plan = { - "lm_head": ColwiseParallel(output_layouts=Replicate()), - } + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index f3663c09902f52..d0a794403da862 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -20,6 +20,11 @@ from packaging import version from safetensors.torch import storage_ptr, storage_size from torch import nn +from torch.distributed.tensor import Replicate +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, +) from .utils import is_torch_xla_available, logging @@ -326,3 +331,24 @@ def 
isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) else: # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 return torch.isin(elements, test_elements) + + +def translate_to_torch_parallel_style(style: str): + """ + In model configurations, we use a neutral type (string) to specify parallel + styles, here we translate them into torch.distributed tensor-parallel + types. + """ + if not isinstance(style, str): + raise ValueError( + f"Unsupported parallel style type {type(style)}, expected str" + ) + + if style == "colwise": + return ColwiseParallel() + elif style == "rowwise": + return RowwiseParallel() + elif style == "colwise_rep": + return ColwiseParallel(output_layouts=Replicate()) + else: + raise ValueError(f"Unsupported parallel style value: {style}") From 79cc524ea7ee2b962c97893dcc3f85223053f80d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 29 Oct 2024 23:17:11 -0700 Subject: [PATCH 03/20] Lint --- src/transformers/models/llama/modeling_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f78eae6b2a2970..46b34b4df82dad 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn From a2934b3362af3daf364b650a0f8c8883bfcfd453 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 30 Oct 2024 00:56:14 -0700 Subject: [PATCH 04/20] Format and warning --- src/transformers/modeling_utils.py | 5 ++--- src/transformers/models/llama/modeling_llama.py | 2 ++ src/transformers/pytorch_utils.py | 4 +--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 98de259c16ece5..95e2bc71a43fd3 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5022,6 +5022,7 @@ def tensor_parallel(self, device_mesh): device_mesh (`torch.distributed.DeviceMesh`): The device mesh to use for tensor parallelism. """ + # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. # No op if `_tp_plan` attribute does not exist under the module. # This is a helper function to be used with `model.apply` to recursively @@ -5029,9 +5030,7 @@ def tensor_parallel(self, device_mesh): def tplize(mod: torch.nn.Module) -> None: tp_plan = getattr(mod, "_tp_plan", None) if tp_plan: - logger.debug( - f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}" - ) + logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}") # In model configs, we use a neutral type (string) to specify # parallel styles, here we translate them into torch TP types. # Using tree_map because `tp_plan` is a dict. 
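To make the translation step above concrete, here is a minimal standalone sketch of how a string-keyed plan is turned into torch tensor-parallel styles and applied to a module. It assumes an already-initialized 1-D `DeviceMesh`; the helper names `translate_style` and `apply_string_plan` are illustrative stand-ins for `translate_to_torch_parallel_style` and `tplize`, not the exact library code.

```python
import torch
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def translate_style(style: str):
    # Same string-to-style mapping as `translate_to_torch_parallel_style` in pytorch_utils.py.
    if style == "colwise":
        return ColwiseParallel()
    if style == "rowwise":
        return RowwiseParallel()
    if style == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    raise ValueError(f"Unsupported parallel style value: {style}")


def apply_string_plan(module: torch.nn.Module, plan: dict, device_mesh) -> None:
    # Translate the neutral (string) plan into torch TP objects, then shard the module in place.
    torch_plan = {fqn: translate_style(style) for fqn, style in plan.items()}
    parallelize_module(module, device_mesh=device_mesh, parallelize_plan=torch_plan)
```

With a plan such as `{"gate_proj": "colwise", "up_proj": "colwise", "down_proj": "rowwise"}` applied to an MLP block, the two input projections are sharded along their output dimension and the output projection along its input dimension, so only the final matmul requires an all-reduce.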
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 46b34b4df82dad..0381ba4720db4a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -813,6 +813,8 @@ def __init__(self, config: LlamaConfig): self.gradient_checkpointing = False self._tp_plan = config._base_model_tp_plan + if config.pretraining_tp != 1: + logger.warn("`pretraining_tp` is deprecated, please use `tensor_parallel` method instead.") # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index d0a794403da862..6ef6a52997d08f 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -340,9 +340,7 @@ def translate_to_torch_parallel_style(style: str): types. """ if not isinstance(style, str): - raise ValueError( - f"Unsupported parallel style type {type(style)}, expected str" - ) + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") if style == "colwise": return ColwiseParallel() From a8fc418c9dd8e7df36ebb9c540b6ac002533dd83 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 30 Oct 2024 01:27:53 -0700 Subject: [PATCH 05/20] Disable copy-from check --- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- .../models/recurrent_gemma/modeling_recurrent_gemma.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b215fb6561bf81..0261f997da1110 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index d4eb348260c1a4..8de6bc90ea3fec 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -980,7 +980,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 60225d4759c6ab..d865c51e50578e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1020,7 +1020,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo 
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index cbb8db0f59dd02..47cb0964eca8b6 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -888,7 +888,7 @@ def _init_weights(self, module): "The bare Olmoe Model outputting raw hidden-states without any specific head on top.", OLMOE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe class OlmoeModel(OlmoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`] diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index d3164b17fe130c..2b3cf7eb0cb82e 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -775,7 +775,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): return causal_mask -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma +# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] From e84a3889d984dc6e3b67cafbd2af4567de9eaea7 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 31 Oct 2024 09:55:50 -0700 Subject: [PATCH 06/20] Conditionally get attr from config --- src/transformers/models/llama/modeling_llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0381ba4720db4a..7f042d27d04277 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -812,9 +812,10 @@ def __init__(self, config: LlamaConfig): self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - self._tp_plan = config._base_model_tp_plan - if config.pretraining_tp != 1: - logger.warn("`pretraining_tp` is deprecated, please use `tensor_parallel` method instead.") + self._tp_plan = getattr(config, "_base_model_tp_plan", None) + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + # Initialize weights and apply final processing self.post_init() From 396d158c9d8053447d3df4728260babb3f12bd5d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 31 Oct 2024 10:11:05 -0700 Subject: [PATCH 07/20] make fix-copies --- src/transformers/models/gemma/modeling_gemma.py | 5 +++++ src/transformers/models/gemma2/modeling_gemma2.py | 5 +++++ src/transformers/models/glm/modeling_glm.py | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9a4de1022c57e9..f794c687e2619f 100644 --- 
a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -724,7 +724,11 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self._tp_plan = getattr(config, "_base_model_tp_plan", None) + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -986,6 +990,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6d61c47619f304..f6d086fe53e42c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -738,7 +738,11 @@ def __init__(self, config: Gemma2Config): [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self._tp_plan = getattr(config, "_base_model_tp_plan", None) + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -958,6 +962,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 5f8eaf89ed9353..0ba04f0d674f0e 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -712,6 +712,9 @@ def __init__(self, config: GlmConfig): dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta ) self.gradient_checkpointing = False + self._tp_plan = getattr(config, "_base_model_tp_plan", None) + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -971,6 +974,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config: GlmConfig): super().__init__(config) From 7b346b552617771065a8bc3408a8c5fbe3684530 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 31 Oct 2024 11:09:28 -0700 Subject: [PATCH 08/20] Move base_model_tp_plan to PretrainedConfig --- src/transformers/configuration_utils.py | 3 +++ src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/llama/configuration_llama.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- 6 files changed, 8 insertions(+), 5 
deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 1d892c49a231fc..de3ac84ce002c2 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin): outputs of the model during inference. - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized naming of attributes. + - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor + parallel plan applied to the sub-module when `model.tensor_parallel` is called. Common attributes (present in all subclasses): @@ -192,6 +194,7 @@ class PretrainedConfig(PushToHubMixin): model_type: str = "" is_composition: bool = False attribute_map: Dict[str, str] = {} + base_model_tp_plan: Optional[Dict[str, Any]] = None _auto_class: Optional[str] = None def __setattr__(self, key, value): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f794c687e2619f..3d3c034b6f7875 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -726,7 +726,7 @@ def __init__(self, config: GemmaConfig): self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False - self._tp_plan = getattr(config, "_base_model_tp_plan", None) + self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index f6d086fe53e42c..0d1cf96b0ae0df 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -740,7 +740,7 @@ def __init__(self, config: Gemma2Config): self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False - self._tp_plan = getattr(config, "_base_model_tp_plan", None) + self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0ba04f0d674f0e..8af8938ef4763d 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -712,7 +712,7 @@ def __init__(self, config: GlmConfig): dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta ) self.gradient_checkpointing = False - self._tp_plan = getattr(config, "_base_model_tp_plan", None) + self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 0f831a376d1024..98d5ecdd2a4fdb 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -142,7 +142,7 @@ class LlamaConfig(PretrainedConfig): model_type = "llama" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `LlamaModel` - _base_model_tp_plan = { + base_model_tp_plan = { 
"layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7f042d27d04277..a967cfe685ec5a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -812,7 +812,7 @@ def __init__(self, config: LlamaConfig): self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - self._tp_plan = getattr(config, "_base_model_tp_plan", None) + self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") From 12fbbe70ddbac9b4aae153092ff672fc4bfec7de Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Nov 2024 19:32:19 -0800 Subject: [PATCH 09/20] Move TP into from_pretrained --- src/transformers/modeling_utils.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 95e2bc71a43fd3..468de904e34d01 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3473,6 +3473,12 @@ def from_pretrained( # Cache path to the GGUF file gguf_path = None + tp_plan = kwargs.pop("tp_plan", None) + if tp_plan is not None and tp_plan != "auto": + raise ValueError( + f"tp_plan supports 'auto' only for now but got {tp_plan}." + ) + if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -4074,6 +4080,7 @@ def from_pretrained( # Instantiate model. init_contexts = [no_init_weights(_enable=_fast_init)] + tp_device = None if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed @@ -4086,6 +4093,19 @@ def from_pretrained( f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" ) init_contexts.append(init_empty_weights()) + elif tp_plan is not None: + if not torch.distributed.is_initialized(): + raise ValueError( + "Tensor Parallel requires torch.distributed to be initialized first." + ) + + # Get device type (e.g. "cuda") + device_type = torch.distributed.distributed_c10d._device_capability()[0] + # Get torch device module (e.g. torch.cuda) based on device type + device_module = torch.get_device_module(device_type) + # Get device with index assuming equal number of devices per host + tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count()) + init_contexts.append(tp_device) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. if not getattr(config, "_attn_implementation_autoset", False): @@ -4325,6 +4345,14 @@ def from_pretrained( } return model, loading_info + if tp_plan is not None: + assert tp_device is not None, "tp_device not set!" 
+ # Assuming sharding the model onto the world + world_size = torch.distributed.get_world_size() + device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) + # Apply Tensor Parallelism + model.tensor_parallel(device_mesh) + return model @classmethod From 02c8c39399830b8c8855f4c999f742abfb733761 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Nov 2024 19:45:45 -0800 Subject: [PATCH 10/20] Add device context for load --- src/transformers/modeling_utils.py | 58 ++++++++++++++++-------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 468de904e34d01..ee2406e747a2e0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4236,32 +4236,38 @@ def from_pretrained( if dtype_orig is not None: torch.set_default_dtype(dtype_orig) - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = cls._load_pretrained_model( - model, - state_dict, - loaded_state_dict_keys, # XXX: rename? - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - sharded_metadata=sharded_metadata, - _fast_init=_fast_init, - low_cpu_mem_usage=low_cpu_mem_usage, - device_map=device_map, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - gguf_path=gguf_path, - weights_only=weights_only, - ) + load_contexts = [] + # Make sure we load onto targeted device + if tp_device is not None: + load_contexts.append(tp_device) + + with ContextManagers(load_contexts): + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? 
+ resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + gguf_path=gguf_path, + weights_only=weights_only, + ) # make sure token embedding weights are still tied if needed model.tie_weights() From 073c521dc883aaba0e3d009c5abc64dc670a9df9 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 7 Nov 2024 12:20:21 -0800 Subject: [PATCH 11/20] Do not serialize --- src/transformers/configuration_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index de3ac84ce002c2..9fda18cfbac1e3 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -838,6 +838,9 @@ def to_diff_dict(self) -> Dict[str, Any]: if "_attn_implementation_internal" in serializable_config_dict: del serializable_config_dict["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in serializable_config_dict: + del serializable_config_dict["base_model_tp_plan"] return serializable_config_dict @@ -857,6 +860,9 @@ def to_dict(self) -> Dict[str, Any]: del output["_commit_hash"] if "_attn_implementation_internal" in output: del output["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in output: + del output["base_model_tp_plan"] # Transformers version when serializing the model output["transformers_version"] = __version__ From db6e5eebe0b5e0b116e1c74f16bf299bb502c9cb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 7 Nov 2024 13:29:20 -0800 Subject: [PATCH 12/20] Move _tp_plan setting to post_init --- src/transformers/modeling_utils.py | 49 +++++++++++-------- .../models/gemma/modeling_gemma.py | 1 - .../models/gemma2/modeling_gemma2.py | 1 - src/transformers/models/glm/modeling_glm.py | 1 - .../models/llama/modeling_llama.py | 1 - 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ee2406e747a2e0..bde262843d80dd 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1399,6 +1399,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Has support for a `QuantoQuantizedCache` instance as `past_key_values` _supports_quantized_cache = False + # A tensor parallel plan to be applied to the model when TP is enabled. For + # top-level models, this attribute is currently defined in respective model + # code. For base models, this attribute comes from + # `config.base_model_tp_plan` during `post_init`. + _tp_plan = None + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ @@ -1443,6 +1449,9 @@ def post_init(self): """ self.init_weights() self._backward_compatibility_gradient_checkpointing() + # If current model is a base model, attach `base_model_tp_plan` from config + if self.base_model is self: + self._tp_plan = self.config.base_model_tp_plan def dequantize(self): """ @@ -3475,9 +3484,8 @@ def from_pretrained( tp_plan = kwargs.pop("tp_plan", None) if tp_plan is not None and tp_plan != "auto": - raise ValueError( - f"tp_plan supports 'auto' only for now but got {tp_plan}." 
- ) + # TODO: we can relax this check when we support taking tp_plan from a json file, for example. + raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.") if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -4095,9 +4103,7 @@ def from_pretrained( init_contexts.append(init_empty_weights()) elif tp_plan is not None: if not torch.distributed.is_initialized(): - raise ValueError( - "Tensor Parallel requires torch.distributed to be initialized first." - ) + raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") # Get device type (e.g. "cuda") device_type = torch.distributed.distributed_c10d._device_capability()[0] @@ -5063,21 +5069,22 @@ def tensor_parallel(self, device_mesh): # parallelize a model. def tplize(mod: torch.nn.Module) -> None: tp_plan = getattr(mod, "_tp_plan", None) - if tp_plan: - logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}") - # In model configs, we use a neutral type (string) to specify - # parallel styles, here we translate them into torch TP types. - # Using tree_map because `tp_plan` is a dict. - tp_plan = torch.utils._pytree.tree_map( - translate_to_torch_parallel_style, - tp_plan, - ) - # Apply TP to current module. - torch.distributed.tensor.parallel.parallelize_module( - mod, - device_mesh=device_mesh, - parallelize_plan=tp_plan, - ) + if tp_plan is None: + return + logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}") + # In model configs, we use a neutral type (string) to specify + # parallel styles, here we translate them into torch TP types. + # Using tree_map because `tp_plan` is a dict. + tp_plan = torch.utils._pytree.tree_map( + translate_to_torch_parallel_style, + tp_plan, + ) + # Apply TP to current module. + torch.distributed.tensor.parallel.parallelize_module( + mod, + device_mesh=device_mesh, + parallelize_plan=tp_plan, + ) # `apply` is a native method of `nn.Module` that recursively applies a # function to every submodule. 
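Putting the pieces of this patch together, the manual path (without `tp_plan="auto"`) looks roughly like the sketch below. It assumes a multi-GPU CUDA host launched with `torchrun`; unlike the `tp_plan` path added to `from_pretrained`, it materializes the full checkpoint on every rank before sharding, so treat it as illustrative rather than the memory-efficient loading route.

```python
import os

import torch
from transformers import AutoModelForCausalLM

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# One GPU per rank, mirroring the device selection done inside `from_pretrained`.
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.distributed.init_process_group("nccl", device_id=device)

# A flat 1-D mesh over all ranks, matching the mesh built for `tp_plan="auto"`.
device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16
).to(device)
model.eval()

# Apply the string-based plan attached to the model during `post_init`.
model.tensor_parallel(device_mesh)
```

Launched with `torchrun --nproc-per-node 4 script.py`, each rank then holds a shard of the attention and MLP projections listed in `base_model_tp_plan`.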
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 650e57de6d93d3..6fead73eced704 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -722,7 +722,6 @@ def __init__(self, config: GemmaConfig): self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False - self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 2e6168bfb70e45..6a3d8f27fb177d 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -742,7 +742,6 @@ def __init__(self, config: Gemma2Config): self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False - self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index d108e645bf562b..58a89d90b44ff5 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -708,7 +708,6 @@ def __init__(self, config: GlmConfig): dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta ) self.gradient_checkpointing = False - self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a967cfe685ec5a..679296648a9135 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -812,7 +812,6 @@ def __init__(self, config: LlamaConfig): self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - self._tp_plan = config.base_model_tp_plan if getattr(config, "pretraining_tp", 1) != 1: logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") From 5bb294ecc33abd419eb6a83fa959cfc70358a705 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 14 Nov 2024 13:19:28 -0800 Subject: [PATCH 13/20] Add has_tp_plan --- src/transformers/modeling_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bde262843d80dd..5d9e69e04026ca 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4359,6 +4359,8 @@ def from_pretrained( if tp_plan is not None: assert tp_device is not None, "tp_device not set!" + if not model.has_tp_plan: + raise NotImplementedError("This model does not have a tensor parallel plan.") # Assuming sharding the model onto the world world_size = torch.distributed.get_world_size() device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) @@ -5054,6 +5056,18 @@ def _is_quantized_training_enabled(self): return self.hf_quantizer.is_trainable + @property + def has_tp_plan(self): + """ + Returns whether the model has a tensor parallelism plan. 
+ """ + if self._tp_plan is not None: + return True + # Check if base model has a TP plan + if getattr(self.base_model, "_tp_plan", None) is not None: + return True + return False + def tensor_parallel(self, device_mesh): """ Tensor parallelize the model across the given device mesh. From 290a7f18dc96cd6b2cd0594aa14e947b54bdb04e Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 14 Nov 2024 20:33:07 -0800 Subject: [PATCH 14/20] Add test_tp --- tests/tp/test_tp.py | 91 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/tp/test_tp.py diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py new file mode 100644 index 00000000000000..14e01e3f8f22d0 --- /dev/null +++ b/tests/tp/test_tp.py @@ -0,0 +1,91 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from transformers import is_torch_available +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel +from transformers.testing_utils import ( + TestCasePlus, + execute_subprocess_async, + get_torch_dist_unique_port, + require_torch_multi_gpu, +) + + +if is_torch_available(): + import torch + + +class TestTensorParallel(TestCasePlus): + @require_torch_multi_gpu + def test_tp(self): + distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} + --master_port={get_torch_dist_unique_port()} + {self.test_file_dir}/test_tp.py + """.split() + output_dir = self.get_auto_remove_tmp_dir() + args = f"--output_dir {output_dir} --report_to none".split() + cmd = ["torchrun"] + distributed_args + args + print(cmd) + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + +if __name__ == "__main__": + # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: + # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py + # or + # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py + + if not is_torch_available(): + exit(0) + + # Test settings + model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + bs = 4 + seqlen = 64 + + # Get distributed settings + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Initialize distributed + device = torch.device(f"cuda:{rank}") + torch.distributed.init_process_group("nccl", device_id=device) + device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,)) + + # Get model config + config = LlamaConfig.from_pretrained(model_id) + # Shrink model size + config.num_hidden_layers //= 8 + config.vocab_size //= 8 + + # Instantiate model + with device: + model = LlamaModel(config) + + model.eval() + + # Tensor Parallel + if world_size > 1: + model.tensor_parallel(device_mesh) + + # Run model + inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device) + with torch.no_grad(): + out = model(inputs) + + assert 
out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size]) From bd2e89c1d44c412d99313ed98e0433d2286aa828 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 14 Nov 2024 21:54:48 -0800 Subject: [PATCH 15/20] Add 'Multi-gpu inference' doc --- docs/source/en/_toctree.yml | 2 + docs/source/en/perf_infer_gpu_multi.md | 68 ++++++++++++++++++++++++++ docs/source/en/performance.md | 2 +- 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/perf_infer_gpu_multi.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a7806059afaa59..4e44ca3038cb2f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -218,6 +218,8 @@ title: CPU inference - local: perf_infer_gpu_one title: GPU inference + - local: perf_infer_gpu_multi + title: Multi-GPU inference title: Optimizing inference - local: big_models title: Instantiate a big model diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md new file mode 100644 index 00000000000000..55a804fac2e03a --- /dev/null +++ b/docs/source/en/perf_infer_gpu_multi.md @@ -0,0 +1,68 @@ + + +# Multi-GPU inference + +Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model onto multiple GPUs, enabling larger model sizes, and parallelizes computations such as matrix multiplication. + +To enable tensor parallel, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]: + +```python +import os +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + +# Initialize distributed +rank = int(os.environ["RANK"]) +device = torch.device(f"cuda:{rank}") +torch.distributed.init_process_group("nccl", device_id=device) + +# Retrieve tensor parallel model +model = AutoModelForCausalLM.from_pretrained( + model_id, + tp_plan="auto", +) + +# Prepare input tokens +tokenizer = AutoTokenizer.from_pretrained(model_id) +prompt = "Can I help" +inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + +# Distributed run +outputs = model(inputs) +``` + +You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU: + +``` +torchrun --nproc-per-node 4 demo.py +``` + +PyTorch tensor parallel is currently supported for the following models: +* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) + +You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request. + +### Expected speedups + +You can benefit from considerable speedups for inference, especially for inputs with large batch size or long sequences. + +For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows: + +
+<!-- Figure: speedup chart for the Llama forward pass described above (image not reproduced here). -->
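A hypothetical micro-benchmark for the speedup described above is sketched below; it reuses the `model`, `tokenizer`, and `device` variables from the earlier snippet in this guide, and the absolute latency will depend on hardware, world size, and dtype.

```python
import os
import time

import torch

bs, seqlen = 4, 512
inputs = torch.randint(tokenizer.vocab_size, (bs, seqlen), device=device)

# Warm-up so CUDA kernels and NCCL communicators are initialized before timing.
for _ in range(3):
    with torch.no_grad():
        model(inputs)

torch.cuda.synchronize(device)
start = time.perf_counter()
with torch.no_grad():
    model(inputs)
torch.cuda.synchronize(device)

if int(os.environ["RANK"]) == 0:
    print(f"Forward latency: {(time.perf_counter() - start) * 1e3:.1f} ms")
```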
diff --git a/docs/source/en/performance.md b/docs/source/en/performance.md index 94e756cf33ada6..b9176be04ec206 100644 --- a/docs/source/en/performance.md +++ b/docs/source/en/performance.md @@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se * [Inference on a single CPU](perf_infer_cpu) * [Inference on a single GPU](perf_infer_gpu_one) -* [Multi-GPU inference](perf_infer_gpu_one) +* [Multi-GPU inference](perf_infer_gpu_multi) * [XLA Integration for TensorFlow Models](tf_xla) From 9648f31629230abb0b99dac028d9f942bcb4043e Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 14 Nov 2024 22:18:34 -0800 Subject: [PATCH 16/20] Add backward support for device type identification --- src/transformers/modeling_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 94d8f408b9c1ee..5100da5844f52b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4116,7 +4116,16 @@ def from_pretrained( raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") # Get device type (e.g. "cuda") - device_type = torch.distributed.distributed_c10d._device_capability()[0] + try: + # torch 2.6 API + device_type = torch.distributed.distributed_c10d._device_capability()[0] + except AttributeError: + if torch.cuda.is_available(): + device_type = "cuda" + else: + raise RuntimeError( + "Device type unknown. Please run model.tensor_parallel with an explicit DeviceMesh." + ) # Get torch device module (e.g. torch.cuda) based on device type device_module = torch.get_device_module(device_type) # Get device with index assuming equal number of devices per host From 93ba28355b1200491d4c6f1971d654f0847dda87 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 15 Nov 2024 19:43:37 -0800 Subject: [PATCH 17/20] Auto-detect accelerator --- src/transformers/modeling_utils.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5100da5844f52b..cee499f37875f8 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4115,18 +4115,8 @@ def from_pretrained( if not torch.distributed.is_initialized(): raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") - # Get device type (e.g. "cuda") - try: - # torch 2.6 API - device_type = torch.distributed.distributed_c10d._device_capability()[0] - except AttributeError: - if torch.cuda.is_available(): - device_type = "cuda" - else: - raise RuntimeError( - "Device type unknown. Please run model.tensor_parallel with an explicit DeviceMesh." - ) - # Get torch device module (e.g. torch.cuda) based on device type + # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. 
+ device_type = torch._C._get_accelerator().type device_module = torch.get_device_module(device_type) # Get device with index assuming equal number of devices per host tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count()) From 73524c901df81fc088cb672dc1507dec6578dccf Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 15 Nov 2024 19:45:07 -0800 Subject: [PATCH 18/20] supports_tp_plan --- src/transformers/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cee499f37875f8..d0dacc4802438d 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4371,7 +4371,7 @@ def from_pretrained( if tp_plan is not None: assert tp_device is not None, "tp_device not set!" - if not model.has_tp_plan: + if not model.supports_tp_plan: raise NotImplementedError("This model does not have a tensor parallel plan.") # Assuming sharding the model onto the world world_size = torch.distributed.get_world_size() @@ -5069,7 +5069,7 @@ def _is_quantized_training_enabled(self): return self.hf_quantizer.is_trainable @property - def has_tp_plan(self): + def supports_tp_plan(self): """ Returns whether the model has a tensor parallelism plan. """ From f312e5513645510debcedd6ca1d3529913a32ecc Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 15 Nov 2024 19:58:49 -0800 Subject: [PATCH 19/20] copyright year --- docs/source/en/perf_infer_gpu_multi.md | 2 +- tests/tp/test_tp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md index 55a804fac2e03a..9975094411527a 100644 --- a/docs/source/en/perf_infer_gpu_multi.md +++ b/docs/source/en/perf_infer_gpu_multi.md @@ -1,4 +1,4 @@ -