36
36
37
37
from __future__ import annotations
38
38
39
- from collections .abc import Mapping
40
- from typing import Any , Dict , Optional , Sequence
39
+ from collections .abc import Mapping , Sequence
40
+ from typing import Any
41
41
42
42
from etils import epath
43
43
import numpy as np
61
61
from tensorflow_datasets .core .utils .lazy_imports_utils import pandas as pd
62
62
63
63
64
+ _RecordOrFeature = Mapping [str , Any ]
65
+
66
+
67
+ def _strip_record_set_prefix (
68
+ record_or_feature : _RecordOrFeature , record_set_id : str
69
+ ) -> _RecordOrFeature :
70
+ """Removes the record set prefix from the field ids of a record or feature."""
71
+ return {
72
+ field_id .removeprefix (f'{ record_set_id } /' ): value
73
+ for field_id , value in record_or_feature .items ()
74
+ }
75
+
76
+
64
77
def datatype_converter (
65
78
field : mlc .Field ,
66
- int_dtype : Optional [ type_utils .TfdsDType ] = np .int64 ,
67
- float_dtype : Optional [ type_utils .TfdsDType ] = np .float32 ,
79
+ int_dtype : type_utils .TfdsDType = np .int64 ,
80
+ float_dtype : type_utils .TfdsDType = np .float32 ,
68
81
):
69
82
"""Converts a Croissant field to a TFDS-compatible feature.
70
83
@@ -162,8 +175,8 @@ def __init__(
162
175
jsonld : epath .PathLike | Mapping [str , Any ],
163
176
record_set_ids : Sequence [str ] | None = None ,
164
177
disable_shuffling : bool | None = False ,
165
- int_dtype : type_utils .TfdsDType | None = np .int64 ,
166
- float_dtype : type_utils .TfdsDType | None = np .float32 ,
178
+ int_dtype : type_utils .TfdsDType = np .int64 ,
179
+ float_dtype : type_utils .TfdsDType = np .float32 ,
167
180
mapping : Mapping [str , epath .PathLike ] | None = None ,
168
181
overwrite_version : version_lib .VersionOrStr | None = None ,
169
182
filters : Mapping [str , Any ] | None = None ,
@@ -214,7 +227,7 @@ def __init__(
214
227
conversion_utils .to_tfds_name (record_set_id )
215
228
for record_set_id in record_set_ids
216
229
]
217
- self .BUILDER_CONFIGS : Sequence [dataset_builder .BuilderConfig ] = [ # pylint: disable=invalid-name
230
+ self .BUILDER_CONFIGS : list [dataset_builder .BuilderConfig ] = [ # pylint: disable=invalid-name
218
231
dataset_builder .BuilderConfig (name = config_name )
219
232
for config_name in config_names
220
233
]
@@ -261,13 +274,14 @@ def get_features(self) -> features_dict.FeaturesDict:
261
274
if field .repeated :
262
275
feature = sequence_feature .Sequence (feature )
263
276
features [field .id ] = feature
277
+ features = _strip_record_set_prefix (features , record_set .id )
264
278
return features_dict .FeaturesDict (features )
265
279
266
280
def _split_generators (
267
281
self ,
268
282
dl_manager : download .DownloadManager ,
269
283
pipeline : beam .Pipeline ,
270
- ) -> Dict [splits_lib .Split , split_builder_lib .SplitGenerator ]:
284
+ ) -> dict [splits_lib .Split , split_builder_lib .SplitGenerator ]:
271
285
# If a split recordset is joined for the required record set, we generate
272
286
# splits accordingly. Otherwise, it generates a single `default` split with
273
287
# all the records.
@@ -317,11 +331,15 @@ def _generate_examples(
317
331
318
332
def convert_to_tfds_format (
319
333
global_index : int ,
320
- record : Any ,
334
+ record : _RecordOrFeature ,
321
335
features : feature_lib .FeatureConnector | None = None ,
322
- ) -> tuple [int , Any ]:
336
+ record_set_id : str | None = None ,
337
+ ) -> tuple [int , _RecordOrFeature ]:
323
338
if not features :
324
339
raise ValueError ('features should not be None.' )
340
+ if not record_set_id :
341
+ raise ValueError ('record_set_id should not be None.' )
342
+ record = _strip_record_set_prefix (record , record_set_id )
325
343
return (
326
344
global_index ,
327
345
conversion_utils .to_tfds_value (record , features ),
@@ -330,5 +348,7 @@ def convert_to_tfds_format(
330
348
return records .beam_reader (
331
349
pipeline = pipeline
332
350
) | 'Convert to TFDS format' >> beam .MapTuple (
333
- convert_to_tfds_format , features = self .info .features
351
+ convert_to_tfds_format ,
352
+ features = self .info .features ,
353
+ record_set_id = record_set .id ,
334
354
)
0 commit comments