Add util function to recursively get device info from model (#1870)
Summary:
Pull Request resolved: #1870

Add this util function so we can get related device info from a model for debugging purposes

Reviewed By: IvanKobzarev

Differential Revision: D55982349

fbshipit-source-id: 9cb84126e48523c5b0fa67e55d41c6e96bd813b5
gnahzg authored and facebook-github-bot committed Apr 17, 2024
1 parent c589d38 commit b5d67e4
Showing 2 changed files with 76 additions and 2 deletions.
67 changes: 66 additions & 1 deletion torchrec/distributed/infer_utils.py
@@ -5,7 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple
from typing import List, Optional, Set, Tuple

import torch

@@ -55,3 +55,68 @@ def get_tbe_specs_from_sharded_module(
            )
        )
    return tbe_specs


def get_path_device_tuples(
    module: object, ignore_list: Optional[List[str]] = None
) -> List[Tuple[str, str]]:
    """Recursively walk a module's attributes and collect (path, device) tuples
    for every device-like attribute, for debugging device placement."""
    path_device_tuples: List[Tuple[str, str]] = []
    visited_path: Set[str] = set()

    cur_ignore_list: List[str] = ignore_list if ignore_list else ["embedding_tables"]

    def recursive_find_device(
        module: object, cur_depth: int, path: str = "", max_depth: int = 50
    ) -> None:
        nonlocal path_device_tuples
        nonlocal visited_path

        if cur_depth > max_depth:
            return

        if path in visited_path:
            return

        visited_path.add(path)
        # Leaf values and callables carry no nested device info; stop here.
        if (
            isinstance(module, (int, float, str, bool, torch.Tensor))
            or type(module).__name__ in ["method", "function", "Proxy"]
            or module is None
        ):
            return

        device_attrs = ("device", "_device", "_device_str", "_device_type")

        for name in dir(module):
            if name in cur_ignore_list:
                continue
            child = getattr(module, name)
            if name.startswith("__"):
                continue
            if name in device_attrs:
                device = getattr(module, name)
                path_device_tuples.append((path + "." + name, str(device)))
            elif isinstance(child, list):
                for idx, child_ in enumerate(child):
                    recursive_find_device(
                        child_,
                        cur_depth + 1,
                        f"{path}.{name}[{idx}]",
                        max_depth=max_depth,
                    )
            elif isinstance(child, dict):
                for key, child_ in child.items():
                    recursive_find_device(
                        child_,
                        cur_depth + 1,
                        f"{path}.{name}[{key}]",
                        max_depth=max_depth,
                    )
            else:
                recursive_find_device(
                    child, cur_depth + 1, f"{path}.{name}", max_depth=max_depth
                )

    recursive_find_device(module, 0, "")

    return path_device_tuples
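
For reference, a minimal usage sketch of the new helper (not part of this commit; ToyModule and its _device_str attribute are hypothetical, chosen only to illustrate the shape of the returned tuples):

import torch

from torchrec.distributed.infer_utils import get_path_device_tuples


# Hypothetical toy module: any attribute named device/_device/_device_str/
# _device_type encountered during the traversal is recorded.
class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self._device_str = "cpu"


path_device_tuples = get_path_device_tuples(ToyModule())
for path, device in path_device_tuples:
    print(path, device)
# Expected to include ("._device_str", "cpu"): paths are dotted attribute
# chains from the root object, and devices are the stringified attribute values.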
11 changes: 10 additions & 1 deletion torchrec/distributed/tests/test_infer_shardings.py
@@ -24,7 +24,10 @@
    KeyedJaggedTensor,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.infer_utils import get_tbes_from_sharded_module
from torchrec.distributed.infer_utils import (
    get_path_device_tuples,
    get_tbes_from_sharded_module,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.shard_estimators import (
@@ -974,6 +977,12 @@ def test_rw_sequence_uneven(self, weight_dtype: torch.dtype, device: str) -> None:
        for tbe in tbes:
            self.assertTrue(tbe.weight_initialized)

        path_device_lists = get_path_device_tuples(sharded_model)

        # TODO: enable this after device propagation fix
        # for path_device in path_device_lists:
        #     assert device in path_device[1]

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",