Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit e22cc2b

Browse files
daphneiMesh TensorFlow Team
authored andcommitted
Add option to only load from the specific checkpoints requested, instead of trying to find the nearest ones.
PiperOrigin-RevId: 360511110
1 parent 27613a9 commit e22cc2b

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,15 +2290,17 @@ def auto_train_steps(batch_size,
22902290

22912291
@gin.configurable
22922292
def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0,
2293-
stop_after=None):
2293+
stop_after=None, find_closest=True):
22942294
"""Get an iterable of checkpoint paths from a provided checkpoint step(s).
22952295
22962296
Args:
2297-
checkpoint_step: If checkpoint_step is an int, find the checkpoint with the
2298-
closest global step and return a singleton list. If checkpoint_step is a
2299-
list of ints, replace each int with the path to the checkpoint with the
2300-
closest global step. If checkpoint_step == "all", return the path of every
2301-
checkpoint in model_dir, starting from the earliest checkpoint. If
2297+
checkpoint_step: If checkpoint_step is an int, return a singleton list with
2298+
that checkpoint path in it. If find_closest, the checkpoint with the
2299+
closest global step will be reurned. If checkpoint_step is a
2300+
list of ints, replace each int with its corresponding path (if
2301+
find_closest, the path with the closest global step). If
2302+
checkpoint_step == "all", return the path of every checkpoint in
2303+
model_dir, starting from the earliest checkpoint. If
23022304
checkpoint_step == -1, return the latest checkpoint as specified in
23032305
model_dir/checkpoint. If checkpoint_step is None, return
23042306
`tf.train.checkpoints_iterator` for `model_dir`.
@@ -2308,6 +2310,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0,
23082310
stop_after: an optional integer - for "None behavior, if specified
23092311
stop after finding a checkpoint number that is >= stop_at. When a
23102312
checkpoint number == stop_at is found, it is yielded before exiting.
2313+
find_closest: If True and a specified checkpoint step does not exist, will
2314+
choose the nearest checkpoint to that step. If False, then will
2315+
only look for a checkpoint matching the exact specified step.
23112316
23122317
Returns:
23132318
An iterable which yields checkpoint paths.
@@ -2338,6 +2343,10 @@ def _get_closest_checkpoint(target_checkpoint):
23382343
def _get_checkpoint_path(step):
23392344
return os.path.join(model_dir, "model.ckpt-{}".format(step))
23402345

2346+
def _get_checkpoint_path_if_exists(step):
2347+
path = _get_checkpoint_path(step)
2348+
return path if tf.train.checkpoint_exists(path) else None
2349+
23412350
def _filter_fn(p):
23422351
return get_step_from_checkpoint_path(p) > skip_until
23432352

@@ -2363,11 +2372,22 @@ def _generate_checkpoints():
23632372
return _generate_checkpoints()
23642373
else:
23652374
return checkpoints_iterator
2366-
elif isinstance(checkpoint_step, int):
2367-
return [_get_checkpoint_path(_get_closest_checkpoint(checkpoint_step))]
2375+
elif find_closest:
2376+
if isinstance(checkpoint_step, int):
2377+
return [_get_checkpoint_path(_get_closest_checkpoint(checkpoint_step))]
2378+
else:
2379+
closests = np.unique(
2380+
[_get_closest_checkpoint(c) for c in checkpoint_step])
2381+
return [_get_checkpoint_path(closest) for closest in closests]
23682382
else:
2369-
closests = np.unique([_get_closest_checkpoint(c) for c in checkpoint_step])
2370-
return [_get_checkpoint_path(closest) for closest in closests]
2383+
if isinstance(checkpoint_step, int):
2384+
checkpoint_step = [checkpoint_step]
2385+
checkpoints = [_get_checkpoint_path_if_exists(c) for c in checkpoint_step]
2386+
checkpoints = [c for c in checkpoints if c]
2387+
if not checkpoints:
2388+
raise ValueError("You asked for checkpoints '%s' but none were found." %
2389+
str(checkpoint_step))
2390+
return checkpoints
23712391

23722392

23732393
# TODO(noam): provide a more informative string for layout_rules:

mesh_tensorflow/transformer/utils_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
2021
from absl.testing import absltest
2122
from absl.testing import parameterized
2223
import mesh_tensorflow as mtf

0 commit comments

Comments
 (0)