Skip to content

Commit

Permalink
Add numba integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 4, 2024
1 parent b9f3c8c commit b854d87
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions numba_dpex/tests/test_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numba as nb
import numpy as np

import numba_dpex


@nb.njit
def add_1(a):
return a + 1


def add_py(a, b):
return np.add(a, b)


add_jit = nb.njit(add_py)


def test_add1():
a = np.asarray([1j])
assert np.array_equal(a, np.asarray([1 + 1j]))


def test_add_py():
a = np.ones((10,), dtype=np.complex128)
assert np.array_equal(add_py(a, 1.5), np.full((10,), 2.5, dtype=a.dtype))


def test_add_jit():
a = np.ones((10,), dtype=np.complex128)
assert np.array_equal(add_jit(a, 1.5), np.full((10,), 2.5, dtype=a.dtype))

0 comments on commit b854d87

Please sign in to comment.