 # pylint: disable=g-import-not-at-top
 from tensorflow_datasets.core import example_serializer
 from tensorflow_datasets.core import features as features_lib
+from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core import utils
@@ -410,14 +411,18 @@ def submit_split_generation(
       generator: SplitGenerator,
       filename_template: naming.ShardedFileTemplate,
       disable_shuffling: bool,
+      nondeterministic_order: bool,
   ) -> _SplitInfoFuture:
     """Start the split generation.

     Args:
-      split_name: Name of the split to generate
-      generator: Generator, beam.PTransform,... yielding the examples
+      split_name: Name of the split to generate.
+      generator: Generator, beam.PTransform,... yielding the examples.
       filename_template: Template to format the filename for a shard.
-      disable_shuffling: Specifies whether to shuffle the examples
+      disable_shuffling: Specifies whether to shuffle the examples.
+      nondeterministic_order: If True, it will not assure deterministic ordering
+        when writing examples to disk. This might result in quicker dataset
+        preparation.

     Returns:
       split_info_future: Future containing the `split_info`, once generation
@@ -433,13 +438,19 @@ def submit_split_generation(
     # Depending on the type of generator, we use the corresponding
     # `_build_from_xyz` method.
     if isinstance(generator, Iterable):
+      if nondeterministic_order:
+        logging.warning(
+            'Enabling `nondeterministic_order` for a dataset that does not use'
+            ' beam has no effect.'
+        )
       return self._build_from_generator(**build_kwargs)
     else:  # Otherwise, beam required
       unknown_generator_type = TypeError(
           f'Invalid split generator value for split `{split_name}`. '
           'Expected generator or apache_beam object. Got: '
           f'{type(generator)}'
       )
+      build_kwargs['nondeterministic_order'] = nondeterministic_order
       if isinstance(generator, beam.PTransform):
         # Generate the beam.PCollection
         pcollection = self.beam_pipeline | split_name >> generator
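As the hunk above shows, the new flag only changes behavior on the Beam path: for plain Python generators it merely logs a warning, while for Beam inputs it is forwarded through `build_kwargs`. A minimal sketch of how a caller might now invoke `submit_split_generation`; the surrounding names (`split_builder`, `ds_generator`, `template`) are illustrative assumptions, not part of this commit:

```python
# Illustrative only: a prepare step passing the new flag alongside the
# existing arguments of submit_split_generation.
split_info_future = split_builder.submit_split_generation(
    split_name='train',
    generator=ds_generator,        # generator or beam.PTransform / PCollection
    filename_template=template,    # naming.ShardedFileTemplate for this split
    disable_shuffling=False,
    nondeterministic_order=True,   # new argument introduced by this change
)
split_info = split_info_future.result()  # blocks until generation completes
```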
@@ -527,20 +538,35 @@ def _build_from_pcollection(
       generator: 'beam.PCollection[KeyExample]',
       filename_template: naming.ShardedFileTemplate,
       disable_shuffling: bool,
+      nondeterministic_order: bool,
   ) -> _SplitInfoFuture:
     """Split generator for `beam.PCollection`."""
     # TODO(tfds): Should try to add support to `max_examples_per_split`
-    beam_writer = writer_lib.BeamWriter(
-        serializer=example_serializer.ExampleSerializer(
-            self._features.get_serialized_info()
-        ),
-        filename_template=filename_template,
-        hash_salt=split_name,
-        disable_shuffling=disable_shuffling,
-        shard_config=self._shard_config,
-        example_writer=self._example_writer,
-        ignore_duplicates=self._ignore_duplicates,
+    serializer = example_serializer.ExampleSerializer(
+        self._features.get_serialized_info()
     )
+    if nondeterministic_order:
+      logging.info(
+          'Order of examples does not matter, using NoShuffleBeamWriter'
+      )
+      beam_writer = writer_lib.NoShuffleBeamWriter(
+          serializer=serializer,
+          file_format=file_adapters.FileFormat.from_value(
+              filename_template.filetype_suffix
+          ),
+          filename_template=filename_template,
+      )
+    else:
+      logging.info('Deterministic ordering is enabled, using BeamWriter')
+      beam_writer = writer_lib.BeamWriter(
+          serializer=serializer,
+          filename_template=filename_template,
+          hash_salt=split_name,
+          disable_shuffling=disable_shuffling,
+          shard_config=self._shard_config,
+          example_writer=self._example_writer,
+          ignore_duplicates=self._ignore_duplicates,
+      )

     def _encode_example(key_ex, encode_fn=self._features.encode_example):
       # We do not access self._features in this function to avoid pickling the
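For context, a hedged end-to-end sketch of how a user could request this behavior when preparing a Beam-based dataset. It assumes a TFDS release in which `tfds.download.DownloadConfig` exposes a `nondeterministic_order` option that is plumbed down to this writer selection; that plumbing is not shown in this diff.

```python
import tensorflow_datasets as tfds

# Assumption: DownloadConfig forwards `nondeterministic_order` down to
# SplitBuilder.submit_split_generation; only the split_builder side of that
# path is visible in this commit.
builder = tfds.builder('my_beam_dataset')  # hypothetical Beam-based dataset
builder.download_and_prepare(
    download_config=tfds.download.DownloadConfig(nondeterministic_order=True),
)
```

When the flag is set, the diff above selects `writer_lib.NoShuffleBeamWriter`, which skips the deterministic hash-and-sort shuffle and can therefore prepare large Beam datasets faster, at the cost of a non-reproducible example order on disk.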