From 056d92e2a03b57ee0d783b095b7ddb674fedac0a Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Mon, 12 Jun 2023 09:37:05 +0000 Subject: [PATCH] sparse.mm backward: performance improvements (#94991) `torch.sparse.mm` - faster and without syncs in "most" cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94991 Approved by: https://github.com/Skylion007, https://github.com/pearu, https://github.com/cpuhrsch --- .../cuda/SparseBinaryOpIntersectionKernel.cu | 32 +++- aten/src/ATen/native/native_functions.yaml | 6 + .../sparse/SparseBinaryOpIntersectionCommon.h | 4 +- .../SparseBinaryOpIntersectionKernel.cpp | 32 +++- aten/src/ATen/native/sparse/SparseStubs.h | 3 + aten/src/ATen/native/sparse/SparseTensor.cpp | 149 ++++++++++++------ ...asDecompTest.test_has_decomposition.expect | 2 + torch/csrc/autograd/FunctionsManual.cpp | 34 +++- torch/overrides.py | 1 + torchgen/static_runtime/generator.py | 1 + 10 files changed, 203 insertions(+), 61 deletions(-) diff --git a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu index 7888ac6a09fe7..67e28d8ba4493 100644 --- a/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu +++ b/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu @@ -36,6 +36,13 @@ struct RhsProjOp { } }; +struct LhsProjOp { + template + static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) { + return a; + } +}; + template C10_LAUNCH_BOUNDS_2(nt, vt) __global__ void apply_kernel(int n, loop_t loop) { @@ -70,11 +77,12 @@ void binary_op_intersection_kernel( TensorIterator& iter, int64_t lhs_nnz_stride, int64_t rhs_nnz_stride, - const Tensor& argsort) { + const Tensor& argsort, + const bool accumulate_matches) { if (!iter.can_use_32bit_indexing()) { for (auto& sub_iter : iter.with_32bit_indexing()) { binary_op_intersection_kernel( - sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort); + sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches); } return; } @@ -106,7 +114,8 @@ void binary_op_intersection_kernel( accscalar_t lhs_values = static_cast(*ptr_lhs_begin); accscalar_t rhs_values; index_t rhs_sorted_nnz_idx; - for (int64_t c = 0; c < count; ++c) { + const auto match_count = accumulate_matches ? count : std::min(count, 1); + for (int64_t c = 0; c < match_count; ++c) { rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++; rhs_values = static_cast(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride)); res_values += binary_op_t::apply(lhs_values, rhs_values); @@ -126,7 +135,8 @@ struct CUDAValueSelectionIntersectionKernel { const Tensor& rhs_values, const Tensor& rhs_select_idx, const Tensor& intersection_counts, - const Tensor& argsort) { + const Tensor& argsort, + const bool accumulate_matches) { auto iter = make_value_selection_intersection_iter( lhs_values, lhs_select_idx, @@ -150,7 +160,7 @@ struct CUDAValueSelectionIntersectionKernel { // COO indices are only 64-bit for now. 
using index_t = int64_t; binary_op_intersection_kernel( - iter, lhs_nnz_stride, rhs_nnz_stride, argsort); + iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches); }); return res_values; @@ -180,9 +190,21 @@ void sparse_mask_intersection_out_cuda_kernel( ); } +void sparse_mask_projection_out_cuda_kernel( + Tensor& result, + const Tensor& x, + const Tensor& y, + const OptTensor& x_hash_opt = c10::nullopt) { + using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel; + _sparse_binary_op_intersection_kernel_out( + result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false + ); +} + } REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel); REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel); +REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2e04e24d3730b..69968cd2d4866 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6801,6 +6801,12 @@ SparseCsrCPU, SparseCsrCUDA: sparse_mask_sparse_csr autogen: sparse_mask.out +- func: _sparse_mask_projection(Tensor self, Tensor mask) -> Tensor + variants: method + dispatch: + SparseCPU, SparseCUDA: sparse_mask_projection + autogen: _sparse_mask_projection.out + - func: _to_cpu(Tensor[] tensors) -> Tensor[] variants: function diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index 0e1b96f12a6d6..94faadf2002f0 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -136,6 +136,7 @@ void _sparse_binary_op_intersection_kernel_impl( const std::vector broadcasted_shape, const c10::optional& x_hash_opt_ = c10::nullopt, const c10::optional& y_hash_opt_ = c10::nullopt, + const bool accumulate_matches = true, const bool distributive_with_sum = true ) { // The common dtype check is relevant when op is done in-place. 
@@ -403,7 +404,8 @@ void _sparse_binary_op_intersection_kernel_impl( probably_coalesced._values().to(binary_op_res_dtype), intersection_first_idx.to(nnz_arange.scalar_type()), intersection_count, - argsort_hash).to(res.scalar_type()); + argsort_hash, + accumulate_matches).to(res.scalar_type()); const auto res_sparse_dim = source.sparse_dim(); const auto res_dense_dim = source.dense_dim(); const auto& res_shape = broadcasted_shape; diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index 211d7f6527655..9d1f3495b7caa 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -35,6 +35,13 @@ struct RhsProjOp { } }; +struct LhsProjOp { + template + static scalar_t apply(scalar_t a, scalar_t b) { + return a; + } +}; + template struct CPUValueSelectionIntersectionKernel { static Tensor apply( @@ -43,7 +50,8 @@ struct CPUValueSelectionIntersectionKernel { const Tensor& rhs_values, const Tensor& rhs_select_idx, const Tensor& intersection_counts, - const Tensor& argsort) { + const Tensor& argsort, + const bool accumulate_matches) { auto iter = make_value_selection_intersection_iter( lhs_values, lhs_select_idx, @@ -86,7 +94,8 @@ struct CPUValueSelectionIntersectionKernel { accscalar_t lhs_values = static_cast(*ptr_lhs_begin); accscalar_t rhs_values; index_t rhs_sorted_nnz_idx; - for (int64_t c = 0; c < count; ++c) { + const auto match_count = accumulate_matches ? count : std::min(count, 1); + for (int64_t c = 0; c < match_count; ++c) { rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++; rhs_values = static_cast(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride)); res_values += binary_op_t::apply(lhs_values, rhs_values); @@ -132,6 +141,17 @@ void sparse_mask_intersection_out_cpu_kernel( ); } +void sparse_mask_projection_out_cpu_kernel( + Tensor& result, + const Tensor& x, + const Tensor& y, + const OptTensor& x_hash_opt = c10::nullopt) { + using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel; + _sparse_binary_op_intersection_kernel_out( + result, x, y, x_hash_opt, c10::nullopt, /*accumulate_matches=*/false + ); +} + } REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel); @@ -145,4 +165,10 @@ REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_interse REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); -} // namespace at::native + +REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel); +REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +} diff --git a/aten/src/ATen/native/sparse/SparseStubs.h b/aten/src/ATen/native/sparse/SparseStubs.h index 7782043666494..0f71fa287120f 100644 --- a/aten/src/ATen/native/sparse/SparseStubs.h +++ b/aten/src/ATen/native/sparse/SparseStubs.h @@ -16,6 +16,9 @@ 
DECLARE_DISPATCH(mul_sparse_sparse_out_fn, mul_sparse_sparse_out_stub); using sparse_mask_intersection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional& x_hash_opt); DECLARE_DISPATCH(sparse_mask_intersection_out_fn, sparse_mask_intersection_out_stub); +using sparse_mask_projection_out_fn = void (*)(Tensor& res, const Tensor& x, const Tensor& y, const c10::optional& x_hash_opt); +DECLARE_DISPATCH(sparse_mask_projection_out_fn, sparse_mask_projection_out_stub); + using flatten_indices_fn = Tensor (*)(const Tensor& indices, IntArrayRef size); DECLARE_DISPATCH(flatten_indices_fn, flatten_indices_stub); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index e446e9bc6e8ae..d54f19f9960e2 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -744,6 +746,72 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) { } DEFINE_DISPATCH(sparse_mask_intersection_out_stub); +DEFINE_DISPATCH(sparse_mask_projection_out_stub); + +using OptTensor = c10::optional; + +std::tuple sparse_mask_like_prepare_sparse_inputs( + const std::string& method_name, + const Tensor& t, + const Tensor& mask) { + // This is a helper function for operations that implement "sparse_mask"-like + // functionality, namely, projection of values of one tensor onto the other. + // These operations mostly rely on COO intersection primitives that heavily + // exploit coalesced inputs to avoid any syncs and calls to sort. The problem + // is that these primitives might project first argument onto second one or + // the other way around depending on which arguments are coalesced and which are + // larger. This function prepares inputs for `sparse_mask` such that `t` is + // projected onto `mask` by sorting `t` if uncoalesced and artifically marking it + // as coalesced all while `mask` is set to uncoalesced. + // The result of this projectionk is going to be uncoalesced, so it is up to the + // user to set the corresponding flag correctly with respect to the operations' + // semantics. + + // We already assume that t.sizes() == mask.sizes() + TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(), + method_name, "(): the number of sparse dimensions in `self` ", + "should match that of the `mask`. ", + "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ", + "`mask.sparse_dim() == ", mask.sparse_dim(), "`."); + + const auto wrapped_tensor = [](const Tensor& t, + const OptTensor& indices = c10::nullopt, + const OptTensor& values = c10::nullopt) -> Tensor { + auto res = at::empty({0}, t.options()); + auto* res_sparse_impl = get_sparse_impl(res); + res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes()); + const auto res_indices = indices.has_value() ? *indices : t._indices(); + const auto res_values = values.has_value() ? 
*values : t._values(); + res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values); + res_sparse_impl->set_nnz_and_narrow(t._nnz()); + res._coalesced_(false); + return res; + }; + + Tensor lhs; + OptTensor lhs_hash_opt; + + std::tie(lhs, lhs_hash_opt) = [&]() -> auto { + if (t.is_coalesced()) { + return std::make_tuple(t, static_cast(c10::nullopt)); + } else { + const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes()); + const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0)); + // Probably worth having a dedicated kernel for. + const auto res_indices = t._indices().index_select(1, argsort_indices_hash); + const auto res_values = t._values().index_select(0, argsort_indices_hash); + const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash); + // NOTE: res is not necessariy coalesced, but it is sorted. + // We mark it as "coalesced" to skip sorting in the intersection kernel. + auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true); + return std::make_tuple(res, static_cast(indices_hash_sorted)); + } + }(); + + const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask; + + return std::make_tuple(lhs, rhs, lhs_hash_opt); +} SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) { TORCH_CHECK( @@ -753,57 +821,25 @@ SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) { " but mask has size ", mask.sizes()); - if (!mask.numel()) { + if (t.is_same(mask)) { + return t; + } + + if (!mask.numel() || !mask._nnz()) { return mask.clone().to(t.device(), t.scalar_type()); } if (t.layout() == at::kSparse) { - TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(), - "sparse_mask(): the number of sparse dimensions in `self` ", - "should match that of the `mask`. ", - "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ", - "`mask.sparse_dim() == ", mask.sparse_dim(), "`."); - - using OptTensor = c10::optional; - - const auto wrapped_tensor = [](const Tensor& t, - const OptTensor& indices = c10::nullopt, - const OptTensor& values = c10::nullopt) -> Tensor { - auto res = at::empty({0}, t.options()); - auto* res_sparse_impl = get_sparse_impl(res); - res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes()); - const auto res_indices = indices.has_value() ? *indices : t._indices(); - const auto res_values = values.has_value() ? *values : t._values(); - res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values); - res_sparse_impl->set_nnz_and_narrow(t._nnz()); - res._coalesced_(false); + if (!t._nnz()) { + auto res = mask.clone().to(t.device(), t.scalar_type()); + res._values().zero_(); return res; - }; - - using OptTensor = c10::optional; - Tensor lhs; - OptTensor lhs_hash_opt; - - std::tie(lhs, lhs_hash_opt) = [&]() -> auto { - if (t.is_coalesced()) { - return std::make_tuple(t, static_cast(c10::nullopt)); - } else { - const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes()); - const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0)); - // Probably worth having a dedicated kernel for. - const auto res_indices = t._indices().index_select(1, argsort_indices_hash); - const auto res_values = t._values().index_select(0, argsort_indices_hash); - const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash); - // NOTE: res is not necessariy coalesced, but it is sorted. - // We mark it as "coalesced" to skip sorting in the intersection kernel. 
- auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true); - return std::make_tuple(res, static_cast(indices_hash_sorted)); - } - }(); - - const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask; + } auto res = at::empty({0}, t.options()); + Tensor lhs, rhs; + OptTensor lhs_hash_opt; + std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("sparse_mask", t, mask); sparse_mask_intersection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt); return res._coalesced_(mask.is_coalesced()); } @@ -816,6 +852,31 @@ SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) { return t.mul(mask_template).to(t.scalar_type()); } +Tensor sparse_mask_projection(const Tensor& t, const Tensor& mask) { + TORCH_INTERNAL_ASSERT(t.is_sparse()); + TORCH_INTERNAL_ASSERT(mask.is_sparse()); + + TORCH_CHECK( + mask.sizes().equals(t.sizes()), + "_sparse_mask_projection(): operands have incompatible sizes; self has size ", + t.sizes(), + " but mask has size ", + mask.sizes()); + + if (!t.numel() || !t._nnz() || !mask._nnz()) { + auto res = t.clone(); + res._values().zero_(); + return res; + } + + auto res = at::empty({0}, t.options()); + Tensor lhs, rhs; + OptTensor lhs_hash_opt; + std::tie(lhs, rhs, lhs_hash_opt) = sparse_mask_like_prepare_sparse_inputs("_sparse_mask_projection", mask, t); + sparse_mask_projection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt); + return res._coalesced_(t.is_coalesced()); +} + Tensor empty_like_sparse_coo( const Tensor& self, c10::optional dtype, diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index f914bab7c206e..2e12685261ccd 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -465,6 +465,8 @@ aten::_sparse_log_softmax aten::_sparse_log_softmax.out aten::_sparse_log_softmax_backward_data aten::_sparse_log_softmax_backward_data.out +aten::_sparse_mask_projection +aten::_sparse_mask_projection.out aten::_sparse_mm_reduce_impl aten::_sparse_mm_reduce_impl_backward aten::_sparse_softmax diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 1cd06e7ee7bf4..342ab86a70061 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1451,6 +1451,30 @@ Tensor mm_mat1_sparse_backward( mat2.layout()); } +Tensor sparse_mask_like_grad(const Tensor& x, const Tensor& gx) { + if (x.is_coalesced() && gx.is_coalesced()) { + if (x._nnz() >= gx._nnz()) { + // search into x is faster + return gx._sparse_mask_projection(x); + } else { + // search into gx is faster + return gx.sparse_mask(x); + } + } else if (x.is_coalesced()) { + return gx.sparse_mask(x); + } else if (gx.is_coalesced()) { + return gx._sparse_mask_projection(x); + } else { + if (x._nnz() >= gx._nnz()) { + // gx.coalesce() is likely faster + return gx.coalesce()._sparse_mask_projection(x); + } else { + // x.coalesce() is likely faster + return gx.sparse_mask(x.coalesce()); + } + } +} + Tensor sparse_sparse_matmul_backward( const Tensor& grad, const Tensor& a, @@ -1475,19 +1499,13 @@ Tensor sparse_sparse_matmul_backward( TORCH_CHECK( grad_order == 0 || grad_order == 1, ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function"); - const auto mask_ones_like = [](const Tensor& t) -> Tensor { - return at::sparse_coo_tensor( - t._indices(), - at::ones({1}, t._values().options()).expand_as(t._values()), - 
t.sizes()); - }; if (grad_order == 0) { auto a_grad = _sparse_sparse_matmul(grad, b.conj().t()); - return a_grad.mul(mask_ones_like(a.coalesce())); + return sparse_mask_like_grad(a, a_grad); } auto b_grad = _sparse_sparse_matmul(a.conj().t(), grad); - return b_grad.mul(mask_ones_like(b.coalesce())); + return sparse_mask_like_grad(b, b_grad); } Tensor renorm_backward( diff --git a/torch/overrides.py b/torch/overrides.py index 11989e7da0df3..c51b35e20829f 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1331,6 +1331,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1, Tensor.sparse_dim: lambda self: -1, Tensor.sparse_mask: lambda self, mask: -1, + Tensor._sparse_mask_projection: lambda self, mask: -1, Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1, Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1, Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1, diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index b057634d2b47e..eb91a4985f0f1 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -126,6 +126,7 @@ def has_alias( "zero", "_sparse_addmm", "sparse_mask", + "_sparse_mask_projection", "_to_dense", "_coalesce", "_coalesced",
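
A minimal end-to-end sketch (not part of the diff) of the autograd path this patch targets: the sparse @ sparse case of `torch.sparse.mm`, whose `sparse_sparse_matmul_backward` now routes the gradient through `sparse_mask` / `_sparse_mask_projection` (chosen by `sparse_mask_like_grad` based on which operand is coalesced and which has fewer specified elements) instead of multiplying by a ones-mask of the freshly coalesced forward operand. Matrix sizes, the relu-based sparsity pattern, and running on CPU are arbitrary choices for the example:

    import torch

    # Two sparse COO operands; gradients of sparse @ sparse are themselves
    # sparse and flow through sparse_sparse_matmul_backward.
    a = torch.randn(128, 64).relu().to_sparse().requires_grad_(True)
    b = torch.randn(64, 32).relu().to_sparse().requires_grad_(True)

    out = torch.sparse.mm(a, b)        # sparse @ sparse -> sparse COO
    torch.sparse.sum(out).backward()   # differentiable reduction over all values

    # Each gradient is masked onto the sparsity pattern of its input.
    print(a.grad.is_sparse, b.grad.is_sparse)

The old formulation called `coalesce()` on the forward operands inside every backward pass, which costs an extra sort and, on CUDA, can force host-device synchronization; the new masking kernels instead reuse an operand's existing coalesced state or a single argsort of flattened indices, which is where the "without syncs in 'most' cases" claim in the subject line comes from.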