Skip to content

Commit d284e46

Browse files
author
The TensorFlow Datasets Authors
committed
Ensure that Croissant versions that only specify major and minor versions are properly converted to Semantic versions.
PiperOrigin-RevId: 760995569
1 parent df781bb commit d284e46

File tree

5 files changed

+63
-5
lines changed

5 files changed

+63
-5
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,17 @@ def __init__(
282282
self.name = croissant_utils.get_tfds_dataset_name(dataset)
283283
self.metadata = dataset.metadata
284284

285-
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
286-
# recommended attribute. If the version is unspecified in Croissant, we set
287-
# it to `1.0.0` in TFDS.
285+
# The dataset version is determined using the following precedence:
286+
# * overwrite_version (if provided).
287+
# * The version from Croissant metadata (self.metadata.version),
288+
# automatically converting major.minor formats to major.minor.0 (e.g., "1.2"
289+
# becomes "1.2.0"). See croissant_utils.get_croissant_version for details.
290+
# * Defaults to '1.0.0' if no version is specified (version is optional in
291+
# Croissant, but mandatory in TFDS).
288292
self.VERSION = version_lib.Version( # pylint: disable=invalid-name
289-
overwrite_version or self.metadata.version or '1.0.0'
293+
overwrite_version
294+
or croissant_utils.get_croissant_version(self.metadata.version)
295+
or '1.0.0'
290296
)
291297
self.RELEASE_NOTES = {} # pylint: disable=invalid-name
292298

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,17 @@ def test_sequence_feature_datatype_converter():
249249
assert isinstance(actual_feature.feature, text_feature.Text)
250250

251251

252+
def test_version_converter(tmp_path):
253+
with testing.dummy_croissant_file(version="1.0") as croissant_file:
254+
builder = croissant_builder.CroissantBuilder(
255+
jsonld=croissant_file,
256+
file_format=FileFormat.ARRAY_RECORD,
257+
disable_shuffling=True,
258+
data_dir=tmp_path,
259+
)
260+
assert builder.version == "1.0.0"
261+
262+
252263
@pytest.fixture(name="crs_builder")
253264
def mock_croissant_dataset_builder(tmp_path, request):
254265
dataset_name = request.param["dataset_name"]

tensorflow_datasets/core/utils/croissant_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import dataclasses
21+
import re
2122
import typing
2223

2324
from tensorflow_datasets.core.utils import conversion_utils
@@ -28,6 +29,7 @@
2829
import mlcroissant as mlc
2930

3031
_HUGGINGFACE_URL_PREFIX = "https://huggingface.co/datasets/"
32+
_VERSION_REGEX_WITHOUT_PATCH = re.compile(r"^(?P<major>\d+)\.(?P<minor>\d+)$")
3133

3234

3335
@dataclasses.dataclass(frozen=True)
@@ -40,6 +42,27 @@ class SplitReference:
4042
reference_field: mlc.Field
4143

4244

45+
def get_croissant_version(version: str | None) -> str | None:
46+
"""Returns the possibly corrected Croissant version in TFDS format.
47+
48+
TFDS expects versions to follow the Semantic versioning 2.0.0 syntax, but
49+
Croissant is more lax and accepts also {major.minor}. To avoid raising errors
50+
in these cases, we add a `0` as a patch version to the Croissant-provided
51+
version.
52+
53+
Args:
54+
version: The Croissant version.
55+
56+
Returns:
57+
The Croissant version in TFDS format.
58+
"""
59+
if not version:
60+
return None
61+
if _VERSION_REGEX_WITHOUT_PATCH.match(version):
62+
return f"{version}.0"
63+
return version
64+
65+
4366
def get_dataset_name(dataset: mlc.Dataset) -> str:
4467
"""Returns dataset name of the given MLcroissant dataset."""
4568
if (url := dataset.metadata.url) and url.startswith(_HUGGINGFACE_URL_PREFIX):

tensorflow_datasets/core/utils/croissant_utils_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ def test_get_tfds_dataset_name(croissant_name, croissant_url, tfds_name):
3636
assert croissant_utils.get_tfds_dataset_name(dataset) == tfds_name
3737

3838

39+
@pytest.mark.parametrize(
40+
'croissant_version,tfds_version',
41+
[
42+
('1.0', '1.0.0'),
43+
('1.2', '1.2.0'),
44+
('1.2.3', '1.2.3'),
45+
('1.2.3.4', '1.2.3.4'),
46+
(None, None),
47+
],
48+
)
49+
def test_get_croissant_version(croissant_version, tfds_version):
50+
assert (
51+
croissant_utils.get_croissant_version(croissant_version) == tfds_version
52+
)
53+
54+
3955
def test_get_record_set_ids():
4056
metadata = mlc.Metadata(
4157
name='dummy_dataset',

tensorflow_datasets/testing/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ def dummy_croissant_file(
724724
raw_data_filename: epath.PathLike = 'raw_data.jsonl',
725725
croissant_filename: epath.PathLike = 'croissant.json',
726726
split_names: Sequence[str] | None = None,
727+
version: str = '1.2.0',
727728
) -> Iterator[epath.Path]:
728729
"""Yields temporary path to a dummy Croissant file.
729730
@@ -746,6 +747,7 @@ def dummy_croissant_file(
746747
If None, the function will create a split record set with the default
747748
split names `train` and `test`. If `split_names` is defined, the `split`
748749
key in the entries must match one of the split names.
750+
version: The version of the dataset. Defaults to `1.2.0`.
749751
"""
750752
if entries is None:
751753
entries = [
@@ -874,7 +876,7 @@ def dummy_croissant_file(
874876
url='https://dummy_url',
875877
distribution=distribution,
876878
record_sets=record_sets,
877-
version='1.2.0',
879+
version=version,
878880
license='Public',
879881
)
880882
# Write Croissant JSON-LD to tempdir.

0 commit comments

Comments
 (0)