Skip to content

Commit

Permalink
Support custom data in Topology
Browse files Browse the repository at this point in the history
Summary:
Add a new field, custom_topology_data, so that ddr_cap and hbm_cap can have non-homogeneous settings.
Follow-up diffs will consume this custom data to support uneven sharding within the planner.

Differential Revision: D55524525
  • Loading branch information
gnahzg authored and facebook-github-bot committed Apr 3, 2024
1 parent a15b5eb commit 90ae582
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions torchrec/distributed/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,34 @@ class DeviceHardware:
perf: Perf


class CustomTopologyData:
    """
    Custom device data for individual device in a topology.

    Holds a mapping from a field name to a list of integer values. Only the
    field names listed in ``supported_fields`` are accepted, and every list
    must contain exactly ``world_size`` entries (presumably one per device;
    the index-to-device mapping is not defined in this class -- confirm
    against the consumer).
    """

    # Field names that may appear as keys in the data mapping.
    supported_fields = ["ddr_cap", "hbm_cap"]

    def __init__(
        self,
        data: Dict[str, List[int]],
        world_size: int,
    ) -> None:
        """
        Validate and store the custom data.

        Args:
            data: mapping of supported field name to a list of values.
            world_size: required length of every list in ``data``.

        Raises:
            AssertionError: if a key is not in ``supported_fields`` or a
                list's length differs from ``world_size``.
        """
        assert all(
            key in self.supported_fields for key in data.keys()
        ), f"{data.keys()} not supported in CustomTopologyData"
        # The check is on list length, so the message reports a length
        # mismatch (it previously claimed the values "must be positive").
        assert all(len(v) == world_size for v in data.values()), (
            f"every list in {data.values()} must have length equal to "
            f"world_size ({world_size})"
        )
        self._data = data
        self._world_size = world_size

    def get_data(self, key: str) -> List[int]:
        """
        Return the list stored for ``key``.

        Raises:
            AssertionError: if ``key`` is not in ``supported_fields``.
            KeyError: if ``key`` is supported but was not provided in the
                data passed to the constructor.
        """
        assert (
            key in self.supported_fields
        ), f"{key} not supported in CustomTopologyData"
        return self._data[key]


class Topology:
def __init__(
self,
Expand All @@ -154,6 +182,7 @@ def __init__(
intra_host_bw: float = INTRA_NODE_BANDWIDTH,
inter_host_bw: float = CROSS_NODE_BANDWIDTH,
bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER,
custom_topology_data: Optional[CustomTopologyData] = None,
) -> None:
"""
Representation of a network of devices in a cluster.
Expand Down

0 comments on commit 90ae582

Please sign in to comment.