diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py
deleted file mode 100644
index c6753a95e6780..0000000000000
--- a/benchmarks/sparse/benchmark_semi_structured_sparsity.py
+++ /dev/null
@@ -1,245 +0,0 @@
-import random
-import torch
-import torch.utils.benchmark as benchmark
-from torch import nn
-from tqdm import tqdm
-import pandas as pd
-import argparse
-from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-
-
-torch.set_printoptions(
-    precision=2,
-    threshold=None,
-    edgeitems=16,
-    linewidth=480,
-    profile=None,
-    sci_mode=False,
-)
-
-
-# helper model definition for pruner
-class Model(nn.Module):
-    def __init__(self, m, k, dtype=None):
-        super().__init__()
-        # transposed so reversed
-        self.linear = nn.Linear(k, m)
-
-    def forward(self, x):
-        return self.linear(x)
-
-
-def rand_sparse_semi_structured_mask(
-    r, c, dtype=torch.float16, device="cuda", choice=None
-):
-    """
-    This function returns a 1:2 sparse matrix of size (r, c).
-    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
-    """
-
-    choices = [[0, 1], [1, 0]]
-    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
-
-    return (
-        torch.tensor(mask_entries, dtype=dtype, device=device)
-        .reshape(r, c)
-        .contiguous()
-    )
-
-
-def test_linear(m, k, n, dtype, contiguous, backend):
-    SparseSemiStructuredTensor.fuse_transpose = contiguous
-    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
-    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
-    input_tensor = torch.zeros(n, k).to(dtype).cuda()
-    model = Model(m, k).to(dtype).cuda().eval()
-
-    dense_measurement = benchmark.Timer(
-        stmt="model(input_tensor)",
-        globals=locals(),
-    ).blocked_autorange()
-
-    dense_output = model(input_tensor)
-
-    # sparsify weights
-    model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight, mask=mask.bool()))
-
-    sparse_output = model(input_tensor)
-
-    sparse_measurement = benchmark.Timer(
-        stmt="model(input_tensor)",
-        globals=locals(),
-    ).blocked_autorange()
-
-    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
-
-    return {
-        "test_function": "linear",
-        "m": m,
-        "k": k,
-        "n": n,
-        "dtype": str(dtype),
-        "backend": backend,
-        "sparse_latency (ms)": sparse_measurement.median * 1000,
-        "dense_latency (ms)": dense_measurement.median * 1000,
-        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
-        "correct": correct,
-        "contiguous": sparse_output.is_contiguous(),
-    }
-
-
-def test_tensor(m, k, n, dtype, contiguous, backend):
-    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
-    B = torch.zeros(k, n).to(dtype).cuda()
-    bias = torch.rand(n).to(dtype).cuda()
-
-    sA = to_sparse_semi_structured(A, mask=A.bool())
-
-    # torch.mm calculation
-    if dtype is not torch.int8:
-        dense_output = torch.mm(A, B)
-
-        dense_measurement = benchmark.Timer(
-            stmt="torch.mm(A, B)",
-            globals=locals(),
-        ).blocked_autorange()
-
-    else:
-        print("int8 baseline not supported")
-        dense_output = torch.mm(sA, B)
-
-        dense_measurement = benchmark.Timer(
-            stmt="torch.mm(sA, B)",
-            globals=locals(),
-        ).blocked_autorange()
-
-    sparse_output = torch.mm(sA, B)
-    sparse_measurement = benchmark.Timer(
-        stmt="torch.mm(sA, B)",
-        globals=locals(),
-    ).blocked_autorange()
-
-    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
-
-    return {
-        "test_function": "tensor",
-        "m": m,
-        "k": k,
-        "n": n,
-        "dtype": str(dtype),
-        "backend": backend,
-        "sparse_latency (ms)": sparse_measurement.median * 1000,
-        "dense_latency (ms)": dense_measurement.median * 1000,
-        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
-        "correct": correct,
-        "contiguous": sparse_output.is_contiguous(),
-    }
-
-
-if __name__ == "__main__":
-    dtype_lookup = {
-        "int8": torch.int8,
-        "fp16": torch.float16,
-        "bf16": torch.bfloat16,
-        "fp32": torch.float32,
-    }
-
-    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
-    parser.add_argument(
-        "--mode",
-        type=str,
-        choices=[
-            "nvidia-bert",
-            "nvidia-fixed-k",
-            "nvidia-fixed-mn",
-        ],
-    )
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        choices=dtype_lookup.keys(),
-        default="fp16",
-    )
-    parser.add_argument(
-        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
-    )
-    parser.add_argument("-contiguous", action="store_true")
-    parser.add_argument("-e2e", action="store_true")
-    parser.add_argument("-save", action="store_true")
-    args = parser.parse_args()
-
-    if args.e2e:
-        eval_fn = test_linear
-    else:
-        eval_fn = test_tensor
-
-    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
-    dtype = dtype_lookup[args.dtype]
-
-    if args.mode == "nvidia-bert":
-        bert_shapes = [
-            (3072, 1024, 16384),
-            (4096, 1024, 16384),
-            (1024, 1024, 16384),
-            (1024, 4096, 16384),
-        ]
-        results = (
-            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
-            for (m, k, n) in tqdm(bert_shapes)
-        )
-
-    elif args.mode == "nvidia-fixed-k":
-        mn_vals = [
-            3072,
-            4096,
-            5120,
-            6144,
-            7168,
-            8192,
-            9216,
-            10240,
-            11264,
-            12288,
-            13312,
-            14336,
-            15360,
-            16384,
-            17408,
-            18432,
-            19456,
-            20480,
-        ]
-        results = (
-            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
-            for mn in tqdm(mn_vals)
-        )
-
-    elif args.mode == "nvidia-fixed-mn":
-        k_vals = [
-            2560,
-            3840,
-            5120,
-            6400,
-            7680,
-            8960,
-            10240,
-            11520,
-            12800,
-            14080,
-            15360,
-            16640,
-            17920,
-            19200,
-            20480,
-        ]
-        results = (
-            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
-            for k in tqdm(k_vals)
-        )
-
-    df = pd.DataFrame.from_records(results)
-    if args.save:
-        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
-        df.to_csv(save_file)
-        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
-    print(df)
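For reference, the removed benchmark's latency numbers come from ``torch.utils.benchmark.Timer`` with ``blocked_autorange()``. Below is a minimal, self-contained sketch of that measurement pattern; the shapes and dtype are illustrative, not taken from the script.

```python
# Minimal sketch of the Timer pattern the deleted script used.
# Shapes and dtype are illustrative; requires a CUDA device.
import torch
import torch.utils.benchmark as benchmark

A = torch.rand(1024, 1024, dtype=torch.float16, device="cuda")
B = torch.rand(1024, 1024, dtype=torch.float16, device="cuda")

measurement = benchmark.Timer(
    stmt="torch.mm(A, B)",
    globals={"A": A, "B": B},
).blocked_autorange()

print(f"median latency: {measurement.median * 1000:.3f} ms")
```

``blocked_autorange()`` picks the number of runs per measurement automatically and returns a measurement whose ``median`` is in seconds, hence the ``* 1000`` used when reporting milliseconds.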
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index 364d457b70fe2..c273f74b8c0b6 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -24,7 +24,7 @@ matrices, pruned weights or point clouds by Tensors whose *elements are
 mostly zero valued*. We recognize these are important applications and aim
 to provide performance optimizations for these use cases via sparse storage formats.
 
-Various sparse storage formats such as COO, CSR/CSC, semi-structured, LIL, etc. have been
+Various sparse storage formats such as COO, CSR/CSC, LIL, etc. have been
 developed over the years. While they differ in exact layouts, they all
 compress data through efficient representation of zero valued elements.
 We call the uncompressed values *specified* in contrast to *unspecified*,
@@ -67,8 +67,6 @@ indices of non-zero elements are stored in this case.
 
 PyTorch currently supports :ref:`COO`, :ref:`CSR`,
 :ref:`CSC`, :ref:`BSR`, and :ref:`BSC`.
-
-We also have a prototype implementation to support :ref:`semi-structured sparsity`.
 Please see the references for more details.
 
 Note that we provide slight generalizations of these formats.
@@ -169,147 +167,6 @@ receiving a particular layout.
 We are working on an API to control the result layout
 and recognize it is an important feature to plan a more optimal path of
 execution for any given model.
 
-.. _sparse-semi-structured-docs:
-
-Sparse Semi-Structured Tensors
-++++++++++++++++++++++++++++++
-
-.. warning::
-
-   Sparse semi-structured tensors are currently a prototype feature and subject to change. Please feel free to open an issue to report a bug or if you have feedback to share.
-
-Semi-structured sparsity is a sparse data layout that was first introduced in NVIDIA's Ampere architecture. It is also referred to as **fine-grained structured sparsity** or **2:4 structured sparsity**.
-
-This sparse layout stores `n` elements out of every `2n` elements, with `n` being determined by the width of the Tensor's data type (dtype). The most frequently used dtype is float16, where `n=2`, thus the term "2:4 structured sparsity".
-
-Semi-structured sparsity is explained in greater detail in `this NVIDIA blog post`_.
-
-In PyTorch, semi-structured sparsity is implemented via a Tensor subclass.
-By subclassing, we can override ``__torch_dispatch__``, allowing us to use faster sparse kernels when performing matrix multiplication.
-We can also store the tensor in its compressed form inside the subclass to reduce memory overhead.
-
-In this compressed form, the sparse tensor is stored by retaining only the *specified* elements and some metadata, which encodes the mask.
-
-.. note::
-   The specified elements and metadata mask of a semi-structured sparse tensor are stored together in a single
-   flat compressed tensor. They are appended to each other to form a contiguous chunk of memory.
-
-   compressed tensor = [ specified elements of original tensor | metadata_mask ]
-
-   For an original tensor of size `(r, c)` we expect the first `r * c // 2` elements to be the kept elements
-   and the rest of the tensor to be metadata.
-
-   In order to make it easier for the user to view the specified elements
-   and mask, one can use ``.indices()`` and ``.values()`` to access the mask and specified elements respectively.
-
-   - ``.values()`` returns the specified elements in a tensor of size `(r, c//2)` and with the same dtype as the dense matrix.
-
-   - ``.indices()`` returns the metadata_mask in a tensor with element type ``torch.int16`` if dtype is ``torch.float16`` and element type ``torch.int32`` if dtype is ``torch.int8``.
-
-For 2:4 sparse tensors, the metadata overhead is minor: just 2 bits per specified element.
-
-.. note::
-   It's important to note that ``torch.float32`` is only supported for 1:2 sparsity. Therefore, it does not follow the same formula as above.
-
-Here, we break down how to calculate the compression ratio (size sparse / size dense) of a 2:4 sparse tensor.
-
-Let `(r, c) = tensor.shape` and `e = bitwidth(tensor.dtype)`, so `e = 16` for ``torch.float16`` and ``torch.bfloat16`` and `e = 8` for ``torch.int8``.
-
-.. math::
-   M_{dense} = r \times c \times e \\
-   M_{sparse} = M_{specified} + M_{metadata} = r \times \frac{c}{2} \times e + r \times \frac{c}{2} \times 2 = \frac{rce}{2} + rc = rce \left( \frac{1}{2} + \frac{1}{e} \right)
-
-Using these calculations, we can determine the total memory footprint for both the original dense and the new sparse representation.
-
-This gives us a simple formula for the compression ratio, which depends only on the bitwidth of the tensor datatype.
-
-.. math::
-   C = \frac{M_{sparse}}{M_{dense}} = \frac{1}{2} + \frac{1}{e}
-
-By using this formula, we find that the compression ratio is 56.25% for ``torch.float16`` and 62.5% for ``torch.int8``.
-
-Constructing Sparse Semi-Structured Tensors
--------------------------------------------
-
-You can transform a dense tensor into a sparse semi-structured tensor by using the ``torch.sparse.to_sparse_semi_structured`` function.
-
-Please also note that we only support CUDA tensors, since hardware compatibility for semi-structured sparsity is limited to NVIDIA GPUs.
-
-The following datatypes are supported for semi-structured sparsity. Note that each datatype has its own shape constraints and compression factor.
-
-.. csv-table::
-   :header: "PyTorch dtype", "Shape Constraints", "Compression Factor", "Sparsity Pattern"
-   :widths: 15, 45, 10, 10
-   :delim: ;
-
-   ``torch.float16``; Tensor must be 2D and (r, c) must both be a positive multiple of 64; 9/16; 2:4
-   ``torch.int8``; Tensor must be 2D and (r, c) must both be a positive multiple of 128; 10/16; 2:4
-
-To construct a semi-structured sparse tensor, start by creating a regular dense tensor that adheres to a 2:4 (or semi-structured) sparse format.
-To do this, we tile a small 1x4 strip to create a 128x128 dense float16 tensor.
-Afterwards, we can call ``to_sparse_semi_structured`` on this matrix to compress it for accelerated inference.
-
-   >>> from torch.sparse import to_sparse_semi_structured
-   >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
-   tensor([[0., 0., 1.,  ..., 0., 1., 1.],
-           [0., 0., 1.,  ..., 0., 1., 1.],
-           [0., 0., 1.,  ..., 0., 1., 1.],
-           ...,
-           [0., 0., 1.,  ..., 0., 1., 1.],
-           [0., 0., 1.,  ..., 0., 1., 1.],
-           [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
-   >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
-   SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1.,  ..., 1., 1., 1.],
-           [1., 1., 1.,  ..., 1., 1., 1.],
-           [1., 1., 1.,  ..., 1., 1., 1.],
-           ...,
-           [1., 1., 1.,  ..., 1., 1., 1.],
-           [1., 1., 1.,  ..., 1., 1., 1.],
-           [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), metadata=tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-           [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-           [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-           ...,
-           [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-           [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-           [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0',
-          dtype=torch.int16))
-
-Sparse Semi-Structured Tensor Operations
-----------------------------------------
-
-Currently, the following operations are supported for semi-structured sparse tensors:
-
-- torch.addmm(bias, dense, sparse.t())
-- torch.mm(dense, sparse)
-- torch.mm(sparse, dense)
-- aten.linear.default(dense, sparse, bias)
-- aten.t.default(sparse)
-- aten.t.detach(sparse)
-
-To use these ops, simply pass in the output of ``to_sparse_semi_structured(tensor)`` in place of ``tensor``, once your tensor has 0s in a semi-structured sparse format, like this:
-
-   >>> a = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()
-   >>> b = torch.rand(64, 64).half().cuda()
-   >>> c = torch.mm(a, b)
-   >>> a_sparse = to_sparse_semi_structured(a, mask=a.bool())
-   >>> torch.allclose(c, torch.mm(a_sparse, b))
-   True
-
-Under the hood, ``SparseSemiStructuredTensor`` will call ``torch._structured_sparse_linear`` for accelerated inference using CUTLASS sparse kernels.
-
-Accelerating nn.Linear with semi-structured sparsity
-----------------------------------------------------
-
-You can accelerate the linear layers in your model, if the weights are already semi-structured sparse, with just a few lines of code:
-
-   >>> input = torch.rand(64, 64).half().cuda()
-   >>> mask = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).cuda().bool()
-   >>> linear = nn.Linear(64, 64).half().cuda()
-   >>> linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))
-
 
 .. _sparse-coo-docs:
@@ -1135,18 +992,12 @@ multiplication, and ``@`` is matrix multiplication.
     :func:`torch.mv`;no; ``M[sparse_csr] @ V[strided] -> V[strided]``
     :func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
     :func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]``
-    :func:`torch.matmul`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
-    :func:`torch.matmul`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
     :func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
-    :func:`torch.mm`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
-    :func:`torch.mm`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
     :func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]``
     :func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]``
     :func:`torch.hspmm`; no; ``M[sparse_coo] @ M[strided] -> M[hybrid sparse_coo]``
     :func:`torch.bmm`; no; ``T[sparse_coo] @ T[strided] -> T[strided]``
     :func:`torch.addmm`; no; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
-    :func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]``
-    :func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]``
     :func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
     :func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]``
     :func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
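The docs removed above derive a compression ratio that depends only on the element bitwidth, :math:`C = \frac{1}{2} + \frac{1}{e}`. A quick arithmetic sketch in plain Python, assuming nothing beyond that formula, reproduces the two quoted figures:

```python
# C = M_sparse / M_dense = 1/2 + 1/e, with e = bitwidth(dtype):
# half the elements are kept, plus 2 bits of metadata per kept element.
for name, e in [("torch.float16", 16), ("torch.int8", 8)]:
    c = 0.5 + 1.0 / e
    print(f"{name}: compressed size is {c:.2%} of dense")
# torch.float16: compressed size is 56.25% of dense
# torch.int8: compressed size is 62.50% of dense
```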
- """ - - choices = [[0, 1], [1, 0]] - mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)] - - return ( - torch.tensor(mask_entries, dtype=dtype, device=device) - .reshape(r, c) - .contiguous() - ) - - -class TestSparseSemiStructured(TestCase): - - @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") - @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) - def test_to_sparse_semi_structured(self, dtype): - A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - - assert A.shape == A_sparse.shape - assert A.device == A_sparse.device - assert A.dtype == A_sparse.dtype - - assert isinstance(A, torch.Tensor) - assert isinstance(A_sparse, SparseSemiStructuredTensor) - - with self.assertRaisesRegex( - NotImplementedError, - "You must pass in a mask to to_sparse_semi_structured, currently mask=None.", - ): - A_sparse = to_sparse_semi_structured(A) - - @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") - @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) - def test_mm_sparse_first_NT(self, dtype, device): - """ - Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8 - Ensure torch.mm(A_sparse, B.t()) is correct - """ - A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - - B = torch.rand((128, 128), device=A_sparse.device).to(dtype) - - # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over - if dtype is torch.int8: - # This should fail - with self.assertRaisesRegex(RuntimeError, "_structured_sparse_linear"): - sparse_result = torch.mm(A_sparse, B) - - # test transpose - # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior. 
-            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to an int8 tensor.
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
-            sparse_result = torch.mm(A_sparse, B.t())
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
-        else:
-            dense_result = torch.mm(A, B)
-            sparse_result = torch.mm(A_sparse, B)
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
-            # test transpose
-            dense_result = torch.mm(A, B.t())
-            sparse_result = torch.mm(A_sparse, B.t())
-            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
-
-    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
-    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
-    def test_mm_sparse_first_T(self, dtype, device):
-        """
-        Ensure torch.mm(A_sparse.t(), B) throws an error.
-        """
-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
-        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
-
-        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
-
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
-        ):
-            torch.mm(A_sparse.t(), B)
-
-    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
-    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
-    def test_mm_sparse_second_T(self, dtype, device):
-        """
-        Ensure torch.mm(A, B_sparse.t()) is correct.
-        """
-        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
-        B_sparse = to_sparse_semi_structured(B, mask=B.bool())
-
-        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
-
-        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
-        if dtype is torch.int8:
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
-            sparse_result = torch.mm(A, B_sparse.t())
-        else:
-            dense_result = torch.mm(A, B.t())
-            sparse_result = torch.mm(A, B_sparse.t())
-
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
-
-    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
-    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
-    def test_mm_sparse_second_NT(self, dtype, device):
-        """
-        Ensure torch.mm(A, B_sparse) throws an error.
-        """
-        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
-        B_sparse = to_sparse_semi_structured(B, mask=B.bool())
-
-        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
-
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
-        ):
-            sparse_result = torch.mm(A, B_sparse)
-
-    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
-    @parametrize("inference_mode", [subtest(False), subtest(True)])
-    def test_linear(self, inference_mode, device):
-        """
-        Test that nn.Linear has the same numerics with dense and sparse weights.
-        """
-        input = torch.rand(128, 128, device=device).half()
-        model = nn.Linear(128, 128).to(device).half()
-        m, n = model.weight.shape
-        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
-        # set masked weight
-        model.weight = nn.Parameter(model.weight * mask)
-
-        dense_result = model(input)
-        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight, mask=mask))
-
-        if inference_mode:
-            with torch.inference_mode():
-                sparse_result = model(input)
-        else:
-            sparse_result = model(input)
-
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-5, atol=1e-5)
-
-    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
"semi-structured sparsity not supported on this library version") - def test_values(self): - A = rand_sparse_semi_structured_mask(128, 128) - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - assert A_sparse.values().shape == (128, 64) - assert (A_sparse.values() == 1).all() - - @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") - def test_indices(self): - A = rand_sparse_semi_structured_mask(128, 128) - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - assert A_sparse.indices().shape == (128, 8) - - @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") - @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) - def test_unsupported_shape(self, dtype, device): - A = rand_sparse_semi_structured_mask(4, 4, dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"): - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - - @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") - @dtypes(*all_types_and_complex()) - def test_unsupported_dtype(self, dtype, device): - A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device) - - if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES: - with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"): - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - else: - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - - @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") - def test_unsupported_dim(self, device): - A = torch.rand(128, 128, 128, device=device, dtype=torch.float16) - - with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"): - A_sparse = to_sparse_semi_structured(A, mask=A.bool()) - - -instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda") - -if __name__ == "__main__": - run_tests() diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index b1a91fb82c865..6f05dfbb22097 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -5,9 +5,6 @@ from torch._C import _add_docstr, _sparse # type: ignore[attr-defined] from torch import Tensor -# Semi structured sparsity support -from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured - # A workaround to support both TorchScript and MyPy: from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -26,10 +23,9 @@ 'sum', 'softmax', 'log_softmax', - 'SparseSemiStructuredTensor', - 'to_sparse_semi_structured', ] + addmm = _add_docstr(_sparse._sparse_addmm, r""" sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py deleted file mode 100644 index 1bf13c0a7089c..0000000000000 --- a/torch/sparse/semi_structured.py +++ /dev/null @@ -1,389 +0,0 @@ -import warnings -from collections import namedtuple -from typing import Any, Optional - -import torch - - -__all__ = [ - "to_sparse_semi_structured", - "SparseSemiStructuredTensor", -] - -_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple( - "_SEMI_STRUCTURED_SPARSE_CONFIG", "compression_factor min_size" -) -_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = { - torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(9, 64), - torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(10, 128), -} - -_WARNING_SHOWN = False - -class SparseSemiStructuredTensor(torch.Tensor): - """This class implementes semi-structured sparsity as a Tensor subclass. 
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index b1a91fb82c865..6f05dfbb22097 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -5,9 +5,6 @@
 from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
 from torch import Tensor
 
-# Semi structured sparsity support
-from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured
-
 # A workaround to support both TorchScript and MyPy:
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -26,10 +23,9 @@
     'sum',
     'softmax',
     'log_softmax',
-    'SparseSemiStructuredTensor',
-    'to_sparse_semi_structured',
 ]
 
+
 addmm = _add_docstr(_sparse._sparse_addmm, r"""
 sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
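With this change the prototype names are gone from ``torch.sparse``'s public surface. A hypothetical compatibility probe for code that must span both versions follows; the ``hasattr`` check is my suggestion, not part of this patch.

```python
# Hypothetical feature probe: these names are removed from torch.sparse
# by this patch, so version-spanning code should detect before use.
import torch.sparse as sparse

if hasattr(sparse, "to_sparse_semi_structured"):
    print("prototype semi-structured API present")
else:
    print("semi-structured API removed; fall back to dense weights")
```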
diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py
deleted file mode 100644
index 1bf13c0a7089c..0000000000000
--- a/torch/sparse/semi_structured.py
+++ /dev/null
@@ -1,389 +0,0 @@
-import warnings
-from collections import namedtuple
-from typing import Any, Optional
-
-import torch
-
-
-__all__ = [
-    "to_sparse_semi_structured",
-    "SparseSemiStructuredTensor",
-]
-
-_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
-    "_SEMI_STRUCTURED_SPARSE_CONFIG", "compression_factor min_size"
-)
-_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
-    torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(9, 64),
-    torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(10, 128),
-}
-
-_WARNING_SHOWN = False
-
-
-class SparseSemiStructuredTensor(torch.Tensor):
-    """This class implements semi-structured sparsity as a Tensor subclass.
-
-    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are sparse,
-    with n determined by the datatype. It is also referred to as 2:4 sparsity or fine-grained
-    structured sparsity.
-
-    Currently, this class supports 2:4 sparsity for int8 and float16 dtypes.
-
-    This subclass stores the dense tensor in a compressed form by only storing the specified elements and a metadata mask.
-    These two are stored next to each other in one contiguous tensor.
-
-    We choose to store the specified elements and the metadata in a single tensor for future compatibility with cuSPARSELt.
-
-    compressed tensor = [ specified elements of original tensor | mask_metadata ]
-
-    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
-    The rest of the tensor is metadata.
-
-    This subclass also overrides __torch_dispatch__ to use _structured_sparse_linear for faster matrix multiplications
-    via sparse CUTLASS kernels. In the future we will also call into cuSPARSELt kernels for more performance gains.
-    """
-
-    @staticmethod
-    def __new__(
-        cls,
-        original_tensor: Optional[torch.Tensor],
-        original_shape: Optional[torch.Size] = None,
-        mask: Optional[torch.Tensor] = None,
-        compressed_tensor: Optional[torch.Tensor] = None,
-        transposed: bool = False,
-    ):
-        """
-        Create a new instance of the class.
-
-        When original_tensor is passed in, we compress it and store the compressed representation.
-        We can also create a new instance of the class from the compressed representation directly, without the original tensor.
-
-        Args:
-            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
-            original_shape: The shape of the original dense tensor.
-            mask: Mask to be applied to the original tensor.
-            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
-            transposed: Whether the tensor is transposed or not.
-
-        Returns:
-            torch.Tensor: A torch.Tensor wrapper subclass.
-
-        Raises:
-            ValueError: If both original_tensor and compressed_tensor are None.
-        """
-        if original_tensor is not None:
-            previous_tensor = original_tensor
-            original_shape = original_tensor.shape
-        elif compressed_tensor is not None:
-            previous_tensor = compressed_tensor
-        else:
-            raise ValueError("Both compressed_tensor and original_tensor are None!")
-
-        kwargs = {}
-        kwargs["device"] = previous_tensor.device  # type: ignore[assignment]
-        kwargs["dtype"] = previous_tensor.dtype  # type: ignore[assignment]
-        kwargs["layout"] = previous_tensor.layout  # type: ignore[assignment]
-        kwargs["requires_grad"] = False  # type: ignore[assignment]
-
-        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)  # type: ignore[attr-defined]
-
-    def __init__(
-        self,
-        original_tensor: Optional[torch.Tensor],
-        original_shape: Optional[torch.Size] = None,
-        mask: Optional[torch.Tensor] = None,
-        compressed_tensor: Optional[torch.Tensor] = None,
-        transposed: bool = False,
-    ) -> None:
-        """SparseSemiStructuredTensor constructor.
-
-        Args:
-            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
-            original_shape: The shape of the original dense tensor.
-            mask: Mask to be applied to the original tensor.
-            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
-            transposed: Whether the tensor is transposed or not.
-
-        Returns:
-            None
-
-        Raises:
-            NotImplementedError: If ``mask=None``, as we currently do not support inferring a mask from the dense tensor.
-            RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
-        """
-        global _WARNING_SHOWN
-        if not _WARNING_SHOWN:
-            warnings.warn(
-                (
-                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
-                    "and will change in the near future. Please open a Github issue "
-                    "for feature requests and see our documentation on the torch.sparse "
-                    "module for further information about the project."
-                ),
-                UserWarning,
-            )
-            _WARNING_SHOWN = True
-
-        # if original tensor is passed in, we need to compress it and store the compressed representation.
-        if original_tensor is not None:
-            # check if mask passed in
-            if mask is None:
-                raise NotImplementedError("You must pass in a mask to to_sparse_semi_structured, currently mask=None.")
-
-            # check device
-            if not original_tensor.is_cuda:
-                raise RuntimeError(
-                    (
-                        f"Error original_tensor.device= {original_tensor.device} is not supported! "
-                        "Only CUDA tensors are currently supported."
-                    )
-                )
-
-            # check dim
-            if original_tensor.dim() != 2:
-                raise RuntimeError(
-                    (
-                        f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
-                        "Only 2d tensors are currently supported."
-                    )
-                )
-
-            # check dtype
-            if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
-                raise RuntimeError(
-                    (
-                        f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
-                        f"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
-                    )
-                )
-
-            # check shape
-            m, n = original_tensor.shape
-            min_size = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_size
-            if m < min_size or m % min_size or n < min_size or n % min_size:
-                # TODO in the future we can add in padding to support dimensions that aren't perfect multiples
-                raise RuntimeError(
-                    (
-                        f"Error original_tensor.shape {original_tensor.shape} is not supported! "
-                        f"Both dimensions must be larger than and a multiple of {min_size}"
-                    )
-                )
-
-            # This code calculates the size of the compressed tensor.
-            # The compression factor depends on the dtype; for 2:4 sparsity it is given by:
-            #   compression_factor = 1/2 + 1/bitwidth(dtype)
-            original_size = original_tensor.nelement()
-            compression_factor = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
-                original_tensor.dtype
-            ].compression_factor
-            compressed_size = original_size * compression_factor // 16
-
-            compressed_tensor = torch.empty(
-                (compressed_size,),
-                dtype=original_tensor.dtype,
-                device=original_tensor.device,
-            )
-
-            # TODO This is a temporary hack to get the mask in compressed form so we can store the compressed tensor.
-            # In the future, we will add in a conversion function from the mask to the meta that we can use instead.
-            placeholder = torch.ones(
-                (128, n), dtype=original_tensor.dtype, device=original_tensor.device
-            )
-            specified = original_tensor.masked_select(mask).view(m, n // 2)
-            _, meta = torch._structured_sparse_linear(placeholder, specified, mask)
-            # set the specified elements
-            compressed_tensor[: m * n // 2] = specified.view(-1)
-            # set the metadata
-            compressed_tensor[m * n // 2 :] = meta.view(original_tensor.dtype).view(-1)
-
-        # set values
-        self.original_tensor = None
-        self.compressed_tensor = compressed_tensor
-        self.transposed = transposed
-
-    def __repr__(self) -> str:
-        """Return a string representation of SparseSemiStructuredTensor.
-
-        Returns:
-            str: String representation
-
-        Raises:
-            None
-        """
-        return (
-            f"SparseSemiStructuredTensor(shape={self.shape}, "
-            f"transposed={self.transposed}, "
-            f"values={self.values()}, "
-            f"metadata={self.indices()})"
-        )
-
-    __torch_function__ = torch._C._disabled_torch_function_impl
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
-        """Overload __torch_dispatch__ to use torch._structured_sparse_linear.
-
-        `torch._structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
-        In the future we plan to also add in support for cuSPARSELt kernels.
-
-        Args:
-            func: The function being dispatched.
-            types: The types of the arguments.
-            args: The arguments passed to the function.
-            kwargs: The keyword arguments passed to the function.
-
-        Returns:
-            Any: The result of the dispatched operation.
-
-        Raises:
-            NotImplementedError: If the dispatched operation is not implemented.
-        """
-        # Since this code runs below autograd, a detach corresponds to only returning a new object
-        if func is torch.ops.aten.detach.default:
-            return SparseSemiStructuredTensor(
-                args[0].original_tensor,
-                original_shape=args[0].shape,
-                mask=None,
-                compressed_tensor=args[0].compressed_tensor,
-                transposed=args[0].transposed,
-            )
-
-        # Because we cannot go from the compressed representation back to the dense representation currently,
-        # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
-        # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
-        if func is torch.ops.aten.t.default:
-            return SparseSemiStructuredTensor(
-                args[0].original_tensor,
-                original_shape=args[0].shape,
-                mask=None,
-                compressed_tensor=args[0].compressed_tensor,
-                transposed=not args[0].transposed,
-            )
-
-        # handle addmm
-        if func is torch.ops.aten.addmm.default:
-            bias, input_A, input_B = args
-
-            # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELt and CUTLASS.
-            # CUTLASS only supports the first input to be sparse for a given matmul.
-            # cuSPARSELt does not have this limitation, although our implementation is only for sparse first.
-
-            # We support second matrix sparse matmul by taking advantage of some transpose properties:
-            # This is also why we want an odd number of transpose calls for second matrix sparse vs an even number
-            # of transpose calls for first matrix sparse.
-            # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
-            #             = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
-            if isinstance(input_B, cls) and input_B.transposed:
-                result, _ = torch._structured_sparse_linear(
-                    input_A, input_B.values(), input_B.indices(), bias=bias
-                )
-                return result
-
-        # handle mm
-        if func is torch.ops.aten.mm.default:
-            input_A, input_B = args
-
-            if isinstance(input_A, cls) and not input_A.transposed:
-                transposed_result, _ = torch._structured_sparse_linear(
-                    input_B.t(), input_A.values(), input_A.indices()
-                )
-                return transposed_result.t()
-
-            elif isinstance(input_B, cls) and input_B.transposed:
-                result, _ = torch._structured_sparse_linear(
-                    input_A, input_B.values(), input_B.indices()
-                )
-                return result
-
-        # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
-        # so we must match the aten.linear op.
-        # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
-        if func is torch.ops.aten.linear.default:
-            input_tensor, weight, bias = args
-            if isinstance(weight, cls):
-                result, _ = torch._structured_sparse_linear(
-                    input_tensor, weight.values(), weight.indices(), bias=bias
-                )
-                return result
-
-        # handle values
-        if func is torch.ops.aten.values.default:
-            m, k = args[0].shape
-            num_kept_elements = m * k // 2
-            return args[0].compressed_tensor[:num_kept_elements].view(m, k // 2)
-
-        # handle indices
-        if func is torch.ops.aten.indices.default:
-            m, k = args[0].shape
-            num_kept_elements = m * k // 2
-            metadata = args[0].compressed_tensor[num_kept_elements:].view(m, -1)
-
-            # the metadata is expected to be in different datatypes for fp16/int8 respectively for CUTLASS.
-            if args[0].dtype is torch.int8:
-                return metadata.view(torch.int32)
-            elif args[0].dtype is torch.float16:
-                return metadata.view(torch.int16)
-
-        error_string = "\n".join(
-            [f"func {func} with args: "]
-            + [f"arg{i}: {arg}" for i, arg in enumerate(args)]
-        )
-        raise NotImplementedError(error_string)
-
-
-def to_sparse_semi_structured(
-    original_tensor: torch.Tensor,
-    mask: Optional[torch.Tensor] = None,
-    transposed: bool = False,
-) -> SparseSemiStructuredTensor:
-    """
-    This function converts a dense tensor into a sparse semi-structured tensor.
-    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
-
-    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
-    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
-    Additionally, your tensor must be a positive multiple of a block size given the dtype:
-
-    - torch.float16  (r, c) must be >= and a multiple of 64
-    - torch.int8     (r, c) must be >= and a multiple of 128
-
-    Args:
-        original_tensor (Tensor): the dense tensor to convert
-        mask (Optional BoolTensor): boolean mask to apply to the original tensor
-        transposed (bool, optional): whether the dense tensor is transposed
-
-    Returns:
-        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor and mask
-
-    Raises:
-        None
-
-    Example:
-        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
-        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
-        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
-                [0., 0., 1.,  ..., 0., 1., 1.],
-                [0., 0., 1.,  ..., 0., 1., 1.],
-                ...,
-                [0., 0., 1.,  ..., 0., 1., 1.],
-                [0., 0., 1.,  ..., 0., 1., 1.],
-                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
-        >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
-        SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1.,  ..., 1., 1., 1.],
-                [1., 1., 1.,  ..., 1., 1., 1.],
-                [1., 1., 1.,  ..., 1., 1., 1.],
-                ...,
-                [1., 1., 1.,  ..., 1., 1., 1.],
-                [1., 1., 1.,  ..., 1., 1., 1.],
-                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
-        metadata=tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-                ...,
-                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
-                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0',
-        dtype=torch.int16))
-    """
-    return SparseSemiStructuredTensor(original_tensor, mask=mask, transposed=transposed)
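To close, an arithmetic sketch of the ``[specified elements | metadata]`` layout the deleted subclass docstring describes. The numbers assume the 9/16 compression factor for ``torch.float16`` and line up with the ``(128, 64)`` and ``(128, 8)`` shapes asserted in the deleted ``test_values`` and ``test_indices``; the variable names are illustrative.

```python
# Compressed-layout arithmetic for a 128x128 float16 tensor:
# compressed = [ m*k//2 specified values | 2-bit-per-value metadata ]
m, k = 128, 128
compression_factor = 9                    # 9/16 for float16, 10/16 for int8
total = m * k * compression_factor // 16  # 9216 fp16 elements in all
values = m * k // 2                       # 8192 -> .values() viewed as (128, 64)
metadata = total - values                 # 1024 fp16 slots, reinterpreted as int16
print(total, values, metadata, (m, metadata // m))  # 9216 8192 1024 (128, 8)
```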