Skip to content

Commit ed6a581

Browse files
author
The TensorFlow Datasets Authors
committed
Use the new NoShuffleBeamWriter when download_config.nondeterministic_order is True and save it to dataset_info proto for documentation.
PiperOrigin-RevId: 693628830
1 parent ed9de76 commit ed6a581

File tree

3 files changed

+85
-35
lines changed

3 files changed

+85
-35
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,10 @@ def download_and_prepare(
622622
data_path = self.data_path
623623
data_exists = data_path.exists()
624624

625+
# Saving nondeterministic_order in the DatasetInfo for documentation.
626+
if download_config.nondeterministic_order:
627+
self.info.set_nondeterministic_order(True)
628+
625629
if download_config.download_mode == UPDATE_DATASET_INFO:
626630
self._update_dataset_info()
627631
return
@@ -1426,6 +1430,8 @@ def _get_filename_template(
14261430
self, split_name: str
14271431
) -> naming.ShardedFileTemplate:
14281432
"""Returns a filename template for the given split."""
1433+
if self.info.file_format is None:
1434+
raise ValueError("File format is not set!")
14291435
return naming.ShardedFileTemplate(
14301436
split=split_name,
14311437
dataset_name=self.name,
@@ -1728,6 +1734,7 @@ def _generate_splits(
17281734
generator=generator,
17291735
filename_template=filename_template,
17301736
disable_shuffling=self.info.disable_shuffling,
1737+
nondeterministic_order=download_config.nondeterministic_order,
17311738
)
17321739
split_info_futures.append(future)
17331740

tensorflow_datasets/core/dataset_builder_beam_test.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ class DummyBeamDataset(dataset_builder.GeneratorBasedBuilder):
3939
'valid_725': 725,
4040
}
4141

42+
FEATURE_DICT = features.FeaturesDict({
43+
'image': features.Image(shape=(16, 16, 1)),
44+
'label': features.ClassLabel(names=['dog', 'cat']),
45+
'id': tf.int32,
46+
})
47+
4248
def _info(self):
4349
return dataset_info.DatasetInfo(
4450
builder=self,
45-
features=features.FeaturesDict({
46-
'image': features.Image(shape=(16, 16, 1)),
47-
'label': features.ClassLabel(names=['dog', 'cat']),
48-
'id': tf.int32,
49-
}),
51+
features=self.FEATURE_DICT,
5052
supervised_keys=('x', 'x'),
5153
metadata=dataset_info.BeamMetadataDict(),
5254
)
@@ -71,6 +73,18 @@ def _generate_examples(self, num_examples):
7173
return examples
7274

7375

76+
class UnshuffledDummyBeamDataset(DummyBeamDataset):
77+
78+
def _info(self) -> dataset_info.DatasetInfo:
79+
return dataset_info.DatasetInfo(
80+
builder=self,
81+
features=self.FEATURE_DICT,
82+
supervised_keys=('x', 'x'),
83+
metadata=dataset_info.BeamMetadataDict(),
84+
disable_shuffling=True,
85+
)
86+
87+
7488
class CommonPipelineDummyBeamDataset(DummyBeamDataset):
7589
EXPECTED_METADATA = {
7690
'label_sum_1000': 500,
@@ -151,12 +165,21 @@ def _compute_mean(examples):
151165
)
152166

153167

168+
def get_id(ex):
169+
return ex['id']
170+
171+
154172
def make_default_config():
155173
return download.DownloadConfig()
156174

157175

158176
@pytest.mark.parametrize(
159-
'dataset_cls', [DummyBeamDataset, CommonPipelineDummyBeamDataset]
177+
'dataset_cls',
178+
[
179+
DummyBeamDataset,
180+
CommonPipelineDummyBeamDataset,
181+
UnshuffledDummyBeamDataset,
182+
],
160183
)
161184
@pytest.mark.parametrize(
162185
'make_dl_config',
@@ -178,29 +201,23 @@ def test_beam_datasets(
178201
assert data_path.exists() # Dataset has been generated
179202

180203
# Check number of shards/generated files
181-
_test_shards(
182-
data_path,
183-
pattern='%s-test.tfrecord-{:05}-of-{:05}' % dataset_name,
184-
# Liquid sharding is not guaranteed to always use the same number.
185-
num_shards=builder.info.splits['test'].num_shards,
186-
)
187-
_test_shards(
188-
data_path,
189-
pattern='%s-train.tfrecord-{:05}-of-{:05}' % dataset_name,
190-
num_shards=1,
191-
)
204+
for split in ['test', 'train']:
205+
_test_shards(
206+
data_path,
207+
pattern='%s-%s.tfrecord-{:05}-of-{:05}' % (dataset_name, split),
208+
num_shards=builder.info.splits[split].num_shards,
209+
)
192210

193211
ds = dataset_utils.as_numpy(builder.as_dataset())
194212

195-
def get_id(ex):
196-
return ex['id']
197-
213+
test_examples = list(ds['test'])
214+
train_examples = list(ds['train'])
198215
_assert_values_equal(
199-
sorted(list(ds['test']), key=get_id),
216+
sorted(test_examples, key=get_id),
200217
sorted([_gen_example(i)[1] for i in range(725)], key=get_id),
201218
)
202219
_assert_values_equal(
203-
sorted(list(ds['train']), key=get_id),
220+
sorted(train_examples, key=get_id),
204221
sorted([_gen_example(i)[1] for i in range(1000)], key=get_id),
205222
)
206223

tensorflow_datasets/core/split_builder.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
# pylint: disable=g-import-not-at-top
3535
from tensorflow_datasets.core import example_serializer
3636
from tensorflow_datasets.core import features as features_lib
37+
from tensorflow_datasets.core import file_adapters
3738
from tensorflow_datasets.core import naming
3839
from tensorflow_datasets.core import splits as splits_lib
3940
from tensorflow_datasets.core import utils
@@ -410,14 +411,18 @@ def submit_split_generation(
410411
generator: SplitGenerator,
411412
filename_template: naming.ShardedFileTemplate,
412413
disable_shuffling: bool,
414+
nondeterministic_order: bool,
413415
) -> _SplitInfoFuture:
414416
"""Start the split generation.
415417
416418
Args:
417-
split_name: Name of the split to generate
418-
generator: Generator, beam.PTransform,... yielding the examples
419+
split_name: Name of the split to generate.
420+
generator: Generator, beam.PTransform,... yielding the examples.
419421
filename_template: Template to format the filename for a shard.
420-
disable_shuffling: Specifies whether to shuffle the examples
422+
disable_shuffling: Specifies whether to shuffle the examples.
423+
nondeterministic_order: If True, it will not ensure deterministic ordering
424+
when writing examples to disk. This might result in quicker dataset
425+
preparation.
421426
422427
Returns:
423428
split_info_future: Future containing the `split_info`, once generation
@@ -433,13 +438,19 @@ def submit_split_generation(
433438
# Depending on the type of generator, we use the corresponding
434439
# `_build_from_xyz` method.
435440
if isinstance(generator, Iterable):
441+
if nondeterministic_order:
442+
logging.warning(
443+
'Enabling `nondeterministic_order` for a dataset that does not use'
444+
' beam has no effect.'
445+
)
436446
return self._build_from_generator(**build_kwargs)
437447
else: # Otherwise, beam required
438448
unknown_generator_type = TypeError(
439449
f'Invalid split generator value for split `{split_name}`. '
440450
'Expected generator or apache_beam object. Got: '
441451
f'{type(generator)}'
442452
)
453+
build_kwargs['nondeterministic_order'] = nondeterministic_order
443454
if isinstance(generator, beam.PTransform):
444455
# Generate the beam.PCollection
445456
pcollection = self.beam_pipeline | split_name >> generator
@@ -527,20 +538,35 @@ def _build_from_pcollection(
527538
generator: 'beam.PCollection[KeyExample]',
528539
filename_template: naming.ShardedFileTemplate,
529540
disable_shuffling: bool,
541+
nondeterministic_order: bool,
530542
) -> _SplitInfoFuture:
531543
"""Split generator for `beam.PCollection`."""
532544
# TODO(tfds): Should try to add support to `max_examples_per_split`
533-
beam_writer = writer_lib.BeamWriter(
534-
serializer=example_serializer.ExampleSerializer(
535-
self._features.get_serialized_info()
536-
),
537-
filename_template=filename_template,
538-
hash_salt=split_name,
539-
disable_shuffling=disable_shuffling,
540-
shard_config=self._shard_config,
541-
example_writer=self._example_writer,
542-
ignore_duplicates=self._ignore_duplicates,
545+
serializer = example_serializer.ExampleSerializer(
546+
self._features.get_serialized_info()
543547
)
548+
if nondeterministic_order:
549+
logging.info(
550+
'Order of examples does not matter, using NoShuffleBeamWriter'
551+
)
552+
beam_writer = writer_lib.NoShuffleBeamWriter(
553+
serializer=serializer,
554+
file_format=file_adapters.FileFormat.from_value(
555+
filename_template.filetype_suffix
556+
),
557+
filename_template=filename_template,
558+
)
559+
else:
560+
logging.info('Deterministic ordering is enabled, using BeamWriter')
561+
beam_writer = writer_lib.BeamWriter(
562+
serializer=serializer,
563+
filename_template=filename_template,
564+
hash_salt=split_name,
565+
disable_shuffling=disable_shuffling,
566+
shard_config=self._shard_config,
567+
example_writer=self._example_writer,
568+
ignore_duplicates=self._ignore_duplicates,
569+
)
544570

545571
def _encode_example(key_ex, encode_fn=self._features.encode_example):
546572
# We do not access self._features in this function to avoid pickling the

0 commit comments

Comments
 (0)