From 6e480507309bd95fe3e3f09dc4413d7faf7bd4fc Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 19 Feb 2021 15:01:55 -0500
Subject: [PATCH] Disable some jax2tf tests that fail on GPU.

Fix TF/JAX array interoperability test on GPU.
---
 jax/experimental/jax2tf/tests/call_tf_test.py       | 4 ++++
 jax/experimental/jax2tf/tests/jax2tf_limitations.py | 7 +++++++
 tests/array_interoperability_test.py                | 8 +++++---
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index 56e17e50927d..c537519d0272 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -175,6 +175,8 @@ def fun_tf(x, y):
 
   @parameterized_jit
   def test_with_var_read(self, with_jit=True):
+    if jtu.device_under_test() == "gpu":
+      raise unittest.SkipTest("Test fails on GPU")
     outer_var = tf.Variable(3., dtype=np.float32)
 
     def fun_tf(x):
@@ -211,6 +213,8 @@ def fun_tf(x):
 
   @parameterized_jit
   def test_with_multiple_capture(self, with_jit=True):
+    if jtu.device_under_test() == "gpu":
+      raise unittest.SkipTest("Test fails on GPU")
     v2 = tf.Variable(2., dtype=np.float32)
     v3 = tf.Variable(3., dtype=np.float32)
     t4 = tf.constant(4., dtype=np.float32)
diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
index 924da813b3ab..b2ac0bdc6517 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -313,6 +313,7 @@ def conv_general_dilated(cls, harness: primitive_harness.Harness):
         Jax2TfLimitation(
             "jax2tf BUG: batch_group_count > 1 not yet converted",
             enabled=(harness.params["batch_group_count"] > 1)),
+        missing_tf_kernel(dtypes=[np.complex64, np.complex128], devices="gpu"),
         custom_numeric(devices="gpu", tol=1e-4),
         custom_numeric(devices="tpu", tol=1e-3),
         # TODO(bchetioui): significant discrepancies in some float16 cases.
@@ -723,6 +724,9 @@ def custom_assert(tst, result_jax, result_tf, *, args, tol):
       tst.assertAllClose(result_jax[~special_cases], result_tf[~special_cases])
 
     return [
+        # TODO(necula): Produces mismatched outputs on GPU.
+        Jax2TfLimitation("mismatched outputs on GPU",
+                         devices=("gpu",), skip_comparison=True),
         missing_tf_kernel(
             dtypes=[dtypes.bfloat16, np.float16]),
         custom_numeric(
@@ -758,6 +762,9 @@ def custom_assert(tst, result_jax, result_tf, *, args, tol):  # noqa: F811
                           rtol=tol)
 
     return [
+        # TODO(necula): Produces mismatched outputs on GPU.
+        Jax2TfLimitation("mismatched outputs on GPU",
+                         devices=("gpu",), skip_comparison=True),
         missing_tf_kernel(
             dtypes=[dtypes.bfloat16, np.float16]),
         custom_numeric(dtypes=np.float64, tol=1e-9),
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index 7dfb6c861cd9..3e4b43b595a9 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -93,17 +93,19 @@ def testJaxRoundTrip(self, shape, dtype, take_ownership):
           for dtype in dlpack_dtypes))
   @unittest.skipIf(not tf, "Test requires TensorFlow")
   def testTensorFlowToJax(self, shape, dtype):
-    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
-                                            jnp.float64]:
+    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
       raise self.skipTest("x64 types are disabled by jax_enable_x64")
     if (jtu.device_under_test() == "gpu" and
         not tf.config.list_physical_devices("GPU")):
       raise self.skipTest("TensorFlow not configured with GPU support")
 
+    if jtu.device_under_test() == "gpu" and dtype == jnp.int32:
+      raise self.skipTest("TensorFlow does not place int32 tensors on GPU")
+
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
     with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
-      x = tf.constant(np)
+      x = tf.identity(tf.constant(np))
     dlpack = tf.experimental.dlpack.to_dlpack(x)
     y = jax.dlpack.from_dlpack(dlpack)
     self.assertAllClose(np, y)