From 0df7e226099023270683007c2e5a9c07eb9a3918 Mon Sep 17 00:00:00 2001 From: CelysPr Date: Mon, 13 Jan 2025 18:59:39 +0800 Subject: [PATCH 1/2] op diff: draft --- src/flag_gems/__init__.py | 1 + src/flag_gems/ops/__init__.py | 2 + src/flag_gems/ops/diff.py | 119 ++++++++++++++++++++++++++++++++++ tests/test_reduction_ops.py | 20 ++++++ 4 files changed, 142 insertions(+) create mode 100644 src/flag_gems/ops/diff.py diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 0ca58a4a..af831e85 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -40,6 +40,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar): ("constant_pad_nd", constant_pad_nd, Autograd.disable), ("cumsum", cumsum, Autograd.disable), ("cummin", cummin, Autograd.disable), + ("diff", diff, Autograd.disable), ("div.Tensor", true_divide, Autograd.disable), ("div.Scalar", true_divide, Autograd.disable), ("div.Tensor_mode", div_mode, Autograd.disable), diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 354a14cc..9c242392 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -30,6 +30,7 @@ from .diag import diag from .diag_embed import diag_embed from .diagonal import diagonal_backward +from .diff import diff from .div import div_mode, floor_divide, remainder, true_divide from .dropout import native_dropout from .embedding import embedding @@ -289,4 +290,5 @@ "logical_xor", "logical_not", "sort", + "diff", ] diff --git a/src/flag_gems/ops/diff.py b/src/flag_gems/ops/diff.py new file mode 100644 index 00000000..fe205367 --- /dev/null +++ b/src/flag_gems/ops/diff.py @@ -0,0 +1,119 @@ +import logging + +import torch +import triton +import triton.language as tl +from torch import Tensor, tensor + +from ..runtime import torch_device_fn +from ..utils import dim_compress, libentry +from ..utils import triton_lang_extension as tle + + +@libentry() +@triton.jit +def diff_kernel_1d(in_ptr, out_ptr, coeff_ptr, n: tl.constexpr): + pid = tle.program_id(0) + + coeff_offsets = tl.arange(0, triton.next_power_of_2(n + 1)) + in_offsets = pid + coeff_offsets + out_offset = pid + + mask_co_in = coeff_offsets < n + 1 + + in_block = tl.load(in_ptr + in_offsets, mask_co_in) + coeff = tl.load(coeff_ptr + coeff_offsets, mask_co_in) + result = tl.sum(in_block * coeff) + tl.store(out_ptr + out_offset, result) + + +@libentry() +@triton.jit +def diff_kernel_2d( + in_ptr, out_ptr, coeff_ptr, M, N, n: tl.constexpr, BLOCK_N: tl.constexpr +): + pid_diff = tle.program_id(1) + pid_n = tle.program_id(0) + + n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = n_offsets < M + + coeff_offsets = tl.arange(0, triton.next_power_of_2(n + 1)) + + in_offsets_diff = pid_diff + coeff_offsets + in_offsets = n_offsets[:, None] * N + in_offsets_diff + + out_offset_diff = pid_diff + out_offsets = n_offsets * tle.num_programs(1) + out_offset_diff + + mask_co_in = coeff_offsets < n + 1 + mask_in = mask_n[:, None] & mask_co_in + mask_out = mask_n + + in_block = tl.load(in_ptr + in_offsets, mask_in) + coeff = tl.load(coeff_ptr + coeff_offsets, mask_co_in) + result = tl.sum(in_block * coeff, axis=1) + tl.store(out_ptr + out_offsets, result, mask_out) + + +def bin_coeff(n, device): + # fg introduce errors + # coeff = torch.ones(n + 1, dtype=torch.int64, device=device) + coeff = [1] * (n + 1) + # coeff[-1] = 1 + for i in range(1, n + 1): + coeff[n - i] = coeff[n + 1 - i] * (n - i + 1) // i * (-1) + return torch.tensor(coeff, dtype=torch.int64, device=device) + + +def 
diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor: + if prepend is not None: + input = torch.cat([prepend, input], dim=dim) + if append is not None: + input = torch.cat([input, append], dim=dim) + + if n == 0: + return input + + max_len = input.shape[dim % input.ndim] + if n >= max_len: + logging.warning( + "Cannot conduct diff(input, n) with diff length = {} and n = {}.".format( + max_len, n + ) + ) + return tensor([], dtype=input.dtype, device=input.device) + + if input.ndim == 1: + shape = list(input.shape) + dim = dim % input.ndim + input = dim_compress(input, dim) + + output_diff_len = shape[dim] - n + output = torch.zeros(output_diff_len, device=input.device, dtype=input.dtype) + + coeff = bin_coeff(n, input.device) + + grid = [output_diff_len] + with torch_device_fn.device(input.device): + diff_kernel_1d[grid](input, output, coeff, n) + return output + + shape = list(input.shape) + dim = dim % input.ndim + input = dim_compress(input, dim) + N = shape[dim] + M = input.numel() // N + + output_diff_len = shape[dim] - n + output_shape = shape[:dim] + shape[(dim + 1) :] + [output_diff_len] + output = torch.zeros(output_shape, device=input.device, dtype=input.dtype) + + coeff = bin_coeff(n, input.device) + + block_size = 16 + grid = [triton.cdiv(M, block_size), output_diff_len] + with torch_device_fn.device(input.device): + diff_kernel_2d[grid](input, output, coeff, M, N, n, block_size) + output = torch.moveaxis(output, -1, dim) + return output diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 37f4ab2c..cbd186f4 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -874,3 +874,23 @@ def test_accuracy_depthwise2d( inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1 ) gems_assert_close(res_out, ref_out, dtype) + + +DIFF_N_VALUES = list(range(0, 10)) + + +@pytest.mark.diff +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST + []) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("n", DIFF_N_VALUES) +def test_accuracy_diff(shape, dim, dtype, n): + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + ref_inp = to_reference(inp) + + ref_out = torch.diff(ref_inp, n, dim) + with flag_gems.use_gems(): + res_out = torch.diff(inp, n, dim) + + reduce_dim = shape[dim % inp.ndim] + gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim) From dd95e7a928bc2ef7a06d2738d0f32ed7f6066878 Mon Sep 17 00:00:00 2001 From: CelysPr <128201073+CelysPr@users.noreply.github.com> Date: Fri, 17 Jan 2025 08:01:04 +0000 Subject: [PATCH 2/2] Op Diff Update [Recursive] --- benchmark/test_reduction_perf.py | 17 +++ src/flag_gems/ops/diff.py | 131 ++++++++---------- .../backend/_iluvatar/tune_configs.yaml | 37 +++++ .../runtime/backend/_metax/tune_configs.yaml | 37 +++++ .../runtime/backend/_nvidia/tune_configs.yaml | 37 +++++ tests/test_reduction_ops.py | 17 ++- 6 files changed, 193 insertions(+), 83 deletions(-) diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index 6cc50495..5a52e259 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -10,6 +10,7 @@ from .performance_utils import ( Benchmark, Config, + GenericBenchmark, GenericBenchmark2DOnly, SkipVersion, generate_tensor_input, @@ -186,3 +187,19 @@ def count_nonzero_input_fn(shape, dtype, device): dtypes=FLOAT_DTYPES, ) bench.run() + + +@pytest.mark.diff +def test_perf_diff(): + def diff_input_fn(shape, cur_dtype, device): + inp = 
generate_tensor_input(shape, cur_dtype, device) + n = 1 + yield inp, n, 0 + + bench = GenericBenchmark( + input_fn=diff_input_fn, + op_name="diff", + torch_op=torch.diff, + dtypes=FLOAT_DTYPES + INT_DTYPES, + ) + bench.run() diff --git a/src/flag_gems/ops/diff.py b/src/flag_gems/ops/diff.py index fe205367..5752245f 100644 --- a/src/flag_gems/ops/diff.py +++ b/src/flag_gems/ops/diff.py @@ -1,69 +1,54 @@ -import logging - import torch import triton import triton.language as tl from torch import Tensor, tensor +from .. import runtime from ..runtime import torch_device_fn -from ..utils import dim_compress, libentry +from ..utils import dim_compress, libentry, libtuner from ..utils import triton_lang_extension as tle @libentry() +@libtuner( + configs=runtime.get_tuned_config("diff_1d"), + key=["N"], +) @triton.jit -def diff_kernel_1d(in_ptr, out_ptr, coeff_ptr, n: tl.constexpr): +def diff_kernel_1d(in_ptr, out_ptr, N, N_bound, BLOCK_DIFF: tl.constexpr): pid = tle.program_id(0) - coeff_offsets = tl.arange(0, triton.next_power_of_2(n + 1)) - in_offsets = pid + coeff_offsets - out_offset = pid - - mask_co_in = coeff_offsets < n + 1 - - in_block = tl.load(in_ptr + in_offsets, mask_co_in) - coeff = tl.load(coeff_ptr + coeff_offsets, mask_co_in) - result = tl.sum(in_block * coeff) - tl.store(out_ptr + out_offset, result) + in_offsets = pid * BLOCK_DIFF + tl.arange(0, BLOCK_DIFF) + mask_in = in_offsets < N_bound - 1 + in_block = tl.load(in_ptr + in_offsets, mask_in) + next_block = tl.load(in_ptr + in_offsets + 1, mask_in) + tl.store(out_ptr + in_offsets, next_block - in_block, mask_in) @libentry() +@libtuner( + configs=runtime.get_tuned_config("diff"), + key=["M", "N"], +) @triton.jit def diff_kernel_2d( - in_ptr, out_ptr, coeff_ptr, M, N, n: tl.constexpr, BLOCK_N: tl.constexpr + in_ptr, out_ptr, M, N, N_bound, BLOCK_M: tl.constexpr, BLOCK_DIFF: tl.constexpr ): + pid_M = tle.program_id(0) pid_diff = tle.program_id(1) - pid_n = tle.program_id(0) - - n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - mask_n = n_offsets < M - coeff_offsets = tl.arange(0, triton.next_power_of_2(n + 1)) + M_offsets = pid_M * BLOCK_M + tl.arange(0, BLOCK_M) + mask_M = M_offsets < M - in_offsets_diff = pid_diff + coeff_offsets - in_offsets = n_offsets[:, None] * N + in_offsets_diff - - out_offset_diff = pid_diff - out_offsets = n_offsets * tle.num_programs(1) + out_offset_diff - - mask_co_in = coeff_offsets < n + 1 - mask_in = mask_n[:, None] & mask_co_in - mask_out = mask_n + in_offsets_diff = pid_diff * BLOCK_DIFF + tl.arange(0, BLOCK_DIFF) + mask_in_diff = in_offsets_diff < N_bound - 1 + in_offsets = M_offsets[:, None] * N + in_offsets_diff + mask_in = mask_M[:, None] & mask_in_diff + out_offsets = M_offsets[:, None] * N + in_offsets_diff in_block = tl.load(in_ptr + in_offsets, mask_in) - coeff = tl.load(coeff_ptr + coeff_offsets, mask_co_in) - result = tl.sum(in_block * coeff, axis=1) - tl.store(out_ptr + out_offsets, result, mask_out) - - -def bin_coeff(n, device): - # fg introduce errors - # coeff = torch.ones(n + 1, dtype=torch.int64, device=device) - coeff = [1] * (n + 1) - # coeff[-1] = 1 - for i in range(1, n + 1): - coeff[n - i] = coeff[n + 1 - i] * (n - i + 1) // i * (-1) - return torch.tensor(coeff, dtype=torch.int64, device=device) + next_block = tl.load(in_ptr + in_offsets + 1, mask_in) + tl.store(out_ptr + out_offsets, next_block - in_block, mask_in) def diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor: @@ -72,48 +57,40 @@ def diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor: if 
append is not None: input = torch.cat([input, append], dim=dim) - if n == 0: + if n <= 0: return input - max_len = input.shape[dim % input.ndim] - if n >= max_len: - logging.warning( - "Cannot conduct diff(input, n) with diff length = {} and n = {}.".format( - max_len, n - ) - ) - return tensor([], dtype=input.dtype, device=input.device) - - if input.ndim == 1: - shape = list(input.shape) - dim = dim % input.ndim - input = dim_compress(input, dim) - - output_diff_len = shape[dim] - n - output = torch.zeros(output_diff_len, device=input.device, dtype=input.dtype) - - coeff = bin_coeff(n, input.device) - - grid = [output_diff_len] - with torch_device_fn.device(input.device): - diff_kernel_1d[grid](input, output, coeff, n) - return output - shape = list(input.shape) dim = dim % input.ndim + reduce_len = shape[dim] + + if n >= reduce_len: + empty_tensor = tensor([], dtype=input.dtype, device=input.device) + return torch.reshape(empty_tensor, shape[:dim] + [0] + shape[(dim + 1) :]) + input = dim_compress(input, dim) - N = shape[dim] + N = reduce_len M = input.numel() // N - output_diff_len = shape[dim] - n - output_shape = shape[:dim] + shape[(dim + 1) :] + [output_diff_len] - output = torch.zeros(output_shape, device=input.device, dtype=input.dtype) - - coeff = bin_coeff(n, input.device) + output = torch.zeros(input.shape, device=input.device, dtype=input.dtype) + + n_steps = n + while n_steps > 0: + cur_in_diff_len = N - (n - n_steps) + if len(shape) == 1: + grid = lambda meta: (triton.cdiv(cur_in_diff_len, meta["BLOCK_DIFF"]),) + with torch_device_fn.device(input.device): + diff_kernel_1d[grid](input, output, N, cur_in_diff_len) + else: + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(cur_in_diff_len, meta["BLOCK_DIFF"]), + ) + with torch_device_fn.device(input.device): + diff_kernel_2d[grid](input, output, M, N, cur_in_diff_len) + n_steps -= 1 + input.copy_(output) - block_size = 16 - grid = [triton.cdiv(M, block_size), output_diff_len] - with torch_device_fn.device(input.device): - diff_kernel_2d[grid](input, output, coeff, M, N, n, block_size) + output = output[..., : (N - n)].contiguous() output = torch.moveaxis(output, -1, dim) return output diff --git a/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml b/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml index c6cdcf47..ba4e4263 100644 --- a/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml @@ -2762,6 +2762,43 @@ index_select: - 1024 - 2048 - 4096 +diff_1d: +- gen: true + param_map: + META: + BLOCK_DIFF: block_diff + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_diff: + - 1 + - 16 + - 256 + - 1024 +diff: +- gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_DIFF: block_diff + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + - 32 + block_diff: + - 1 + - 16 + - 256 + - 1024 layer_norm_persistent: - gen: true param_map: diff --git a/src/flag_gems/runtime/backend/_metax/tune_configs.yaml b/src/flag_gems/runtime/backend/_metax/tune_configs.yaml index 170de35f..e1d29f50 100644 --- a/src/flag_gems/runtime/backend/_metax/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_metax/tune_configs.yaml @@ -500,6 +500,43 @@ index_select: - 1024 - 2048 - 4096 +diff_1d: +- gen: true + param_map: + META: + BLOCK_DIFF: block_diff + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_diff: + - 1 + - 16 + - 256 + - 1024 +diff: +- gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_DIFF: 
block_diff + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + - 32 + block_diff: + - 1 + - 16 + - 256 + - 1024 layer_norm_persistent: - gen: true param_map: diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml index 879d9474..77ad685f 100644 --- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml @@ -493,6 +493,43 @@ index_select: - 1024 - 2048 - 4096 +diff_1d: +- gen: true + param_map: + META: + BLOCK_DIFF: block_diff + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_diff: + - 1 + - 16 + - 256 + - 1024 +diff: +- gen: true + param_map: + META: + BLOCK_M: block_m + BLOCK_DIFF: block_diff + num_warps: warps + warps: + - 4 + - 8 + - 16 + block_m: + - 1 + - 2 + - 4 + - 8 + - 32 + block_diff: + - 1 + - 16 + - 256 + - 1024 layer_norm_persistent: - gen: true param_map: diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index cbd186f4..acbc459d 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -880,17 +880,22 @@ def test_accuracy_depthwise2d( @pytest.mark.diff -@pytest.mark.parametrize("shape", REDUCTION_SHAPES) -@pytest.mark.parametrize("dim", DIM_LIST + []) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("shape", [(1024**3,)] + REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) @pytest.mark.parametrize("n", DIFF_N_VALUES) def test_accuracy_diff(shape, dim, dtype, n): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in INT_DTYPES: + inp = torch.randint( + low=-10, high=11, size=shape, dtype=dtype, device=flag_gems.device + ) + else: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) - ref_out = torch.diff(ref_inp, n, dim) + ref_out = torch.diff(ref_inp, n, dim % inp.ndim) with flag_gems.use_gems(): res_out = torch.diff(inp, n, dim) reduce_dim = shape[dim % inp.ndim] - gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim) + gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim, equal_nan=True)
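
Reviewer note (not part of the patches): PATCH 2 reworks `diff` to apply a single adjacent-difference kernel n times in a loop, shortening the valid length along the reduced axis by one per pass, then slices the output to length N - n and moves the axis back with `torch.moveaxis`. The short pure-PyTorch sketch below is only a semantic cross-check of that loop for review purposes; the helper name `ref_diff` is hypothetical and everything it calls (`torch.cat`, slicing, `torch.moveaxis`) is stock PyTorch, not the patch's Triton path.

    import torch

    def ref_diff(x: torch.Tensor, n: int = 1, dim: int = -1,
                 prepend=None, append=None) -> torch.Tensor:
        # Optional prepend/append along dim, as in the patched diff().
        if prepend is not None:
            x = torch.cat([prepend, x], dim=dim)
        if append is not None:
            x = torch.cat([x, append], dim=dim)
        # Move the reduction dim to the end, mirroring dim_compress in the patch.
        x = torch.moveaxis(x, dim, -1).contiguous()
        # n passes of adjacent differences, like the diff_kernel_1d/2d loop:
        # each pass shortens the valid length along the last axis by one.
        for _ in range(n):
            x = x[..., 1:] - x[..., :-1]
        # Move the (now shortened) axis back, as the patch does with torch.moveaxis.
        return torch.moveaxis(x, -1, dim)

    # Example: ref_diff(torch.tensor([1., 3., 6., 10.]), n=2) -> tensor([1., 1.])
    # which matches torch.diff(torch.tensor([1., 3., 6., 10.]), n=2, dim=0).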