diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 18fcb354aedb..4c11d15e9d12 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -360,20 +360,14 @@ def discover_pjrt_plugins() -> None: if sys.version_info < (3, 10): # Use the backport library because it provides a forward-compatible # implementation. - try: - from importlib_metadata import entry_points - except ModuleNotFoundError: - logger.debug( - "No importlib_metadata found (for Python < 3.10): " - "Plugins advertised from entrypoints will not be found.") - entry_points = None + from importlib_metadata import entry_points else: from importlib.metadata import entry_points - if entry_points: - for entry_point in entry_points(group="jax_plugins"): - logger.debug("Discovered entry-point based JAX plugin: %s", - entry_point.value) - plugin_modules.add(entry_point.value) + + for entry_point in entry_points(group="jax_plugins"): + logger.debug("Discovered entry-point based JAX plugin: %s", + entry_point.value) + plugin_modules.add(entry_point.value) # Now load and initialize them all. for plugin_module_name in plugin_modules: diff --git a/setup.py b/setup.py index 75cf89dba95e..0bcb23b78bf7 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,10 @@ def generate_proto(source): 'numpy>=1.21', 'opt_einsum', 'scipy>=1.7', + # Required by xla_bridge.discover_pjrt_plugins for forwards compat with + # Python versions < 3.10. Can be dropped when 3.10 is the minimum + # required Python version. + 'importlib_metadata>=4.6;python_version<"3.10"', ], extras_require={ # Minimum jaxlib version; used in testing. @@ -82,9 +86,7 @@ def generate_proto(source): # Cloud TPU VM jaxlib can be installed via: # $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 'tpu': [f'jaxlib=={_current_jaxlib_version}', - f'libtpu-nightly=={_libtpu_version}', - # Required by cloud_tpu_init.py - 'requests'], + f'libtpu-nightly=={_libtpu_version}'], # $ pip install jax[australis] 'australis': ['protobuf>=3.13,<4'],