Merge pull request jax-ml#5791 from hawkinsp:jax2tf2
PiperOrigin-RevId: 358459171
jax authors committed Feb 19, 2021
2 parents c5bfdcc + 6e48050 commit 15a2f50
Showing 3 changed files with 16 additions and 3 deletions.
4 changes: 4 additions & 0 deletions jax/experimental/jax2tf/tests/call_tf_test.py
@@ -175,6 +175,8 @@ def fun_tf(x, y):
 
   @parameterized_jit
   def test_with_var_read(self, with_jit=True):
+    if jtu.device_under_test() == "gpu":
+      raise unittest.SkipTest("Test fails on GPU")
     outer_var = tf.Variable(3., dtype=np.float32)
 
     def fun_tf(x):
@@ -211,6 +213,8 @@ def fun_tf(x):
 
   @parameterized_jit
   def test_with_multiple_capture(self, with_jit=True):
+    if jtu.device_under_test() == "gpu":
+      raise unittest.SkipTest("Test fails on GPU")
     v2 = tf.Variable(2., dtype=np.float32)
     v3 = tf.Variable(3., dtype=np.float32)
     t4 = tf.constant(4., dtype=np.float32)
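For context, the guard added to both tests above is the usual JAX test-suite pattern for skipping a backend where a test is known to fail. Below is a minimal self-contained sketch; the test class and body are stand-ins rather than the real call_tf tests, and jax.default_backend() stands in for the suite's jtu.device_under_test() helper.

import unittest

import jax
import numpy as np
import tensorflow as tf


class CallTfSketchTest(unittest.TestCase):
  """Stand-in test illustrating the backend skip guard added above."""

  def test_with_var_read(self):
    # Skip on GPU, mirroring the guards added to call_tf_test.py.
    if jax.default_backend() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    outer_var = tf.Variable(3., dtype=np.float32)
    self.assertEqual(float(outer_var.numpy()), 3.0)


if __name__ == "__main__":
  unittest.main()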
7 changes: 7 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -313,6 +313,7 @@ def conv_general_dilated(cls, harness: primitive_harness.Harness):
         Jax2TfLimitation(
             "jax2tf BUG: batch_group_count > 1 not yet converted",
             enabled=(harness.params["batch_group_count"] > 1)),
+        missing_tf_kernel(dtypes=[np.complex64, np.complex128], devices="gpu"),
         custom_numeric(devices="gpu", tol=1e-4),
         custom_numeric(devices="tpu", tol=1e-3),
         # TODO(bchetioui): significant discrepancies in some float16 cases.
@@ -723,6 +724,9 @@ def custom_assert(tst, result_jax, result_tf, *, args, tol):
       tst.assertAllClose(result_jax[~special_cases], result_tf[~special_cases])
 
     return [
+        # TODO(necula): Produces mismatched outputs on GPU.
+        Jax2TfLimitation("mismatched outputs on GPU",
+                         devices=("gpu",), skip_comparison=True),
         missing_tf_kernel(
             dtypes=[dtypes.bfloat16, np.float16]),
         custom_numeric(
@@ -758,6 +762,9 @@ def custom_assert(tst, result_jax, result_tf, *, args, tol): # noqa: F811
                          rtol=tol)
 
     return [
+        # TODO(necula): Produces mismatched outputs on GPU.
+        Jax2TfLimitation("mismatched outputs on GPU",
+                         devices=("gpu",), skip_comparison=True),
         missing_tf_kernel(
             dtypes=[dtypes.bfloat16, np.float16]),
         custom_numeric(dtypes=np.float64, tol=1e-9),
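The entries added above are data records consumed by the jax2tf comparison harness; Jax2TfLimitation, missing_tf_kernel, and custom_numeric are internal helpers of jax2tf_limitations.py rather than public API. The sketch below uses a hypothetical stand-in record solely to illustrate the shape of such an entry (device and dtype scoping, an optional tolerance, and a skip_comparison flag); it is not the actual jax2tf implementation.

from dataclasses import dataclass
from typing import Optional, Sequence

import numpy as np


@dataclass
class LimitationSketch:
  """Hypothetical stand-in for a jax2tf harness limitation entry."""
  description: str
  devices: Sequence[str] = ("cpu", "gpu", "tpu")
  dtypes: Sequence[type] = ()
  tol: Optional[float] = None
  skip_comparison: bool = False

  def applies(self, device: str, dtype) -> bool:
    # Active when the device matches and either no dtype restriction is given
    # or the test dtype is one of the restricted dtypes.
    return device in self.devices and (not self.dtypes or dtype in self.dtypes)


limitations = [
    # Rough analogues of the entries added in this commit.
    LimitationSketch("mismatched outputs on GPU", devices=("gpu",),
                     skip_comparison=True),
    LimitationSketch("missing TF kernel", devices=("gpu",),
                     dtypes=(np.complex64, np.complex128)),
    LimitationSketch("looser numeric tolerance", devices=("gpu",), tol=1e-4),
]

# Example query: which limitations fire for complex64 on GPU?
print([lim.description for lim in limitations if lim.applies("gpu", np.complex64)])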
8 changes: 5 additions & 3 deletions tests/array_interoperability_test.py
@@ -93,17 +93,19 @@ def testJaxRoundTrip(self, shape, dtype, take_ownership):
       for dtype in dlpack_dtypes))
   @unittest.skipIf(not tf, "Test requires TensorFlow")
   def testTensorFlowToJax(self, shape, dtype):
-    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
-                                            jnp.float64]:
+    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
       raise self.skipTest("x64 types are disabled by jax_enable_x64")
     if (jtu.device_under_test() == "gpu" and
         not tf.config.list_physical_devices("GPU")):
       raise self.skipTest("TensorFlow not configured with GPU support")
 
+    if jtu.device_under_test() == "gpu" and dtype == jnp.int32:
+      raise self.skipTest("TensorFlow does not place int32 tensors on GPU")
+
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
     with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
-      x = tf.constant(np)
+      x = tf.identity(tf.constant(np))
     dlpack = tf.experimental.dlpack.to_dlpack(x)
     y = jax.dlpack.from_dlpack(dlpack)
     self.assertAllClose(np, y)
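For reference, the TensorFlow-to-JAX DLPack hand-off exercised by this test can be reproduced outside the test harness roughly as follows. This is a CPU-only sketch; the tf.identity copy mirrors the change above, presumably so that the tensor is actually materialized under the surrounding tf.device scope rather than left as a host-pinned constant.

import jax
import jax.dlpack
import numpy as np
import tensorflow as tf

x_np = np.arange(6, dtype=np.float32).reshape(2, 3)

# Build the TF tensor under an explicit device scope and copy it with
# tf.identity, as the test now does.
with tf.device("/CPU:0"):
  x_tf = tf.identity(tf.constant(x_np))

# Export a DLPack capsule from TensorFlow and import it into JAX.
capsule = tf.experimental.dlpack.to_dlpack(x_tf)
y = jax.dlpack.from_dlpack(capsule)

np.testing.assert_allclose(np.asarray(y), x_np)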
