diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 730322ef7..01fa577bb 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -900,13 +900,14 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract( lengths: torch.Tensor, offsets: torch.Tensor, indices: torch.Tensor, - batch_size: int, + batch_size: torch.SymInt, weights: Optional[torch.Tensor] = None, - selected_lengths_sum: Optional[int] = None, + selected_lengths_sum: Optional[torch.SymInt] = None, ) -> List[torch.Tensor]: - num_batches = len(lengths) // batch_size - torch._check(len(lengths) + 1 == len(offsets)) - torch._check(len(lengths) % batch_size == 0) + num_batches = lengths.size(0) // batch_size + torch._check(lengths.size(0) + 1 == offsets.size(0)) + # pyre-ignore + torch._check(lengths.size(0) % batch_size == 0) if weights is not None: # weights must have the same shape as values