From 843ecd4602625d19d5ad8fc3441c7c90e5ea187c Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 4 Oct 2024 04:00:56 -0400 Subject: [PATCH] Scattermoe TP and SP (#29) Signed-off-by: Shawn Tan Signed-off-by: Mayank Mishra Co-authored-by: Mayank Mishra --- Makefile | 5 +- dolomite_engine/checkpointing.py | 2 +- dolomite_engine/hf_models/__init__.py | 10 +- dolomite_engine/hf_models/mixins/__init__.py | 3 +- .../hf_models/mixins/moe/__init__.py | 2 +- dolomite_engine/hf_models/mixins/moe/base.py | 57 +++++- dolomite_engine/hf_models/mixins/moe/main.py | 8 + .../hf_models/mixins/moe_TP/__init__.py | 2 + .../hf_models/mixins/moe_TP/base.py | 75 ++++++++ .../hf_models/mixins/moe_TP/main.py | 5 + .../hf_models/modeling_utils_TP/TP.py | 9 +- dolomite_engine/hf_models/models/__init__.py | 13 +- .../models/gpt_dolomite_TP/__init__.py | 2 +- .../gpt_dolomite_TP/weights/__init__.py | 2 +- .../models/gpt_dolomite_TP/weights/unshard.py | 4 +- .../hf_models/models/moe_dolomite/base.py | 58 +----- .../hf_models/models/moe_dolomite/main.py | 8 - .../models/moe_dolomite_TP/__init__.py | 3 + .../hf_models/models/moe_dolomite_TP/base.py | 12 ++ .../hf_models/models/moe_dolomite_TP/layer.py | 59 ++++++ .../hf_models/models/moe_dolomite_TP/main.py | 8 + .../models/moe_dolomite_TP/moe_TP/scatter.py | 8 +- .../moe_dolomite_TP/weights/__init__.py | 2 + .../models/moe_dolomite_TP/weights/shard.py | 115 ++++++++++++ .../models/moe_dolomite_TP/weights/unshard.py | 177 ++++++++++++++++++ .../hf_models/models/rnn_dolomite/base.py | 10 +- .../hf_models/models/rnn_dolomite/layer.py | 6 +- dolomite_engine/hf_models/register_hf.py | 18 +- dolomite_engine/hf_models/unshard.py | 51 +++++ dolomite_engine/model_wrapper/base.py | 6 +- tests/hf_models/multi_gpu/dcp/dcp_test.py | 4 + .../tensor_parallel/tensor_parallel_test.py | 1 + .../multi_gpu/unsharding/unsharding.py | 102 +++++++--- .../multi_gpu/unsharding/unsharding_test.py | 24 ++- tests/hf_models/test_common.py | 11 +- .../params_group/params_group_test.py | 6 + .../training_configs/rnn_dolomite_config.yml | 1 + 37 files changed, 755 insertions(+), 134 deletions(-) create mode 100644 dolomite_engine/hf_models/mixins/moe_TP/__init__.py create mode 100644 dolomite_engine/hf_models/mixins/moe_TP/base.py create mode 100644 dolomite_engine/hf_models/mixins/moe_TP/main.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/__init__.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/base.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/layer.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/main.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/weights/__init__.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/weights/shard.py create mode 100644 dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py create mode 100644 dolomite_engine/hf_models/unshard.py diff --git a/Makefile b/Makefile index b0691cd2..ce3b1cab 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,10 @@ install-dev: cd .. 
test: - pytest tests + RUN_SLOW=True pytest tests + +test-fast: + RUN_SLOW=False pytest tests update-precommit: pre-commit autoupdate diff --git a/dolomite_engine/checkpointing.py b/dolomite_engine/checkpointing.py index bf642ed8..c5621062 100644 --- a/dolomite_engine/checkpointing.py +++ b/dolomite_engine/checkpointing.py @@ -28,7 +28,7 @@ from .arguments import InferenceArgs, TrainingArgs, UnshardingArgs from .data import ResumableDataLoader from .enums import DistributedBackend, Mode, TuningMethod -from .hf_models.models.gpt_dolomite_TP import fix_unsharded_state_dict +from .hf_models import fix_unsharded_state_dict from .model_wrapper import ModelWrapper, get_model from .optimization import get_scheduler from .utils import ExperimentsTracker, ProcessGroupManager, load_yaml, log_rank_0, run_rank_n, string_to_torch_dtype diff --git a/dolomite_engine/hf_models/__init__.py b/dolomite_engine/hf_models/__init__.py index d61dc38f..feadfa40 100644 --- a/dolomite_engine/hf_models/__init__.py +++ b/dolomite_engine/hf_models/__init__.py @@ -11,18 +11,16 @@ GPTDolomiteModel_TP, MoEDolomiteConfig, MoEDolomiteForCausalLM, + MoEDolomiteForCausalLM_TP, MoEDolomiteModel, + MoEDolomiteModel_TP, RNNDolomiteConfig, RNNDolomiteForCausalLM, RNNDolomiteModel, convert_gpt_dolomite_to_gpt_crosslayer, ) -from .register_hf import ( - get_tensor_parallel_class, - is_custom_model, - is_tensor_parallel_compatible_model, - register_model_classes, -) +from .register_hf import get_tensor_parallel_class, is_custom_model, register_model_classes +from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts from .utils import convert_padding_free_lists_to_tensors diff --git a/dolomite_engine/hf_models/mixins/__init__.py b/dolomite_engine/hf_models/mixins/__init__.py index 06b4a213..42292e55 100644 --- a/dolomite_engine/hf_models/mixins/__init__.py +++ b/dolomite_engine/hf_models/mixins/__init__.py @@ -1,3 +1,4 @@ from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP -from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin +from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin +from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP diff --git a/dolomite_engine/hf_models/mixins/moe/__init__.py b/dolomite_engine/hf_models/mixins/moe/__init__.py index 4d93fbd6..b7a62fe0 100644 --- a/dolomite_engine/hf_models/mixins/moe/__init__.py +++ b/dolomite_engine/hf_models/mixins/moe/__init__.py @@ -1,2 +1,2 @@ -from .base import BaseMoEModelMixin +from .base import BaseMoEModelMixin, PreTrainedMoEModelMixin from .main import CausalLMMoEModelMixin diff --git a/dolomite_engine/hf_models/mixins/moe/base.py b/dolomite_engine/hf_models/mixins/moe/base.py index 87465528..54ed982a 100644 --- a/dolomite_engine/hf_models/mixins/moe/base.py +++ b/dolomite_engine/hf_models/mixins/moe/base.py @@ -1,10 +1,14 @@ from dataclasses import dataclass import torch +import torch.nn as nn from transformers import DynamicCache from transformers.modeling_outputs import MoeModelOutputWithPast -from ..dense import BaseModelMixin +from ...config import CommonConfig +from ...enums import AttentionHeadType, PositionEmbeddingType +from ...modeling_utils import ParameterizedEmbedding, get_normalization_function +from ..dense import BaseModelMixin, PreTrainedModelMixin @dataclass @@ -12,7 +16,58 @@ class MoeModelOutputWithPastAndAuxLoss(MoeModelOutputWithPast): 
aux_loss: torch.Tensor | None = None +class PreTrainedMoEModelMixin(PreTrainedModelMixin): + def __init__(self, config: CommonConfig, *args, **kwargs) -> None: + self.moe_implementation = kwargs.get("moe_implementation", "eager") + assert self.moe_implementation in ["eager", "scattermoe"] + + super().__init__(config, *args, **kwargs) + + class BaseMoEModelMixin(BaseModelMixin): + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.embed_dim = config.n_embd + self.num_heads = config.n_head + self.m_emb = config.m_emb + self.initializer_range = config.initializer_range + self.mask_value = None + + assert ( + self.embed_dim % self.num_heads == 0 + ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})" + + self.head_dim = self.embed_dim // self.num_heads + + self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + + self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList( + [ + self.layer_class( + config, + normalization_implementation=self.normalization_implementation, + attention_implementation=self.attention_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + moe_implementation=self.moe_implementation, + layer_idx=i, + ) + for i in range(config.n_layer) + ] + ) + self.ln_f = get_normalization_function( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + normalization_implementation=self.normalization_implementation, + ) + + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self._setup_positional_encoding() + + # Initialize weights and apply final processing + self.post_init() + def forward( self, input_ids: torch.Tensor | None = None, diff --git a/dolomite_engine/hf_models/mixins/moe/main.py b/dolomite_engine/hf_models/mixins/moe/main.py index 9c49e3cc..9f52c4d1 100644 --- a/dolomite_engine/hf_models/mixins/moe/main.py +++ b/dolomite_engine/hf_models/mixins/moe/main.py @@ -1,11 +1,19 @@ import torch from transformers.modeling_outputs import MoeCausalLMOutputWithPast +from ...config import CommonConfig from ..dense import CausalLMModelMixin from .base import MoeModelOutputWithPastAndAuxLoss class CausalLMMoEModelMixin(CausalLMModelMixin): + def __init__(self, config: CommonConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, diff --git a/dolomite_engine/hf_models/mixins/moe_TP/__init__.py b/dolomite_engine/hf_models/mixins/moe_TP/__init__.py new file mode 100644 index 00000000..e4e90ab7 --- /dev/null +++ b/dolomite_engine/hf_models/mixins/moe_TP/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP +from .main import CausalLMMoEModelMixin_TP diff --git a/dolomite_engine/hf_models/mixins/moe_TP/base.py b/dolomite_engine/hf_models/mixins/moe_TP/base.py new file mode 100644 index 00000000..55b09ded --- /dev/null +++ b/dolomite_engine/hf_models/mixins/moe_TP/base.py @@ -0,0 +1,75 @@ +import torch.nn as nn + +from ....utils import ProcessGroupManager +from ...config import CommonConfig +from ...enums import AttentionHeadType, PositionEmbeddingType +from ...modeling_utils_TP import Dropout_TP, 
Embedding_TP, get_normalization_function_TP +from ..dense_TP import BaseModelMixin_TP, PreTrainedModelMixin_TP +from ..moe import BaseMoEModelMixin, PreTrainedMoEModelMixin + + +class PreTrainedMoEModelMixin_TP(PreTrainedMoEModelMixin, PreTrainedModelMixin_TP): + def __init__(self, config: CommonConfig, *args, **kwargs): + self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.sequence_parallel = kwargs.get("sequence_parallel", False) + + super().__init__(config, *args, **kwargs) + + +class BaseMoEModelMixin_TP(BaseMoEModelMixin, BaseModelMixin_TP): + def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.m_emb = config.m_emb + self.initializer_range = config.initializer_range + self.head_dim = self.embed_dim // self.num_heads + + self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.wte = Embedding_TP( + config.vocab_size, + self.embed_dim, + std=self.initializer_range, + tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.drop = ( + nn.Identity() + if config.embd_pdrop == 0 + else Dropout_TP( + config.embd_pdrop, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + ) + self.h = nn.ModuleList( + [ + self.layer_class( + config, + normalization_implementation=self.normalization_implementation, + attention_implementation=self.attention_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + moe_implementation=self.moe_implementation, + layer_idx=i, + sequence_parallel=self.sequence_parallel, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = get_normalization_function_TP( + config.normalization_function, + self.embed_dim, + eps=config.layer_norm_epsilon, + normalization_implementation=self.normalization_implementation, + use_padding_free_transformer=self._use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self._setup_positional_encoding() + + # Initialize weights and apply final processing + self.post_init() diff --git a/dolomite_engine/hf_models/mixins/moe_TP/main.py b/dolomite_engine/hf_models/mixins/moe_TP/main.py new file mode 100644 index 00000000..3ccc622d --- /dev/null +++ b/dolomite_engine/hf_models/mixins/moe_TP/main.py @@ -0,0 +1,5 @@ +from ..dense_TP import CausalLMModelMixin_TP +from ..moe import CausalLMMoEModelMixin + + +class CausalLMMoEModelMixin_TP(CausalLMMoEModelMixin, CausalLMModelMixin_TP): ... 
diff --git a/dolomite_engine/hf_models/modeling_utils_TP/TP.py b/dolomite_engine/hf_models/modeling_utils_TP/TP.py index 01b07f70..24126be1 100644 --- a/dolomite_engine/hf_models/modeling_utils_TP/TP.py +++ b/dolomite_engine/hf_models/modeling_utils_TP/TP.py @@ -70,13 +70,16 @@ def dtensor_to_tensor( @torch.no_grad() def modify_state_dict_to_dtensor_dict(module: nn.Module, state_dict: dict, prefix: str, strip_keys: bool) -> dict: + module_state_dict = module.state_dict() + result = {} for key, tensor in state_dict.items(): if key.startswith(prefix): - striped_key = key.split(prefix)[1] if strip_keys else key + stripped_key = key.split(prefix)[1] if strip_keys and prefix != "" else key - device_mesh = getattr(module, striped_key).device_mesh - placements = getattr(module, striped_key).placements + param = module_state_dict[stripped_key] + device_mesh = param.device_mesh + placements = param.placements result[key] = DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements) return result diff --git a/dolomite_engine/hf_models/models/__init__.py b/dolomite_engine/hf_models/models/__init__.py index db7fa832..7a6bac78 100644 --- a/dolomite_engine/hf_models/models/__init__.py +++ b/dolomite_engine/hf_models/models/__init__.py @@ -5,6 +5,17 @@ convert_gpt_dolomite_to_gpt_crosslayer, ) from .gpt_dolomite import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel -from .gpt_dolomite_TP import GPTDolomiteForCausalLM_TP, GPTDolomiteModel_TP +from .gpt_dolomite_TP import ( + GPTDolomiteForCausalLM_TP, + GPTDolomiteModel_TP, + fix_gpt_dolomite_unsharded_state_dict, + unshard_gpt_dolomite_tensor_parallel_state_dicts, +) from .moe_dolomite import MoEDolomiteConfig, MoEDolomiteForCausalLM, MoEDolomiteModel +from .moe_dolomite_TP import ( + MoEDolomiteForCausalLM_TP, + MoEDolomiteModel_TP, + fix_moe_dolomite_unsharded_state_dict, + unshard_moe_dolomite_tensor_parallel_state_dicts, +) from .rnn_dolomite import RNNDolomiteConfig, RNNDolomiteForCausalLM, RNNDolomiteModel diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/__init__.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/__init__.py index c151abd6..041a3c63 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/__init__.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/__init__.py @@ -1,3 +1,3 @@ from .base import GPTDolomiteModel_TP from .main import GPTDolomiteForCausalLM_TP -from .weights import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts +from .weights import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py index 639e370c..893e2763 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py @@ -1,2 +1,2 @@ from .shard import get_gpt_dolomite_tp_state_dict -from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts +from .unshard import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts diff --git a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py index 9edc651c..af35de32 100644 --- a/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py +++ b/dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/unshard.py @@ -6,7 +6,7 @@ from 
...gpt_dolomite import GPTDolomiteConfig -def unshard_tensor_parallel_state_dicts( +def unshard_gpt_dolomite_tensor_parallel_state_dicts( config: GPTDolomiteConfig, tensor_parallel_state_dicts: list[dict], tensor_parallel_word_embeddings: bool, @@ -106,7 +106,7 @@ def unshard_tensor_parallel_state_dicts( return output_state_dict -def fix_unsharded_state_dict( +def fix_gpt_dolomite_unsharded_state_dict( config: GPTDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = "" ) -> dict: state_dict[prefix + "transformer.wte.weight"] = state_dict[prefix + "transformer.wte.weight"][ diff --git a/dolomite_engine/hf_models/models/moe_dolomite/base.py b/dolomite_engine/hf_models/models/moe_dolomite/base.py index 3062d555..7bdfd2f8 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite/base.py +++ b/dolomite_engine/hf_models/models/moe_dolomite/base.py @@ -1,64 +1,12 @@ -import torch.nn as nn - -from ...enums import AttentionHeadType, PositionEmbeddingType -from ...mixins import BaseMoEModelMixin, PreTrainedModelMixin -from ...modeling_utils import ParameterizedEmbedding, get_normalization_function +from ...mixins import BaseMoEModelMixin, PreTrainedMoEModelMixin from .config import MoEDolomiteConfig from .layer import SparseMoEBlock -class MoEDolomitePreTrainedModel(PreTrainedModelMixin): +class MoEDolomitePreTrainedModel(PreTrainedMoEModelMixin): config_class = MoEDolomiteConfig layer_class = SparseMoEBlock _no_split_modules = ["SparseMoEBlock"] - def __init__(self, config: MoEDolomiteConfig, *args, **kwargs) -> None: - self.moe_implementation = kwargs.get("moe_implementation", "eager") - assert self.moe_implementation in ["eager", "scattermoe"] - - super().__init__(config, *args, **kwargs) - - -class MoEDolomiteModel(MoEDolomitePreTrainedModel, BaseMoEModelMixin): - def _init_model(self, config: MoEDolomiteConfig, **kwargs) -> None: - self.attention_head_type = AttentionHeadType(config.attention_head_type) - self.embed_dim = config.n_embd - self.num_heads = config.n_head - self.m_emb = config.m_emb - self.initializer_range = config.initializer_range - self.mask_value = None - - assert ( - self.embed_dim % self.num_heads == 0 - ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})" - - self.head_dim = self.embed_dim // self.num_heads - - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) - - self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) - self.h = nn.ModuleList( - [ - self.layer_class( - config, - normalization_implementation=self.normalization_implementation, - attention_implementation=self.attention_implementation, - use_padding_free_transformer=self._use_padding_free_transformer, - moe_implementation=self.moe_implementation, - layer_idx=i, - ) - for i in range(config.n_layer) - ] - ) - self.ln_f = get_normalization_function( - config.normalization_function, - self.embed_dim, - eps=config.layer_norm_epsilon, - normalization_implementation=self.normalization_implementation, - ) - - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) - self._setup_positional_encoding() - # Initialize weights and apply final processing - self.post_init() +class MoEDolomiteModel(MoEDolomitePreTrainedModel, BaseMoEModelMixin): ... 
diff --git a/dolomite_engine/hf_models/models/moe_dolomite/main.py b/dolomite_engine/hf_models/models/moe_dolomite/main.py index cf8a3d18..480c27bd 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite/main.py +++ b/dolomite_engine/hf_models/models/moe_dolomite/main.py @@ -1,14 +1,6 @@ from ...mixins import CausalLMMoEModelMixin from .base import MoEDolomiteModel, MoEDolomitePreTrainedModel -from .config import MoEDolomiteConfig class MoEDolomiteForCausalLM(MoEDolomitePreTrainedModel, CausalLMMoEModelMixin): base_model_class = MoEDolomiteModel - - def __init__(self, config: MoEDolomiteConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - self.router_aux_loss_coef = config.router_aux_loss_coef - self.num_experts = config.num_experts - self.num_experts_per_tok = config.num_experts_per_tok diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/__init__.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/__init__.py new file mode 100644 index 00000000..cea12f14 --- /dev/null +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/__init__.py @@ -0,0 +1,3 @@ +from .base import MoEDolomiteModel_TP +from .main import MoEDolomiteForCausalLM_TP +from .weights import fix_moe_dolomite_unsharded_state_dict, unshard_moe_dolomite_tensor_parallel_state_dicts diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/base.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/base.py new file mode 100644 index 00000000..4300cfee --- /dev/null +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/base.py @@ -0,0 +1,12 @@ +from ...mixins import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP +from ..moe_dolomite import MoEDolomiteConfig +from .layer import SparseMoEBlock_TP + + +class MoEDolomitePreTrainedModel_TP(PreTrainedMoEModelMixin_TP): + config_class = MoEDolomiteConfig + layer_class = SparseMoEBlock_TP + _no_split_modules = ["SparseMoEBlock_TP"] + + +class MoEDolomiteModel_TP(MoEDolomitePreTrainedModel_TP, BaseMoEModelMixin_TP): ... 
diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/layer.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/layer.py new file mode 100644 index 00000000..1f60b70b --- /dev/null +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/layer.py @@ -0,0 +1,59 @@ +import torch.nn as nn + +from ...enums import AttentionHeadType +from ...modeling_utils_TP import get_attention_module_TP, get_normalization_function_TP +from ..moe_dolomite import MoEDolomiteConfig +from ..moe_dolomite.layer import SparseMoEBlock +from .moe_TP.scatter import ScatterMoE_TP + + +class SparseMoEBlock_TP(SparseMoEBlock): + def __init__( + self, + config: MoEDolomiteConfig, + normalization_implementation: str, + attention_implementation: str, + use_padding_free_transformer: bool, + moe_implementation: str, + layer_idx: int | None = None, + sequence_parallel: bool = False, + ) -> None: + nn.Module.__init__(self) + + hidden_size = config.hidden_size + self.inner_dim = config.n_inner + self.attention_head_type = AttentionHeadType(config.attention_head_type) + self.layer_idx = layer_idx + self.m_residual = config.m_residual + + self.ln_1 = get_normalization_function_TP( + config.normalization_function, + hidden_size, + eps=config.layer_norm_epsilon, + normalization_implementation=normalization_implementation, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) + self.attn = get_attention_module_TP( + config, + True, + attention_implementation=attention_implementation, + use_padding_free_transformer=use_padding_free_transformer, + layer_idx=layer_idx, + sequence_parallel=sequence_parallel, + ) + self.ln_2 = get_normalization_function_TP( + config.normalization_function, + hidden_size, + eps=config.layer_norm_epsilon, + normalization_implementation=normalization_implementation, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) + + assert moe_implementation == "scattermoe", "TP for MoE is only implemented with scattermoe" + self.moe = ScatterMoE_TP( + config, + use_padding_free_transformer=use_padding_free_transformer, + layer_idx=layer_idx, + ) diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py new file mode 100644 index 00000000..9be09d50 --- /dev/null +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/main.py @@ -0,0 +1,8 @@ +from ...mixins import CausalLMMoEModelMixin_TP +from .base import MoEDolomiteModel_TP, MoEDolomitePreTrainedModel_TP +from .weights import get_moe_dolomite_tp_state_dict + + +class MoEDolomiteForCausalLM_TP(MoEDolomitePreTrainedModel_TP, CausalLMMoEModelMixin_TP): + base_model_class = MoEDolomiteModel_TP + tensor_parallel_state_dict_function = get_moe_dolomite_tp_state_dict diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py index f9b97004..ce785e20 100644 --- a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py @@ -23,7 +23,7 @@ if is_kernel_hyperdrive_available(): - from khd.scattermoe.triton_implementation import padded_block_indices, scattered_experts + from khd.kernels.scattermoe.triton_implementation import scattered_experts class ColumnParallelScatteredExperts(ParameterizedScatteredExperts, DTensorModule): @@ -61,7 +61,7 @@ def __init__( DTensor.from_local( self.weight, 
device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - placements=[Shard(1)], + placements=[Shard(0)], run_check=False, ) ) @@ -89,7 +89,7 @@ def forward( results = scattered_experts( inputs, - weight.permute(0, 2, 1), + weight.permute(1, 2, 0), k, sorted_expert_idxs, sorted_scattered_idxs, @@ -161,7 +161,7 @@ def forward( inputs = scattered_experts( inputs, - weight.permute(0, 2, 1), + weight.permute(1, 2, 0), k, sorted_expert_idxs, sorted_scattered_idxs, diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/__init__.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/__init__.py new file mode 100644 index 00000000..2b61b1b4 --- /dev/null +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/__init__.py @@ -0,0 +1,2 @@ +from .shard import get_moe_dolomite_tp_state_dict +from .unshard import fix_moe_dolomite_unsharded_state_dict, unshard_moe_dolomite_tensor_parallel_state_dicts diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/shard.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/shard.py new file mode 100644 index 00000000..2ebadaeb --- /dev/null +++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/shard.py @@ -0,0 +1,115 @@ +import torch + +from .....utils import ProcessGroupManager, SafeTensorsWeightsManager +from ....enums import PositionEmbeddingType +from ....modeling_utils import is_glu +from ....utils import divide_if_divisible +from ...gpt_dolomite_TP.weights.shard import _get_attention_weights, _get_word_embedding_weights +from ...moe_dolomite import MoEDolomiteConfig + + +def get_moe_dolomite_tp_state_dict( + config: MoEDolomiteConfig, + safetensors_weights_manager: SafeTensorsWeightsManager, + tensor_parallel_word_embeddings: bool, +) -> dict: + # word embeddings + state_dict = _get_word_embedding_weights( + safetensors_weights_manager, + prefix="transformer.wte.", + vocab_size=config.vocab_size, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + ) + + # positional embeddings + if PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute: + state_dict.update( + _get_word_embedding_weights( + safetensors_weights_manager, + prefix="transformer.wpe.", + vocab_size=config.n_positions, + tensor_parallel_word_embeddings=False, + ) + ) + + for layer_idx in range(config.n_layer): + prefix = f"transformer.h.{layer_idx}." + + state_dict.update({prefix + "ln_1.weight": safetensors_weights_manager.get_tensor(prefix + "ln_1.weight")}) + if safetensors_weights_manager.has_tensor(prefix + "ln_1.bias"): + state_dict.update({prefix + "ln_1.bias": safetensors_weights_manager.get_tensor(prefix + "ln_1.bias")}) + + state_dict.update( + _get_attention_weights( + config=config, safetensors_weights_manager=safetensors_weights_manager, prefix=prefix + "attn." + ) + ) + + state_dict.update({prefix + "ln_2.weight": safetensors_weights_manager.get_tensor(prefix + "ln_2.weight")}) + if safetensors_weights_manager.has_tensor(prefix + "ln_2.bias"): + state_dict.update({prefix + "ln_2.bias": safetensors_weights_manager.get_tensor(prefix + "ln_2.bias")}) + + state_dict.update( + _get_moe_weights( + config=config, safetensors_weights_manager=safetensors_weights_manager, prefix=prefix + "moe." 
+            )
+        )
+
+    state_dict.update({"transformer.ln_f.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight")})
+    if safetensors_weights_manager.has_tensor("transformer.ln_f.bias"):
+        state_dict.update({"transformer.ln_f.bias": safetensors_weights_manager.get_tensor("transformer.ln_f.bias")})
+
+    if not config.tie_word_embeddings:
+        state_dict.update(
+            _get_word_embedding_weights(
+                safetensors_weights_manager=safetensors_weights_manager,
+                prefix="lm_head.",
+                vocab_size=config.vocab_size,
+                tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
+            )
+        )
+
+    return state_dict
+
+
+def _get_moe_weights(
+    config: MoEDolomiteConfig,
+    safetensors_weights_manager: SafeTensorsWeightsManager,
+    prefix: str,
+) -> dict:
+    # GLU is a special case and needs to be handled explicitly
+    state_dict = {prefix + "gate.weight": safetensors_weights_manager.get_tensor(prefix + "gate.weight")}
+    weight = safetensors_weights_manager.get_tensor(prefix + "c_fc.weight")
+    tp_rank = ProcessGroupManager.get_tensor_parallel_rank()
+    tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()
+    if is_glu(config.activation_function):
+        # weight = safetensors_weights_manager.get_slice(prefix + "c_fc.weight")
+        shape = (config.n_inner * 2, config.num_experts, config.n_embd)
+        sharded_out_dim = divide_if_divisible(
+            shape[0],
+            tp_world_size * 2,
+            f"split dimension ({0}) is not divisible by 2 x tensor parallel world size (2 x {tp_world_size})",
+        )
+        weight = weight.view(tp_world_size, sharded_out_dim, config.num_experts, config.n_embd)
+        # split weight tensors into gate and non-gate
+        weight_1 = weight[tp_rank]
+        weight_2 = weight[tp_world_size + tp_rank]
+        state_dict[prefix + "c_fc.weight"] = torch.cat([weight_1, weight_2], dim=1)
+    else:
+        shape = (config.n_inner, config.num_experts, config.n_embd)
+        sharded_out_dim = divide_if_divisible(
+            shape[0],
+            tp_world_size,
+            f"split dimension ({0}) is not divisible by tensor parallel world size ({tp_world_size})",
+        )
+        weight = weight.view(tp_world_size, sharded_out_dim, config.num_experts, config.n_embd)
+        # select this rank's shard of the expert weights
+        weight = weight[tp_rank]
+        state_dict[prefix + "c_fc.weight"] = weight
+
+    weight = safetensors_weights_manager.get_tensor(prefix + "c_proj.weight")
+    sharded_in_dim = sharded_out_dim
+    weight = weight.view(config.n_embd, config.num_experts, tp_world_size, sharded_in_dim)
+    state_dict[prefix + "c_proj.weight"] = weight[:, :, tp_rank]
+
+    return state_dict
diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py
new file mode 100644
index 00000000..2f04641f
--- /dev/null
+++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/weights/unshard.py
@@ -0,0 +1,177 @@
+import torch
+from tqdm import trange
+
+from ....enums import AttentionHeadType, PositionEmbeddingType
+from ....modeling_utils import is_glu
+from ...gpt_dolomite_TP.weights.unshard import (
+    _concatenate_tensors_from_state_dicts,
+    _fix_attention_weights,
+    _get_attention,
+    _get_embeddings_or_lm_head,
+    _get_layernorm,
+    _get_once_from_state_dicts_with_check,
+)
+from ...moe_dolomite import MoEDolomiteConfig
+
+
+def unshard_moe_dolomite_tensor_parallel_state_dicts(
+    config: MoEDolomiteConfig,
+    tensor_parallel_state_dicts: list[dict],
+    tensor_parallel_word_embeddings: bool,
+    prefix: str = "",
+    check_correctness: bool = True,
+) -> dict:
+    attention_head_type = AttentionHeadType(config.attention_head_type)
+
position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + + # word embeddings + output_state_dict = _get_embeddings_or_lm_head( + tensor_parallel_state_dicts, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + prefix=prefix + "transformer.wte.weight", + vocab_size=config.vocab_size, + check_correctness=check_correctness, + ) + + # positional embeddings if using learned positional embeddings + if position_embedding_type == PositionEmbeddingType.learned_absolute: + output_state_dict.update( + _get_embeddings_or_lm_head( + tensor_parallel_state_dicts, + # TODO change this if we support tensor parallel position embeddings + tensor_parallel_word_embeddings=False, + prefix=prefix + "transformer.wpe.weight", + vocab_size=config.n_positions, + check_correctness=check_correctness, + ) + ) + + # layers + for layer_idx in trange(config.n_layer): + # first layernorm + output_state_dict.update( + _get_layernorm( + tensor_parallel_state_dicts, + prefix=prefix + f"transformer.h.{layer_idx}.ln_1.", + normalization_function=config.normalization_function, + check_correctness=check_correctness, + ) + ) + + # attention + output_state_dict.update( + _get_attention( + tensor_parallel_state_dicts, + attention_head_type=attention_head_type, + add_bias=config.add_bias, + prefix=prefix + f"transformer.h.{layer_idx}.attn.", + check_correctness=check_correctness, + ) + ) + + # second layernorm + output_state_dict.update( + _get_layernorm( + tensor_parallel_state_dicts, + prefix=prefix + f"transformer.h.{layer_idx}.ln_2.", + normalization_function=config.normalization_function, + check_correctness=check_correctness, + ) + ) + + # mlp + output_state_dict.update( + _get_moe( + tensor_parallel_state_dicts, + prefix=prefix + f"transformer.h.{layer_idx}.moe.", + config=config, + check_correctness=check_correctness, + ) + ) + + # final layernorm + output_state_dict.update( + _get_layernorm( + tensor_parallel_state_dicts, + prefix=prefix + f"transformer.ln_f.", + normalization_function=config.normalization_function, + check_correctness=check_correctness, + ) + ) + + if not config.tie_word_embeddings: + output_state_dict.update( + _get_embeddings_or_lm_head( + tensor_parallel_state_dicts, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + prefix=prefix + "lm_head.weight", + vocab_size=config.vocab_size, + check_correctness=check_correctness, + ) + ) + + return output_state_dict + + +def fix_moe_dolomite_unsharded_state_dict( + config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = "" +) -> dict: + state_dict[prefix + "transformer.wte.weight"] = state_dict[prefix + "transformer.wte.weight"][ + : config.vocab_size, : + ] + state_dict = _fix_attention_weights(config, state_dict, prefix) + state_dict = _fix_moe_weights(config, state_dict, tensor_parallel_size, prefix) + return state_dict + + +def _concatenate_tensors_from_moe( + tensor_parallel_state_dicts: list[dict], + key: str, + dim: int, +) -> torch.Tensor: + tensor_list = [state_dict[key] for state_dict in tensor_parallel_state_dicts] + tensor = torch.cat(tensor_list, dim=dim) + return tensor + + +def _get_moe( + tensor_parallel_state_dicts: list[dict], config: MoEDolomiteConfig, prefix: str, check_correctness: bool +) -> dict: + assert not config.add_bias + + output = { + prefix + + "gate.weight": _get_once_from_state_dicts_with_check( + tensor_parallel_state_dicts, prefix + "gate.weight", True + ) + } + if is_glu(config.activation_function): + # per_rank_dim = config.n_inner // 
len(tensor_parallel_state_dicts) + weights = [state_dict[prefix + "c_fc.weight"].chunk(2, dim=0) for state_dict in tensor_parallel_state_dicts] + weights = (torch.cat([w[0] for w in weights], dim=0), torch.cat([w[1] for w in weights], dim=0)) + output[prefix + "c_fc.weight"] = torch.cat(weights, dim=0) + else: + output[prefix + "c_fc.weight"] = _concatenate_tensors_from_state_dicts( + tensor_parallel_state_dicts, prefix + "c_fc.weight", dim=0 + ) + + output[prefix + "c_proj.weight"] = _concatenate_tensors_from_moe( + tensor_parallel_state_dicts, prefix + "c_proj.weight", dim=2 + ) + return output + + +def _fix_moe_weights(config: MoEDolomiteConfig, state_dict: dict, tensor_parallel_size: int, prefix: str) -> dict: + assert not config.add_bias + + if is_glu(config.activation_function): + for layer_idx in range(config.n_layer): + key = f"{prefix}transformer.h.{layer_idx}.mlp.c_fc.weight" + weight = state_dict[key] + weight = weight.chunk(tensor_parallel_size, dim=0) + weight = [w.chunk(2, dim=0) for w in weight] + w0 = torch.cat([w[0] for w in weight]) + w1 = torch.cat([w[1] for w in weight]) + state_dict[key] = torch.cat([w0, w1], dim=0) + + return state_dict diff --git a/dolomite_engine/hf_models/models/rnn_dolomite/base.py b/dolomite_engine/hf_models/models/rnn_dolomite/base.py index 37e1fdf6..362a92ac 100644 --- a/dolomite_engine/hf_models/models/rnn_dolomite/base.py +++ b/dolomite_engine/hf_models/models/rnn_dolomite/base.py @@ -20,6 +20,8 @@ class RNNDolomitePreTrainedModel(PreTrainedModelMixin): config_class = RNNDolomiteConfig layer_class = RNNDolomiteBlock _no_split_modules = ["RNNDolomiteBlock"] + _supports_sdpa = False + _supports_flash_attn_2 = True def __init__(self, config: RNNDolomiteConfig, *args, **kwargs): super().__init__(config, *args, **kwargs) @@ -35,7 +37,7 @@ def _init_model(self, config: RNNDolomiteConfig, **kwargs) -> None: self.m_emb = config.m_emb self.initializer_range = config.initializer_range - self.attention_pattern = self.mapping_attention_pattern(config.attention_pattern) + self.attention_pattern = self.parse_attention_pattern(config.attention_pattern) self.head_dim = divide_if_divisible( self.embed_dim, @@ -71,13 +73,13 @@ def _init_model(self, config: RNNDolomiteConfig, **kwargs) -> None: # Initialize weights and apply final processing self.post_init() - def mapping_attention_pattern(self, attention_pattern: str) -> list[str]: + def parse_attention_pattern(self, attention_pattern: str) -> list[str]: attention_implementation_list = [] for pattern in attention_pattern: if pattern == "a": - attention_implementation_list.append(self.attention_implementation) + attention_implementation_list.append("flash_attention_2") elif pattern == "d": - attention_implementation_list.append("DeltaNet") + attention_implementation_list.append("deltanet") else: raise ValueError(f"Attention pattern {pattern} not supported") return attention_implementation_list diff --git a/dolomite_engine/hf_models/models/rnn_dolomite/layer.py b/dolomite_engine/hf_models/models/rnn_dolomite/layer.py index a6c4414a..c928d2d8 100644 --- a/dolomite_engine/hf_models/models/rnn_dolomite/layer.py +++ b/dolomite_engine/hf_models/models/rnn_dolomite/layer.py @@ -38,10 +38,10 @@ def __init__( normalization_implementation=normalization_implementation, ) - if attention_pattern == "DeltaNet": - self.attn = DeltaNet(config=config, layer_idx=layer_idx) - elif attention_pattern == "flash_attention_2": + if attention_pattern == "flash_attention_2": self.attn = RNNFlashAttention2(config, True, layer_idx) + elif 
attention_pattern == "deltanet": + self.attn = DeltaNet(config=config, layer_idx=layer_idx) else: raise ValueError(f"Attention pattern {attention_pattern} not supported.") diff --git a/dolomite_engine/hf_models/register_hf.py b/dolomite_engine/hf_models/register_hf.py index c84bcbcf..9d930fb0 100644 --- a/dolomite_engine/hf_models/register_hf.py +++ b/dolomite_engine/hf_models/register_hf.py @@ -10,6 +10,7 @@ GPTDolomiteModel, MoEDolomiteConfig, MoEDolomiteForCausalLM, + MoEDolomiteForCausalLM_TP, MoEDolomiteModel, RNNDolomiteConfig, RNNDolomiteForCausalLM, @@ -44,15 +45,14 @@ def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForS return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES -def is_tensor_parallel_compatible_model( - model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str -) -> bool: - return model_class.__name__ == "GPTDolomiteForCausalLM" or model_type == "gpt_dolomite" - - -_TENSOR_PARALLEL_CLASS_MAPPING = {"gpt_dolomite": GPTDolomiteForCausalLM_TP} +_TENSOR_PARALLEL_CLASS_MAPPING = { + GPTDolomiteConfig.model_type: GPTDolomiteForCausalLM_TP, + MoEDolomiteConfig.model_type: MoEDolomiteForCausalLM_TP, +} def get_tensor_parallel_class(model_type: str) -> AutoModelForCausalLM: - assert is_tensor_parallel_compatible_model(AutoModelForCausalLM, model_type) - return _TENSOR_PARALLEL_CLASS_MAPPING[model_type] + if model_type in _TENSOR_PARALLEL_CLASS_MAPPING: + return _TENSOR_PARALLEL_CLASS_MAPPING[model_type] + + raise ValueError(f"tensor parallel is not supported with `model_type` ({model_type})") diff --git a/dolomite_engine/hf_models/unshard.py b/dolomite_engine/hf_models/unshard.py new file mode 100644 index 00000000..f85ad426 --- /dev/null +++ b/dolomite_engine/hf_models/unshard.py @@ -0,0 +1,51 @@ +from .config import CommonConfig +from .models import ( + GPTDolomiteConfig, + MoEDolomiteConfig, + fix_gpt_dolomite_unsharded_state_dict, + fix_moe_dolomite_unsharded_state_dict, + unshard_gpt_dolomite_tensor_parallel_state_dicts, + unshard_moe_dolomite_tensor_parallel_state_dicts, +) + + +_UNSHARD_STATE_DICT_FUNCTIONS = { + GPTDolomiteConfig.model_type: unshard_gpt_dolomite_tensor_parallel_state_dicts, + MoEDolomiteConfig.model_type: unshard_moe_dolomite_tensor_parallel_state_dicts, +} + + +def unshard_tensor_parallel_state_dicts( + config: MoEDolomiteConfig, + tensor_parallel_state_dicts: list[dict], + tensor_parallel_word_embeddings: bool, + prefix: str = "", + check_correctness: bool = True, +) -> dict: + if config.model_type in _UNSHARD_STATE_DICT_FUNCTIONS: + return _UNSHARD_STATE_DICT_FUNCTIONS[config.model_type]( + config=config, + tensor_parallel_state_dicts=tensor_parallel_state_dicts, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + prefix=prefix, + check_correctness=check_correctness, + ) + + raise ValueError(f"unsupported `model_type` ({config.model_type})") + + +_FIX_UNSHARDED_STATE_DICT_FUNCTIONS = { + GPTDolomiteConfig.model_type: fix_gpt_dolomite_unsharded_state_dict, + MoEDolomiteConfig.model_type: fix_moe_dolomite_unsharded_state_dict, +} + + +def fix_unsharded_state_dict( + config: CommonConfig, state_dict: dict, tensor_parallel_size: int, prefix: str = "" +) -> dict: + if config.model_type in _FIX_UNSHARDED_STATE_DICT_FUNCTIONS: + return _FIX_UNSHARDED_STATE_DICT_FUNCTIONS[config.model_type]( + config=config, state_dict=state_dict, tensor_parallel_size=tensor_parallel_size, prefix=prefix + ) + + raise ValueError(f"unsupported `model_type` 
({config.model_type})") diff --git a/dolomite_engine/model_wrapper/base.py b/dolomite_engine/model_wrapper/base.py index 9e4bf834..7df6b9f5 100644 --- a/dolomite_engine/model_wrapper/base.py +++ b/dolomite_engine/model_wrapper/base.py @@ -6,7 +6,7 @@ from transformers.integrations import HfDeepSpeedConfig from ..enums import AttentionImplementation, DistributedBackend, Mode, MoEImplementation -from ..hf_models import get_tensor_parallel_class, is_custom_model, is_tensor_parallel_compatible_model +from ..hf_models import get_tensor_parallel_class, is_custom_model from ..utils import ProcessGroupManager, SafeTensorsWeightsManager, log_rank_0, string_to_torch_dtype @@ -78,10 +78,6 @@ def __init__( if self.tp_world_size > 1: self.model_class = get_tensor_parallel_class(self.config.model_type) - assert is_tensor_parallel_compatible_model( - self.model_class, self.config.model_type - ), "tensor parallel is not supported with this model" - if self.use_padding_free_transformer: assert is_custom_model( self.model_class, self.config.model_type diff --git a/tests/hf_models/multi_gpu/dcp/dcp_test.py b/tests/hf_models/multi_gpu/dcp/dcp_test.py index cd0f125c..febee66c 100644 --- a/tests/hf_models/multi_gpu/dcp/dcp_test.py +++ b/tests/hf_models/multi_gpu/dcp/dcp_test.py @@ -15,7 +15,11 @@ class DCPTest(TestCommons): TestCommons.make_args_matrix( TestCommons.get_attention_head_types(), ["gelu", "geglu"], [False, True], [(3, 2, 2), (3, 1, 4), (0, 4, 1)] ) + + TestCommons.make_args_matrix( + [AttentionHeadType.gqa], ["gelu", "geglu"], [False], [(3, 2, 2), (3, 1, 4), (0, 4, 1)] + ) ) + @TestCommons.slow_test def test_dcp( self, attention_head_type: AttentionHeadType, diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_test.py index 6df36c5f..68af49eb 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_test.py @@ -22,6 +22,7 @@ class TensorParallelTest(TestCommons): [False, True], ) ) + @TestCommons.slow_test def test_tensor_parallel_forward( self, attention_head_type: AttentionHeadType, diff --git a/tests/hf_models/multi_gpu/unsharding/unsharding.py b/tests/hf_models/multi_gpu/unsharding/unsharding.py index 616e2843..aff2c662 100644 --- a/tests/hf_models/multi_gpu/unsharding/unsharding.py +++ b/tests/hf_models/multi_gpu/unsharding/unsharding.py @@ -3,9 +3,16 @@ import torch import torch.distributed - -from dolomite_engine.hf_models import AttentionHeadType, GPTDolomiteConfig, GPTDolomiteForCausalLM_TP -from dolomite_engine.hf_models.models.gpt_dolomite_TP import fix_unsharded_state_dict +from torch.distributed._tensor.api import DTensor + +from dolomite_engine.hf_models import ( + AttentionHeadType, + GPTDolomiteConfig, + MoEDolomiteConfig, + fix_unsharded_state_dict, + get_tensor_parallel_class, + unshard_tensor_parallel_state_dicts, +) from dolomite_engine.utils import ProcessGroupManager from ...test_common import TestCommons @@ -14,6 +21,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--attention-head-type", type=str) parser.add_argument("--activation-function", type=str) +parser.add_argument("--model-type", type=str) parser.add_argument("--tensor-parallel-word-embeddings", action="store_true") parser.add_argument("--tmp-path", type=str) args = parser.parse_args() @@ -28,15 +36,30 @@ if AttentionHeadType(args.attention_head_type) == AttentionHeadType.gqa: num_key_value_heads = 8 -config = GPTDolomiteConfig( - 
attention_head_type=args.attention_head_type, - n_layer=1, - position_embedding_type="learned_absolute", - num_key_value_heads=num_key_value_heads, - add_bias=False, - n_embd=128, - n_head=16, -) +kwargs = {} + +if args.model_type == GPTDolomiteConfig.model_type: + config = GPTDolomiteConfig( + attention_head_type=args.attention_head_type, + n_layer=1, + position_embedding_type="learned_absolute", + num_key_value_heads=num_key_value_heads, + add_bias=False, + n_embd=128, + n_head=16, + ) +elif args.model_type == MoEDolomiteConfig.model_type: + config = MoEDolomiteConfig( + attention_head_type=args.attention_head_type, + n_layer=1, + position_embedding_type="learned_absolute", + num_key_value_heads=num_key_value_heads, + add_bias=False, + n_embd=128, + n_head=16, + ) + kwargs["moe_implementation"] = "scattermoe" + if tp_rank == 0: model = TestCommons.from_config(None, config) @@ -44,21 +67,56 @@ torch.distributed.barrier() -model_tp = GPTDolomiteForCausalLM_TP.from_pretrained( - args.tmp_path, tensor_parallel_word_embeddings=args.tensor_parallel_word_embeddings +model_tp = get_tensor_parallel_class(args.model_type).from_pretrained( + args.tmp_path, tensor_parallel_word_embeddings=args.tensor_parallel_word_embeddings, **kwargs ) tp_state_dict = model_tp.state_dict() -tp_state_dict = {key: value.to("cpu").full_tensor() for key, value in tp_state_dict.items()} -tp_state_dict = fix_unsharded_state_dict(config, tp_state_dict, ProcessGroupManager.get_tensor_parallel_world_size()) -torch.distributed.barrier() -if tp_rank == 0: - original_state_dict = model.state_dict() +def run_check(fix: bool): + cpu_state_dict = {key: value.to("cpu") for key, value in tp_state_dict.items()} + + if fix: + tp_state_dict_unsharded = { + key: value.full_tensor() if isinstance(value, DTensor) else value for key, value in cpu_state_dict.items() + } + tp_state_dict_unsharded = fix_unsharded_state_dict( + config, tp_state_dict_unsharded, ProcessGroupManager.get_tensor_parallel_world_size() + ) + else: + cpu_state_dict = { + key: value.to_local() if isinstance(value, DTensor) else value for key, value in cpu_state_dict.items() + } + torch.save( + cpu_state_dict, os.path.join(args.tmp_path, f"tp-{ProcessGroupManager.get_tensor_parallel_rank()}.pt") + ) + del cpu_state_dict + + torch.distributed.barrier() + + tensor_parallel_state_dicts = [ + torch.load(os.path.join(args.tmp_path, f"tp-{i}.pt")) + for i in range(ProcessGroupManager.get_tensor_parallel_world_size()) + ] + + tp_state_dict_unsharded = unshard_tensor_parallel_state_dicts( + config, + tensor_parallel_state_dicts=tensor_parallel_state_dicts, + tensor_parallel_word_embeddings=args.tensor_parallel_word_embeddings, + ) + + torch.distributed.barrier() + + if tp_rank == 0: + original_state_dict = model.state_dict() + + assert tp_state_dict_unsharded.keys() == original_state_dict.keys() + for key in original_state_dict: + assert original_state_dict[key].equal(tp_state_dict_unsharded[key]) + - assert tp_state_dict.keys() == original_state_dict.keys() - for key in original_state_dict: - assert original_state_dict[key].equal(tp_state_dict[key]) +run_check(True) +run_check(False) ProcessGroupManager.destroy_process_groups() diff --git a/tests/hf_models/multi_gpu/unsharding/unsharding_test.py b/tests/hf_models/multi_gpu/unsharding/unsharding_test.py index 4ce560d4..ab6fa6a0 100644 --- a/tests/hf_models/multi_gpu/unsharding/unsharding_test.py +++ b/tests/hf_models/multi_gpu/unsharding/unsharding_test.py @@ -5,17 +5,33 @@ import torch.distributed from parameterized import 
parameterized -from dolomite_engine.hf_models import AttentionHeadType +from dolomite_engine.hf_models import AttentionHeadType, GPTDolomiteConfig, MoEDolomiteConfig from ...test_common import TestCommons class UnshardingTest(TestCommons): @parameterized.expand( - TestCommons.make_args_matrix(TestCommons.get_attention_head_types(), ["gelu", "geglu"], [False, True]) + TestCommons.make_args_matrix( + TestCommons.get_attention_head_types(), + ["gelu", "geglu"], + [False, True], + [GPTDolomiteConfig.model_type], + ) + + TestCommons.make_args_matrix( + [AttentionHeadType.gqa], + ["gelu", "geglu"], + [False], + [MoEDolomiteConfig.model_type], + ) ) + @TestCommons.slow_test def test_unsharding( - self, attention_head_type: AttentionHeadType, activation_function: str, tensor_parallel_word_embeddings: bool + self, + attention_head_type: AttentionHeadType, + activation_function: str, + tensor_parallel_word_embeddings: bool, + model_type: str, ) -> None: self.skip_test_if_device_unavailable(torch.device("cuda")) @@ -34,6 +50,8 @@ def test_unsharding( activation_function, "--tmp-path", tmp_path, + "--model-type", + model_type, ] if tensor_parallel_word_embeddings: diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index 299173e1..917e911d 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -2,8 +2,8 @@ import os import tempfile from itertools import product -from typing import Any -from unittest import TestCase +from typing import Any, Callable +from unittest import TestCase, skipUnless import torch from torch.testing import assert_close @@ -21,6 +21,9 @@ from dolomite_engine.hf_models.config import CommonConfig +_RUN_SLOW = True if os.getenv("RUN_SLOW", "False").lower() in ["1", "true"] else False + + class TestCommons(TestCase): @staticmethod def get_all_devices() -> list[torch.device]: @@ -292,3 +295,7 @@ def assert_equal_tensors( assert_close(x, y, rtol=rtol_bfloat16, atol=atol_bfloat16) else: raise ValueError(f"unexpected dtype ({dtype})") + + @staticmethod + def slow_test(func: Callable) -> Callable: + return skipUnless(_RUN_SLOW, "skipping slow test since RUN_SLOW=True is not set in the environment")(func) diff --git a/tests/training/params_group/params_group_test.py b/tests/training/params_group/params_group_test.py index 3f4b766f..b4f7dfe6 100644 --- a/tests/training/params_group/params_group_test.py +++ b/tests/training/params_group/params_group_test.py @@ -26,6 +26,9 @@ def test_mup_group(self, config_filename: str, expected_groups_filename: str) -> ) if "rnn_dolomite" in config_filename: + if not torch.cuda.is_available(): + self.skipTest("skipping test because CUDA is unavailable") + try: with ( torch.device("meta"), @@ -63,6 +66,9 @@ def test_normal_group(self, config_filename: str, expected_groups_filename: str) ) if "rnn_dolomite" in config_filename: + if not torch.cuda.is_available(): + self.skipTest("skipping test because CUDA is unavailable") + try: with ( torch.device("meta"), diff --git a/tests/training/params_group/training_configs/rnn_dolomite_config.yml b/tests/training/params_group/training_configs/rnn_dolomite_config.yml index 4b1df615..f9bb6249 100644 --- a/tests/training/params_group/training_configs/rnn_dolomite_config.yml +++ b/tests/training/params_group/training_configs/rnn_dolomite_config.yml @@ -50,6 +50,7 @@ model_args: init_method: mup tie_word_embeddings: true upcast_logits_for_loss: true + attention_implementation: flash_attention_2 tuning_args: tuning_method: pretraining