From 95e54d36fe558d37f5718e98b76d78e8bd9f3ec0 Mon Sep 17 00:00:00 2001
From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com>
Date: Wed, 25 Sep 2024 01:12:48 -0400
Subject: [PATCH] use DTensor API for embedding matrix (#23)

Signed-off-by: Mayank Mishra
---
 .../hf_models/modeling_utils_TP/embedding.py | 27 +++----------------
 1 file changed, 3 insertions(+), 24 deletions(-)

diff --git a/dolomite_engine/hf_models/modeling_utils_TP/embedding.py b/dolomite_engine/hf_models/modeling_utils_TP/embedding.py
index 8cb85ee1..b57fe29b 100644
--- a/dolomite_engine/hf_models/modeling_utils_TP/embedding.py
+++ b/dolomite_engine/hf_models/modeling_utils_TP/embedding.py
@@ -2,9 +2,8 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
+from torch.distributed._tensor.placement_types import Replicate, Shard
 
 from ...utils import ProcessGroupManager
 from ..modeling_utils import ParameterizedEmbedding
@@ -50,31 +49,11 @@ def __init__(
         self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.tensor_parallel_word_embeddings:
-            input_mask = (input < self.vocab_start_index) | (input >= self.vocab_end_index)
-            input = input - self.vocab_start_index
-            input[input_mask] = 0
-
-            input = F.embedding(input, self.weight.to_local())
-
-            input[input_mask, :] = 0
-            input = tensor_to_dtensor(input, current_placement=Partial())
-        else:
-            input = F.embedding(input, self.weight.to_local())
-            input = tensor_to_dtensor(input, current_placement=Replicate())
-
+        input = tensor_to_dtensor(input, current_placement=Replicate())
+        input = super().forward(input)
         input = dtensor_to_tensor(input, desired_placement=self.output_placement)
-
         return input
 
-    # FIXME sadly this code is not working when we have 2 embedding matrices (absolute embeddings)
-    # my guess is that PyTorch is saving the mask globaly and wpe sees the mask of wte
-    # def forward(self, input: torch.Tensor) -> torch.Tensor:
-    #     input = tensor_to_dtensor(input, current_placement=Replicate())
-    #     input = super().forward(input)
-    #     input = dtensor_to_tensor(input, desired_placement=self.output_placement)
-    #     return input
-
 
 def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]:
     tp_rank = ProcessGroupManager.get_tensor_parallel_rank()
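
For reference, a minimal sketch (not the repository's code) of the pattern the new forward relies on: the embedding weight lives as a DTensor sharded over the vocabulary dimension, the token ids are wrapped as a replicated DTensor, and DTensor's sharding propagation for the embedding op performs the masked local lookup and cross-rank reduction that the deleted branch did by hand. The function name, toy sizes, and the final Replicate output placement below are illustrative assumptions; the snippet also assumes torch.distributed is already initialized and approximates the repository's tensor_to_dtensor / dtensor_to_tensor helpers with DTensor.from_local, redistribute, and to_local.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

VOCAB_SIZE, HIDDEN_SIZE = 1024, 64  # toy sizes for illustration only

# 1-D tensor-parallel mesh over all ranks (assumes init_process_group was called)
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

embedding = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
# shard the embedding table row-wise over the vocabulary dimension
embedding.weight = nn.Parameter(distribute_tensor(embedding.weight.data, mesh, [Shard(0)]))


def tp_embedding_forward(token_ids: torch.Tensor) -> torch.Tensor:
    # token ids are identical on every tensor-parallel rank, so Replicate is a valid placement
    token_ids = DTensor.from_local(token_ids, mesh, [Replicate()])
    # nn.Embedding.forward dispatches F.embedding on DTensors; out-of-range masking
    # and the partial-sum handling happen inside DTensor's sharding propagation
    hidden = embedding(token_ids)
    # reduce the partial results and return a plain local tensor
    return hidden.redistribute(mesh, [Replicate()]).to_local()

In the actual module, dtensor_to_tensor converts to whatever output placement get_module_placements selected (e.g. a Shard placement under sequence parallelism) rather than always replicating.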