improve granularity of PooledEmbeddingArchAwaitable (#1843)
Summary:
Pull Request resolved: #1843

# Context
This is an attempt to reland D47272219 (which was reverted in S355730) after fixing the QPS regression. This diff is necessary for using PT2 IR to train APS SNN models; otherwise, training gets stuck due to out-of-order collective communication, see post: https://fb.workplace.com/groups/319878845696681/permalink/1130124208005470/

# Original summary from D47272219

Today, if I do a `__getitem__` call on `PooledEmbeddingArchAwaitable`, it triggers the wait. We'd like to defer that further to when the result of `__getitem__` is actually used.

So instead, have `__getitem__` return another `LazyAwaitable` that represents the pooled embedding. Using that value in the context of a torch function will then trigger the wait as desired.
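
To illustrate, here is a minimal, self-contained sketch; the `PooledSketch` class below is hypothetical, and only `LazyAwaitable` / `LazyGetItemMixin` come from this diff:

from typing import Dict

import torch

from torchrec.distributed.types import LazyAwaitable, LazyGetItemMixin


# Hypothetical stand-in for an awaitable over pooled embeddings keyed by feature name.
class PooledSketch(
    LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[Dict[str, torch.Tensor]]
):
    def __init__(self, pooled: Dict[str, torch.Tensor]) -> None:
        super().__init__()
        self._pooled = pooled

    def _wait_impl(self) -> Dict[str, torch.Tensor]:
        # In the real module this is where the pooled all-to-all output would be awaited.
        return self._pooled


pooled = PooledSketch({"f1": torch.ones(2, 4), "f2": torch.ones(2, 4)})
f1 = pooled["f1"]        # now returns another LazyAwaitable; nothing is waited yet
f2 = pooled["f2"]        # still no wait
out = torch.add(f1, f2)  # first use in a torch function triggers the wait on the parent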

This ends up being important for PT2 IR integration, which eagerly dumps a bunch of `__getitem__` calls right after the sparse arch because PT2 IR prefers to operate on "flat" values. With improved granularity, we still get the desired lazy behavior.

For pure eager users, this should be a no-op (we generally only call `__getitem__` right before use, so this doesn't reorder anything).

The laziness affects the ordering of comms/compute, which is important in two ways:
1. The PooledEmbeddingArch (PEA) design, with its per-rank feature processing, makes the specific order of execution load-bearing. Without the laziness, the execution order of ranks with feature processing diverges from that of ranks without it, causing training hangs.
2. Overlapping the all-to-all comms with dense compute is likely a performance improvement, although it is hard to measure directly because of point 1 above (a rough sketch follows this list).
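
A rough sketch of the overlap in point 2, using raw torch.distributed rather than TorchRec's actual forward path (function and tensor names here are illustrative; assumes an initialized process group):

import torch
import torch.distributed as dist


def overlapped_forward(
    sparse_local: torch.Tensor, dense_in: torch.Tensor, dense_arch: torch.nn.Module
) -> torch.Tensor:
    pooled_out = torch.empty_like(sparse_local)
    # Launch the all-to-all asynchronously; async_op=True returns a work handle.
    work = dist.all_to_all_single(pooled_out, sparse_local, async_op=True)
    dense_out = dense_arch(dense_in)  # dense compute overlaps with the in-flight collective
    work.wait()  # block only when the pooled embeddings are actually consumed
    return torch.cat([dense_out, pooled_out], dim=1)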

Further details can be found in: https://fb.workplace.com/groups/319878845696681/posts/1017888535895705

Reviewed By: IvanKobzarev

Differential Revision: D54879753

fbshipit-source-id: 97c6b7d891d22c2280d93bda9319da3f61325aeb
wilson100hong authored and facebook-github-bot committed Apr 4, 2024
1 parent 030c694 commit 849a24f
Showing 3 changed files with 77 additions and 2 deletions.
5 changes: 4 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -55,6 +55,7 @@
     EmbeddingModuleShardingPlan,
     EnumerableShardingSpec,
     LazyAwaitable,
+    LazyGetItemMixin,
     NullShardedModuleContext,
     ParameterSharding,
     QuantizedCommCodecs,
@@ -353,7 +354,9 @@ def _wait_impl(self) -> KeyedTensor:
         )
 
 
-class EmbeddingBagCollectionAwaitable(LazyAwaitable[KeyedTensor]):
+class EmbeddingBagCollectionAwaitable(
+    LazyGetItemMixin[str, Tensor], LazyAwaitable[KeyedTensor]
+):
     def __init__(
         self,
         awaitables: List[Awaitable[torch.Tensor]],
32 changes: 31 additions & 1 deletion torchrec/distributed/tests/test_lazy_awaitable.py
@@ -12,7 +12,7 @@
 
 import torch
 import torch.fx
-from torchrec.distributed.types import LazyAwaitable
+from torchrec.distributed.types import LazyAwaitable, LazyGetItemMixin
 
 
 class NeedWait(LazyAwaitable[torch.Tensor]):
@@ -256,3 +256,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4)))
 
         tempFile.close()
+
+    def test_lazy_getitem_mixin(self) -> None:
+        class LazyGetItemAwaitable(
+            LazyGetItemMixin[str, torch.Tensor], LazyAwaitable[Dict[str, torch.Tensor]]
+        ):
+            def __init__(self, actual_value: Dict[str, torch.Tensor]):
+                super().__init__()
+                self.actual_value = actual_value
+
+            def _wait_impl(self) -> Dict[str, torch.Tensor]:
+                for v in self.actual_value.values():
+                    v *= 3
+                return self.actual_value
+
+        actual_value = {"foo": torch.tensor(1), "bar": torch.tensor(2)}
+        a = LazyGetItemAwaitable(actual_value)
+        lazy_foo = a["foo"]
+        lazy_bar = a["bar"]
+        # The returned value should be lazy
+        self.assertIsInstance(lazy_foo, LazyAwaitable)
+        self.assertIsInstance(lazy_bar, LazyAwaitable)
+
+        # Our lazy values should not have been waited yet
+        self.assertIsNone(lazy_foo._result)
+        self.assertIsNone(lazy_bar._result)
+        self.assertIsNone(a._result)
+
+        # The use of a torch op should trigger exactly one wait on the parent object.
+        result = torch.add(lazy_foo, lazy_bar)
+        self.assertEqual(result, torch.tensor(1 * 3 + 2 * 3))
42 changes: 42 additions & 0 deletions torchrec/distributed/types.py
@@ -402,6 +402,48 @@ def _wait_impl(self) -> W:
         return self._obj
 
 
+KT = TypeVar("KT")
+VT_co = TypeVar("VT_co")
+ParentW = TypeVar("ParentW")
+
+
+class LazyGetItemMixin(Generic[KT, VT_co]):
+    """Augments the base LazyAwaitable with a lazy __getitem__ method.
+
+    Instead of triggering a wait() on a __getitem__ call, KeyedLazyAwaitable
+    will return another awaitable. This can achieve better
+    communication/computation overlap by deferring the wait() until the
+    tensor data is actually needed.
+
+    This is intended for Awaitables that model keyed collections, like
+    dictionaries or EmbeddingBagCollectionAwaitable.
+
+    NOTE: if using this mixin, please include it before LazyAwaitable in the
+    inheritance list, so that Python MRO can properly select this __getitem__
+    implementation.
+    """
+
+    def __getitem__(self, key: KT) -> LazyAwaitable[VT_co]:
+        return GetItemLazyAwaitable(self, key)
+
+
+class GetItemLazyAwaitable(LazyAwaitable[W], Generic[W, ParentW, KT]):
+    """The LazyAwaitable returned from a __getitem__ call on `LazyGetItemMixin`.
+
+    When the actual value of this awaitable is requested, wait on the parent and
+    then call __getitem__ on the result.
+    """
+
+    def __init__(self, parent: LazyAwaitable[ParentW], key: KT) -> None:
+        super().__init__()
+        self._parent = parent
+        self._key = key
+
+    def _wait_impl(self) -> W:
+        kt = LazyAwaitable._wait_async(self._parent)
+        return kt[self._key]
+
+
 # install magic methods
 for orig_method_name in torch.fx.graph.magic_methods:
     as_magic = f"__{orig_method_name}__"
