diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py
index 3dd7f7181..167b3c177 100644
--- a/torchrec/distributed/comm_ops.py
+++ b/torchrec/distributed/comm_ops.py
@@ -1009,11 +1009,7 @@ def reduce_scatter_v_pooled(
        [ip_split if d == 0 else input_size[d] for d in range(len(input_size))]
        for ip_split in input_splits
    ]
-
-    equal_splits = False
-    if not torch.compiler.is_dynamo_compiling():
-        # We can not check during tracing equality of splits -> fallback on general
-        equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits)
+    equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits)

    rsvi = ReduceScatterVInfo(
        input_sizes=input_sizes,
diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 54576650b..5354eedb4 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -104,17 +104,10 @@ def _get_recat(
        for j in feature_order:  # range(num_splits):
            recat.append(i + j * local_split)

-    vb_condition: bool = batch_size_per_rank is not None
-    if not torch.compiler.is_dynamo_compiling():
-        vb_condition = vb_condition and any(
-            # pyre-ignore
-            bs != batch_size_per_rank[0]
-            # pyre-ignore
-            for bs in batch_size_per_rank
-        )
-
    # variable batch size
-    if vb_condition:
+    if batch_size_per_rank is not None and any(
+        bs != batch_size_per_rank[0] for bs in batch_size_per_rank
+    ):
        batch_size_per_feature = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, local_split) for x in batch_size_per_rank
@@ -245,8 +238,6 @@ def __init__(
        self._device: torch.device = device
        self._input = input
        self._splits = splits
-        self._input_splits_list = input_splits
-        self._output_splits_list = output_splits
        self._input_splits: Dict[str, List[int]] = dict(zip(labels, input_splits))
        self._output_splits: Dict[str, List[int]] = dict(zip(labels, output_splits))
        self._keys = keys
@@ -264,7 +255,6 @@ def __init__(

        self._output_tensors: List[torch.Tensor] = []
        self._awaitables: List[dist.Work] = []
-        self._world_size: int = self._pg.size()

        for input_split, output_split, input_tensor, label in zip(
            input_splits,
@@ -406,20 +396,6 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
                self._output_splits = output_list[:-1]
                self._stride_per_rank = output_list[-1]

-        if torch.compiler.is_dynamo_compiling():
-            rank: int = self._pg.rank()
-            for i in range(len(self._output_splits)):
-                for j in range(len(self._output_splits[i])):
-                    torch._check_is_size(self._output_splits[i][j])
-                torch._check(
-                    self._output_splits[i][rank] == self._input_splits[i][rank]
-                )
-            if self._stride_per_rank is not None:
-                # pyre-ignore
-                for i in range(len(self._stride_per_rank)):
-                    # pyre-ignore
-                    torch._check_is_size(self._stride_per_rank[i])
-
        return KJTAllToAllTensorsAwaitable(
            pg=self._pg,
            input=self._input,
@@ -498,7 +474,7 @@ def __init__(
        stagger: int = 1,
    ) -> None:
        super().__init__()
-        torch._check(len(splits) == pg.size())
+        assert len(splits) == pg.size()
        self._pg: dist.ProcessGroup = pg
        self._splits = splits
        self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
@@ -1050,25 +1026,14 @@ def forward(
            PooledEmbeddingsAwaitable: awaitable of pooled embeddings of tensor
                of shape [batch_size, dimension].
""" - # Dynamo can not trace through data dependent condition: len(set(input_splits)) > 1 - if torch.compiler.is_dynamo_compiling(): - if input_splits is not None: - tensor_awaitable = reduce_scatter_v_pooled( - local_embs, input_splits, self._pg, codecs=self._codecs - ) - else: - tensor_awaitable = reduce_scatter_base_pooled( - local_embs, self._pg, codecs=self._codecs - ) + if input_splits and len(set(input_splits)) > 1: + tensor_awaitable = reduce_scatter_v_pooled( + local_embs, input_splits, self._pg, codecs=self._codecs + ) else: - if input_splits and len(set(input_splits)) > 1: - tensor_awaitable = reduce_scatter_v_pooled( - local_embs, input_splits, self._pg, codecs=self._codecs - ) - else: - tensor_awaitable = reduce_scatter_base_pooled( - local_embs, self._pg, codecs=self._codecs - ) + tensor_awaitable = reduce_scatter_base_pooled( + local_embs, self._pg, codecs=self._codecs + ) return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index c1474e77c..74d6ca672 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -14,9 +14,6 @@ import torch from torch.autograd.profiler import record_function from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec - -# pyre-ignore -from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node from torchrec.streamable import Pipelineable @@ -676,10 +673,6 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten def _assert_tensor_has_no_elements_or_has_integers( tensor: torch.Tensor, tensor_name: str ) -> None: - if torch.compiler.is_dynamo_compiling(): - # Skipping assert on tensor.numel() == 0 for dynamo to avoid DataDependentError - return - assert tensor.numel() == 0 or tensor.dtype in [ torch.long, torch.int, @@ -796,27 +789,15 @@ def _maybe_compute_length_per_key( else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist() ) elif len(keys) and lengths is not None: - _length: List[int] = [] - if variable_stride_per_key: - _length = _length_per_key_from_stride_per_key(lengths, stride_per_key) - else: - cond: bool = False - if ( - torch.compiler.is_dynamo_compiling() - and not torch.jit.is_scripting() - ): - # pyre-ignore - cond = guard_size_oblivious(lengths.numel() != 0) - else: - cond = lengths.numel() != 0 - - _length = ( - torch.jit.annotate( - List[int], torch.sum(lengths.view(-1, stride), dim=1).tolist() - ) - if cond + _length: List[int] = ( + _length_per_key_from_stride_per_key(lengths, stride_per_key) + if variable_stride_per_key + else ( + torch.sum(lengths.view(-1, stride), dim=1).tolist() + if lengths.numel() != 0 else [0] * len(keys) ) + ) else: _length: List[int] = [] length_per_key = _length @@ -1324,7 +1305,7 @@ def __init__( self._stride_per_key_per_rank = stride_per_key_per_rank self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank] self._variable_stride_per_key = True - if stride_per_key_per_rank is not None: + if not stride_per_key_per_rank: self._stride = 0 elif all(s == self.stride_per_key()[0] for s in self.stride_per_key()): self._stride = self.stride_per_key()[0] @@ -2183,20 +2164,17 @@ def dist_init( cumsum_lengths[strides_cumsum[1:]] - cumsum_lengths[strides_cumsum[:-1]] ) with record_function("## all2all_data:recat_values ##"): - recat_cond: bool = recat is not None - if recat_cond and not is_torchdynamo_compiling(): - recat_cond = 
torch.jit._unwrap_optional(recat).numel() > 0 - if recat_cond: + if recat is not None and recat.numel() > 0: lengths, _ = _permute_tensor_by_segments( lengths, stride_per_rank_per_key, - torch.jit._unwrap_optional(recat), + recat, None, ) values, weights = _permute_tensor_by_segments( values, length_per_key, - torch.jit._unwrap_optional(recat), + recat, weights, ) if not stride_per_key_per_rank: @@ -2223,23 +2201,16 @@ def dist_init( else: assert stride_per_rank is not None with record_function("## all2all_data:recat_values ##"): - recat_cond: bool = recat is not None - if recat_cond and not is_torchdynamo_compiling(): - recat_cond = torch.jit._unwrap_optional(recat).numel() > 0 - - if recat_cond: + if recat is not None and recat.numel() > 0: stride = stride_per_rank[0] # dynamo don't handle generators well # so had to unroll the original generator into # this for loop. - single_batch_per_rank = False - if not is_torchdynamo_compiling(): - # Dynamo symbolic shapes could not pass through s != stride condition without hints => Dynamo always use VB path - single_batch_per_rank = True - for s in stride_per_rank: - if s != stride: - single_batch_per_rank = False + single_batch_per_rank = True + for s in stride_per_rank: + if s != stride: + single_batch_per_rank = False if single_batch_per_rank: ( @@ -2247,7 +2218,7 @@ def dist_init( values, weights, ) = torch.ops.fbgemm.permute_2D_sparse_data( - torch.jit._unwrap_optional(recat), + recat, lengths.view(-1, stride), values, weights, @@ -2260,7 +2231,7 @@ def dist_init( values, weights, ) = torch.ops.fbgemm.permute_1D_sparse_data( - torch.jit._unwrap_optional(recat), + recat, lengths.view(-1), values, weights,
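
A minimal sketch of the split-uniformity dispatch restored in `PooledEmbeddingsReduceScatter.forward` (and mirrored by the `equal_splits` check in `reduce_scatter_v_pooled`): the variable-splits collective is only needed when the per-rank splits actually differ. The helper name `use_variable_splits` is hypothetical, not a torchrec API:

```python
from typing import List, Optional


def use_variable_splits(input_splits: Optional[List[int]]) -> bool:
    # Uniform (or absent) splits can use the fixed-size reduce-scatter;
    # only genuinely uneven splits need the _v variant.
    return bool(input_splits) and len(set(input_splits)) > 1


assert use_variable_splits(None) is False
assert use_variable_splits([4, 4, 4, 4]) is False
assert use_variable_splits([4, 6, 2, 4]) is True
```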
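A small sketch of the data-dependent ternary restored in `_maybe_compute_length_per_key`: reshape the flat lengths to one row per key and sum each row. The keys, stride, and values below are made-up illustration data:

```python
import torch

keys = ["f1", "f2"]
stride = 3  # batch size per key
# lengths laid out key-major: one entry per (key, batch) pair
lengths = torch.tensor([1, 0, 2, 3, 1, 1])

length_per_key = (
    torch.sum(lengths.view(-1, stride), dim=1).tolist()
    if lengths.numel() != 0
    else [0] * len(keys)
)
print(length_per_key)  # [3, 5]
```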
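A standalone sketch of the uniform-stride check that decides between `torch.ops.fbgemm.permute_2D_sparse_data` and `permute_1D_sparse_data` in `dist_init`, using the unrolled loop kept in the diff. The helper name `has_single_batch_per_rank` is hypothetical:

```python
from typing import List


def has_single_batch_per_rank(stride_per_rank: List[int]) -> bool:
    # Unrolled equivalent of all(s == stride for s in stride_per_rank),
    # mirroring the loop retained in dist_init.
    stride = stride_per_rank[0]
    single_batch_per_rank = True
    for s in stride_per_rank:
        if s != stride:
            single_batch_per_rank = False
    return single_batch_per_rank


assert has_single_batch_per_rank([8, 8, 8]) is True
assert has_single_batch_per_rank([8, 4, 8]) is False
```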