Fix specialization issue in keyed_jagged_index_select_dim1_forward_cuda (pytorch#3578)

Summary:
X-link: facebookresearch/FBGEMM#664

Pull Request resolved: pytorch#3578

`lengths` is a tensor with symbolic shapes. Calling `len` on it forces specialization, which causes the data-dependent failure shown below:
{F1974383976}

tlparse: https://fburl.com/74rjmr8e

The fix is to replace `len` with equivalent operations that support symbolic shapes.
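
For illustration, a minimal sketch (not the FBGEMM code) of the difference: `len(t)` calls `t.__len__()`, which must return a plain Python `int`, so a tracer has to pin the symbolic dimension to a concrete value, whereas `t.size(0)` returns a `torch.SymInt` and keeps the shape symbolic.

```python
# Minimal sketch, assuming a 1-D `lengths` tensor; both functions behave
# identically in eager mode, but only `good` traces without specializing.
import torch

def bad(lengths: torch.Tensor, batch_size: int) -> int:
    # len() must produce a concrete Python int, pinning lengths.size(0).
    return len(lengths) // batch_size

def good(lengths: torch.Tensor, batch_size: int) -> int:
    # size(0) stays a torch.SymInt under tracing, so no specialization.
    return lengths.size(0) // batch_size

lengths = torch.arange(8)
assert bad(lengths, 4) == good(lengths, 4) == 2
```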

Reviewed By: TroyGarden

Differential Revision: D67491452

fbshipit-source-id: ed2207b310697d774a284f296c8d34ca2da61adc
Microve authored and facebook-github-bot committed Jan 17, 2025
1 parent 3e0db25 commit 21d1260
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
```diff
@@ -900,13 +900,14 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
     lengths: torch.Tensor,
     offsets: torch.Tensor,
     indices: torch.Tensor,
-    batch_size: int,
+    batch_size: torch.SymInt,
     weights: Optional[torch.Tensor] = None,
-    selected_lengths_sum: Optional[int] = None,
+    selected_lengths_sum: Optional[torch.SymInt] = None,
 ) -> List[torch.Tensor]:
-    num_batches = len(lengths) // batch_size
-    torch._check(len(lengths) + 1 == len(offsets))
-    torch._check(len(lengths) % batch_size == 0)
+    num_batches = lengths.size(0) // batch_size
+    torch._check(lengths.size(0) + 1 == offsets.size(0))
+    # pyre-ignore
+    torch._check(lengths.size(0) % batch_size == 0)

     if weights is not None:
         # weights must have the same shape as values
```
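
For context, a hedged sketch of the setting in which this matters: the function above is an abstract ("fake") implementation that runs on FakeTensors during tracing, where `size(0)` may be a `torch.SymInt`. The op and names below are hypothetical, not the FBGEMM registration, and assume the `torch.library.custom_op` API (PyTorch 2.4+).

```python
import torch

# Hypothetical demo op, not FBGEMM's: folds a flat lengths tensor into
# shape (num_batches, batch_size).
@torch.library.custom_op("demo::fold_batches", mutates_args=())
def fold_batches(lengths: torch.Tensor, batch_size: int) -> torch.Tensor:
    return lengths.reshape(-1, batch_size)

@fold_batches.register_fake
def _(lengths: torch.Tensor, batch_size: int) -> torch.Tensor:
    # len(lengths) here could specialize or fail on a symbolic shape;
    # size(0) yields a SymInt that the divisibility fact below constrains.
    torch._check(lengths.size(0) % batch_size == 0)
    num_batches = lengths.size(0) // batch_size
    return lengths.new_empty(num_batches, batch_size)
```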
