@@ -32,6 +32,7 @@
 import itertools
 import multiprocessing
 import os
+import time
 from typing import Any, Dict, Optional, Union
 
 from absl import logging
@@ -108,9 +109,24 @@ class _ShardInfo:
   num_exceptions: int
 
 
+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.Dataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  if hf_builder.repo_id is None:
+    return hf_builder.as_dataset(split=split)
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id or hf_builder.cache_dir,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
@@ -136,12 +152,19 @@ def _write_shard(
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
-    for i in range(shard_spec.num_examples):
+    dataset = iter(dataset)
+    # Skipping the first `start_index` examples. `streaming=True` returns an
+    # iterable dataset, so we cannot jump to a specific index. This is not too
+    # costly because it takes <0.5 ms/element in the wikipedia dataset.
+    for _ in range(shard_spec.start_index):
+      next(dataset)
+    for _ in range(shard_spec.num_examples):
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
@@ -155,6 +178,7 @@ def get_serialized_examples_iter():
       num_bytes += len(serialized_example)
       yield serialized_example
 
+  start = time.time()
   example_writer.write(
       os.fspath(shard_spec.path),
       tqdm_utils.tqdm(
@@ -166,6 +190,11 @@ def get_serialized_examples_iter():
           mininterval=1.0,
       ),
   )
+  logging.info(
+      'Generated %s examples in %s seconds',
+      shard_spec.num_examples,
+      time.time() - start,
+  )
 
   return _ShardInfo(
       num_bytes=num_bytes,
@@ -247,6 +276,7 @@ def __init__(
     self._builder_config = self._converted_builder_config
     self.generation_errors = []
     self._ignore_hf_errors = ignore_hf_errors
+    login_to_hf(self._hf_hub_token)
 
   @property
   def builder_config(self) -> Optional[Any]:
@@ -257,14 +287,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config
 
-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +300,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )
 
   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Return the features from the first split.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')
 
   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
@@ -309,7 +338,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()
 
     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
    for hf_split, hf_split_info in self._hf_info.splits.items():
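
For reference, a minimal standalone sketch of the streaming access pattern that the new `_load_dataset` helper and the shard-skipping loop rely on. This is not part of the commit: the repo id, config name, field name, and shard bounds are placeholder assumptions, and `itertools.islice` stands in for the explicit `next()` loop in the diff.

```python
import itertools

import datasets as hf_datasets  # HuggingFace `datasets` library

# streaming=True returns an IterableDataset with no random access (ds[i]
# is unsupported), so a shard must consume and discard the elements before
# its start_index instead of jumping to it.
ds = hf_datasets.load_dataset(
    'wikimedia/wikipedia',  # placeholder repo_id
    '20231101.en',  # placeholder config
    split='train',
    streaming=True,
)

start_index, num_examples = 1_000, 10  # placeholder shard bounds
shard = itertools.islice(iter(ds), start_index, start_index + num_examples)
for example in shard:
    print(example['title'])  # assumes this dataset's schema has a 'title' field
```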