[Profiler] Add group_info output #206

Open · wants to merge 1 commit into base: main
97 changes: 96 additions & 1 deletion megatron/megatron/core/parallel_state.py
@@ -2,6 +2,11 @@

"""Model and data parallel groups."""


import json
import pickle
import socket

import os
import warnings
from datetime import timedelta
@@ -269,7 +274,10 @@ def __init__(
        }
        self.order = order
        order = order.lower()

        self.parallelism_to_groups = {}
        self.rank_to_parallelism_to_group_id = {}

        if 'ep' in order:
            if 'ep-dp' not in order and 'dp-ep' not in order:
                raise RuntimeError(f"The ep and dp must be adjacent in order ({self.order}).")
@@ -334,8 +342,35 @@ def get_ranks(self, token, independent_ep=False):
        for rank_group in ranks:
            for i in range(len(rank_group)):
                rank_group[i] += self.rank_offset

        # Record the full group layout for this parallelism token and, for each
        # rank, the index of the group it belongs to.
        self.parallelism_to_groups[token] = ranks
        group_id = 0
        for group in ranks:
            for rank in group:
                if rank not in self.rank_to_parallelism_to_group_id:
                    self.rank_to_parallelism_to_group_id[rank] = {}
                self.rank_to_parallelism_to_group_id[rank][token] = group_id
            group_id = group_id + 1

        return ranks

    def print_ranks(self, print_path, generator_type):
        """Dump the recorded group layouts to stdout or to <print_path>/<generator_type>/."""
        if print_path == "stdout":
            print(generator_type + ": parallelism_to_groups", self.parallelism_to_groups)
            print(generator_type + ": rank_to_parallelism_to_group_id", self.rank_to_parallelism_to_group_id)
        else:
            print_path = print_path + "/" + generator_type
            try:
                os.makedirs(print_path, exist_ok=True)
            except OSError as e:
                raise RuntimeError(f"Failed to create path '{print_path}'. Error: {e}")
            parallelism_to_groups_file = print_path + "/parallelism_to_groups.json"
            with open(parallelism_to_groups_file, 'w') as file:
                json.dump(self.parallelism_to_groups, file, ensure_ascii=False, indent=4)
            rank_to_parallelism_to_group_id_file = print_path + "/rank_to_parallelism_to_group_id.json"
            with open(rank_to_parallelism_to_group_id_file, 'w') as file:
                json.dump(self.rank_to_parallelism_to_group_id, file, ensure_ascii=False, indent=4)


def default_embedding_ranks(pp_ranks, split_rank=None):
"""Return the default ranks that constitute the stages on which the word embeddings live.
@@ -361,6 +396,59 @@ def default_position_embedding_ranks(pp_ranks, split_rank=None):
    return [pp_ranks[0]]


def print_ranks(print_path: str):
    """Gather host and device info from every rank and print or save it as JSON."""
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()

    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    rank_to_host_and_device = {}
    host_name = socket.gethostname()
    host_ip = socket.gethostbyname(host_name)
    device_id = None
    device_name = None
    if torch.cuda.is_available():
        device_id = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_id)
    rank_to_host_and_device[rank] = {
        "host_ip": host_ip,
        "host_name": host_name,
        "device_id": device_id,
        "device_name": device_name,
    }

    # Serialize the local entry and gather it from all ranks. Payload sizes can
    # differ across ranks, so gather the lengths first and pad to the maximum.
    serialized_data = pickle.dumps(rank_to_host_and_device)
    serialized_tensor = torch.ByteTensor(list(serialized_data)).cuda()

    local_length = torch.tensor([len(serialized_tensor)], dtype=torch.int).cuda()
    all_lengths = [torch.tensor([0], dtype=torch.int).cuda() for _ in range(world_size)]
    torch.distributed.all_gather(all_lengths, local_length)

    max_length = max(length.item() for length in all_lengths)
    padded_tensor = torch.zeros(max_length, dtype=torch.uint8).cuda()
    padded_tensor[:local_length.item()] = serialized_tensor

    gathered_tensors = [torch.zeros(max_length, dtype=torch.uint8).cuda() for _ in range(world_size)]
    torch.distributed.all_gather(gathered_tensors, padded_tensor)

    if rank == 0 and print_path is not None:
        rank_to_host_and_device = {}
        for i, tensor in enumerate(gathered_tensors):
            serialized_data = tensor[:all_lengths[i].item()].tolist()
            rank_dict = pickle.loads(bytes(serialized_data))
            rank_to_host_and_device.update(rank_dict)
        if print_path == "stdout":
            print("rank_to_host_and_device", rank_to_host_and_device)
        else:
            try:
                os.makedirs(print_path, exist_ok=True)
            except OSError as e:
                raise RuntimeError(f"Failed to create path '{print_path}'. Error: {e}")
            rank_to_host_and_device_file = print_path + "/rank_to_host_and_device.json"
            with open(rank_to_host_and_device_file, 'w') as file:
                json.dump(rank_to_host_and_device, file, ensure_ascii=False, indent=4)
    else:
        print("rank_to_host_and_device", rank_to_host_and_device)

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
@@ -377,6 +465,7 @@ def initialize_model_parallel(
    encoder_pipeline_model_parallel_size: Optional[int] = 0,
    get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    analyze_save_path: str = "stdout",
) -> None:
    """Initialize model data parallel groups.

@@ -933,6 +1022,12 @@ def generator_wrapper(group_type, **kwargs):
    # we could stick it there
    _set_global_memory_buffer()

    if encoder_rank_generator:
        encoder_rank_generator.print_ranks(analyze_save_path, "encoder")
    if decoder_rank_generator:
        decoder_rank_generator.print_ranks(analyze_save_path, "decoder")
    print_ranks(analyze_save_path)


def is_initialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
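For reference, here is a minimal, self-contained sketch (not part of the diff) of the bookkeeping that the new code in get_ranks builds up. The four-rank layout, the 'tp'/'dp' tokens, and the group lists below are illustrative assumptions, not values taken from this PR:

# Illustrative only: mirrors the group-id bookkeeping added to RankGenerator.get_ranks.
parallelism_to_groups = {}
rank_to_parallelism_to_group_id = {}

example_groups = {
    'tp': [[0, 1], [2, 3]],  # two tensor-parallel groups of 2 ranks each
    'dp': [[0, 2], [1, 3]],  # two data-parallel groups of 2 ranks each
}

for token, groups in example_groups.items():
    parallelism_to_groups[token] = groups
    group_id = 0
    for group in groups:
        for rank in group:
            rank_to_parallelism_to_group_id.setdefault(rank, {})[token] = group_id
        group_id = group_id + 1

print(rank_to_parallelism_to_group_id)
# {0: {'tp': 0, 'dp': 0}, 1: {'tp': 0, 'dp': 1}, 2: {'tp': 1, 'dp': 0}, 3: {'tp': 1, 'dp': 1}}

This is the structure that print_ranks later serializes to parallelism_to_groups.json and rank_to_parallelism_to_group_id.json.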
10 changes: 10 additions & 0 deletions megatron/megatron/training/arguments.py
@@ -58,6 +58,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    parser = _add_customized_device_args(parser)
    parser = _add_hetero_args(parser)
    parser = _add_auto_tuner_args(parser)
    parser = _add_analyze_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -2225,3 +2226,12 @@ def _add_auto_tuner_args(parser):
                       help='use auto tuner')

    return parser


def _add_analyze_args(parser):
    group = parser.add_argument_group(title="analyze")

    group.add_argument('--analyze-save-dir', type=str, default=None,
                       help='Directory in which to save analysis information, including the '
                            'parallel-group information files and other analysis files.')

    return parser
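Usage note (not part of the diff): --analyze-save-dir is a plain argparse option, so it can be appended to an existing training command. The launcher, script name, and other arguments below are placeholders:

torchrun --nproc_per_node=8 pretrain_gpt.py <existing training args> --analyze-save-dir ./analyze_output

When the flag is omitted, the call site in initialize.py below falls back to "stdout" and the grouping information is printed during initialization instead of written to files.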
1 change: 1 addition & 0 deletions megatron/megatron/training/initialize.py
@@ -295,6 +295,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
            encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
            get_embedding_ranks=get_embedding_ranks,
            get_position_embedding_ranks=get_position_embedding_ranks,
            analyze_save_path=args.analyze_save_dir if args.analyze_save_dir else "stdout",
        )
        if args.rank == 0:
            print(
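Taken together, when --analyze-save-dir is set the changes above write the following files during initialize_model_parallel (layout inferred from the paths built in parallel_state.py; <save_dir> is a placeholder for the configured directory):

<save_dir>/decoder/parallelism_to_groups.json
<save_dir>/decoder/rank_to_parallelism_to_group_id.json
<save_dir>/encoder/parallelism_to_groups.json             (only when an encoder rank generator exists)
<save_dir>/encoder/rank_to_parallelism_to_group_id.json   (only when an encoder rank generator exists)
<save_dir>/rank_to_host_and_device.json                   (written by rank 0)

With the default value "stdout", the same information is printed to standard output instead.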