Skip to content

Commit

Permalink
Cleanup the @generated_jit leftovers (pygae#430)
Browse files Browse the repository at this point in the history
The test_function_cache is rewritten to use @njit, which may not make
much sense for particular test.
  • Loading branch information
trundev committed Nov 19, 2024
1 parent 9b301f1 commit c79c30f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 48 deletions.
11 changes: 0 additions & 11 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,7 @@

# Major library imports.
import numpy as np
import numba as _numba # to avoid clashing with clifford.numba
import sparse
try:
from numba.np import numpy_support as _numpy_support
except ImportError:
import numba.numpy_support as _numpy_support


from clifford.io import write_ga_file, read_ga_file # noqa: F401
Expand Down Expand Up @@ -152,12 +147,6 @@ def get_mult_function(mt: sparse.COO, gradeList,
return _get_mult_function_runtime_sparse(mt)


def _get_mult_function_result_type(a: _numba.types.Type, b: _numba.types.Type, mt: np.dtype):
a_dt = _numpy_support.as_dtype(getattr(a, 'dtype', a))
b_dt = _numpy_support.as_dtype(getattr(b, 'dtype', b))
return np.result_type(a_dt, mt, b_dt)


def _get_mult_function(mt: sparse.COO):
"""
Get a function similar to `` lambda a, b: np.einsum('i,ijk,k->j', a, mt, b)``
Expand Down
28 changes: 0 additions & 28 deletions clifford/_numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,39 +59,11 @@ def __repr__(self):
return "_pickleable_function({!r})".format(self.__func)


class _fake_generated_jit:
def __init__(self, f):
self.__cache = {}
self.__func = pickleable_function(f)
functools.update_wrapper(self, self.__func)

def __getnewargs_ex__(self):
return (self.__func,), {}

def __getstate__(self):
return {}

def __call__(self, *args):
arg_type = tuple(numba.typeof(arg) for arg in args)
try:
func = self.__cache[arg_type]
except KeyError:
func = self.__cache[arg_type] = self.__func(*arg_type)
return func(*args)


if not DISABLE_JIT:
njit = numba.njit
generated_jit = numba.generated_jit
else:
def njit(f=None, **kwargs):
if f is None:
return pickleable_function
else:
return pickleable_function(f)

def generated_jit(f=None, **kwargs):
if f is None:
return _fake_generated_jit
else:
return _fake_generated_jit(f)
14 changes: 5 additions & 9 deletions clifford/test/test_function_cache.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import numpy as np
from clifford._numba_utils import generated_jit
from clifford._numba_utils import njit
import pytest


@generated_jit(cache=True)
def foo(x):
from clifford.g3 import e3

def impl(x):
return (x * e3).value
return impl
@njit(cache=True)
def foo(x, y):
return (x * y).value


# Make the test fail on a failed cache warning
@pytest.mark.filterwarnings("error")
def test_function_cache():
from clifford.g3 import e3
np.testing.assert_array_equal((1.0*e3).value, foo(1.0))
np.testing.assert_array_equal((1.0*e3).value, foo(1.0, e3))

0 comments on commit c79c30f

Please sign in to comment.