Skip to content

Commit bb3b942

Browse files
authored
Add utility to map COCO IDs to class names (#2219)
Similar to the ImageNet utility.
1 parent 9d319ff commit bb3b942

File tree

6 files changed

+212
-0
lines changed

6 files changed

+212
-0
lines changed

keras_hub/api/utils/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras_hub.src.utils.coco.coco_utils import (
8+
coco_id_to_name as coco_id_to_name,
9+
)
10+
from keras_hub.src.utils.coco.coco_utils import (
11+
coco_name_to_id as coco_name_to_id,
12+
)
713
from keras_hub.src.utils.imagenet.imagenet_utils import (
814
decode_imagenet_predictions as decode_imagenet_predictions,
915
)
16+
from keras_hub.src.utils.imagenet.imagenet_utils import (
17+
imagenet_id_to_name as imagenet_id_to_name,
18+
)
19+
from keras_hub.src.utils.imagenet.imagenet_utils import (
20+
imagenet_name_to_id as imagenet_name_to_id,
21+
)

keras_hub/src/utils/coco/__init__.py

Whitespace-only changes.
+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
3+
4+
@keras_hub_export("keras_hub.utils.coco_id_to_name")
5+
def coco_id_to_name(id):
6+
"""Convert a single COCO class name to a class ID.
7+
8+
Args:
9+
id: An integer class id from 0 to 91.
10+
11+
Returns:
12+
The human readable image class name, e.g. "bicycle".
13+
14+
Example:
15+
>>> keras_hub.utils.coco_id_to_name(2)
16+
'bicycle'
17+
"""
18+
return COCO_NAMES[id]
19+
20+
21+
@keras_hub_export("keras_hub.utils.coco_name_to_id")
22+
def coco_name_to_id(name):
23+
"""Convert a single COCO class name to a class ID.
24+
25+
Args:
26+
name: A human readable image class name, e.g. "bicycle".
27+
28+
Returns:
29+
The integer class id from 0 to 999.
30+
31+
Example:
32+
>>> keras_hub.utils.coco_name_to_id("bicycle")
33+
2
34+
"""
35+
return COCO_IDS[name]
36+
37+
38+
COCO_NAMES = {
39+
0: "unlabeled",
40+
1: "person",
41+
2: "bicycle",
42+
3: "car",
43+
4: "motorcycle",
44+
5: "airplane",
45+
6: "bus",
46+
7: "train",
47+
8: "truck",
48+
9: "boat",
49+
10: "traffic_light",
50+
11: "fire_hydrant",
51+
12: "street_sign",
52+
13: "stop_sign",
53+
14: "parking_meter",
54+
15: "bench",
55+
16: "bird",
56+
17: "cat",
57+
18: "dog",
58+
19: "horse",
59+
20: "sheep",
60+
21: "cow",
61+
22: "elephant",
62+
23: "bear",
63+
24: "zebra",
64+
25: "giraffe",
65+
26: "hat",
66+
27: "backpack",
67+
28: "umbrella",
68+
29: "shoe",
69+
30: "eye_glasses",
70+
31: "handbag",
71+
32: "tie",
72+
33: "suitcase",
73+
34: "frisbee",
74+
35: "skis",
75+
36: "snowboard",
76+
37: "sports_ball",
77+
38: "kite",
78+
39: "baseball_bat",
79+
40: "baseball_glove",
80+
41: "skateboard",
81+
42: "surfboard",
82+
43: "tennis_racket",
83+
44: "bottle",
84+
45: "plate",
85+
46: "wine_glass",
86+
47: "cup",
87+
48: "fork",
88+
49: "knife",
89+
50: "spoon",
90+
51: "bowl",
91+
52: "banana",
92+
53: "apple",
93+
54: "sandwich",
94+
55: "orange",
95+
56: "broccoli",
96+
57: "carrot",
97+
58: "hot_dog",
98+
59: "pizza",
99+
60: "donut",
100+
61: "cake",
101+
62: "chair",
102+
63: "couch",
103+
64: "potted_plant",
104+
65: "bed",
105+
66: "mirror",
106+
67: "dining_table",
107+
68: "window",
108+
69: "desk",
109+
70: "toilet",
110+
71: "door",
111+
72: "tv",
112+
73: "laptop",
113+
74: "mouse",
114+
75: "remote",
115+
76: "keyboard",
116+
77: "cell_phone",
117+
78: "microwave",
118+
79: "oven",
119+
80: "toaster",
120+
81: "sink",
121+
82: "refrigerator",
122+
83: "blender",
123+
84: "book",
124+
85: "clock",
125+
86: "vase",
126+
87: "scissors",
127+
88: "teddy_bear",
128+
89: "hair_drier",
129+
90: "toothbrush",
130+
91: "hair_brush",
131+
}
132+
133+
COCO_IDS = {v: k for k, v in COCO_NAMES.items()}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from keras_hub.src.tests.test_case import TestCase
2+
from keras_hub.src.utils.coco.coco_utils import coco_id_to_name
3+
from keras_hub.src.utils.coco.coco_utils import coco_name_to_id
4+
5+
6+
class CocoUtilsTest(TestCase):
7+
def test_coco_id_to_name(self):
8+
self.assertEqual(coco_id_to_name(0), "unlabeled")
9+
self.assertEqual(coco_id_to_name(24), "zebra")
10+
with self.assertRaises(KeyError):
11+
coco_id_to_name(2001)
12+
13+
def test_coco_name_to_id(self):
14+
self.assertEqual(coco_name_to_id("unlabeled"), 0)
15+
self.assertEqual(coco_name_to_id("zebra"), 24)
16+
with self.assertRaises(KeyError):
17+
coco_name_to_id("whirligig")

keras_hub/src/utils/imagenet/imagenet_utils.py

+36
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,40 @@
33
from keras_hub.src.api_export import keras_hub_export
44

55

6+
@keras_hub_export("keras_hub.utils.imagenet_id_to_name")
7+
def imagenet_id_to_name(id):
8+
"""Convert a single ImageNet class ID to a class name.
9+
10+
Args:
11+
id: An integer class id from 0 to 999.
12+
13+
Returns:
14+
The human readable image class name, e.g. "goldfish".
15+
16+
Example:
17+
>>> keras_hub.utils.imagenet_id_to_name(1)
18+
'goldfish'
19+
"""
20+
return IMAGENET_NAMES[id][1]
21+
22+
23+
@keras_hub_export("keras_hub.utils.imagenet_name_to_id")
24+
def imagenet_name_to_id(name):
25+
"""Convert a single ImageNet class name to a class ID.
26+
27+
Args:
28+
name: A human readable image class name, e.g. "goldfish".
29+
30+
Returns:
31+
The integer class id from 0 to 999.
32+
33+
Example:
34+
>>> keras_hub.utils.imagenet_name_to_id("goldfish")
35+
1
36+
"""
37+
return IMAGENET_IDS[name]
38+
39+
640
@keras_hub_export("keras_hub.utils.decode_imagenet_predictions")
741
def decode_imagenet_predictions(preds, top=5, include_synset_ids=False):
842
"""Decodes the predictions for an ImageNet-1k prediction.
@@ -1052,3 +1086,5 @@ def decode_imagenet_predictions(preds, top=5, include_synset_ids=False):
10521086
998: ("n13133613", "ear"),
10531087
999: ("n15075141", "toilet_tissue"),
10541088
}
1089+
1090+
IMAGENET_IDS = {v[1]: k for k, v in IMAGENET_NAMES.items()}

keras_hub/src/utils/imagenet/imagenet_utils_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,23 @@
44
from keras_hub.src.utils.imagenet.imagenet_utils import (
55
decode_imagenet_predictions,
66
)
7+
from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name
8+
from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id
79

810

911
class ImageNetUtilsTest(TestCase):
12+
def test_imagenet_id_to_name(self):
13+
self.assertEqual(imagenet_id_to_name(0), "tench")
14+
self.assertEqual(imagenet_id_to_name(21), "kite")
15+
with self.assertRaises(KeyError):
16+
imagenet_id_to_name(2001)
17+
18+
def test_imagenet_name_to_id(self):
19+
self.assertEqual(imagenet_name_to_id("tench"), 0)
20+
self.assertEqual(imagenet_name_to_id("kite"), 21)
21+
with self.assertRaises(KeyError):
22+
imagenet_name_to_id(2001)
23+
1024
def test_decode_imagenet_predictions(self):
1125
preds = np.array(
1226
[

0 commit comments

Comments
 (0)