Commit e3d499f

marcenacp authored and The TensorFlow Datasets Authors committed

Make tfds.data_source pickable.

PiperOrigin-RevId: 636470622

1 parent 33fe626 commit e3d499f
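
This commit makes `tfds.data_source` objects picklable, which PyGrain requires before it can hand a data source to worker processes. A minimal sketch of the behavior the change enables (illustrative only: it assumes a dataset such as 'mnist' has already been downloaded and prepared locally):

    import pickle

    import tensorflow_datasets as tfds

    # 'mnist' is only an example name; any prepared dataset works.
    ds = tfds.data_source('mnist', split='train')
    _ = ds[0]  # Read once so the lazily created underlying reader exists.
    restored = pickle.loads(pickle.dumps(ds))  # Round-trips even after use.
    assert restored[0] is not None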

File tree (7 files changed: +109 −69 lines changed)

  tensorflow_datasets/core/data_sources/array_record.py
  tensorflow_datasets/core/data_sources/base.py
  tensorflow_datasets/core/data_sources/base_test.py
  tensorflow_datasets/core/data_sources/parquet.py
  tensorflow_datasets/core/dataset_builder.py
  tensorflow_datasets/testing/mocking.py
  tensorflow_datasets/testing/mocking_test.py

tensorflow_datasets/core/data_sources/array_record.py

Lines changed: 2 additions & 16 deletions

@@ -20,13 +20,8 @@
 """
 
 import dataclasses
-from typing import Any, Optional
 
-from tensorflow_datasets.core import dataset_info as dataset_info_lib
-from tensorflow_datasets.core import decode
-from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.data_sources import base
-from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source
 
 
@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
   source.
   """
 
-  dataset_info: dataset_info_lib.DatasetInfo
-  split: splits_lib.Split = None
-  decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
-      None
-  )
-  # In order to lazy load array_record, we don't load
-  # `array_record_data_source.ArrayRecordDataSource` here.
-  data_source: Any = dataclasses.field(init=False)
-  length: int = dataclasses.field(init=False)
-
   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     self.data_source = array_record_data_source.ArrayRecordDataSource(
         file_instructions
     )

tensorflow_datasets/core/data_sources/base.py

Lines changed: 24 additions & 8 deletions

@@ -17,12 +17,14 @@
 
 from collections.abc import MappingView, Sequence
 import dataclasses
+import functools
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar
 
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
+from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.utils import shard_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -54,6 +56,14 @@ def file_instructions(
   return split_dict[split].file_instructions
 
 
+class _DatasetBuilder(Protocol):
+  """Protocol for the DatasetBuilder to avoid cyclic imports."""
+
+  @property
+  def info(self) -> dataset_info_lib.DatasetInfo:
+    ...
+
+
 @dataclasses.dataclass
 class BaseDataSource(MappingView, Sequence):
   """Base DataSource to override all dunder methods with the deserialization.
@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
   deserialization/decoding.
 
   Attributes:
-    dataset_info: The DatasetInfo of the
+    dataset_builder: The dataset builder.
     split: The split to load in the data source.
     decoders: Optional decoders for decoding.
     data_source: The underlying data source to initialize in the __post_init__.
   """
 
-  dataset_info: dataset_info_lib.DatasetInfo
+  dataset_builder: _DatasetBuilder
   split: splits_lib.Split | None = None
   decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
   data_source: DataSource[Any] = dataclasses.field(init=False)
 
+  @functools.cached_property
+  def _features(self) -> top_level_feature.TopLevelFeature:
+    """Caches features because we log the use of dataset_builder.info."""
+    features = self.dataset_builder.info.features
+    if not features:
+      raise ValueError('No feature defined in the dataset builder.')
+    return features
+
   def __getitem__(self, key: SupportsIndex) -> Any:
     record = self.data_source[key.__index__()]
-    return self.dataset_info.features.deserialize_example_np(
-        record, decoders=self.decoders
-    )
+    return self._features.deserialize_example_np(record, decoders=self.decoders)
 
   def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     """Retrieves items by batch.
@@ -98,24 +114,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     if not keys:
       return []
     records = self.data_source.__getitems__(keys)
-    features = self.dataset_info.features
     if len(keys) != len(records):
       raise IndexError(
           f'Requested {len(keys)} records but got'
           f' {len(records)} records.'
          f'{keys=}, {records=}'
       )
     return [
-        features.deserialize_example_np(record, decoders=self.decoders)
+        self._features.deserialize_example_np(record, decoders=self.decoders)
         for record in records
     ]
 
   def __repr__(self) -> str:
     decoders_repr = (
         tree.map_structure(type, self.decoders) if self.decoders else None
     )
+    name = self.dataset_builder.info.name
     return (
-        f'{self.__class__.__name__}(name={self.dataset_info.name}, '
+        f'{self.__class__.__name__}(name={name}, '
         f'split={self.split!r}, '
         f'decoders={decoders_repr})'
     )
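
The core of the change is above: `BaseDataSource` now keeps the whole `dataset_builder` (typed against a small local Protocol so the real builder class need not be imported, avoiding a cycle) and resolves features lazily through `functools.cached_property`. A simplified sketch of that pattern, with illustrative names rather than the real TFDS classes:

    import dataclasses
    import functools
    from typing import Protocol


    class _Builder(Protocol):
      """Structural type only; stands in for the real dataset builder."""

      @property
      def info(self):
        ...


    @dataclasses.dataclass
    class _SketchSource:
      builder: _Builder

      @functools.cached_property
      def _features(self):
        # Looked up on first access, then cached on the instance.
        return self.builder.info.features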

tensorflow_datasets/core/data_sources/base_test.py

Lines changed: 37 additions & 10 deletions

@@ -15,13 +15,15 @@
 
 """Tests for all data sources."""
 
+import pickle
 from unittest import mock
 
+import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder
+from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder.DatasetBuilder,
+    builder_cls: dataset_builder_lib.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,34 @@ def test_read_write(
   ]
 
 
-def create_dataset_info(file_format: file_adapters.FileFormat):
+def create_dataset_builder(file_format: file_adapters.FileFormat):
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
     dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-    return dataset_info
+
+    dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
+    dataset_builder.info = dataset_info
+
+    return dataset_builder
 
 
 @pytest.mark.parametrize(
     'data_source_cls',
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_info, split='doesnotexist')
+    data_source_cls(dataset_builder, split='doesnotexist')
 
 
 @pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +144,10 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
-  source = data_source_cls(dataset_info, split='train')
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
+  source = data_source_cls(dataset_builder, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +160,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  dataset_builder = create_dataset_builder(
+      file_adapters.FileFormat.ARRAY_RECORD
+  )
   source = data_source_cls(
-      dataset_info,
+      dataset_builder,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )
@@ -181,3 +193,18 @@ def test_data_source_is_sliceable():
   file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
   assert file_instructions[0].skip == 0
   assert file_instructions[0].take == 30000
+
+
+# PyGrain requires that data sources are picklable.
+@pytest.mark.parametrize(
+    'file_format',
+    file_adapters.FileFormat.with_random_access(),
+)
+@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
+def test_data_source_is_picklable_after_use(file_format, pickle_module):
+  with tfds.testing.tmp_dir() as data_dir:
+    builder = tfds.testing.DummyDataset(data_dir=data_dir)
+    builder.download_and_prepare(file_format=file_format)
+    data_source = builder.as_data_source(split='train')
+    assert data_source[0] == {'id': 0}
+    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}

tensorflow_datasets/core/data_sources/parquet.py

Lines changed: 2 additions & 1 deletion

@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
   """ParquetDataSource to read from a ParquetDataset."""
 
   def __post_init__(self):
-    file_instructions = base.file_instructions(self.dataset_info, self.split)
+    dataset_info = self.dataset_builder.info
+    file_instructions = base.file_instructions(dataset_info, self.split)
     filenames = [
         file_instruction.filename for file_instruction in file_instructions
     ]

tensorflow_datasets/core/dataset_builder.py

Lines changed: 2 additions & 2 deletions

@@ -774,13 +774,13 @@ def build_single_data_source(
     file_format = self.info.file_format
     if file_format == file_adapters.FileFormat.ARRAY_RECORD:
       return array_record.ArrayRecordDataSource(
-          self.info,
+          self,
           split=split,
           decoders=decoders,
       )
     elif file_format == file_adapters.FileFormat.PARQUET:
       return parquet.ParquetDataSource(
-          self.info,
+          self,
           split=split,
           decoders=decoders,
       )

tensorflow_datasets/testing/mocking.py

Lines changed: 33 additions & 32 deletions

@@ -83,13 +83,25 @@ class PickableDataSourceMock(mock.MagicMock):
   """Makes MagicMock pickable in order to work with multiprocessing in Grain."""
 
   def __getstate__(self):
-    return {'num_examples': len(self), 'generator': self._generator}
+    return {
+        'num_examples': len(self),
+        'generator': self._generator,
+        'serialize_example': self._serialize_example,
+    }
 
   def __setstate__(self, state):
-    num_examples, generator = state['num_examples'], state['generator']
+    num_examples, generator, serialize_example = (
+        state['num_examples'],
+        state['generator'],
+        state['serialize_example'],
+    )
     self.__len__.return_value = num_examples
-    self.__getitem__ = functools.partial(_getitem, generator=generator)
-    self.__getitems__ = functools.partial(_getitems, generator=generator)
+    self.__getitem__ = functools.partial(
+        _getitem, generator=generator, serialize_example=serialize_example
+    )
+    self.__getitems__ = functools.partial(
+        _getitems, generator=generator, serialize_example=serialize_example
+    )
 
   def __reduce__(self):
     return (PickableDataSourceMock, (), self.__getstate__())
@@ -99,50 +111,33 @@ def _getitem(
     self,
     record_key: int,
     generator: RandomFakeGenerator,
-    serialized: bool = False,
+    serialize_example=None,
 ) -> Any:
   """Function to overwrite __getitem__ in data sources."""
+  del self
   example = generator[record_key]
-  if serialized:
+  if serialize_example:
     # Return serialized raw bytes
-    return self.dataset_info.features.serialize_example(example)
+    return serialize_example(example)
   return example
 
 
 def _getitems(
     self,
     record_keys: Sequence[int],
     generator: RandomFakeGenerator,
-    serialized: bool = False,
+    serialize_example=None,
 ) -> Sequence[Any]:
   """Function to overwrite __getitems__ in data sources."""
   items = [
-      _getitem(self, record_key, generator, serialized=serialized)
+      _getitem(self, record_key, generator, serialize_example=serialize_example)
       for record_key in record_keys
   ]
-  if serialized:
+  if serialize_example:
     return np.array(items)
   return items
 
 
-def _deserialize_example_np(serialized_example, *, decoders=None):
-  """Function to overwrite dataset_info.features.deserialize_example_np.
-
-  Warning: this has to be defined in the outer scope in order for the function
-  to be pickable.
-
-  Args:
-    serialized_example: the example to deserialize.
-    decoders: optional decoders.
-
-  Returns:
-    The serialized example, because deserialization is taken care by
-    RandomFakeGenerator.
-  """
-  del decoders
-  return serialized_example
-
-
 class MockPolicy(enum.Enum):
   """Strategy to use with `tfds.testing.mock_data` to mock the dataset.
 
@@ -385,21 +380,27 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
         # Force ARRAY_RECORD as the default file_format.
         return_value=file_adapters.FileFormat.ARRAY_RECORD,
     ):
-      self.info.features.deserialize_example_np = _deserialize_example_np
+      # Make mock_data_source pickable with a given len:
       mock_data_source.return_value.__len__.return_value = num_examples
+      # Make mock_data_source pickable with a given generator:
       mock_data_source.return_value._generator = (  # pylint:disable=protected-access
           generator
       )
+      # Make mock_data_source pickable with a given serialize_example:
+      mock_data_source.return_value._serialize_example = (  # pylint:disable=protected-access
+          self.info.features.serialize_example
+      )
+      serialize_example = self.info.features.serialize_example
       mock_data_source.return_value.__getitem__ = functools.partial(
-          _getitem, generator=generator
+          _getitem, generator=generator, serialize_example=serialize_example
      )
       mock_data_source.return_value.__getitems__ = functools.partial(
-          _getitems, generator=generator
+          _getitems, generator=generator, serialize_example=serialize_example
      )
 
       def build_single_data_source(split):
        single_data_source = array_record.ArrayRecordDataSource(
-            dataset_info=self.info, split=split, decoders=decoders
+            dataset_builder=self, split=split, decoders=decoders
        )
        return single_data_source
 
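
The mock stays usable under PyGrain's multiprocessing because it spells out its pickled state explicitly: `__getstate__` whitelists what must survive (the length, the generator, and now the serialize function), `__reduce__` rebuilds the object from the class plus that state, and `__setstate__` reinstalls the dunder methods from it. A self-contained sketch of the same explicit-state pattern, using a plain class and illustrative names rather than the TFDS mock itself:

    import pickle


    class _SketchSource:
      """Illustrative only: carry explicit state across pickling."""

      def __init__(self, examples=None):
        self._examples = list(examples or [])

      def __getitem__(self, i):
        return self._examples[i]

      def __getstate__(self):
        # Whitelist exactly what should survive pickling.
        return {'examples': self._examples}

      def __setstate__(self, state):
        self._examples = state['examples']

      def __reduce__(self):
        # Rebuild from the class plus the whitelisted state, mirroring
        # (PickableDataSourceMock, (), self.__getstate__()) above.
        return (_SketchSource, (), self.__getstate__())


    source = _SketchSource(examples=[{'id': 0}, {'id': 1}])
    assert pickle.loads(pickle.dumps(source))[1] == {'id': 1}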

tensorflow_datasets/testing/mocking_test.py

Lines changed: 9 additions & 0 deletions

@@ -392,3 +392,12 @@ def test_as_data_source_fn():
   assert imagenet[0] == 'foo'
   assert imagenet[1] == 'bar'
   assert imagenet[2] == 'baz'
+
+
+# PyGrain requires that data sources are picklable.
+def test_mocked_data_source_is_pickable():
+  with tfds.testing.mock_data(num_examples=2):
+    data_source = tfds.data_source('imagenet2012', split='train')
+    pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
+    assert len(pickled_and_unpickled_data_source) == 2
+    assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
