From 6d440456e4d0d474b07546c8cf06eb83409b2bf1 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Thu, 12 May 2022 11:44:16 -0700 Subject: [PATCH] Make mesh_tensorflow's call of `get_replicated_var_handle` backward-compatible with tf <= 2.8.0. Fixes https://github.com/google-research/text-to-text-transfer-transformer/issues/1020. PiperOrigin-RevId: 448301001 --- mesh_tensorflow/tpu_variables.py | 11 ++++++++--- setup.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mesh_tensorflow/tpu_variables.py b/mesh_tensorflow/tpu_variables.py index 06d5d191..570fb6c2 100644 --- a/mesh_tensorflow/tpu_variables.py +++ b/mesh_tensorflow/tpu_variables.py @@ -20,6 +20,7 @@ from __future__ import print_function import contextlib +import inspect # pylint: disable=g-direct-tensorflow-import from tensorflow.python.framework import ops @@ -82,9 +83,13 @@ def handle(self): if tpu_context is None: return self._primary_var.handle - # Using variable name as handle id. - return tpu_context.get_replicated_var_handle(self._name, self._name, - self._vars) + # TODO(adarob): Remove backward-compatibility when TF 2.10 is released. + if 'handle_id' not in inspect.signature( + tpu_context.get_replicated_var_handle).parameters: + return tpu_context.get_replicated_var_handle( + name=self._name, vars_=self._vars) + return tpu_context.get_replicated_var_handle( + name=self._name, handle_id=self._name, vars_=self._vars) @contextlib.contextmanager def _assign_dependencies(self): diff --git a/setup.py b/setup.py index baf8eefb..c07a8741 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='mesh-tensorflow', - version='0.1.20', + version='0.1.21', description='Mesh TensorFlow', author='Google Inc.', author_email='no-reply@google.com',