Skip to content

Commit

Permalink
Allow empty list of stats, and add Noop Stats (pytorch#1628)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1628

Unblock D52651099 by allowing an empty list for Stats, and adding a Noop Stats.

Reviewed By: gnahzg

Differential Revision: D52716401

fbshipit-source-id: 483f1620bbc5b8378beee1c1a98041800877da2e
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Jan 13, 2024
1 parent aa9c9ba commit 431fb6c
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(
performance_model if performance_model else NoopPerfModel(topology=topology)
)

if stats:
if stats is not None:
self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats
else:
self._stats = [EmbeddingStats()]
Expand Down
22 changes: 22 additions & 0 deletions torchrec/distributed/planner/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,25 @@ def _reduce_int_list(input_list: List[int]) -> str:
reduced.append(str(prev_num))

return ", ".join(reduced)


class NoopEmbeddingStats(Stats):
"""
Noop Stats for a sharding planner execution.
"""

def log(
self,
sharding_plan: ShardingPlan,
topology: Topology,
batch_size: int,
storage_reservation: StorageReservation,
num_proposals: int,
num_plans: int,
run_time: float,
best_plan: List[ShardingOption],
constraints: Optional[Dict[str, ParameterConstraints]] = None,
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
debug: bool = True,
) -> None:
pass
68 changes: 68 additions & 0 deletions torchrec/distributed/planner/tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import List

import torch
from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.stats import NoopEmbeddingStats
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig


class TWvsRWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]):
def sharding_types(self, compute_device_type: str) -> List[str]:
return [ShardingType.ROW_WISE.value, ShardingType.TABLE_WISE.value]

def compute_kernels(
self, sharding_type: str, compute_device_type: str
) -> List[str]:
return [EmbeddingComputeKernel.FUSED.value]


class TestEmbeddingStats(unittest.TestCase):
def setUp(self) -> None:
compute_device = "cuda"
self.topology = Topology(
world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device
)
tables = [
EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=64,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(4)
]
self.model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))

def test_embedding_stats_runs(self) -> None:
planner = EmbeddingShardingPlanner(topology=self.topology)
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()])
self.assertEqual(len(planner._stats), 1)
stats: List[str] = planner._stats[0]._stats_table # pyre-ignore[16]
self.assertTrue(isinstance(stats, list))
self.assertTrue(stats[0].startswith("####"))

def test_empty_embedding_stats_runs(self) -> None:
planner = EmbeddingShardingPlanner(topology=self.topology, stats=[])
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()])
self.assertEqual(len(planner._stats), 0)

def test_noop_embedding_stats_runs(self) -> None:
planner = EmbeddingShardingPlanner(
topology=self.topology, stats=NoopEmbeddingStats()
)
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()])
self.assertEqual(len(planner._stats), 1)

0 comments on commit 431fb6c

Please sign in to comment.