Skip to content

Commit 00fb679

Browse files
pierrot0The TensorFlow Datasets Authors
authored and
The TensorFlow Datasets Authors
committed
Fix asqa_dataset_builder for numpy2 by using int64 for sample_id feature.
PiperOrigin-RevId: 666805583
1 parent b9b534a commit 00fb679

File tree

5 files changed

+20
-7
lines changed

5 files changed

+20
-7
lines changed

tensorflow_datasets/core/beam_utils_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def test_read_from_tfds(
5151
dummy_dataset, split=split, workers_per_shard=workers_per_shard
5252
)
5353
| beam.Map(dataset_utils.as_numpy)
54+
# Post numpy2, we don't get `{'id': 0}` but
55+
# `{'id': np.int64(0)}`
56+
| beam.Map(lambda x: {'id': int(x['id'])})
5457
| beam.io.WriteToText(os.fspath(tmp_path / 'out.txt'))
5558
)
5659

tensorflow_datasets/core/features/bounding_boxes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ def encode_example(self, bbox: Union[bb_utils.BBox, np.ndarray]):
8888

8989
# Validate the coordinates
9090
for coordinate in bbox:
91-
if not isinstance(coordinate, float):
91+
if not isinstance(
92+
coordinate,
93+
(float, np.float16, np.float32, np.float64),
94+
):
9295
raise ValueError(
9396
'BBox coordinates should be float. Got {}.'.format(bbox)
9497
)

tensorflow_datasets/core/utils/image_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,17 @@ def get_colormap() -> np.ndarray:
127127
"""
128128
colormap_path = resource_utils.tfds_path() / 'core/utils/colormap.csv'
129129
with colormap_path.open() as f:
130-
return np.array(list(csv.reader(f)), dtype=np.uint8)
130+
colormap = np.array(list(csv.reader(f)), dtype=np.uint8)
131+
assert colormap.shape == (256, 3)
132+
return colormap
131133

132134

133135
def apply_colormap(image: np.ndarray) -> np.ndarray:
134136
"""Apply colormap from grayscale (h, w, 1) to colored (h, w, 3) image."""
135137
image = image.squeeze(axis=-1) # (h, w, 1) -> (h, w)
136138
cmap = get_colormap() # Get the (256, 3) colormap
137-
# Normalize uint16 and convert each value to a unique color
138-
return cmap[image % len(cmap)]
139+
# Convert the image to uint64 first to avoid overflow.
140+
return cmap[(image.astype(np.int64) % 256).astype(np.uint8)]
139141

140142

141143
# Visualization single image

tensorflow_datasets/datasets/asqa/asqa_dataset_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
def _features():
2525
return tfds.features.FeaturesDict({
26-
'sample_id': np.int32,
26+
'sample_id': np.int64,
2727
'ambiguous_question': tfds.features.Text(
2828
doc='Disambiguated question from AmbigQA.'
2929
),
@@ -82,6 +82,7 @@ class Builder(tfds.core.GeneratorBasedBuilder):
8282

8383
VERSION = tfds.core.Version('1.0.0')
8484
RELEASE_NOTES = {
85+
'2.0.0': 'Sample ID goes from int32 (overflowing) to int64.',
8586
'1.0.0': 'Initial release.',
8687
}
8788

tensorflow_datasets/datasets/duke_ultrasound/duke_ultrasound_dataset_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
"""DAS beamformed phantom images and paired clinical post-processed images."""
1717

1818
import csv
19+
from datetime import datetime
1920
import os
20-
2121
from etils import epath
2222
import numpy as np
2323
from tensorflow_datasets.core.utils import bool_utils
@@ -44,6 +44,7 @@ class Builder(tfds.core.GeneratorBasedBuilder):
4444

4545
VERSION = tfds.core.Version('1.0.1')
4646
RELEASE_NOTES = {
47+
'2.0.0': r'Fix timestamp_id from %Y%m%d%H%M%S to posix timestamp.',
4748
'1.0.1': 'Fixes parsing of boolean field `harmonic`.',
4849
'1.0.0': 'Initial release.',
4950
}
@@ -124,6 +125,9 @@ def _generate_examples(self, datapath, csvpath):
124125
iq = iq / iq.max()
125126
iq = 20 * np.log10(iq)
126127

128+
timestamp_id = datetime.strptime(row['timestamp_id'], '%Y%m%d%H%M%S')
129+
timestamp = int(timestamp_id.timestamp())
130+
127131
yield row['filename'], {
128132
'das': {
129133
'dB': iq.astype(np.float32),
@@ -143,6 +147,6 @@ def _generate_examples(self, datapath, csvpath):
143147
'probe': row['probe'],
144148
'scanner': row['scanner'],
145149
'target': row['target'],
146-
'timestamp_id': row['timestamp_id'],
150+
'timestamp_id': timestamp,
147151
'harmonic': bool_utils.parse_bool(row['harm']),
148152
}

0 commit comments

Comments
 (0)