From d091d11ec3b24a05a7a822e3d1ea4fc014d6cfef Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Tue, 9 Apr 2024 18:06:51 -0700 Subject: [PATCH] add mean pooling divisor to cuda stream (#1863) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1863 Initial mean pooling implementation was not attending to the appropriate CUDA stream properly with respect to train pipeline. We now register the divisor tensor into the CUDA stream in context. The key insight: Tensors used on a different stream than their origin, the memory allocator may reuse the memory unexpectedly. We also split the callback function into two (create divisor, apply mean pooling). Change the context from holding a callable to divisor tensor instead. This is because recording non tensors into a CUDA stream is non trivial, whereas recording a tensor into a CUDA stream is easily supported. This has no perf regressions from the original implementation nor lack of clarity. Differential Revision: D55945969 --- torchrec/distributed/embeddingbag.py | 54 ++++++++++++++++------------ 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 822a7bbd1..46d5fccad 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -10,6 +10,7 @@ import copy from collections import defaultdict, OrderedDict from dataclasses import dataclass, field +from functools import partial from typing import ( Any, Callable, @@ -386,7 +387,7 @@ class EmbeddingBagCollectionContext(Multistreamable): ) inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None variable_batch_per_feature: bool = False - mean_pooling_callback: Optional[Callable[[KeyedTensor], KeyedTensor]] = None + divisor: Optional[torch.Tensor] = None def record_stream(self, stream: torch.cuda.streams.Stream) -> None: for ctx in self.sharding_contexts: @@ -394,6 +395,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: ctx.record_stream(stream) if self.inverse_indices is not None: self.inverse_indices[1].record_stream(stream) + if self.divisor is not None: + self.divisor.record_stream(stream) class ShardedEmbeddingBagCollection( @@ -859,7 +862,7 @@ def input_dist( self._features_order_tensor, ) if self._has_mean_pooling_callback: - ctx.mean_pooling_callback = _create_mean_pooling_callback( + ctx.divisor = _create_mean_pooling_divisor( lengths=features.lengths(), stride=features.stride(), keys=features.keys(), @@ -939,7 +942,9 @@ def output_dist( # register callback if there are features that need mean pooling if self._has_mean_pooling_callback: - awaitable.callbacks.append(ctx.mean_pooling_callback) + awaitable.callbacks.append( + partial(_apply_mean_pooling, divisor=ctx.divisor) + ) return awaitable @@ -984,7 +989,9 @@ def compute_and_output_dist( # register callback if there are features that need mean pooling if self._has_mean_pooling_callback: - awaitable.callbacks.append(ctx.mean_pooling_callback) + awaitable.callbacks.append( + partial(_apply_mean_pooling, divisor=ctx.divisor) + ) return awaitable @@ -1260,7 +1267,7 @@ def module_type(self) -> Type[nn.EmbeddingBag]: return nn.EmbeddingBag -def _create_mean_pooling_callback( +def _create_mean_pooling_divisor( lengths: torch.Tensor, keys: List[str], stride: int, @@ -1274,7 +1281,7 @@ def _create_mean_pooling_callback( kjt_key_indices: Dict[str, int], kt_key_ordering: torch.Tensor, inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, -) -> Callable[[KeyedTensor], KeyedTensor]: +) -> torch.Tensor: with record_function("## ebc create mean pooling callback ##"): batch_size = ( none_throws(inverse_indices)[1].size(dim=1) @@ -1321,22 +1328,23 @@ def _create_mean_pooling_callback( ) eps = 1e-6 # used to safe guard against 0 division divisor = divisor + eps + return divisor.detach() - # pyre-ignore[53] - def _apply_mean_pooling(keyed_tensor: KeyedTensor) -> KeyedTensor: - """ - Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes. - This function is applied as a callback to the awaitable - """ - with record_function("## ebc apply mean pooling ##"): - mean_pooled_values = ( - keyed_tensor.values() / divisor - ) # [batch size, num_features * embedding dim] - return KeyedTensor( - keys=keyed_tensor.keys(), - values=mean_pooled_values, - length_per_key=keyed_tensor.length_per_key(), - key_dim=1, - ) - return _apply_mean_pooling +def _apply_mean_pooling( + keyed_tensor: KeyedTensor, divisor: torch.Tensor +) -> KeyedTensor: + """ + Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes. + This function is applied as a callback to the awaitable + """ + with record_function("## ebc apply mean pooling ##"): + mean_pooled_values = ( + keyed_tensor.values() / divisor + ) # [batch size, num_features * embedding dim] + return KeyedTensor( + keys=keyed_tensor.keys(), + values=mean_pooled_values, + length_per_key=keyed_tensor.length_per_key(), + key_dim=1, + )