From b2a42b4cceb459c6bf5c7cb9c90c008d56581df3 Mon Sep 17 00:00:00 2001
From: Paul Zhang
Date: Tue, 3 Dec 2024 11:50:27 -0800
Subject: [PATCH] Pyre fix

Differential Revision: D66717161
---
 torchrec/distributed/model_parallel.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 0f60362b6..5cbd2429b 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -746,6 +746,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         all_weights = [
             w
             for emb_kernel in self._modules_to_sync
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             for w in emb_kernel.split_embedding_weights()
         ]
         handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts)
@@ -755,6 +756,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
             # Sync accumulated square of grad of local optimizer shards
             optim_list = []
             for emb_kernel in self._modules_to_sync:
+                # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                 all_optimizer_states = emb_kernel.get_optimizer_state()
                 momentum1 = [optim["sum"] for optim in all_optimizer_states]
                 optim_list.extend(momentum1)
@@ -864,6 +866,8 @@ def _find_sharded_modules(
             if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
                 sharded_modules.append(module)
             if hasattr(module, "_lookups"):
+                # pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is
+                #  not a function.
                 for lookup in module._lookups:
                     _find_sharded_modules(lookup)
             return