diff --git a/numba_dpex/core/targets/dpjit_target.py b/numba_dpex/core/targets/dpjit_target.py index 883a00bc5e..c79f3c3561 100644 --- a/numba_dpex/core/targets/dpjit_target.py +++ b/numba_dpex/core/targets/dpjit_target.py @@ -14,6 +14,8 @@ from numba.core.imputils import Registry from numba.core.target_extension import CPU, target_registry +from numba_dpex.dpnp_iface import dpnp_ufunc_db + class Dpex(CPU): pass @@ -57,3 +59,6 @@ def load_additional_registries(self): # loading CPU specific registries super().load_additional_registries() + + def get_ufunc_info(self, ufunc_key): + return dpnp_ufunc_db.get_ufunc_info(ufunc_key) diff --git a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py index 348f213e8d..e5c33010b7 100644 --- a/numba_dpex/dpnp_iface/dpnp_ufunc_db.py +++ b/numba_dpex/dpnp_iface/dpnp_ufunc_db.py @@ -33,6 +33,20 @@ def get_ufuncs(): return _dpnp_ufunc_db.keys() +def get_ufunc_info(ufunc_key): + """get the lowering information for the ufunc with key ufunc_key. + + The lowering information is a dictionary that maps from a numpy + loop string (as given by the ufunc types attribute) to a function + that handles code generation for a scalar version of the ufunc + (that is, generates the "per element" operation"). + + raises a KeyError if the ufunc is not in the ufunc_db + """ + _lazy_init_dpnp_db() + return _dpnp_ufunc_db[ufunc_key] + + def _fill_ufunc_db_with_dpnp_ufuncs(ufunc_db): """Populates the _dpnp_ufunc_db from Numba's NumPy ufunc_db"""