Add Op Diff #428

Open · wants to merge 3 commits into master
17 changes: 17 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -10,6 +10,7 @@
from .performance_utils import (
Benchmark,
Config,
GenericBenchmark,
GenericBenchmark2DOnly,
SkipVersion,
generate_tensor_input,
@@ -202,3 +203,19 @@ def count_nonzero_input_fn(shape, dtype, device):
dtypes=FLOAT_DTYPES,
)
bench.run()


@pytest.mark.diff
def test_perf_diff():
def diff_input_fn(shape, cur_dtype, device):
inp = generate_tensor_input(shape, cur_dtype, device)
n = 1
yield inp, n, 0

bench = GenericBenchmark(
input_fn=diff_input_fn,
op_name="diff",
torch_op=torch.diff,
dtypes=FLOAT_DTYPES + INT_DTYPES,
)
bench.run()
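
The new benchmark case is selectable via the @pytest.mark.diff marker and yields (inp, n, dim) tuples that line up with torch.diff's positional signature. A minimal standalone sketch of that calling convention (torch.randn stands in for generate_tensor_input here, and the shape is illustrative):

import torch

# Mirrors diff_input_fn above: each yielded tuple maps onto torch.diff(input, n, dim).
def diff_input_fn(shape, cur_dtype, device):
    inp = torch.randn(shape, dtype=cur_dtype, device=device)
    yield inp, 1, 0

inp, n, dim = next(diff_input_fn((1024, 1024), torch.float32, "cpu"))
out = torch.diff(inp, n, dim)
assert out.shape == (1023, 1024)  # a first-order diff shortens dim 0 by one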
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -41,6 +41,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("constant_pad_nd", constant_pad_nd, Autograd.disable),
("cumsum", cumsum, Autograd.disable),
("cummin", cummin, Autograd.disable),
("diff", diff, Autograd.disable),
("div.Tensor", true_divide, Autograd.disable),
("div.Scalar", true_divide, Autograd.disable),
("div.Tensor_mode", div_mode, Autograd.disable),
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -30,6 +30,7 @@
from .diag import diag
from .diag_embed import diag_embed
from .diagonal import diagonal_backward
from .diff import diff
from .div import div_mode, floor_divide, remainder, true_divide
from .dropout import native_dropout
from .embedding import embedding
@@ -296,6 +297,7 @@
"logical_xor",
"logical_not",
"sort",
"diff",
"nll_loss_forward",
"nll_loss_backward",
"nll_loss2d_forward",
96 changes: 96 additions & 0 deletions src/flag_gems/ops/diff.py
@@ -0,0 +1,96 @@
import torch
import triton
import triton.language as tl
from torch import Tensor, tensor

from .. import runtime
from ..runtime import torch_device_fn
from ..utils import dim_compress, libentry, libtuner
from ..utils import triton_lang_extension as tle


@libentry()
@libtuner(
configs=runtime.get_tuned_config("diff_1d"),
key=["N"],
)
@triton.jit
def diff_kernel_1d(in_ptr, out_ptr, N, N_bound, BLOCK_DIFF: tl.constexpr):
pid = tle.program_id(0)

in_offsets = pid * BLOCK_DIFF + tl.arange(0, BLOCK_DIFF)
mask_in = in_offsets < N_bound - 1
in_block = tl.load(in_ptr + in_offsets, mask_in)
next_block = tl.load(in_ptr + in_offsets + 1, mask_in)
tl.store(out_ptr + in_offsets, next_block - in_block, mask_in)


@libentry()
@libtuner(
configs=runtime.get_tuned_config("diff"),
key=["M", "N"],
)
@triton.jit
def diff_kernel_2d(
in_ptr, out_ptr, M, N, N_bound, BLOCK_M: tl.constexpr, BLOCK_DIFF: tl.constexpr
):
pid_M = tle.program_id(0)
pid_diff = tle.program_id(1)

M_offsets = pid_M * BLOCK_M + tl.arange(0, BLOCK_M)
mask_M = M_offsets < M

in_offsets_diff = pid_diff * BLOCK_DIFF + tl.arange(0, BLOCK_DIFF)
mask_in_diff = in_offsets_diff < N_bound - 1
in_offsets = M_offsets[:, None] * N + in_offsets_diff
mask_in = mask_M[:, None] & mask_in_diff
out_offsets = M_offsets[:, None] * N + in_offsets_diff

in_block = tl.load(in_ptr + in_offsets, mask_in)
next_block = tl.load(in_ptr + in_offsets + 1, mask_in)
tl.store(out_ptr + out_offsets, next_block - in_block, mask_in)


def diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor:
if prepend is not None:
input = torch.cat([prepend, input], dim=dim)
if append is not None:
input = torch.cat([input, append], dim=dim)

if n <= 0:
return input

shape = list(input.shape)
dim = dim % input.ndim
reduce_len = shape[dim]

if n >= reduce_len:
empty_tensor = tensor([], dtype=input.dtype, device=input.device)
return torch.reshape(empty_tensor, shape[:dim] + [0] + shape[(dim + 1) :])

input = dim_compress(input, dim)
N = reduce_len
M = input.numel() // N

output = torch.zeros(input.shape, device=input.device, dtype=input.dtype)

n_steps = n
while n_steps > 0:
cur_in_diff_len = N - (n - n_steps)
if len(shape) == 1:
grid = lambda meta: (triton.cdiv(cur_in_diff_len, meta["BLOCK_DIFF"]),)
with torch_device_fn.device(input.device):
diff_kernel_1d[grid](input, output, N, cur_in_diff_len)
else:
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(cur_in_diff_len, meta["BLOCK_DIFF"]),
)
with torch_device_fn.device(input.device):
diff_kernel_2d[grid](input, output, M, N, cur_in_diff_len)
n_steps -= 1
input.copy_(output)

output = output[..., : (N - n)].contiguous()
output = torch.moveaxis(output, -1, dim)
return output
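
With the ("diff", diff, Autograd.disable) entry registered in flag_gems/__init__.py above, torch.diff dispatches to this implementation whenever FlagGems is active. A minimal usage sketch, following the same pattern as the accuracy test below (shapes are illustrative):

import torch
import flag_gems

x = torch.randn(8, 1024, device=flag_gems.device)
with flag_gems.use_gems():          # route torch.diff to the Triton kernels above
    y = torch.diff(x, n=2, dim=-1)
assert y.shape == (8, 1022)         # each order of differencing drops one element along dim

The wrapper computes the n-th order difference as n passes of the first-order kernel, shrinking the active length by one per pass and finally slicing the result to length N - n along the compressed dimension.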
37 changes: 37 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
@@ -2762,6 +2762,43 @@ index_select:
- 1024
- 2048
- 4096
diff_1d:
- gen: true
param_map:
META:
BLOCK_DIFF: block_diff
num_warps: warps
warps:
- 4
- 8
- 16
block_diff:
- 1
- 16
- 256
- 1024
diff:
- gen: true
param_map:
META:
BLOCK_M: block_m
BLOCK_DIFF: block_diff
num_warps: warps
warps:
- 4
- 8
- 16
block_m:
- 1
- 2
- 4
- 8
- 32
block_diff:
- 1
- 16
- 256
- 1024
layer_norm_persistent:
- gen: true
param_map:
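
The gen: true entries above map YAML value lists onto the kernel meta-parameters via param_map; presumably the runtime expands them into one autotuning candidate per combination. A rough sketch of that expansion under a cartesian-product assumption (illustrative only, not FlagGems' actual config generator):

import itertools
import triton

warps = [4, 8, 16]
block_m = [1, 2, 4, 8, 32]
block_diff = [1, 16, 256, 1024]

# One triton.Config per (BLOCK_M, BLOCK_DIFF, num_warps) combination for diff_kernel_2d.
diff_2d_configs = [
    triton.Config({"BLOCK_M": m, "BLOCK_DIFF": d}, num_warps=w)
    for m, d, w in itertools.product(block_m, block_diff, warps)
]
assert len(diff_2d_configs) == 5 * 4 * 3  # 60 candidates for the autotuner

Identical diff and diff_1d blocks are added to the Metax and NVIDIA backend configs below.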
37 changes: 37 additions & 0 deletions src/flag_gems/runtime/backend/_metax/tune_configs.yaml
@@ -510,6 +510,43 @@ index_select:
- 1024
- 2048
- 4096
diff_1d:
- gen: true
param_map:
META:
BLOCK_DIFF: block_diff
num_warps: warps
warps:
- 4
- 8
- 16
block_diff:
- 1
- 16
- 256
- 1024
diff:
- gen: true
param_map:
META:
BLOCK_M: block_m
BLOCK_DIFF: block_diff
num_warps: warps
warps:
- 4
- 8
- 16
block_m:
- 1
- 2
- 4
- 8
- 32
block_diff:
- 1
- 16
- 256
- 1024
layer_norm_persistent:
- gen: true
param_map:
37 changes: 37 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -493,6 +493,43 @@ index_select:
- 1024
- 2048
- 4096
diff_1d:
- gen: true
param_map:
META:
BLOCK_DIFF: block_diff
num_warps: warps
warps:
- 4
- 8
- 16
block_diff:
- 1
- 16
- 256
- 1024
diff:
- gen: true
param_map:
META:
BLOCK_M: block_m
BLOCK_DIFF: block_diff
num_warps: warps
warps:
- 4
- 8
- 16
block_m:
- 1
- 2
- 4
- 8
- 32
block_diff:
- 1
- 16
- 256
- 1024
layer_norm_persistent:
- gen: true
param_map:
25 changes: 25 additions & 0 deletions tests/test_reduction_ops.py
@@ -927,3 +927,28 @@ def test_accuracy_depthwise2d(
inp, weight, kernel, bias=None, stride=stride, padding=padding, dilation=1
)
gems_assert_close(res_out, ref_out, dtype)


DIFF_N_VALUES = list(range(0, 10))


@pytest.mark.diff
@pytest.mark.parametrize("shape", [(1024**3,)] + REDUCTION_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
@pytest.mark.parametrize("n", DIFF_N_VALUES)
def test_accuracy_diff(shape, dim, dtype, n):
if dtype in INT_DTYPES:
inp = torch.randint(
low=-10, high=11, size=shape, dtype=dtype, device=flag_gems.device
)
else:
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
ref_inp = to_reference(inp)

ref_out = torch.diff(ref_inp, n, dim % inp.ndim)
with flag_gems.use_gems():
res_out = torch.diff(inp, n, dim)

reduce_dim = shape[dim % inp.ndim]
gems_assert_close(res_out, ref_out, dtype, reduce_dim=reduce_dim, equal_nan=True)
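
The test sweeps n from 0 through 9 across the reduction shapes, dims, and both float and int dtypes. For n > 1, torch.diff is the first-order difference applied n times, which is the recurrence the Triton wrapper's while-loop reproduces. A self-contained reference check of that equivalence in pure PyTorch (diff_ref is a hypothetical helper for illustration, not part of this PR):

import torch

def diff_ref(x, n, dim):
    # Apply the first-order difference n times along dim.
    for _ in range(n):
        length = x.size(dim) - 1
        x = torch.narrow(x, dim, 1, length) - torch.narrow(x, dim, 0, length)
    return x

x = torch.randn(4, 9)
for n in (1, 2, 3):
    assert torch.allclose(torch.diff(x, n, dim=1), diff_ref(x, n, 1))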