Skip to content

Commit e0aadf6

Browse files
author
The TensorFlow Datasets Authors
committed
Ensure that Croissant versions that only specify major and minor versions are properly converted to Semantic versions.
PiperOrigin-RevId: 760995569
1 parent 2034928 commit e0aadf6

File tree

5 files changed

+70
-5
lines changed

5 files changed

+70
-5
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,17 @@ def __init__(
279279
self.name = croissant_utils.get_tfds_dataset_name(dataset)
280280
self.metadata = dataset.metadata
281281

282-
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
283-
# recommended attribute. If the version is unspecified in Croissant, we set
284-
# it to `1.0.0` in TFDS.
282+
# The dataset version is determined using the following precedence:
283+
# * overwrite_version (if provided).
284+
# * The version from Croissant metadata (self.metadata.version),
285+
# automatically converting major.minor formats to major.minor.0 (e.g., "1.2"
286+
# becomes "1.2.0"). See croissant_utils.get_croissant_version for details.
287+
# * Defaults to '1.0.0' if no version is specified (version is optional in
288+
# Croissant, but mandatory in TFDS).
285289
self.VERSION = version_lib.Version( # pylint: disable=invalid-name
286-
overwrite_version or self.metadata.version or '1.0.0'
290+
overwrite_version
291+
or croissant_utils.get_croissant_version(self.metadata.version)
292+
or '1.0.0'
287293
)
288294
self.RELEASE_NOTES = {} # pylint: disable=invalid-name
289295

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,17 @@ def test_sequence_feature_datatype_converter():
249249
assert isinstance(actual_feature.feature, text_feature.Text)
250250

251251

252+
def test_version_converter(tmp_path):
  """A major.minor Croissant version becomes major.minor.0 on the builder."""
  croissant_version = "1.0"
  with testing.dummy_croissant_file(version=croissant_version) as jsonld_path:
    crs_builder = croissant_builder.CroissantBuilder(
        jsonld=jsonld_path,
        file_format=FileFormat.ARRAY_RECORD,
        disable_shuffling=True,
        data_dir=tmp_path,
    )
    # The missing patch component is expected to be filled in with 0.
    assert crs_builder.version == "1.0.0"
261+
262+
252263
@pytest.fixture(name="crs_builder")
253264
def mock_croissant_dataset_builder(tmp_path, request):
254265
dataset_name = request.param["dataset_name"]

tensorflow_datasets/core/utils/croissant_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import dataclasses
21+
import re
2122
import typing
2223

2324
from tensorflow_datasets.core.utils import conversion_utils
@@ -29,6 +30,12 @@
2930

3031
_HUGGINGFACE_URL_PREFIX = "https://huggingface.co/datasets/"
3132

33+
# Template for a `{major}.{minor}` version with no patch component. It is
# anchored with `\Z` rather than `$`: in Python `$` also matches just before a
# trailing newline, which would let a version followed by a newline through and
# produce a malformed result once `.0` is appended.
_VERSION_TMPL_WITHOUT_PATCH = r"^(?P<major>{v})" r"\.(?P<minor>{v})\Z"
# A single digit, or a multi-digit number that does not start with 0.
_NO_LEADING_ZEROS = r"\d|[1-9]\d*"
_VERSION_REGEX_WITHOUT_PATCH = re.compile(
    _VERSION_TMPL_WITHOUT_PATCH.format(v=_NO_LEADING_ZEROS)
)
38+
3239

3340
@dataclasses.dataclass(frozen=True)
3441
class SplitReference:
@@ -40,6 +47,28 @@ class SplitReference:
4047
reference_field: mlc.Field
4148

4249

50+
def get_croissant_version(version: str | None) -> str | None:
  """Normalizes a Croissant version string into the format TFDS expects.

  TFDS requires `{major}.{minor}.{patch}` versions, whereas Croissant also
  allows the shorter `{major}.{minor}` form (with no leading zeros). For that
  shorter form a patch component of `0` is appended; any other non-empty value
  is handed back untouched.

  Args:
    version: The version read from the Croissant metadata, if any.

  Returns:
    `None` when `version` is `None` or empty, `"{version}.0"` when `version`
    has the `{major}.{minor}` form, and `version` itself otherwise.
  """
  if version and _VERSION_REGEX_WITHOUT_PATCH.match(version):
    return f"{version}.0"
  # Empty strings are reported as "no version", the same as None.
  return version or None
70+
71+
4372
def get_dataset_name(dataset: mlc.Dataset) -> str:
4473
"""Returns dataset name of the given MLcroissant dataset."""
4574
if (url := dataset.metadata.url) and url.startswith(_HUGGINGFACE_URL_PREFIX):

tensorflow_datasets/core/utils/croissant_utils_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,23 @@ def test_get_tfds_dataset_name(croissant_name, croissant_url, tfds_name):
3636
assert croissant_utils.get_tfds_dataset_name(dataset) == tfds_name
3737

3838

39+
@pytest.mark.parametrize(
    'croissant_version,tfds_version',
    [
        # `major.minor` versions get a `.0` patch component appended.
        ('1.0', '1.0.0'),
        ('1.2', '1.2.0'),
        # A lone `0` is a valid component (not a leading zero), so this is
        # converted like any other `major.minor` version.
        ('0.1', '0.1.0'),
        # Components with leading zeros do not match and pass through.
        ('01.2', '01.2'),
        # Versions that are not `major.minor` are returned unchanged.
        ('1.2.3', '1.2.3'),
        ('1.2.3.4', '1.2.3.4'),
        (None, None),
    ],
)
def test_get_croissant_version(croissant_version, tfds_version):
  """Checks `major.minor` gets a patch version and other inputs pass through."""
  assert (
      croissant_utils.get_croissant_version(croissant_version) == tfds_version
  )
54+
55+
3956
def test_get_record_set_ids():
4057
metadata = mlc.Metadata(
4158
name='dummy_dataset',

tensorflow_datasets/testing/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ def dummy_croissant_file(
724724
raw_data_filename: epath.PathLike = 'raw_data.jsonl',
725725
croissant_filename: epath.PathLike = 'croissant.json',
726726
split_names: Sequence[str] | None = None,
727+
version: str = '1.2.0',
727728
) -> Iterator[epath.Path]:
728729
"""Yields temporary path to a dummy Croissant file.
729730
@@ -746,6 +747,7 @@ def dummy_croissant_file(
746747
If None, the function will create a split record set with the default
747748
split names `train` and `test`. If `split_names` is defined, the `split`
748749
key in the entries must match one of the split names.
750+
version: The version of the dataset. Defaults to `1.2.0`.
749751
"""
750752
if entries is None:
751753
entries = [
@@ -874,7 +876,7 @@ def dummy_croissant_file(
874876
url='https://dummy_url',
875877
distribution=distribution,
876878
record_sets=record_sets,
877-
version='1.2.0',
879+
version=version,
878880
license='Public',
879881
)
880882
# Write Croissant JSON-LD to tempdir.

0 commit comments

Comments
 (0)