Commit

pp
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Oct 25, 2024
1 parent 917c3c6 commit 15aa971
Showing 6 changed files with 49 additions and 30 deletions.
1 change: 0 additions & 1 deletion dolomite_engine/distributed/__init__.py
@@ -220,7 +220,6 @@ def _param_init(module: nn.Module) -> None:
     if ProcessGroupManager.get_pipeline_parallel_world_size() > 1:
         micro_batch_size = args.training_parameters.micro_batch_size
         sequence_length = args.datasets[0].class_args.get("sequence_length")
-        args.model_args.pretrained_config.get("n_embd")
 
     for model in model_container:
         intermediate_dtype = string_to_torch_dtype(args.mixed_precision_args.dtype)
6 changes: 4 additions & 2 deletions dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -18,7 +18,7 @@


 class CausalLMModelMixin_TP(PreTrainedModelMixin_TP, CausalLMModelMixin):
-    tensor_parallel_state_dict_function = None
+    model_parallel_state_dict_function = None
 
     def _init_model(self, config: CommonConfig, **kwargs) -> None:
         self.vocab_size = config.vocab_size
@@ -199,10 +199,12 @@ def load_from_safetensors_weights_manager(self, safetensors_weights_manager: Saf
         elif position_embedding_type == PositionEmbeddingType.rope:
             self.transformer.rope.reset_parameters()
 
-        state_dict = self.__class__.tensor_parallel_state_dict_function(
+        state_dict = self.__class__.model_parallel_state_dict_function(
             config=self.config,
             safetensors_weights_manager=safetensors_weights_manager,
             tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
+            num_pipeline_stages=self.num_pipeline_stages,
+            pipeline_stage_id=self.pipeline_stage_id,
         )
 
         self.load_state_dict(state_dict)
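The renamed attribute is the per-model hook that the files below wire up: each `*ForCausalLM_TP` class assigns its own sharding function to `model_parallel_state_dict_function`, and the mixin now forwards the pipeline-parallel arguments when loading weights. A minimal sketch of that pattern, with made-up names (`CausalLMMixinSketch`, `ToyModel`, `toy_state_dict_fn`) that are not part of the repository:

# Illustrative sketch only: a class-level sharding hook filled in by subclasses,
# called with the pipeline-parallel arguments added in this commit.
from typing import Callable, Optional


def toy_state_dict_fn(config, safetensors_weights_manager, tensor_parallel_word_embeddings,
                      num_pipeline_stages, pipeline_stage_id) -> dict:
    # stand-in for a real sharding function such as get_gpt_dolomite_model_parallel_state_dict
    return {"stage": pipeline_stage_id, "stages": num_pipeline_stages}


class CausalLMMixinSketch:
    model_parallel_state_dict_function: Optional[Callable[..., dict]] = None

    def __init__(self, config, num_pipeline_stages: int, pipeline_stage_id: int):
        self.config = config
        self.num_pipeline_stages = num_pipeline_stages
        self.pipeline_stage_id = pipeline_stage_id

    def load_sharded_weights(self, weights_manager) -> dict:
        # same call shape as load_from_safetensors_weights_manager above
        return self.__class__.model_parallel_state_dict_function(
            config=self.config,
            safetensors_weights_manager=weights_manager,
            tensor_parallel_word_embeddings=False,
            num_pipeline_stages=self.num_pipeline_stages,
            pipeline_stage_id=self.pipeline_stage_id,
        )


class ToyModel(CausalLMMixinSketch):
    model_parallel_state_dict_function = toy_state_dict_fn

Looking the function up on `self.__class__` returns the plain function rather than a bound method, so `self` is not injected as a positional argument, which is why the call uses keyword arguments only.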
4 changes: 2 additions & 2 deletions dolomite_engine/hf_models/models/gpt_dolomite_TP/main.py
@@ -1,8 +1,8 @@
 from ...mixins import CausalLMModelMixin_TP
 from .base import GPTDolomiteModel_TP, GPTDolomitePreTrainedModel_TP
-from .weights import get_gpt_dolomite_tensor_parallel_state_dict
+from .weights import get_gpt_dolomite_model_parallel_state_dict
 
 
 class GPTDolomiteForCausalLM_TP(GPTDolomitePreTrainedModel_TP, CausalLMModelMixin_TP):
     base_model_class = GPTDolomiteModel_TP
-    tensor_parallel_state_dict_function = get_gpt_dolomite_tensor_parallel_state_dict
+    model_parallel_state_dict_function = get_gpt_dolomite_model_parallel_state_dict
2 changes: 1 addition & 1 deletion dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/__init__.py
@@ -1,2 +1,2 @@
-from .shard import get_gpt_dolomite_tensor_parallel_state_dict
+from .shard import get_gpt_dolomite_model_parallel_state_dict
 from .unshard import fix_gpt_dolomite_unsharded_state_dict, unshard_gpt_dolomite_tensor_parallel_state_dicts
64 changes: 41 additions & 23 deletions dolomite_engine/hf_models/models/gpt_dolomite_TP/weights/shard.py
@@ -7,31 +7,48 @@
 from ...gpt_dolomite import GPTDolomiteConfig
 
 
-def get_gpt_dolomite_tensor_parallel_state_dict(
+def get_gpt_dolomite_model_parallel_state_dict(
     config: GPTDolomiteConfig,
     safetensors_weights_manager: SafeTensorsWeightsManager,
     tensor_parallel_word_embeddings: bool,
+    num_pipeline_stages: int,
+    pipeline_stage_id: int,
 ) -> dict:
-    # word embeddings
-    state_dict = _get_embeddings_or_lm_head(
-        safetensors_weights_manager,
-        prefix="transformer.wte.",
-        vocab_size=config.vocab_size,
-        tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
-    )
-
-    # positional embeddings
-    if PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute:
-        state_dict.update(
-            _get_embeddings_or_lm_head(
-                safetensors_weights_manager,
-                prefix="transformer.wpe.",
-                vocab_size=config.n_positions,
-                tensor_parallel_word_embeddings=False,
-            )
-        )
-
-    for layer_idx in range(config.n_layer):
+    is_first_pipeline_stage = pipeline_stage_id == 0
+    is_last_pipeline_stage = pipeline_stage_id == num_pipeline_stages - 1
+
+    layers_per_stage = divide_if_divisible(
+        config.n_layer, num_pipeline_stages, "layers should be divisible by num_pipeline_stages"
+    )
+
+    layer_start_id = layers_per_stage * pipeline_stage_id
+    layer_end_id = layers_per_stage * (pipeline_stage_id + 1)
+
+    state_dict = {}
+
+    if is_first_pipeline_stage:
+        # word embeddings
+        state_dict.update(
+            _get_embeddings_or_lm_head(
+                safetensors_weights_manager,
+                prefix="transformer.wte.",
+                vocab_size=config.vocab_size,
+                tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
+            )
+        )
+
+        # positional embeddings
+        if PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute:
+            state_dict.update(
+                _get_embeddings_or_lm_head(
+                    safetensors_weights_manager,
+                    prefix="transformer.wpe.",
+                    vocab_size=config.n_positions,
+                    tensor_parallel_word_embeddings=False,
+                )
+            )
+
+    for layer_idx in range(layer_start_id, layer_end_id):
         prefix = f"transformer.h.{layer_idx}."
 
         state_dict.update(_get_layernorm(safetensors_weights_manager, prefix=prefix + "ln_1."))
@@ -58,17 +75,18 @@ def get_gpt_dolomite_tensor_parallel_state_dict(
             )
         )
 
-    state_dict.update(_get_layernorm(safetensors_weights_manager, prefix="transformer.ln_f."))
-
-    if not config.tie_word_embeddings:
-        state_dict.update(
-            _get_embeddings_or_lm_head(
-                safetensors_weights_manager=safetensors_weights_manager,
-                prefix="lm_head.",
-                vocab_size=config.vocab_size,
-                tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
-            )
-        )
+    if is_last_pipeline_stage:
+        state_dict.update(_get_layernorm(safetensors_weights_manager, prefix="transformer.ln_f."))
+
+        if not config.tie_word_embeddings:
+            state_dict.update(
+                _get_embeddings_or_lm_head(
+                    safetensors_weights_manager=safetensors_weights_manager,
+                    prefix="lm_head.",
+                    vocab_size=config.vocab_size,
+                    tensor_parallel_word_embeddings=tensor_parallel_word_embeddings,
+                )
+            )
 
     return state_dict

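The substance of the shard.py change is that each pipeline stage now materializes only the parameters it owns: the first stage takes the word and position embeddings, every stage takes an equal contiguous slice of the transformer layers, and the last stage takes the final layernorm and the untied lm_head. A standalone sketch of the same partitioning arithmetic, with a hypothetical `stage_layer_range` helper and a local re-implementation of `divide_if_divisible` (only its behavior is assumed from the diff above):

# Minimal sketch of the per-stage partitioning used above
# (hypothetical, standalone helpers; not the repository's actual API).

def divide_if_divisible(numerator: int, denominator: int, msg: str) -> int:
    # mirrors the helper called in shard.py: fail unless evenly divisible
    assert numerator % denominator == 0, msg
    return numerator // denominator


def stage_layer_range(n_layer: int, num_pipeline_stages: int, pipeline_stage_id: int) -> range:
    layers_per_stage = divide_if_divisible(
        n_layer, num_pipeline_stages, "layers should be divisible by num_pipeline_stages"
    )
    layer_start_id = layers_per_stage * pipeline_stage_id
    layer_end_id = layers_per_stage * (pipeline_stage_id + 1)
    return range(layer_start_id, layer_end_id)


# Example: 24 layers split over 4 pipeline stages
num_stages = 4
for stage in range(num_stages):
    layers = list(stage_layer_range(24, num_stages, stage))
    owns_embeddings = stage == 0                  # wte / wpe only on the first stage
    owns_lm_head = stage == num_stages - 1        # ln_f / lm_head only on the last stage
    print(stage, layers[0], layers[-1], owns_embeddings, owns_lm_head)

With 24 layers over 4 stages, stage 0 loads layers 0-5 plus the embeddings and stage 3 loads layers 18-23 plus ln_f and lm_head; a layer count that does not divide evenly by the number of stages raises an error, matching the check in the diff.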
2 changes: 1 addition & 1 deletion dolomite_engine/hf_models/models/moe_dolomite_TP/main.py
@@ -5,4 +5,4 @@

 class MoEDolomiteForCausalLM_TP(MoEDolomitePreTrainedModel_TP, CausalLMMoEModelMixin_TP):
     base_model_class = MoEDolomiteModel_TP
-    tensor_parallel_state_dict_function = get_moe_dolomite_tensor_parallel_state_dict
+    model_parallel_state_dict_function = get_moe_dolomite_tensor_parallel_state_dict
