Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
improve granularity of
PooledEmbeddingArchAwaitable
(#1843)
Summary: Pull Request resolved: #1843 # Context This is try to reland D47272219 (which got reverted in S355730) after fixing QPS regression. This diff is necessary in using PT2 IR for training for APS SNN models, otherwise training will get stuck due to disordered collective communication, see post: https://fb.workplace.com/groups/319878845696681/permalink/1130124208005470/ # Original summary from D47272219 Today, if I do a `__getitem__` call on `PooledEmbeddingArchAwaitable`, it triggers the wait. We'd like to defer that further to when the result of `__getitem__` is actually used. So instead, have `__getitem__` return another `LazyAwaitable` which represents the pooled embedding. Usage of that value in the context of a torchfunction will trigger the wait as desired. This ends up being important for PT2 IR integration, which eagerly dumps a bunch of `__getitem__` calls right after the sparse arch because PT2 IR prefers to operate on "flat" values. With improved granularity, we still get the desired lazy behavior. For pure eager users, this should be a no-op (we generally only call `__getitem__` right before use, so this doesn't reorder anything). The laziness affects the ordering of comms/compute, which is important in two ways: 1. PEA design means that the per-rank feature processing behavior causes the specific order of execution to be load-bearing. Without the laziness, the execution order of ranks with vs. without feature processing will diverge, causing training hangs. 2. getting comms/compute overlapping for the all to all comms vs. dense compute is likely to be a performance improvement, although it is hard to make a direct comparison because of issue #1. Further details can be found in: https://fb.workplace.com/groups/319878845696681/posts/1017888535895705 Reviewed By: IvanKobzarev Differential Revision: D54879753 fbshipit-source-id: 97c6b7d891d22c2280d93bda9319da3f61325aeb
- Loading branch information