@@ -2290,15 +2290,17 @@ def auto_train_steps(batch_size,
2290
2290
2291
2291
@gin .configurable
2292
2292
def get_checkpoint_iterator (checkpoint_step , model_dir , skip_until = 0 ,
2293
- stop_after = None ):
2293
+ stop_after = None , find_closest = True ):
2294
2294
"""Get an iterable of checkpoint paths from a provided checkpoint step(s).
2295
2295
2296
2296
Args:
2297
- checkpoint_step: If checkpoint_step is an int, find the checkpoint with the
2298
- closest global step and return a singleton list. If checkpoint_step is a
2299
- list of ints, replace each int with the path to the checkpoint with the
2300
- closest global step. If checkpoint_step == "all", return the path of every
2301
- checkpoint in model_dir, starting from the earliest checkpoint. If
2297
+ checkpoint_step: If checkpoint_step is an int, return a singleton list with
2298
+ that checkpoint path in it. If find_closest, the checkpoint with the
2299
+ closest global step will be returned. If checkpoint_step is a
2300
+ list of ints, replace each int with its corresponding path (if
2301
+ find_closest, the path with the closest global step). If
2302
+ checkpoint_step == "all", return the path of every checkpoint in
2303
+ model_dir, starting from the earliest checkpoint. If
2302
2304
checkpoint_step == -1, return the latest checkpoint as specified in
2303
2305
model_dir/checkpoint. If checkpoint_step is None, return
2304
2306
`tf.train.checkpoints_iterator` for `model_dir`.
@@ -2308,6 +2310,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0,
2308
2310
stop_after: an optional integer - for "None" behavior, if specified
2309
2311
stop after finding a checkpoint number that is >= stop_after. When a
2310
2312
checkpoint number == stop_after is found, it is yielded before exiting.
2313
+ find_closest: If True and a specified checkpoint step does not exist, will
2314
+ choose the nearest checkpoint to that step. If False, then will
2315
+ only look for a checkpoint matching the exact specified step.
2311
2316
2312
2317
Returns:
2313
2318
An iterable which yields checkpoint paths.
@@ -2338,6 +2343,10 @@ def _get_closest_checkpoint(target_checkpoint):
2338
2343
def _get_checkpoint_path (step ):
2339
2344
return os .path .join (model_dir , "model.ckpt-{}" .format (step ))
2340
2345
2346
+ def _get_checkpoint_path_if_exists (step ):
2347
+ path = _get_checkpoint_path (step )
2348
+ return path if tf .train .checkpoint_exists (path ) else None
2349
+
2341
2350
def _filter_fn (p ):
2342
2351
return get_step_from_checkpoint_path (p ) > skip_until
2343
2352
@@ -2363,11 +2372,22 @@ def _generate_checkpoints():
2363
2372
return _generate_checkpoints ()
2364
2373
else :
2365
2374
return checkpoints_iterator
2366
- elif isinstance (checkpoint_step , int ):
2367
- return [_get_checkpoint_path (_get_closest_checkpoint (checkpoint_step ))]
2375
+ elif find_closest :
2376
+ if isinstance (checkpoint_step , int ):
2377
+ return [_get_checkpoint_path (_get_closest_checkpoint (checkpoint_step ))]
2378
+ else :
2379
+ closests = np .unique (
2380
+ [_get_closest_checkpoint (c ) for c in checkpoint_step ])
2381
+ return [_get_checkpoint_path (closest ) for closest in closests ]
2368
2382
else :
2369
- closests = np .unique ([_get_closest_checkpoint (c ) for c in checkpoint_step ])
2370
- return [_get_checkpoint_path (closest ) for closest in closests ]
2383
+ if isinstance (checkpoint_step , int ):
2384
+ checkpoint_step = [checkpoint_step ]
2385
+ checkpoints = [_get_checkpoint_path_if_exists (c ) for c in checkpoint_step ]
2386
+ checkpoints = [c for c in checkpoints if c ]
2387
+ if not checkpoints :
2388
+ raise ValueError ("You asked for checkpoints '%s' but none were found." %
2389
+ str (checkpoint_step ))
2390
+ return checkpoints
2371
2391
2372
2392
2373
2393
# TODO(noam): provide a more informative string for layout_rules:
0 commit comments