add bag reverse mapping for block_bucketize kernel inference (#2477)
Summary:
X-link: pytorch/torchrec#1845

Pull Request resolved: #2477

* Add a per-index bucketize bag reverse index as an extra return value, needed for inference batching.
* To avoid a backward-compatibility (BC) break in the training kernel, the new behavior is branched off into a separate inference kernel.
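
As a rough illustration of the first point, here is a minimal standalone sketch of block bucketization that also returns one bucket id per input index. It is not the FBGEMM kernel: the `BucketizeResult` struct and the `block_bucketize_sketch` function are hypothetical names, the sketch handles a single feature, and it assumes every index is smaller than `block_size * my_size`. The exact layout of the mapping returned by the real inference op is also an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type for illustration only; not an FBGEMM type.
struct BucketizeResult {
  std::vector<int64_t> new_lengths;     // my_size * num_bags entries, bucket-major
  std::vector<int64_t> new_indices;     // indices grouped by (bucket, bag), rebased per bucket
  std::vector<int64_t> bucket_mapping;  // bucket_mapping[i] = bucket chosen for input index i
};

// Sketch for a single feature: bucket = idx / block_size, new idx = idx % block_size.
// Assumption: 0 <= idx < block_size * my_size for every input index.
BucketizeResult block_bucketize_sketch(
    const std::vector<int64_t>& lengths,
    const std::vector<int64_t>& indices,
    int64_t block_size,
    int64_t my_size) {
  assert(block_size > 0 && my_size > 0);
  const std::size_t num_bags = lengths.size();

  BucketizeResult out;
  out.new_lengths.assign(static_cast<std::size_t>(my_size) * num_bags, 0);
  out.new_indices.assign(indices.size(), 0);
  out.bucket_mapping.assign(indices.size(), 0);

  // Pass 1: record the bucket of every index and count entries per (bucket, bag).
  std::size_t i = 0;
  for (std::size_t b = 0; b < num_bags; ++b) {
    for (int64_t k = 0; k < lengths[b]; ++k, ++i) {
      assert(i < indices.size());
      const int64_t bucket = indices[i] / block_size;
      assert(indices[i] >= 0 && bucket < my_size);
      out.bucket_mapping[i] = bucket;
      ++out.new_lengths[static_cast<std::size_t>(bucket) * num_bags + b];
    }
  }
  assert(i == indices.size());

  // Exclusive prefix sum over new_lengths gives the write offset of each (bucket, bag) slot.
  std::vector<std::size_t> offsets(out.new_lengths.size(), 0);
  for (std::size_t s = 1; s < offsets.size(); ++s) {
    offsets[s] = offsets[s - 1] + static_cast<std::size_t>(out.new_lengths[s - 1]);
  }

  // Pass 2: scatter the rebased indices into their (bucket, bag) slots.
  i = 0;
  for (std::size_t b = 0; b < num_bags; ++b) {
    for (int64_t k = 0; k < lengths[b]; ++k, ++i) {
      const std::size_t slot =
          static_cast<std::size_t>(out.bucket_mapping[i]) * num_bags + b;
      out.new_indices[offsets[slot]++] = indices[i] % block_size;
    }
  }
  return out;
}
```

With `lengths = {2, 1}`, `indices = {0, 5, 3}`, `block_size = 4` and `my_size = 2`, the sketch places indices 0 and 3 in bucket 0 and index 5 in bucket 1. The per-index mapping lets a caller trace each bucketized entry back to the input position it came from, which is presumably what inference batching needs.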

Reviewed By: xing-liu

Differential Revision: D55492793

fbshipit-source-id: ffc4c8be4dcb8de735e21e7442241e8179c31636
YazhiGao authored and facebook-github-bot committed Apr 23, 2024
1 parent 1cb1dd1 commit b70dcd8
Showing 5 changed files with 587 additions and 333 deletions.
43 changes: 43 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -173,6 +173,49 @@ block_bucketize_sparse_features_cpu(
const int64_t max_batch_size,
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);

std::tuple<
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>>
///@ingroup sparse-data-cuda
block_bucketize_sparse_features_inference_cuda(
const at::Tensor& lengths,
const at::Tensor& indices,
const bool bucketize_pos,
const bool sequence,
const at::Tensor& block_sizes,
const int64_t my_size,
const c10::optional<at::Tensor>& weights,
const c10::optional<at::Tensor>& batch_size_per_feature,
const int64_t max_batch_size,
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping);

std::tuple<
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>>

///@ingroup sparse-data-cpu
block_bucketize_sparse_features_inference_cpu(
const at::Tensor& lengths,
const at::Tensor& indices,
const bool bucketize_pos,
const bool sequence,
const at::Tensor& block_sizes,
const int64_t my_size,
const c10::optional<at::Tensor>& weights,
const c10::optional<at::Tensor>& batch_size_per_feature,
const int64_t max_batch_size,
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping);

std::tuple<
at::Tensor,
at::Tensor,