diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index d115e6efb11f..eb0f1b02b67c 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -21,6 +21,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax.interpreters import xla from jax._src.config import config @@ -99,7 +100,7 @@ def test_register_plugin(self): with mock.patch.object( xc, "load_pjrt_plugin_dynamically", autospec=True ) as mock_load_plugin: - if xc._version >= 152: + if xla_extension_version >= 152: with mock.patch.object( xc, "pjrt_plugin_loaded", autospec=True ) as mock_plugin_loaded: @@ -115,18 +116,13 @@ def test_register_plugin(self): self.assertIn("name1", xb._backend_factories) self.assertIn("name2", xb._backend_factories) self.assertEqual(priotiy, 400) - if xc._version >= 152: + if xla_extension_version >= 152: mock_plugin_loaded.assert_called_once_with("name1") else: mock_load_plugin.assert_called_once_with("name1", "path1") - if xc._version >= 134: - mock_make.assert_called_once_with("name1", None) - else: - mock_make.assert_called_once_with("name1") + mock_make.assert_called_once_with("name1", None) def test_register_plugin_with_config(self): - if xc._version < 134: - return test_json_file_path = os.path.join( os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json" ) @@ -137,7 +133,7 @@ def test_register_plugin_with_config(self): with mock.patch.object( xc, "load_pjrt_plugin_dynamically", autospec=True ) as mock_load_plugin: - if xc._version >= 152: + if xla_extension_version >= 152: with mock.patch.object( xc, "pjrt_plugin_loaded", autospec=True ) as mock_plugin_loaded: @@ -147,7 +143,7 @@ def test_register_plugin_with_config(self): self.assertIn("name1", xb._backend_factories) self.assertEqual(priority, 400) - if xc._version >= 152: + if xla_extension_version >= 152: mock_plugin_loaded.assert_called_once_with("name1") else: mock_load_plugin.assert_called_once_with(