
Commit 290ee7e

marcenacp authored and The TensorFlow Datasets Authors committed

Stream from Hugging Face instead of downloading and preparing everything.

PiperOrigin-RevId: 657212303

1 parent 2123db7 · commit 290ee7e
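The gist of the change: instead of materializing every split with the Hugging Face builder's download_and_prepare(), TFDS now reads examples lazily with streaming=True. A minimal sketch of the two access patterns using the `datasets` library follows; the repo id is purely illustrative and not taken from this commit.

import datasets as hf_datasets

REPO_ID = 'username/some_dataset'  # hypothetical repo id, for illustration only

# Old pattern: materialize the whole dataset on disk before reading it.
builder = hf_datasets.load_dataset_builder(REPO_ID)
builder.download_and_prepare()
prepared = builder.as_dataset(split='train')

# New pattern: stream examples lazily, with no download-and-prepare step.
streamed = hf_datasets.load_dataset(REPO_ID, split='train', streaming=True)
for example in streamed.take(3):  # IterableDataset.take yields the first n examples
  print(example)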

File tree

2 files changed: +48 −25 lines changed


tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 47 additions & 19 deletions
@@ -32,6 +32,7 @@
 import itertools
 import multiprocessing
 import os
+import time
 from typing import Any, Dict, Optional, Union

 from absl import logging
@@ -108,9 +109,24 @@ class _ShardInfo:
   num_exceptions: int


+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.Dataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  if hf_builder.repo_id is None:
+    return hf_builder.as_dataset(split=split)
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id or hf_builder.cache_dir,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
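For context, a rough sketch of what the new helper amounts to when the builder was created from a Hub repository. It assumes a `datasets` release where the builder exposes repo_id and config_id, as the diff relies on; the repo id is illustrative.

import datasets as hf_datasets

hf_builder = hf_datasets.load_dataset_builder('username/some_dataset')  # hypothetical repo id

# repo_id is set, so the helper streams from the Hub rather than reading a
# locally prepared copy; the result is an iterable dataset, not an indexable one.
ds = hf_datasets.load_dataset(
    hf_builder.repo_id,
    hf_builder.config_id,
    split='train',
    streaming=True,
)
print(next(iter(ds)))  # the first example, fetched on demand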
@@ -136,12 +152,19 @@ def _write_shard(
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
-    for i in range(shard_spec.num_examples):
+    dataset = iter(dataset)
+    # Skipping the first `start_index` examples. `streaming=True` returns an
+    # iterable dataset, so we cannot jump to a specific index. This is not too
+    # costly because it takes <0.5 ms/element in the wikipedia dataset.
+    for _ in range(shard_spec.start_index):
+      next(dataset)
+    for _ in range(shard_spec.num_examples):
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
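Since an iterable dataset cannot be indexed, the writer now advances the iterator by hand to reach its shard's start_index. A functionally equivalent sketch of that skip-then-take pattern using itertools.islice; the function and argument names are illustrative, not part of the commit.

import itertools
from typing import Any, Iterable, Iterator

def iter_shard(dataset: Iterable[Any], start_index: int, num_examples: int) -> Iterator[Any]:
  # Skip the first `start_index` elements, then yield the next `num_examples`.
  # This mirrors the explicit next() loops in the diff above.
  return itertools.islice(iter(dataset), start_index, start_index + num_examples)

# e.g. list(iter_shard(range(10), start_index=3, num_examples=4)) == [3, 4, 5, 6]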
@@ -155,6 +178,7 @@ def get_serialized_examples_iter():
       num_bytes += len(serialized_example)
       yield serialized_example

+  start = time.time()
   example_writer.write(
       os.fspath(shard_spec.path),
       tqdm_utils.tqdm(
@@ -166,6 +190,11 @@ def get_serialized_examples_iter():
           mininterval=1.0,
       ),
   )
+  logging.info(
+      'Generated %s examples in %s seconds',
+      shard_spec.num_examples,
+      time.time() - start,
+  )

   return _ShardInfo(
       num_bytes=num_bytes,
@@ -247,6 +276,7 @@ def __init__(
     self._builder_config = self._converted_builder_config
     self.generation_errors = []
     self._ignore_hf_errors = ignore_hf_errors
+    login_to_hf(self._hf_hub_token)

   @property
   def builder_config(self) -> Optional[Any]:
@@ -257,14 +287,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config

-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +300,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )

   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Return the features from the first split.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')

   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
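Without a prepared copy on disk, the feature schema may be absent from the builder info, so the new _hf_features falls back to the info attached to a streamed split. A standalone sketch of that fallback; it assumes the split list is populated even when features are not, and the repo id is illustrative.

import datasets as hf_datasets

def features_from_any_split(repo_id: str) -> hf_datasets.Features:
  """Returns the schema from the builder info or from the first split that has one."""
  builder = hf_datasets.load_dataset_builder(repo_id)
  if builder.info.features:
    return builder.info.features
  # Assumption: builder.info.splits is available even when features are missing.
  for split in builder.info.splits:
    ds = hf_datasets.load_dataset(repo_id, split=split, streaming=True)
    if ds.info.features:
      return ds.info.features
  raise ValueError('No features found in the dataset.')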
@@ -309,7 +338,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()

     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
     for hf_split, hf_split_info in self._hf_info.splits.items():

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py

Lines changed: 1 addition & 6 deletions
@@ -62,6 +62,7 @@ def mock_load_dataset_builder(tmp_path):
   with mock.patch.object(
       hf_datasets, 'load_dataset_builder', return_value=hf_builder
   ) as load_dataset_builder:
+    hf_builder.download_and_prepare()
     yield load_dataset_builder


@@ -133,12 +134,6 @@ def test_download_and_prepare(builder):
   assert len(ds['train_clean']) == 2


-def test_all_parameters_are_passed_down_to_hf(builder):
-  builder._hf_builder.download_and_prepare.assert_called_once_with(
-      verification_mode='no_checks', num_proc=100
-  )
-
-
 def test_hf_features(builder):
   assert builder._hf_features() == {
       'number': hf_datasets.Value('int64'),
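Because the builder class no longer calls download_and_prepare() itself, the test fixture (first hunk above) now prepares the mocked Hugging Face builder before yielding it. A minimal pytest sketch of that pattern; the MagicMock stand-in is illustrative, not the actual test setup, which builds a small real dataset.

from unittest import mock

import datasets as hf_datasets
import pytest

@pytest.fixture
def mock_load_dataset_builder():
  hf_builder = mock.MagicMock()  # illustrative stand-in for a real DatasetBuilder
  with mock.patch.object(
      hf_datasets, 'load_dataset_builder', return_value=hf_builder
  ) as load_dataset_builder:
    # Prepare the (fake) data up front, since the code under test only reads it.
    hf_builder.download_and_prepare()
    yield load_dataset_builder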
