Skip to content

Commit 8df6c44

Browse files
authored
DataKey: add 'label' as alias of 'class' (kornia#2873)
* DataKey: add 'label' as alias of 'class' * Move up LABEL * Simplify expressions
1 parent f6bf869 commit 8df6c44

File tree

6 files changed

+14
-10
lines changed

6 files changed

+14
-10
lines changed

kornia/augmentation/_2d/mix/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class MixAugmentationBaseV2(_BasicAugmentationBase):
2929
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
3030
to the batch form ``False``.
3131
data_keys: the input type sequential for applying augmentations.
32-
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
32+
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label".
3333
"""
3434

3535
def __init__(

kornia/augmentation/_2d/mix/jigsaw.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ class RandomJigsaw(MixAugmentationBaseV2):
2424
ensure_perm: to ensure the nonidentical patch permutation generation against
2525
the original one.
2626
data_keys: the input type sequential for applying augmentations.
27-
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
27+
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints",
28+
"class", "label".
2829
p: probability of applying the transformation for the whole batch.
2930
same_on_batch: apply the same transformation across the batch.
3031
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it

kornia/augmentation/_2d/mix/mosaic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class RandomMosaic(MixAugmentationBaseV2):
3535
each output will mix 4 images in a 2x2 grid.
3636
min_bbox_size: minimum area of bounding boxes. Default to 0.
3737
data_keys: the input type sequential for applying augmentations.
38-
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
38+
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints",
39+
"class", "label".
3940
p: probability of applying the transformation for the whole batch.
4041
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
4142
to the batch form ``False``.

kornia/augmentation/container/augment.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
3333
*args: a list of kornia augmentation modules.
3434
3535
data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
36-
"bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
36+
"bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label".
3737
3838
same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
3939
settings.
@@ -234,7 +234,7 @@ def __init__(
234234
self.data_keys = data_keys
235235

236236
if self.data_keys:
237-
if not all(in_type in DataKey for in_type in self.data_keys):
237+
if any(in_type not in DataKey for in_type in self.data_keys):
238238
raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")
239239

240240
if self.data_keys[0] != DataKey.INPUT:
@@ -446,9 +446,9 @@ def _preproc_dict_data(
446446
if self.data_keys is not None:
447447
raise ValueError("If you are using a dictionary as input, the data_keys should be None.")
448448

449-
data_keys = self._read_datakeys_from_dict(tuple(data.keys()))
450449
keys = tuple(data.keys())
451-
data_unpacked = tuple(v for v in data.values())
450+
data_keys = self._read_datakeys_from_dict(keys)
451+
data_unpacked = tuple(data.values())
452452

453453
return keys, data_keys, data_unpacked
454454

@@ -467,7 +467,7 @@ def retrieve_key(key: str) -> DataKey:
467467

468468
allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey)
469469
raise ValueError(
470-
f"You input data dictionary keys should starts with some of datakey values: {allowed_dk}. Got `{key}`"
470+
f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`"
471471
)
472472

473473
return [DataKey.get(retrieve_key(k)) for k in keys]

kornia/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class DataKey(Enum, metaclass=_KORNIA_EnumMeta):
130130
BBOX_XYXY = 3
131131
BBOX_XYWH = 4
132132
KEYPOINTS = 5
133+
LABEL = 6
133134
CLASS = 6
134135

135136
@classmethod

tests/augmentation/test_container.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -514,10 +514,11 @@ def test_bbox(self, bbox, augmentation, device, dtype):
514514
for i in range(len(bbox)):
515515
assert len(bboxes_transformed[i]) == len(bbox[i])
516516

517-
def test_class(self, device, dtype):
517+
@pytest.mark.parametrize("class_data_key", ["class", "label"])
518+
def test_class(self, class_data_key, device, dtype):
518519
img = torch.zeros((5, 1, 5, 5))
519520
labels = torch.randint(0, 10, size=(5, 1))
520-
aug = K.AugmentationSequential(K.RandomCrop((3, 3), pad_if_needed=True), data_keys=["input", "class"])
521+
aug = K.AugmentationSequential(K.RandomCrop((3, 3), pad_if_needed=True), data_keys=["input", class_data_key])
521522

522523
_, out_labels = aug(img, labels)
523524
assert labels is out_labels

0 commit comments

Comments
 (0)