sparse.mm backward: performance improvements (pytorch#94991)
`torch.sparse.mm` - faster and without syncs in "most" cases.

Pull Request resolved: pytorch#94991
Approved by: https://github.com/Skylion007, https://github.com/pearu, https://github.com/cpuhrsch
nikitaved authored and pytorchmergebot committed Jun 12, 2023
1 parent d083d44 commit 056d92e
Showing 10 changed files with 203 additions and 61 deletions.
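
For context, here is a minimal sketch of the operator this commit speeds up, using only the public API. The shapes and values are made up for illustration, and the sparse gradient layout shown is the documented behaviour of `torch.sparse.mm`, not something introduced by this PR.

```python
import torch

# Sparse (COO) x dense matmul; this commit targets the backward pass.
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
A = torch.sparse_coo_tensor(indices, values, (2, 3)).coalesce().requires_grad_()
B = torch.randn(3, 4, requires_grad=True)

C = torch.sparse.mm(A, B)   # dense result of shape (2, 4)
C.sum().backward()

# The gradient w.r.t. the sparse operand is itself sparse and only carries
# entries at A's nonzero positions.
print(A.grad.is_sparse, A.grad._nnz())   # True 3
print(B.grad.shape)                      # torch.Size([3, 4])
```
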
32 changes: 27 additions & 5 deletions aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu
@@ -36,6 +36,13 @@ struct RhsProjOp {
}
};

struct LhsProjOp {
template <typename scalar_t>
static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) {
return a;
}
};

template <int nt, int vt, typename loop_t>
C10_LAUNCH_BOUNDS_2(nt, vt)
__global__ void apply_kernel(int n, loop_t loop) {
@@ -70,11 +77,12 @@ void binary_op_intersection_kernel(
TensorIterator& iter,
int64_t lhs_nnz_stride,
int64_t rhs_nnz_stride,
const Tensor& argsort) {
const Tensor& argsort,
const bool accumulate_matches) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
}
return;
}
@@ -106,7 +114,8 @@ void binary_op_intersection_kernel(
accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
accscalar_t rhs_values;
index_t rhs_sorted_nnz_idx;
for (int64_t c = 0; c < count; ++c) {
const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
for (int64_t c = 0; c < match_count; ++c) {
rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
res_values += binary_op_t::apply(lhs_values, rhs_values);
@@ -126,7 +135,8 @@ struct CUDAValueSelectionIntersectionKernel {
const Tensor& rhs_values,
const Tensor& rhs_select_idx,
const Tensor& intersection_counts,
const Tensor& argsort) {
const Tensor& argsort,
const bool accumulate_matches) {
auto iter = make_value_selection_intersection_iter(
lhs_values,
lhs_select_idx,
@@ -150,7 +160,7 @@ struct CUDAValueSelectionIntersectionKernel {
// COO indices are only 64-bit for now.
using index_t = int64_t;
binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
iter, lhs_nnz_stride, rhs_nnz_stride, argsort);
iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
});

return res_values;
@@ -180,9 +190,21 @@ void sparse_mask_intersection_out_cuda_kernel(
);
}

void sparse_mask_projection_out_cuda_kernel(
Tensor& result,
const Tensor& x,
const Tensor& y,
const OptTensor& x_hash_opt = c10::nullopt) {
using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel<LhsProjOp>;
_sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueLhsProjKernel>(
result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false
);
}

}

REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel);

} // namespace at::native
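
In words, the kernel above walks over each lhs nonzero, finds the run of rhs nonzeros with the same flattened index (via `intersection_counts` and the argsort), and folds them with `binary_op_t`; with `accumulate_matches == false` only the first match contributes, which is all `LhsProjOp` needs. A rough Python model of that inner loop follows; the dict-based matching is illustrative only and stands in for the actual hash/argsort machinery.

```python
from collections import defaultdict

def value_selection_intersection(lhs_idx, lhs_val, rhs_idx, rhs_val,
                                 binary_op, accumulate_matches=True):
    # Group rhs values by their (flattened) sparse index.
    rhs_by_idx = defaultdict(list)
    for i, v in zip(rhs_idx, rhs_val):
        rhs_by_idx[i].append(v)

    out = []
    for i, lv in zip(lhs_idx, lhs_val):
        matches = rhs_by_idx.get(i, [])
        if not accumulate_matches:
            matches = matches[:1]   # std::min(count, 1), as in the kernel
        # Entries with no match contribute 0 here.
        out.append(sum(binary_op(lv, rv) for rv in matches))
    return out

# mul-style intersection accumulates every duplicate match ...
print(value_selection_intersection([0, 2], [1.0, 2.0], [0, 0, 2], [3.0, 4.0, 5.0],
                                    lambda a, b: a * b))                        # [7.0, 10.0]
# ... while LhsProjOp keeps the lhs value and needs at most one match.
print(value_selection_intersection([0, 2], [1.0, 2.0], [0, 0, 2], [3.0, 4.0, 5.0],
                                    lambda a, b: a, accumulate_matches=False))  # [1.0, 2.0]
```
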
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6801,6 +6801,12 @@
SparseCsrCPU, SparseCsrCUDA: sparse_mask_sparse_csr
autogen: sparse_mask.out

- func: _sparse_mask_projection(Tensor self, Tensor mask) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_mask_projection
autogen: _sparse_mask_projection.out

- func: _to_cpu(Tensor[] tensors) -> Tensor[]
variants: function
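
The new entry above declares `_sparse_mask_projection` as an internal `Tensor` method dispatched to the sparse CPU/CUDA kernels added in this PR, with an autogenerated `.out` variant. If I read the codegen correctly (hedged, since the op is private and its semantics live entirely in the C++ kernels), the binding can be checked like this on a build that contains this commit:

```python
import torch

# Method variants in native_functions.yaml become (here private) Tensor methods.
print(hasattr(torch.Tensor, "_sparse_mask_projection"))
```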

@@ -136,6 +136,7 @@ void _sparse_binary_op_intersection_kernel_impl(
const std::vector<int64_t> broadcasted_shape,
const c10::optional<Tensor>& x_hash_opt_ = c10::nullopt,
const c10::optional<Tensor>& y_hash_opt_ = c10::nullopt,
const bool accumulate_matches = true,
const bool distributive_with_sum = true
) {
// The common dtype check is relevant when op is done in-place.
@@ -403,7 +404,8 @@ void _sparse_binary_op_intersection_kernel_impl(
probably_coalesced._values().to(binary_op_res_dtype),
intersection_first_idx.to(nnz_arange.scalar_type()),
intersection_count,
argsort_hash).to(res.scalar_type());
argsort_hash,
accumulate_matches).to(res.scalar_type());
const auto res_sparse_dim = source.sparse_dim();
const auto res_dense_dim = source.dense_dim();
const auto& res_shape = broadcasted_shape;
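
The new `accumulate_matches` flag (defaulting to the old behaviour) matters because an uncoalesced operand can carry repeated indices: a mul-style intersection has to fold all of them, consistent with what `coalesce()` would have summed, whereas the projection kernel deliberately stops after the first match. A quick public-API reminder of how duplicates behave (values made up):

```python
import torch

i = torch.tensor([[0, 0, 1], [1, 1, 0]])    # the index (0, 1) appears twice
v = torch.tensor([2.0, 3.0, 4.0])
x = torch.sparse_coo_tensor(i, v, (2, 2))   # constructed uncoalesced

print(x.is_coalesced())                     # False
xc = x.coalesce()                           # duplicates are summed: (0, 1) -> 5.0
print(xc._nnz(), xc._values())
```
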
32 changes: 29 additions & 3 deletions aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
@@ -35,6 +35,13 @@ struct RhsProjOp {
}
};

struct LhsProjOp {
template <typename scalar_t>
static scalar_t apply(scalar_t a, scalar_t b) {
return a;
}
};

template <typename binary_op_t>
struct CPUValueSelectionIntersectionKernel {
static Tensor apply(
@@ -43,7 +50,8 @@ struct CPUValueSelectionIntersectionKernel {
const Tensor& rhs_values,
const Tensor& rhs_select_idx,
const Tensor& intersection_counts,
const Tensor& argsort) {
const Tensor& argsort,
const bool accumulate_matches) {
auto iter = make_value_selection_intersection_iter(
lhs_values,
lhs_select_idx,
@@ -86,7 +94,8 @@ struct CPUValueSelectionIntersectionKernel {
accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
accscalar_t rhs_values;
index_t rhs_sorted_nnz_idx;
for (int64_t c = 0; c < count; ++c) {
const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
for (int64_t c = 0; c < match_count; ++c) {
rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
res_values += binary_op_t::apply(lhs_values, rhs_values);
@@ -132,6 +141,17 @@ void sparse_mask_intersection_out_cpu_kernel(
);
}

void sparse_mask_projection_out_cpu_kernel(
Tensor& result,
const Tensor& x,
const Tensor& y,
const OptTensor& x_hash_opt = c10::nullopt) {
using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel<LhsProjOp>;
_sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueLhsProjKernel>(
result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false
);
}

}

REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
@@ -145,4 +165,10 @@ REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_interse
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
} // namespace at::native

REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
}
3 changes: 3 additions & 0 deletions aten/src/ATen/native/sparse/SparseStubs.h
@@ -16,6 +16,9 @@ DECLARE_DISPATCH(mul_sparse_sparse_out_fn, mul_sparse_sparse_out_stub);
using sparse_mask_intersection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional<Tensor>& x_hash_opt);
DECLARE_DISPATCH(sparse_mask_intersection_out_fn, sparse_mask_intersection_out_stub);

using sparse_mask_projection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional<Tensor>& x_hash_opt);
DECLARE_DISPATCH(sparse_mask_projection_out_fn, sparse_mask_projection_out_stub);

using flatten_indices_fn = Tensor (*)(const Tensor& indices, IntArrayRef size);
DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub);

149 changes: 105 additions & 44 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -46,6 +46,7 @@
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/index_select.h>
#include <ATen/ops/indices_native.h>
#include <ATen/ops/is_coalesced_native.h>
@@ -55,6 +56,7 @@
#include <ATen/ops/sparse_coo_tensor_native.h>
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/sparse_mask_native.h>
#include <ATen/ops/_sparse_mask_projection_native.h>
#include <ATen/ops/sparse_resize_and_clear_native.h>
#include <ATen/ops/sparse_resize_native.h>
#include <ATen/ops/to_dense_native.h>
@@ -744,6 +746,72 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
}

DEFINE_DISPATCH(sparse_mask_intersection_out_stub);
DEFINE_DISPATCH(sparse_mask_projection_out_stub);

using OptTensor = c10::optional<Tensor>;

std::tuple<Tensor, Tensor, OptTensor> sparse_mask_like_prepare_sparse_inputs(
const std::string& method_name,
const Tensor& t,
const Tensor& mask) {
// This is a helper function for operations that implement "sparse_mask"-like
// functionality, namely, projection of values of one tensor onto the other.
// These operations mostly rely on COO intersection primitives that heavily
// exploit coalesced inputs to avoid any syncs and calls to sort. The problem
// is that these primitives might project the first argument onto the second one or
// the other way around, depending on which argument is coalesced and which is
// larger. This function prepares inputs for `sparse_mask` such that `t` is
// projected onto `mask` by sorting `t` if it is uncoalesced and artificially marking it
// as coalesced, while `mask` is set to uncoalesced.
// The result of this projection is going to be uncoalesced, so it is up to the
// user to set the corresponding flag correctly with respect to the operations'
// semantics.

// We already assume that t.sizes() == mask.sizes()
TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
method_name, "(): the number of sparse dimensions in `self` ",
"should match that of the `mask`. ",
"Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
"`mask.sparse_dim() == ", mask.sparse_dim(), "`.");

const auto wrapped_tensor = [](const Tensor& t,
const OptTensor& indices = c10::nullopt,
const OptTensor& values = c10::nullopt) -> Tensor {
auto res = at::empty({0}, t.options());
auto* res_sparse_impl = get_sparse_impl(res);
res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
const auto res_indices = indices.has_value() ? *indices : t._indices();
const auto res_values = values.has_value() ? *values : t._values();
res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
res_sparse_impl->set_nnz_and_narrow(t._nnz());
res._coalesced_(false);
return res;
};

Tensor lhs;
OptTensor lhs_hash_opt;

std::tie(lhs, lhs_hash_opt) = [&]() -> auto {
if (t.is_coalesced()) {
return std::make_tuple(t, static_cast<OptTensor>(c10::nullopt));
} else {
const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
// Probably worth having a dedicated kernel for.
const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
const auto res_values = t._values().index_select(0, argsort_indices_hash);
const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
// NOTE: res is not necessarily coalesced, but it is sorted.
// We mark it as "coalesced" to skip sorting in the intersection kernel.
auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
return std::make_tuple(res, static_cast<OptTensor>(indices_hash_sorted));
}
}();

const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;

return std::make_tuple(lhs, rhs, lhs_hash_opt);
}
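
The trick in this helper: rather than fully coalescing `t`, it flattens `t`'s COO indices into scalar keys, argsorts once, permutes indices and values into that order, and hands the pre-sorted keys to the intersection kernel so the kernel can skip its own sort (and the device syncs that come with it). A rough sketch of the flatten-and-sort step using public ops; the manual row-major flattening below stands in for `at::sparse::flatten_indices`, and the helper name is mine, not PyTorch's.

```python
import torch

def flatten_and_sort(t):
    """Return indices/values of sparse COO `t` permuted into flattened-index order."""
    indices, values = t._indices(), t._values()
    sizes = torch.tensor(t.shape[:t.sparse_dim()], dtype=indices.dtype)
    # Row-major strides, e.g. for a (2, 3) sparse shape: [3, 1].
    strides = torch.cat([torch.ones(1, dtype=sizes.dtype),
                         sizes.flip(0)[:-1]]).cumprod(0).flip(0)
    flat = (indices * strides.unsqueeze(1)).sum(0)     # (i, j) -> i * 3 + j here
    order = flat.argsort()
    return indices[:, order], values[order], flat[order]

i = torch.tensor([[1, 0, 1], [0, 2, 0]])   # unsorted, (1, 0) appears twice
v = torch.tensor([10.0, 20.0, 30.0])
t = torch.sparse_coo_tensor(i, v, (2, 3))
print(flatten_and_sort(t))                 # sorted but, unlike coalesce(), not deduplicated
```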

SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) {
TORCH_CHECK(
@@ -753,57 +821,25 @@ SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) {
" but mask has size ",
mask.sizes());

if (!mask.numel()) {
if (t.is_same(mask)) {
return t;
}

if (!mask.numel() || !mask._nnz()) {
return mask.clone().to(t.device(), t.scalar_type());
}

if (t.layout() == at::kSparse) {
TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
"sparse_mask(): the number of sparse dimensions in `self` ",
"should match that of the `mask`. ",
"Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
"`mask.sparse_dim() == ", mask.sparse_dim(), "`.");

using OptTensor = c10::optional<Tensor>;

const auto wrapped_tensor = [](const Tensor& t,
const OptTensor& indices = c10::nullopt,
const OptTensor& values = c10::nullopt) -> Tensor {
auto res = at::empty({0}, t.options());
auto* res_sparse_impl = get_sparse_impl(res);
res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
const auto res_indices = indices.has_value() ? *indices : t._indices();
const auto res_values = values.has_value() ? *values : t._values();
res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
res_sparse_impl->set_nnz_and_narrow(t._nnz());
res._coalesced_(false);
if (!t._nnz()) {
auto res = mask.clone().to(t.device(), t.scalar_type());
res._values().zero_();
return res;
};

using OptTensor = c10::optional<Tensor>;
Tensor lhs;
OptTensor lhs_hash_opt;

std::tie(lhs, lhs_hash_opt) = [&]() -> auto {
if (t.is_coalesced()) {
return std::make_tuple(t, static_cast<OptTensor>(c10::nullopt));
} else {
const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
// Probably worth having a dedicated kernel for.
const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
const auto res_values = t._values().index_select(0, argsort_indices_hash);
const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
// NOTE: res is not necessariy coalesced, but it is sorted.
// We mark it as "coalesced" to skip sorting in the intersection kernel.
auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
return std::make_tuple(res, static_cast<OptTensor>(indices_hash_sorted));
}
}();

const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;
}

auto res = at::empty({0}, t.options());
Tensor lhs, rhs;
OptTensor lhs_hash_opt;
std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("sparse_mask", t, mask);
sparse_mask_intersection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
return res._coalesced_(mask.is_coalesced());
}
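
`sparse_mask`'s public contract is unchanged: the result takes the `mask`'s sparsity pattern and `self`'s values at those positions. What is new are the cheap early returns for `mask` aliasing `self`, an empty or nnz-0 mask, and an nnz-0 `self`, none of which need the intersection kernel any more. A hedged illustration with made-up inputs:

```python
import torch

t = torch.tensor([[1., 2.], [3., 4.]]).to_sparse()
mask = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]),
                               torch.tensor([1., 1.]), (2, 2))

# Values of t kept only where mask has entries: [[0., 2.], [3., 0.]]
print(t.sparse_mask(mask).to_dense())

# Fast paths added by this commit: self-masking and an empty mask.
print(t.sparse_mask(t)._nnz())                                       # 4
print(t.sparse_mask(torch.sparse_coo_tensor(size=(2, 2)))._nnz())    # 0
```
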
@@ -816,6 +852,31 @@ SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) {
return t.mul(mask_template).to(t.scalar_type());
}

Tensor sparse_mask_projection(const Tensor& t, const Tensor& mask) {
TORCH_INTERNAL_ASSERT(t.is_sparse());
TORCH_INTERNAL_ASSERT(mask.is_sparse());

TORCH_CHECK(
mask.sizes().equals(t.sizes()),
"_sparse_mask_projection(): operands have incompatible sizes; self has size ",
t.sizes(),
" but mask has size ",
mask.sizes());

if (!t.numel() || !t._nnz() || !mask._nnz()) {
auto res = t.clone();
res._values().zero_();
return res;
}

auto res = at::empty({0}, t.options());
Tensor lhs, rhs;
OptTensor lhs_hash_opt;
std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("_sparse_mask_projection", mask, t);
sparse_mask_projection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
return res._coalesced_(t.is_coalesced());
}
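
Per the PR title, this projection serves `torch.sparse.mm`'s backward, where the gradient with respect to a sparse operand is the dense gradient product restricted to that operand's sparsity pattern. A hedged numerical check of that identity using only public ops (random inputs; the private `_sparse_mask_projection` is not called directly here):

```python
import torch

indices = torch.tensor([[0, 1, 2], [1, 0, 2]])
values = torch.tensor([1.5, -2.0, 0.5])
A = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce().requires_grad_()
B = torch.randn(3, 4, requires_grad=True)

C = torch.sparse.mm(A, B)
C.sum().backward()

# d(sum(C))/dA = (ones_like(C) @ B^T) masked to A's nonzero pattern.
expected = torch.ones_like(C).mm(B.t()).sparse_mask(A.detach())
print(torch.allclose(A.grad.to_dense(), expected.to_dense()))   # expected: True
```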

Tensor empty_like_sparse_coo(
const Tensor& self,
c10::optional<ScalarType> dtype,
2 changes: 2 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -465,6 +465,8 @@ aten::_sparse_log_softmax
aten::_sparse_log_softmax.out
aten::_sparse_log_softmax_backward_data
aten::_sparse_log_softmax_backward_data.out
aten::_sparse_mask_projection
aten::_sparse_mask_projection.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_softmax