Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit fb55a77

Browse files
crccwMesh TensorFlow Team
authored andcommitted
Change get_replicated_var_handle to accept resource tensors instead of variables
This avoid the needs to wrap packed tensor handle in a variable. This enables a ongoing refactoring of DistributedVariable. PiperOrigin-RevId: 371530563
1 parent f4755c4 commit fb55a77

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

mesh_tensorflow/tpu_variables.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121

2222
import contextlib
2323

24+
# pylint: disable=g-direct-tensorflow-import
25+
from tensorflow.python.compat import compat
2426
from tensorflow.python.framework import ops
2527
from tensorflow.python.ops import control_flow_ops
2628
from tensorflow.python.ops import gen_resource_variable_ops
2729

2830
try:
29-
from tensorflow.python.types import core # pylint:disable=g-import-not-at-top,g-direct-tensorflow-import
31+
from tensorflow.python.types import core # pylint:disable=g-import-not-at-top
3032
TF_23 = True
3133
except ImportError:
3234
TF_23 = False
@@ -80,8 +82,11 @@ def handle(self):
8082
tpu_context = _enclosing_tpu_context()
8183
if tpu_context is None:
8284
return self._primary_var.handle
83-
84-
return tpu_context.get_replicated_var_handle(self._name, self._vars)
85+
if compat.forward_compatible(2021, 4, 29):
86+
handles = [v.handle for v in self._vars]
87+
return tpu_context.get_replicated_var_handle(self._name, handles)
88+
else:
89+
return tpu_context.get_replicated_var_handle(self._name, self._vars)
8590

8691
@contextlib.contextmanager
8792
def _assign_dependencies(self):

0 commit comments

Comments
 (0)