diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index 7a75210b2..137bc7ad1 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -162,7 +162,7 @@ def __init__( performance_model if performance_model else NoopPerfModel(topology=topology) ) - if stats: + if stats is not None: self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats else: self._stats = [EmbeddingStats()] diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index a53928e9f..f988a316c 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -694,3 +694,25 @@ def _reduce_int_list(input_list: List[int]) -> str: reduced.append(str(prev_num)) return ", ".join(reduced) + + +class NoopEmbeddingStats(Stats): + """ + Noop Stats for a sharding planner execution. + """ + + def log( + self, + sharding_plan: ShardingPlan, + topology: Topology, + batch_size: int, + storage_reservation: StorageReservation, + num_proposals: int, + num_plans: int, + run_time: float, + best_plan: List[ShardingOption], + constraints: Optional[Dict[str, ParameterConstraints]] = None, + sharders: Optional[List[ModuleSharder[nn.Module]]] = None, + debug: bool = True, + ) -> None: + pass diff --git a/torchrec/distributed/planner/tests/test_stats.py b/torchrec/distributed/planner/tests/test_stats.py new file mode 100644 index 000000000..0f789003c --- /dev/null +++ b/torchrec/distributed/planner/tests/test_stats.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import List + +import torch +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.planners import EmbeddingShardingPlanner +from torchrec.distributed.planner.stats import NoopEmbeddingStats +from torchrec.distributed.planner.types import Topology +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +class TWvsRWSharder(EmbeddingBagCollectionSharder, ModuleSharder[nn.Module]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value, ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.FUSED.value] + + +class TestEmbeddingStats(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology( + world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device + ) + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=64, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + self.model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + + def test_embedding_stats_runs(self) -> None: + planner = EmbeddingShardingPlanner(topology=self.topology) + _ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 1) + stats: List[str] = planner._stats[0]._stats_table # pyre-ignore[16] + self.assertTrue(isinstance(stats, list)) + self.assertTrue(stats[0].startswith("####")) + + def test_empty_embedding_stats_runs(self) -> None: + planner = EmbeddingShardingPlanner(topology=self.topology, stats=[]) + _ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 0) + + def test_noop_embedding_stats_runs(self) -> None: + planner = EmbeddingShardingPlanner( + topology=self.topology, stats=NoopEmbeddingStats() + ) + _ = planner.plan(module=self.model, sharders=[TWvsRWSharder()]) + self.assertEqual(len(planner._stats), 1)