From 849a24f9d6d9ea9accc2332295c709674c6a91fb Mon Sep 17 00:00:00 2001 From: Wilson Hong Date: Thu, 4 Apr 2024 16:37:44 -0700 Subject: [PATCH] improve granularity of `PooledEmbeddingArchAwaitable` (#1843) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1843 # Context This is try to reland D47272219 (which got reverted in S355730) after fixing QPS regression. This diff is necessary in using PT2 IR for training for APS SNN models, otherwise training will get stuck due to disordered collective communication, see post: https://fb.workplace.com/groups/319878845696681/permalink/1130124208005470/ # Original summary from D47272219 Today, if I do a `__getitem__` call on `PooledEmbeddingArchAwaitable`, it triggers the wait. We'd like to defer that further to when the result of `__getitem__` is actually used. So instead, have `__getitem__` return another `LazyAwaitable` which represents the pooled embedding. Usage of that value in the context of a torchfunction will trigger the wait as desired. This ends up being important for PT2 IR integration, which eagerly dumps a bunch of `__getitem__` calls right after the sparse arch because PT2 IR prefers to operate on "flat" values. With improved granularity, we still get the desired lazy behavior. For pure eager users, this should be a no-op (we generally only call `__getitem__` right before use, so this doesn't reorder anything). The laziness affects the ordering of comms/compute, which is important in two ways: 1. PEA design means that the per-rank feature processing behavior causes the specific order of execution to be load-bearing. Without the laziness, the execution order of ranks with vs. without feature processing will diverge, causing training hangs. 2. getting comms/compute overlapping for the all to all comms vs. dense compute is likely to be a performance improvement, although it is hard to make a direct comparison because of issue #1. Further details can be found in: https://fb.workplace.com/groups/319878845696681/posts/1017888535895705 Reviewed By: IvanKobzarev Differential Revision: D54879753 fbshipit-source-id: 97c6b7d891d22c2280d93bda9319da3f61325aeb --- torchrec/distributed/embeddingbag.py | 5 ++- .../distributed/tests/test_lazy_awaitable.py | 32 +++++++++++++- torchrec/distributed/types.py | 42 +++++++++++++++++++ 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 54e703205..d2bf2594b 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -55,6 +55,7 @@ EmbeddingModuleShardingPlan, EnumerableShardingSpec, LazyAwaitable, + LazyGetItemMixin, NullShardedModuleContext, ParameterSharding, QuantizedCommCodecs, @@ -353,7 +354,9 @@ def _wait_impl(self) -> KeyedTensor: ) -class EmbeddingBagCollectionAwaitable(LazyAwaitable[KeyedTensor]): +class EmbeddingBagCollectionAwaitable( + LazyGetItemMixin[str, Tensor], LazyAwaitable[KeyedTensor] +): def __init__( self, awaitables: List[Awaitable[torch.Tensor]], diff --git a/torchrec/distributed/tests/test_lazy_awaitable.py b/torchrec/distributed/tests/test_lazy_awaitable.py index 8fd9b51eb..bc97a3b6d 100644 --- a/torchrec/distributed/tests/test_lazy_awaitable.py +++ b/torchrec/distributed/tests/test_lazy_awaitable.py @@ -12,7 +12,7 @@ import torch import torch.fx -from torchrec.distributed.types import LazyAwaitable +from torchrec.distributed.types import LazyAwaitable, LazyGetItemMixin class NeedWait(LazyAwaitable[torch.Tensor]): @@ -256,3 +256,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4))) tempFile.close() + + def test_lazy_getitem_mixin(self) -> None: + class LazyGetItemAwaitable( + LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[Dict[str, torch.Tensor]] + ): + def __init__(self, actual_value: Dict[str, torch.Tensor]): + super().__init__() + self.actual_value = actual_value + + def _wait_impl(self) -> Dict[str, torch.Tensor]: + for v in self.actual_value.values(): + v *= 3 + return self.actual_value + + actual_value = {"foo": torch.tensor(1), "bar": torch.tensor(2)} + a = LazyGetItemAwaitable(actual_value) + lazy_foo = a["foo"] + lazy_bar = a["bar"] + # The returned value should be lazy + self.assertIsInstance(lazy_foo, LazyAwaitable) + self.assertIsInstance(lazy_bar, LazyAwaitable) + + # Our lazy values should not have been waited yet + self.assertIsNone(lazy_foo._result) + self.assertIsNone(lazy_bar._result) + self.assertIsNone(a._result) + + # The use of a torch op should trigger exactly one wait on the parent object. + result = torch.add(lazy_foo, lazy_bar) + self.assertEqual(result, torch.tensor(1 * 3 + 2 * 3)) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index fe34bed79..00515814a 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -402,6 +402,48 @@ def _wait_impl(self) -> W: return self._obj +KT = TypeVar("KT") +VT_co = TypeVar("VT_co") +ParentW = TypeVar("ParentW") + + +class LazyGetItemMixin(Generic[KT, VT_co]): + """Augments the base LazyAwaitable with a lazy __getitem__ method. + + Instead of triggering a wait() on a __getitem__ call, KeyedLazyAwaitable + will return another awaitable. This can achieve better + communication/computation overlap by deferring the wait() until the + tensor data is actually needed. + + This is intended for Awaitables that model keyed collections, like + dictionaries or EmbeddingBagCollectionAwaitable. + + NOTE: if using this mixin, please include it before LazyAwaitable in the + inheritance list, so that Python MRO can properly select this __getitem__ + implementation. + """ + + def __getitem__(self, key: KT) -> LazyAwaitable[VT_co]: + return GetItemLazyAwaitable(self, key) + + +class GetItemLazyAwaitable(LazyAwaitable[W], Generic[W, ParentW, KT]): + """The LazyAwaitable returned from a __getitem__ call on `LazyGetItemMixin`. + + When the actual value of this awaitable is requested, wait on the parent and + then call __getitem__ on the result. + """ + + def __init__(self, parent: LazyAwaitable[ParentW], key: KT) -> None: + super().__init__() + self._parent = parent + self._key = key + + def _wait_impl(self) -> W: + kt = LazyAwaitable._wait_async(self._parent) + return kt[self._key] + + # install magic methods for orig_method_name in torch.fx.graph.magic_methods: as_magic = f"__{orig_method_name}__"