Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jan 5, 2024
1 parent 8389dfe commit 8033cb6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
1 change: 1 addition & 0 deletions numba_dpex/core/targets/dpjit_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,6 @@ def load_additional_registries(self):
# loading CPU specific registries
super().load_additional_registries()

# TODO: do we need it?
def get_ufunc_info(self, ufunc_key):
return dpnp_ufunc_db.get_ufunc_info(ufunc_key)
9 changes: 5 additions & 4 deletions numba_dpex/core/targets/kernel_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,14 @@ def init(self):
self.data_model_manager = _init_data_model_manager()
self.extra_compile_options = dict()

from numba_dpex.dpnp_iface.dpnp_ufunc_db import (
_dpnp_ufunc_db,
_lazy_init_dpnp_db,
)
from numba_dpex.dpnp_iface.dpnp_ufunc_db import _lazy_init_dpnp_db

_lazy_init_dpnp_db()

# we need to import it after, because before init it is None and
# variable is passed by value
from numba_dpex.dpnp_iface.dpnp_ufunc_db import _dpnp_ufunc_db

self.ufunc_db = _dpnp_ufunc_db

def create_module(self, name):
Expand Down
12 changes: 10 additions & 2 deletions numba_dpex/dpnp_iface/dpnp_ufunc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ def get_ufunc_info(ufunc_key):
def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
"""Populates the _dpnp_ufunc_db from Numba's NumPy ufunc_db"""

from numba.np.ufunc_db import _lazy_init_db, _ufunc_db
from numba.np.ufunc_db import _lazy_init_db

_lazy_init_db()

# we need to import it after, because before init it is None and
# variable is passed by value
from numba.np.ufunc_db import _ufunc_db

for ufuncop in dpnpdecl.supported_ufuncs:
if ufuncop == "erf":
op = getattr(dpnp, "erf")
Expand All @@ -77,7 +81,11 @@ def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db):
op.nargs = npop.nargs
op.types = npop.types
op.is_dpnp_ufunc = True
ufunc_db.update({op: copy.copy(_ufunc_db[npop])})
cp = copy.copy(_ufunc_db[npop])
if "'divide'" in str(npop):
# TODO: why do we need to do it only for divide???
ufunc_db.update({npop: cp})
ufunc_db.update({op: cp})
for key in list(ufunc_db[op].keys()):
if (
"FF->" in key
Expand Down

0 comments on commit 8033cb6

Please sign in to comment.