Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 5d16bca

Browse files
author
Mesh TensorFlow Team
committed
Change default checkpoint saving dtype to float32 instead of bfloat16. Saving
PiperOrigin-RevId: 363758862
1 parent 441ff47 commit 5d16bca

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def parse_gin_defaults_and_flags(skip_unknown=False, finalize_config=True):
8787
# this stupid VariableDtype class and stop passing it all over creation.
8888
@gin.configurable
8989
def get_variable_dtype(
90-
master_dtype=tf.bfloat16,
90+
master_dtype=tf.float32,
9191
slice_dtype=tf.float32,
9292
activation_dtype=tf.float32):
9393
"""Datatypes to use for the run.

0 commit comments

Comments
 (0)