Adds all fetch_* SPIR-V overloads to experimental #1261

Merged 1 commit on Jan 10, 2024
@@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
]

context.extra_compile_options[LLVM_SPIRV_ARGS] = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
]

ptr_type = retty.as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space

@@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


def _atomic_sub_float_wrapper(gen_fn):
def gen(context, builder, sig, args):
        # args is a tuple, which is immutable.
        # Convert the tuple to a list, replace args[1] with its negation
        # (fneg), and then convert back to a tuple.
args_lst = list(args)
args_lst[1] = builder.fneg(args[1])
args = tuple(args_lst)

gen_fn(context, builder, sig, args)

return gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
if ty_atomic_ref.dtype in (types.float32, types.float64):
# dpcpp does not support ``__spirv_AtomicFSubEXT``. fetch_sub
# for floats is implemented by negating the value and calling fetch_add.
        # For example, A.fetch_sub(val) is implemented as A.fetch_add(-val).
sig, gen = _intrinsic_helper(
ty_context, ty_atomic_ref, ty_val, "fetch_add"
)
return sig, _atomic_sub_float_wrapper(gen)

return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")
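
# A minimal illustrative sketch (not part of the diff): for a float AtomicRef
# the two calls below end up equivalent, since the value is negated and routed
# through fetch_add. ``aref`` and ``val`` are hypothetical names used only
# for this example.
#
#     aref.fetch_sub(val)   # user-facing call
#     aref.fetch_add(-val)  # what is effectively emitted for float32/float64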


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_min")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_max")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_and")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_or")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_xor")


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_atomic_ref_ctor(
ty_context, ref, ty_index, ty_retty_ref # pylint: disable=unused-argument
@@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val):
return _intrinsic_fetch_add(atomic_ref, val)

return ol_fetch_add_impl


@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_sub` function.

Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to sub: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_fetch_sub_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_sub(atomic_ref, val)

return ol_fetch_sub_impl


@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_min` function.

Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to find min: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_fetch_min_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_min(atomic_ref, val)

return ol_fetch_min_impl


@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_max` function.

Raises:
TypingError: When the dtype of the aggregator value does not match the
dtype of the AtomicRef type.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to find max: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

def ol_fetch_max_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_max(atomic_ref, val)

return ol_fetch_max_impl


@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_and` function.

Raises:
        TypingError: When the dtype of the aggregator value does not match the
        dtype of the AtomicRef type.
        TypingError: When the dtype of the AtomicRef is not int32 or int64.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to and: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype not in (types.int32, types.int64):
raise errors.TypingError(
"fetch_and operation only supported on int32 and int64 dtypes."
)

def ol_fetch_and_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_and(atomic_ref, val)

return ol_fetch_and_impl


@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_or` function.

Raises:
        TypingError: When the dtype of the aggregator value does not match the
        dtype of the AtomicRef type.
        TypingError: When the dtype of the AtomicRef is not int32 or int64.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to or: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype not in (types.int32, types.int64):
raise errors.TypingError(
"fetch_or operation only supported on int32 and int64 dtypes."
)

def ol_fetch_or_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_or(atomic_ref, val)

return ol_fetch_or_impl


@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
"""SPIR-V overload for
:meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.

Generates the same LLVM IR instruction as dpcpp for the
`atomic_ref::fetch_xor` function.

Raises:
        TypingError: When the dtype of the aggregator value does not match the
        dtype of the AtomicRef type.
        TypingError: When the dtype of the AtomicRef is not int32 or int64.
"""
if atomic_ref.dtype != val:
raise errors.TypingError(
f"Type of value to xor: {val} does not match the type of the "
f"reference: {atomic_ref.dtype} stored in the atomic ref."
)

if atomic_ref.dtype not in (types.int32, types.int64):
raise errors.TypingError(
"fetch_xor operation only supported on int32 and int64 dtypes."
)

def ol_fetch_xor_impl(atomic_ref, val):
# pylint: disable=no-value-for-parameter
return _intrinsic_fetch_xor(atomic_ref, val)

return ol_fetch_xor_impl
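
# A hedged usage sketch (not part of the diff): with the overloads above in
# place, an experimental kernel can call the fetch_phi methods directly. The
# kernel and array names below are hypothetical; AtomicRef and call_kernel
# are the APIs exercised by the tests further down.
#
#     @dpex_exp.kernel
#     def reduce_max(a, b):
#         i = dpex.get_global_id(0)
#         ref = AtomicRef(b, index=0)
#         ref.fetch_max(a[i])
#
#     dpex_exp.call_kernel(reduce_max, dpex.Range(a.size), a, b)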
5 changes: 4 additions & 1 deletion numba_dpex/spirv_generator.py
@@ -159,7 +159,10 @@ def finalize(self):
# TODO: find better approach to set SPIRV compiler arguments. Workaround
# against caching intrinsic that sets this argument.
# https://github.com/IntelPython/numba-dpex/issues/1262
llvm_spirv_args = ["--spirv-ext=+SPV_EXT_shader_atomic_float_add"]
llvm_spirv_args = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
]
for key in list(self.context.extra_compile_options.keys()):
if key == LLVM_SPIRV_ARGS:
llvm_spirv_args = self.context.extra_compile_options[key]
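
# A hedged sketch (not part of the diff) of how the defaults above interact
# with per-kernel options: when an intrinsic has stored LLVM_SPIRV_ARGS in
# extra_compile_options, that list replaces the default extension flags.
#
#     llvm_spirv_args = [
#         "--spirv-ext=+SPV_EXT_shader_atomic_float_add",
#         "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
#     ]
#     overrides = self.context.extra_compile_options.get(LLVM_SPIRV_ARGS)
#     if overrides is not None:
#         llvm_spirv_args = overrides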
@@ -4,6 +4,7 @@

import dpnp
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
@@ -14,30 +15,80 @@
no_bool=True, no_float16=True, no_none=True, no_complex=True
)

list_of_fetch_phi_funcs = [
"fetch_add",
"fetch_sub",
"fetch_min",
"fetch_max",
"fetch_and",
"fetch_or",
"fetch_xor",
]


@pytest.fixture(params=list_of_fetch_phi_funcs)
def fetch_phi_fn(request):
return request.param


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays(request):
    # The size of the input and output arrays to be used.
N = 10
a = dpnp.arange(N, dtype=request.param)
b = dpnp.ones(N, dtype=request.param)
return a, b


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_phi_fn(input_arrays, ref_index, fetch_phi_fn):
"""A test for all fetch_phi atomic functions."""

@dpex_exp.kernel
def _kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
getattr(v, fetch_phi_fn)(a[i])

a, b = input_arrays

if (
fetch_phi_fn in ["fetch_and", "fetch_or", "fetch_xor"]
and issubclass(a.dtype.type, dpnp.floating)
and issubclass(b.dtype.type, dpnp.floating)
):
        # fetch_and, fetch_or, and fetch_xor accept only integer arguments.
        # Test that a TypingError is raised when float arguments are passed.
with pytest.raises(TypingError):
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
else:
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
        # Verify that the values of `a` accumulated into b[ref_index] by the
        # kernel match the values of `a` accumulated into b[ref_index + 1] in Python.
for i in range(a.size):
v = AtomicRef(b, index=ref_index + 1)
getattr(v, fetch_phi_fn)(a[i])

assert b[ref_index] == b[ref_index + 1]


def test_fetch_phi_diff_types(fetch_phi_fn):
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value to be added are of different types.
"""

@dpex_exp.kernel
def _kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
getattr(v, fetch_phi_fn)(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b)


@dpex_exp.kernel
@@ -54,7 +105,7 @@ def atomic_ref_1(a):
v.fetch_add(a[i + 2])


def test_spirv_compiler_flags_add():
"""Check if float atomic flag is being populated from intrinsic for the
second call.

@@ -68,3 +119,36 @@ def test_spirv_compiler_flags():

assert a[0] == N - 1
assert a[1] == N - 1


@dpex_exp.kernel
def atomic_max_0(a):
i = dpex.get_global_id(0)
v = AtomicRef(a, index=0)
if i != 0:
v.fetch_max(a[i])


@dpex_exp.kernel
def atomic_max_1(a):
i = dpex.get_global_id(0)
v = AtomicRef(a, index=0)
if i != 0:
v.fetch_max(a[i])


def test_spirv_compiler_flags_max():
"""Check if float atomic flag is being populated from intrinsic for the
second call.

https://github.com/IntelPython/numba-dpex/issues/1262
"""
N = 10
a = dpnp.arange(N, dtype=dpnp.float32)
b = dpnp.arange(N, dtype=dpnp.float32)

dpex_exp.call_kernel(atomic_max_0, dpex.Range(N), a)
dpex_exp.call_kernel(atomic_max_1, dpex.Range(N), b)

assert a[0] == N - 1
assert b[0] == N - 1