Skip to content

Commit ed9de76

Browse files
fineguy
authored and
The TensorFlow Datasets Authors
committed
Fix feature names in Croissant builder.
PiperOrigin-RevId: 693628650
1 parent 3e5515f commit ed9de76

File tree

3 files changed

+43
-23
lines changed

3 files changed

+43
-23
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636

3737
from __future__ import annotations
3838

39-
from collections.abc import Mapping
40-
from typing import Any, Dict, Optional, Sequence
39+
from collections.abc import Mapping, Sequence
40+
from typing import Any
4141

4242
from etils import epath
4343
import numpy as np
@@ -61,10 +61,23 @@
6161
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
6262

6363

64+
_RecordOrFeature = Mapping[str, Any]
65+
66+
67+
def _strip_record_set_prefix(
68+
record_or_feature: _RecordOrFeature, record_set_id: str
69+
) -> _RecordOrFeature:
70+
"""Removes the record set prefix from the field ids of a record or feature."""
71+
return {
72+
field_id.removeprefix(f'{record_set_id}/'): value
73+
for field_id, value in record_or_feature.items()
74+
}
75+
76+
6477
def datatype_converter(
6578
field: mlc.Field,
66-
int_dtype: Optional[type_utils.TfdsDType] = np.int64,
67-
float_dtype: Optional[type_utils.TfdsDType] = np.float32,
79+
int_dtype: type_utils.TfdsDType = np.int64,
80+
float_dtype: type_utils.TfdsDType = np.float32,
6881
):
6982
"""Converts a Croissant field to a TFDS-compatible feature.
7083
@@ -162,8 +175,8 @@ def __init__(
162175
jsonld: epath.PathLike | Mapping[str, Any],
163176
record_set_ids: Sequence[str] | None = None,
164177
disable_shuffling: bool | None = False,
165-
int_dtype: type_utils.TfdsDType | None = np.int64,
166-
float_dtype: type_utils.TfdsDType | None = np.float32,
178+
int_dtype: type_utils.TfdsDType = np.int64,
179+
float_dtype: type_utils.TfdsDType = np.float32,
167180
mapping: Mapping[str, epath.PathLike] | None = None,
168181
overwrite_version: version_lib.VersionOrStr | None = None,
169182
filters: Mapping[str, Any] | None = None,
@@ -214,7 +227,7 @@ def __init__(
214227
conversion_utils.to_tfds_name(record_set_id)
215228
for record_set_id in record_set_ids
216229
]
217-
self.BUILDER_CONFIGS: Sequence[dataset_builder.BuilderConfig] = [ # pylint: disable=invalid-name
230+
self.BUILDER_CONFIGS: list[dataset_builder.BuilderConfig] = [ # pylint: disable=invalid-name
218231
dataset_builder.BuilderConfig(name=config_name)
219232
for config_name in config_names
220233
]
@@ -261,13 +274,14 @@ def get_features(self) -> features_dict.FeaturesDict:
261274
if field.repeated:
262275
feature = sequence_feature.Sequence(feature)
263276
features[field.id] = feature
277+
features = _strip_record_set_prefix(features, record_set.id)
264278
return features_dict.FeaturesDict(features)
265279

266280
def _split_generators(
267281
self,
268282
dl_manager: download.DownloadManager,
269283
pipeline: beam.Pipeline,
270-
) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
284+
) -> dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
271285
# If a split recordset is joined for the required record set, we generate
272286
# splits accordingly. Otherwise, it generates a single `default` split with
273287
# all the records.
@@ -317,11 +331,15 @@ def _generate_examples(
317331

318332
def convert_to_tfds_format(
319333
global_index: int,
320-
record: Any,
334+
record: _RecordOrFeature,
321335
features: feature_lib.FeatureConnector | None = None,
322-
) -> tuple[int, Any]:
336+
record_set_id: str | None = None,
337+
) -> tuple[int, _RecordOrFeature]:
323338
if not features:
324339
raise ValueError('features should not be None.')
340+
if not record_set_id:
341+
raise ValueError('record_set_id should not be None.')
342+
record = _strip_record_set_prefix(record, record_set_id)
325343
return (
326344
global_index,
327345
conversion_utils.to_tfds_value(record, features),
@@ -330,5 +348,7 @@ def convert_to_tfds_format(
330348
return records.beam_reader(
331349
pipeline=pipeline
332350
) | 'Convert to TFDS format' >> beam.MapTuple(
333-
convert_to_tfds_format, features=self.info.features
351+
convert_to_tfds_format,
352+
features=self.info.features,
353+
record_set_id=record_set.id,
334354
)

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@
9191
def test_simple_datatype_converter(field, feature_type, int_dtype, float_dtype):
9292
actual_feature = croissant_builder.datatype_converter(
9393
field,
94-
int_dtype=int_dtype if int_dtype else np.int64,
95-
float_dtype=float_dtype if float_dtype else np.float32,
94+
int_dtype=int_dtype or np.int64,
95+
float_dtype=float_dtype or np.float32,
9696
)
9797
assert actual_feature == feature_type
9898

@@ -221,6 +221,6 @@ def test_download_and_prepare(crs_builder, expected_entries, split_name):
221221
crs_builder.download_and_prepare()
222222
data_source = crs_builder.as_data_source(split=split_name)
223223
assert len(data_source) == 2
224-
for i in range(2):
225-
assert data_source[i]["jsonl/index"] == expected_entries[i]["index"]
226-
assert data_source[i]["jsonl/text"].decode() == expected_entries[i]["text"]
224+
for entry, expected_entry in zip(data_source, expected_entries):
225+
assert entry["index"] == expected_entry["index"]
226+
assert entry["text"].decode() == expected_entry["text"]

tensorflow_datasets/core/features/features_dict.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Mapping
2021
import concurrent.futures
21-
from typing import Dict, List, Union
2222

2323
from tensorflow_datasets.core import utils
2424
from tensorflow_datasets.core.features import feature as feature_lib
@@ -33,7 +33,7 @@
3333
WORKER_COUNT = 16
3434

3535

36-
class _DictGetCounter(object):
36+
class _DictGetCounter:
3737
"""Wraps dict.get and counts successful key accesses."""
3838

3939
def __init__(self, d):
@@ -114,15 +114,15 @@ class FeaturesDict(top_level_feature.TopLevelFeature):
114114

115115
def __init__(
116116
self,
117-
feature_dict: Dict[str, feature_lib.FeatureConnectorArg],
117+
feature_dict: Mapping[str, feature_lib.FeatureConnectorArg],
118118
*,
119119
doc: feature_lib.DocArg = None,
120120
):
121121
"""Initialize the features.
122122
123123
Args:
124-
feature_dict (dict): Dictionary containing the feature connectors of a
125-
example. The keys should correspond to the data dict as returned by
124+
feature_dict: Mapping containing the feature connectors of a example. The
125+
keys should correspond to the data dict as returned by
126126
tf.data.Dataset(). Types (np.int32,...) and dicts will automatically be
127127
converted into FeatureConnector.
128128
doc: Documentation of this feature (e.g. description).
@@ -173,7 +173,7 @@ def __repr__(self):
173173

174174
def catalog_documentation(
175175
self,
176-
) -> List[feature_lib.CatalogFeatureDocumentation]:
176+
) -> list[feature_lib.CatalogFeatureDocumentation]:
177177
feature_docs = [
178178
feature_lib.CatalogFeatureDocumentation(
179179
name='',
@@ -210,7 +210,7 @@ def get_serialized_info(self):
210210

211211
@classmethod
212212
def from_json_content(
213-
cls, value: Union[Json, feature_pb2.FeaturesDict]
213+
cls, value: Json | feature_pb2.FeaturesDict
214214
) -> 'FeaturesDict':
215215
if isinstance(value, dict):
216216
features = {

0 commit comments

Comments (0)