diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py
index f541ba9f0..253d6e615 100644
--- a/torchrec/distributed/planner/proposers.py
+++ b/torchrec/distributed/planner/proposers.py
@@ -348,6 +348,16 @@ def feedback(
             f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
         )
 
+        if not partitionable:
+            # Focus our search on smaller plans by assuming plans larger than this
+            # proposal will also fail to partition.
+            starting_size = sum(
+                sharding_option.total_storage.hbm
+                for sharding_option in self.starting_proposal
+            )
+            new_budget = hbm_used_previously - starting_size
+            self.search.shrink_right(new_budget)  # pyre-ignore
+
         assert self.search is not None  # keep pyre happy
         budget = self.search.next(perf_rating or 1e99)
         if budget is not None:
diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py
index e3beb917e..145c3d078 100644
--- a/torchrec/distributed/planner/tests/test_proposers.py
+++ b/torchrec/distributed/planner/tests/test_proposers.py
@@ -536,6 +536,77 @@ def test_scaleup(self) -> None:
             },
         )
 
+    def test_budget_shrink(self) -> None:
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=2_000_000,
+                embedding_dim=10000,
+                name="table_0",
+                feature_names=["feature_0"],
+            )
+        ]
+        constraints = {
+            "table_0": ParameterConstraints(
+                compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
+                cache_params=CacheParams(
+                    load_factor=0.1,
+                    stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2),
+                ),
+            ),
+        }
+
+        GB = 1024 * 1024 * 1024
+        storage_constraint = Topology(
+            world_size=1, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB
+        )
+        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
+        enumerator = EmbeddingEnumerator(
+            topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
+        )
+        search_space = enumerator.enumerate(
+            module=model,
+            sharders=[
+                cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())
+            ],
+        )
+        proposer = EmbeddingOffloadScaleupProposer()
+        proposer.load(search_space, enumerator=enumerator)
+
+        proposal = proposer.propose()
+        best_plan = None
+        best_perf = 1e99
+        proposals = -1
+        initial_mem = None
+        while proposal is not None:
+            proposals += 1
+            mem = sum(so.total_storage.hbm for so in proposal)
+            if initial_mem is None:
+                initial_mem = mem
+            # Budget given constraints:
+            # cache scale up budget=92.53 GB, exploring [7.47, 100.0] GB
+            #
+            # Simple perf model: assume the partitioner gives its lowest score at
+            # 7.9 GB and anything larger than 8 GB fails to partition. This minimum
+            # is very hard to hit when exploring the larger [7.47, 100] range with
+            # limited iterations without shrinkage.
+            perf = abs(mem - (7.9 * GB))
+            partitionable = mem < 8 * GB
+            if perf < best_perf:
+                best_plan = mem
+                best_perf = perf
+            proposer.feedback(
+                partitionable=partitionable,
+                plan=proposal,
+                perf_rating=perf if partitionable else None,
+                storage_constraint=storage_constraint,
+            )
+            proposal = proposer.propose()
+
+        self.assertEqual(proposals, 16)
+        self.assertNotEqual(initial_mem, best_plan, "couldn't find a better plan")
+        # goal is 7.9, we get very close
+        self.assertEqual(best_plan, 7.960684550926089 * GB)
+
     def test_proposers_to_proposals_list(self) -> None:
         def make_mock_proposal(name: str) -> List[ShardingOption]:
             return [
diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py
index c5d7d6e12..7ad237553 100644
--- a/torchrec/distributed/planner/utils.py
+++ b/torchrec/distributed/planner/utils.py
@@ -199,6 +199,13 @@ def __init__(
         self.fright: Optional[float] = None
         self.d: float = self.right - self.left
 
+    def shrink_right(self, B: float) -> None:
+        "Shrink the right boundary to B, given that f maps [B, infinity) to infinity"
+        self.right = B
+        self.fright = math.inf
+        self.d = self.right - self.left
+        self.x = self.clamp(self.x)
+
     def clamp(self, x: float) -> float:
         "Clamp x into range [left, right]"
         if x < self.left:
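
Reviewer note: below is a minimal, self-contained sketch of the shrink-on-failure
mechanism, for reviewing without the rest of the planner handy. The `Search`
class here is a stand-in assumption, not the planner's actual search class: it
carries only the fields shrink_right() touches in utils.py, and the budget
numbers mirror the test comment (7.47 GB starting proposal, 92.53 GB scale-up
budget).

    import math

    GB = 1024 * 1024 * 1024

    class Search:
        """Stand-in search over extra HBM budget above the starting proposal."""

        def __init__(self, left: float, right: float) -> None:
            self.left = left
            self.right = right
            self.fright = None               # f(right); pinned to +inf once shrunk
            self.d = self.right - self.left  # width of the search interval
            self.x = (left + right) / 2      # current candidate budget

        def clamp(self, x: float) -> float:
            "Clamp x into range [left, right]"
            return min(max(x, self.left), self.right)

        def shrink_right(self, B: float) -> None:
            # Same body as the diff: budgets >= B are assumed unpartitionable,
            # so move the right boundary to B and treat f(right) as infinite.
            self.right = B
            self.fright = math.inf
            self.d = self.right - self.left
            self.x = self.clamp(self.x)

    # Explore [0, 92.53] GB of extra budget, i.e. [7.47, 100.0] GB of total HBM.
    search = Search(left=0.0, right=92.53 * GB)

    # Suppose a 50 GB plan failed to partition. As in feedback(), subtract the
    # starting size so the new cutoff is expressed as extra budget; all larger
    # budgets are pruned from the search.
    search.shrink_right(50 * GB - 7.47 * GB)
    print(f"right boundary now {search.right / GB:.2f} GB of extra budget")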