Skip to content

Commit 55c4c54

Browse files
author
The TensorFlow Datasets Authors
committed
The dataset_utils.as_numpy doesn't convert RaggedTensors. Fix that but providing custom conversion.
PiperOrigin-RevId: 736137874
1 parent 91a57fe commit 55c4c54

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

tensorflow_datasets/robotics/asimov/asimov.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from tensorflow_datasets.robotics import dataset_importer_builder
1919

2020
ASIMOV_CITATION = """
21-
@article{Sermanet2025RobotConstitutions,
21+
@article{sermanet2025asimov,
2222
author = {Pierre Sermanet and Anirudha Majumdar and Alex Irpan and Dmitry Kalashnikov and Vikas Sindhwani},
2323
title = {Generating Robot Constitutions & Benchmarks for Semantic Safety},
24-
journal = {arXiv preprint arXiv:FIXME},
25-
url = {https://arxiv.org/abs/FIXME},
24+
journal = {arXiv preprint arXiv:2503.08663},
25+
url = {https://arxiv.org/abs/2503.08663},
2626
year = {2025},
2727
}
2828
"""

tensorflow_datasets/robotics/dataset_importer_builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensorflow_datasets.core.utils import read_config as read_config_lib
2929
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3030
import tensorflow_datasets.public_api as tfds
31+
import tree
3132

3233

3334

@@ -218,7 +219,7 @@ def _generate_examples(
218219
read_config = read_config_lib.ReadConfig(add_tfds_id=True)
219220

220221
def converter_fn(example):
221-
example_out = dataset_utils.as_numpy(example)
222+
example_out = tree.map_structure(to_np, example)
222223
example_id = example_out['tfds_id'].decode('utf-8')
223224
del example_out['tfds_id']
224225

@@ -235,3 +236,10 @@ def get_ds_builder(self):
235236
ds_location = self.get_dataset_location()
236237
ds_builder = tfds.builder_from_directory(ds_location)
237238
return ds_builder
239+
240+
241+
def to_np(tensor):
242+
"""Convert tensor to numpy."""
243+
if isinstance(tensor, tf.Tensor) or isinstance(tensor, tf.RaggedTensor):
244+
return tensor.numpy()
245+
return tensor

0 commit comments

Comments
 (0)