Skip to content
This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

Commit

Permalink
Make mesh_tensorflow's call of get_replicated_var_handle backward-c…
Browse files Browse the repository at this point in the history
…ompatible with tf <= 2.8.0. Fixes google-research/text-to-text-transfer-transformer#1020.

PiperOrigin-RevId: 448301001
  • Loading branch information
adarob authored and Mesh TensorFlow Team committed May 12, 2022
1 parent b631ad5 commit 6d44045
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
11 changes: 8 additions & 3 deletions mesh_tensorflow/tpu_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import print_function

import contextlib
import inspect

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
Expand Down Expand Up @@ -82,9 +83,13 @@ def handle(self):
if tpu_context is None:
return self._primary_var.handle

# Using variable name as handle id.
return tpu_context.get_replicated_var_handle(self._name, self._name,
self._vars)
# TODO(adarob): Remove backward-compatibility when TF 2.10 is released.
if 'handle_id' not in inspect.signature(
tpu_context.get_replicated_var_handle).parameters:
return tpu_context.get_replicated_var_handle(
name=self._name, vars_=self._vars)
return tpu_context.get_replicated_var_handle(
name=self._name, handle_id=self._name, vars_=self._vars)

@contextlib.contextmanager
def _assign_dependencies(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='mesh-tensorflow',
version='0.1.20',
version='0.1.21',
description='Mesh TensorFlow',
author='Google Inc.',
author_email='[email protected]',
Expand Down

0 comments on commit 6d44045

Please sign in to comment.