use DTensor API for embedding matrix (#23)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Sep 25, 2024
1 parent 0103bfb commit 95e54d3
Showing 1 changed file with 3 additions and 24 deletions.
dolomite_engine/hf_models/modeling_utils_TP/embedding.py (27 changes: 3 additions & 24 deletions)
@@ -2,9 +2,8 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
+from torch.distributed._tensor.placement_types import Replicate, Shard
 
 from ...utils import ProcessGroupManager
 from ..modeling_utils import ParameterizedEmbedding
@@ -50,31 +49,11 @@ def __init__(
         self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.tensor_parallel_word_embeddings:
-            input_mask = (input < self.vocab_start_index) | (input >= self.vocab_end_index)
-            input = input - self.vocab_start_index
-            input[input_mask] = 0
-
-            input = F.embedding(input, self.weight.to_local())
-
-            input[input_mask, :] = 0
-            input = tensor_to_dtensor(input, current_placement=Partial())
-        else:
-            input = F.embedding(input, self.weight.to_local())
-            input = tensor_to_dtensor(input, current_placement=Replicate())
-
+        input = tensor_to_dtensor(input, current_placement=Replicate())
+        input = super().forward(input)
         input = dtensor_to_tensor(input, desired_placement=self.output_placement)
-
         return input
-
-    # FIXME sadly this code is not working when we have 2 embedding matrices (absolute embeddings)
-    # my guess is that PyTorch is saving the mask globaly and wpe sees the mask of wte
-    # def forward(self, input: torch.Tensor) -> torch.Tensor:
-    #     input = tensor_to_dtensor(input, current_placement=Replicate())
-    #     input = super().forward(input)
-    #     input = dtensor_to_tensor(input, desired_placement=self.output_placement)
-    #     return input
 
 
 def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]:
     tp_rank = ProcessGroupManager.get_tensor_parallel_rank()
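For context, the new forward path wraps the token ids as a replicated DTensor, lets the regular nn.Embedding forward run against the vocab-sharded DTensor weight, and converts the result back to a plain tensor, instead of masking out-of-range ids against vocab_start_index/vocab_end_index by hand and tagging the local lookup as a Partial sum. Below is a minimal, runnable sketch of that idea using stock PyTorch DTensor APIs; it is not the repository's code. The file name sketch.py, the single-process CPU mesh, and the variable names are illustrative; DTensor.from_local and redistribute(...).to_local() stand in for the repo's tensor_to_dtensor / dtensor_to_tensor helpers; and it assumes a PyTorch recent enough (roughly 2.3+) for the embedding op to accept a vocab-sharded Shard(0) weight.

# sketch.py -- run with: torchrun --nproc_per_node=1 sketch.py
# Minimal sketch (assumed names, not dolomite_engine code) of a DTensor-backed
# vocab-sharded embedding: shard the weight on dim 0, replicate the token ids,
# and let DTensor's sharding propagation handle the lookup.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import distribute_tensor
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh


def main() -> None:
    dist.init_process_group("gloo")  # CPU backend so the sketch runs anywhere
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    vocab_size, hidden_size = 16, 8
    wte = nn.Embedding(vocab_size, hidden_size)
    # shard the embedding matrix along the vocabulary dimension (rows)
    wte.weight = nn.Parameter(distribute_tensor(wte.weight.detach(), mesh, [Shard(0)]))

    token_ids = torch.tensor([[1, 3, 5, 7]])
    # replicate the token ids; no manual vocab_start_index/vocab_end_index masking
    token_ids = DTensor.from_local(token_ids, mesh, [Replicate()], run_check=False)

    hidden_states = wte(token_ids)  # DTensor output, partial across vocab shards
    hidden_states = hidden_states.redistribute(mesh, [Replicate()]).to_local()
    print(hidden_states.shape)  # torch.Size([1, 4, 8])

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Under real tensor parallelism the same pattern runs with one vocabulary shard per rank, and the redistribute step is where the per-shard partial lookups are reduced into the full embedding output.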
