From 8df6c44ce6ba233c07185ac46107c6d9a076d3c0 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 8 Apr 2024 16:25:06 +0400 Subject: [PATCH] DataKey: add 'label' as alias of 'class' (#2873) * DataKey: add 'label' as alias of 'class' * Move up LABEL * Simplify expressions --- kornia/augmentation/_2d/mix/base.py | 2 +- kornia/augmentation/_2d/mix/jigsaw.py | 3 ++- kornia/augmentation/_2d/mix/mosaic.py | 3 ++- kornia/augmentation/container/augment.py | 10 +++++----- kornia/constants.py | 1 + tests/augmentation/test_container.py | 5 +++-- 6 files changed, 14 insertions(+), 10 deletions(-) diff --git a/kornia/augmentation/_2d/mix/base.py b/kornia/augmentation/_2d/mix/base.py index 44efd496a0..0d2dbd6fbd 100644 --- a/kornia/augmentation/_2d/mix/base.py +++ b/kornia/augmentation/_2d/mix/base.py @@ -29,7 +29,7 @@ class MixAugmentationBaseV2(_BasicAugmentationBase): keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. data_keys: the input type sequential for applying augmentations. - Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label". """ def __init__( diff --git a/kornia/augmentation/_2d/mix/jigsaw.py b/kornia/augmentation/_2d/mix/jigsaw.py index 37c815e8a1..bd994a3e63 100644 --- a/kornia/augmentation/_2d/mix/jigsaw.py +++ b/kornia/augmentation/_2d/mix/jigsaw.py @@ -24,7 +24,8 @@ class RandomJigsaw(MixAugmentationBaseV2): ensure_perm: to ensure the nonidentical patch permutation generation against the original one. data_keys: the input type sequential for applying augmentations. - Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", + "class", "label". p: probability of applying the transformation for the whole batch. same_on_batch: apply the same transformation across the batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it diff --git a/kornia/augmentation/_2d/mix/mosaic.py b/kornia/augmentation/_2d/mix/mosaic.py index acbc1ca592..2a5537dec5 100644 --- a/kornia/augmentation/_2d/mix/mosaic.py +++ b/kornia/augmentation/_2d/mix/mosaic.py @@ -35,7 +35,8 @@ class RandomMosaic(MixAugmentationBaseV2): each output will mix 4 images in a 2x2 grid. min_bbox_size: minimum area of bounding boxes. Default to 0. data_keys: the input type sequential for applying augmentations. - Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", + "class", "label". p: probability of applying the transformation for the whole batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 21716f277e..7d5dcc8206 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -33,7 +33,7 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential): *args: a list of kornia augmentation modules. data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask", - "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label". same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise settings. @@ -234,7 +234,7 @@ def __init__( self.data_keys = data_keys if self.data_keys: - if not all(in_type in DataKey for in_type in self.data_keys): + if any(in_type not in DataKey for in_type in self.data_keys): raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.") if self.data_keys[0] != DataKey.INPUT: @@ -446,9 +446,9 @@ def _preproc_dict_data( if self.data_keys is not None: raise ValueError("If you are using a dictionary as input, the data_keys should be None.") - data_keys = self._read_datakeys_from_dict(tuple(data.keys())) keys = tuple(data.keys()) - data_unpacked = tuple(v for v in data.values()) + data_keys = self._read_datakeys_from_dict(keys) + data_unpacked = tuple(data.values()) return keys, data_keys, data_unpacked @@ -467,7 +467,7 @@ def retrieve_key(key: str) -> DataKey: allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey) raise ValueError( - f"You input data dictionary keys should starts with some of datakey values: {allowed_dk}. Got `{key}`" + f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`" ) return [DataKey.get(retrieve_key(k)) for k in keys] diff --git a/kornia/constants.py b/kornia/constants.py index 24d0f644b4..72ac09863f 100644 --- a/kornia/constants.py +++ b/kornia/constants.py @@ -130,6 +130,7 @@ class DataKey(Enum, metaclass=_KORNIA_EnumMeta): BBOX_XYXY = 3 BBOX_XYWH = 4 KEYPOINTS = 5 + LABEL = 6 CLASS = 6 @classmethod diff --git a/tests/augmentation/test_container.py b/tests/augmentation/test_container.py index 25e5fad4b0..8963bb6ddd 100644 --- a/tests/augmentation/test_container.py +++ b/tests/augmentation/test_container.py @@ -514,10 +514,11 @@ def test_bbox(self, bbox, augmentation, device, dtype): for i in range(len(bbox)): assert len(bboxes_transformed[i]) == len(bbox[i]) - def test_class(self, device, dtype): + @pytest.mark.parametrize("class_data_key", ["class", "label"]) + def test_class(self, class_data_key, device, dtype): img = torch.zeros((5, 1, 5, 5)) labels = torch.randint(0, 10, size=(5, 1)) - aug = K.AugmentationSequential(K.RandomCrop((3, 3), pad_if_needed=True), data_keys=["input", "class"]) + aug = K.AugmentationSequential(K.RandomCrop((3, 3), pad_if_needed=True), data_keys=["input", class_data_key]) _, out_labels = aug(img, labels) assert labels is out_labels