add mean pooling divisor to cuda stream (#1863)
Summary:

The initial mean pooling implementation did not handle the appropriate CUDA stream correctly with respect to the train pipeline. We now register the divisor tensor on the CUDA stream through the context.

The key insight:
When a tensor is used on a stream other than the one it originated on, the memory allocator may reuse its memory unexpectedly unless the tensor is recorded on that stream.
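
As a rough illustration of that stream semantics (a minimal sketch with made-up tensor shapes and stream names, not code from this diff), recording the tensor on the consuming stream is what keeps the allocator from recycling its memory early:

```python
import torch

copy_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

with torch.cuda.stream(copy_stream):
    # divisor is allocated while copy_stream is current, so its lifetime is
    # tied to work enqueued on copy_stream
    divisor = torch.rand(1024, 1, device="cuda") + 1e-6

# compute_stream must wait for the work that produced divisor ...
compute_stream.wait_stream(copy_stream)
# ... and divisor must be recorded on compute_stream so its memory is not
# reused until this stream's pending work completes
divisor.record_stream(compute_stream)

pooled = torch.rand(1024, 64, device="cuda")
mean_pooled = pooled / divisor  # broadcasts over the embedding dimension
```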

We also split the callback function into two (create the divisor, apply mean pooling) and changed the context to hold the divisor tensor instead of a callable. Recording a non-tensor object into a CUDA stream is non-trivial, whereas recording a tensor is directly supported. This introduces no perf regression and no loss of clarity relative to the original implementation.
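
Schematically, the shape of the change looks like the sketch below; `_apply_mean_pooling_sketch` is a hypothetical stand-in that operates on a plain tensor rather than a KeyedTensor:

```python
import torch
from functools import partial

def _apply_mean_pooling_sketch(pooled: torch.Tensor, divisor: torch.Tensor) -> torch.Tensor:
    # stand-in for _apply_mean_pooling: divide pooled embeddings by the divisor
    return pooled / divisor

divisor = torch.rand(8, 1, device="cuda") + 1e-6
stream = torch.cuda.current_stream()

# A tensor held on the context can be registered on a stream directly ...
divisor.record_stream(stream)

# ... whereas a closure capturing the divisor exposes no record_stream, so the
# callback is instead rebuilt at output_dist time with functools.partial.
callback = partial(_apply_mean_pooling_sketch, divisor=divisor)
out = callback(torch.rand(8, 16, device="cuda"))
```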

Reviewed By: joshuadeng

Differential Revision: D55945969
iamzainhuda authored and facebook-github-bot committed Apr 10, 2024
1 parent 3eec2fc commit 2ca3a1b
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions torchrec/distributed/embeddingbag.py
@@ -10,6 +10,7 @@
 import copy
 from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
+from functools import partial
 from typing import (
     Any,
     Callable,
@@ -386,14 +387,16 @@ class EmbeddingBagCollectionContext(Multistreamable):
     )
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None
     variable_batch_per_feature: bool = False
-    mean_pooling_callback: Optional[Callable[[KeyedTensor], KeyedTensor]] = None
+    divisor: Optional[torch.Tensor] = None
 
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         for ctx in self.sharding_contexts:
             if ctx:
                 ctx.record_stream(stream)
         if self.inverse_indices is not None:
             self.inverse_indices[1].record_stream(stream)
+        if self.divisor is not None:
+            self.divisor.record_stream(stream)
 
 
 class ShardedEmbeddingBagCollection(
@@ -859,7 +862,7 @@ def input_dist(
                     self._features_order_tensor,
                 )
             if self._has_mean_pooling_callback:
-                ctx.mean_pooling_callback = _create_mean_pooling_callback(
+                ctx.divisor = _create_mean_pooling_divisor(
                     lengths=features.lengths(),
                     stride=features.stride(),
                     keys=features.keys(),
@@ -939,7 +942,9 @@ def output_dist(
 
         # register callback if there are features that need mean pooling
         if self._has_mean_pooling_callback:
-            awaitable.callbacks.append(ctx.mean_pooling_callback)
+            awaitable.callbacks.append(
+                partial(_apply_mean_pooling, divisor=ctx.divisor)
+            )
 
         return awaitable
 
@@ -984,7 +989,9 @@ def compute_and_output_dist(
 
         # register callback if there are features that need mean pooling
        if self._has_mean_pooling_callback:
-            awaitable.callbacks.append(ctx.mean_pooling_callback)
+            awaitable.callbacks.append(
+                partial(_apply_mean_pooling, divisor=ctx.divisor)
+            )
 
         return awaitable
 
@@ -1260,7 +1267,7 @@ def module_type(self) -> Type[nn.EmbeddingBag]:
         return nn.EmbeddingBag
 
 
-def _create_mean_pooling_callback(
+def _create_mean_pooling_divisor(
     lengths: torch.Tensor,
     keys: List[str],
     stride: int,
@@ -1274,7 +1281,7 @@ def _create_mean_pooling_callback(
     kjt_key_indices: Dict[str, int],
     kt_key_ordering: torch.Tensor,
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
-) -> Callable[[KeyedTensor], KeyedTensor]:
+) -> torch.Tensor:
     with record_function("## ebc create mean pooling callback ##"):
         batch_size = (
             none_throws(inverse_indices)[1].size(dim=1)
@@ -1321,22 +1328,23 @@ def _create_mean_pooling_callback(
         )
         eps = 1e-6  # used to safe guard against 0 division
         divisor = divisor + eps
+        return divisor.detach()
 
-    # pyre-ignore[53]
-    def _apply_mean_pooling(keyed_tensor: KeyedTensor) -> KeyedTensor:
-        """
-        Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes.
-        This function is applied as a callback to the awaitable
-        """
-        with record_function("## ebc apply mean pooling ##"):
-            mean_pooled_values = (
-                keyed_tensor.values() / divisor
-            )  # [batch size, num_features * embedding dim]
-            return KeyedTensor(
-                keys=keyed_tensor.keys(),
-                values=mean_pooled_values,
-                length_per_key=keyed_tensor.length_per_key(),
-                key_dim=1,
-            )
 
-    return _apply_mean_pooling
+def _apply_mean_pooling(
+    keyed_tensor: KeyedTensor, divisor: torch.Tensor
+) -> KeyedTensor:
+    """
+    Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes.
+    This function is applied as a callback to the awaitable
+    """
+    with record_function("## ebc apply mean pooling ##"):
+        mean_pooled_values = (
+            keyed_tensor.values() / divisor
+        )  # [batch size, num_features * embedding dim]
+        return KeyedTensor(
+            keys=keyed_tensor.keys(),
+            values=mean_pooled_values,
+            length_per_key=keyed_tensor.length_per_key(),
+            key_dim=1,
+        )
