Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from kornia:main #5

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kornia/augmentation/_2d/mix/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MixAugmentationBaseV2(_BasicAugmentationBase):
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
to the batch form ``False``.
data_keys: the input type sequential for applying augmentations.
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label".
"""

def __init__(
Expand Down
3 changes: 2 additions & 1 deletion kornia/augmentation/_2d/mix/jigsaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class RandomJigsaw(MixAugmentationBaseV2):
ensure_perm: to ensure the nonidentical patch permutation generation against
the original one.
data_keys: the input type sequential for applying augmentations.
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints",
"class", "label".
p: probability of applying the transformation for the whole batch.
same_on_batch: apply the same transformation across the batch.
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
Expand Down
3 changes: 2 additions & 1 deletion kornia/augmentation/_2d/mix/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class RandomMosaic(MixAugmentationBaseV2):
each output will mix 4 images in a 2x2 grid.
min_bbox_size: minimum area of bounding boxes. Default to 0.
data_keys: the input type sequential for applying augmentations.
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints",
"class", "label".
p: probability of applying the transformation for the whole batch.
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
to the batch form ``False``.
Expand Down
10 changes: 5 additions & 5 deletions kornia/augmentation/container/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
*args: a list of kornia augmentation modules.

data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
"bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
"bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label".

same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
settings.
Expand Down Expand Up @@ -234,7 +234,7 @@ def __init__(
self.data_keys = data_keys

if self.data_keys:
if not all(in_type in DataKey for in_type in self.data_keys):
if any(in_type not in DataKey for in_type in self.data_keys):
raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")

if self.data_keys[0] != DataKey.INPUT:
Expand Down Expand Up @@ -446,9 +446,9 @@ def _preproc_dict_data(
if self.data_keys is not None:
raise ValueError("If you are using a dictionary as input, the data_keys should be None.")

data_keys = self._read_datakeys_from_dict(tuple(data.keys()))
keys = tuple(data.keys())
data_unpacked = tuple(v for v in data.values())
data_keys = self._read_datakeys_from_dict(keys)
data_unpacked = tuple(data.values())

return keys, data_keys, data_unpacked

Expand All @@ -467,7 +467,7 @@ def retrieve_key(key: str) -> DataKey:

allowed_dk = " | ".join(f"`{d.name}`" for d in DataKey)
raise ValueError(
f"You input data dictionary keys should starts with some of datakey values: {allowed_dk}. Got `{key}`"
f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`"
)

return [DataKey.get(retrieve_key(k)) for k in keys]
Expand Down
1 change: 1 addition & 0 deletions kornia/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class DataKey(Enum, metaclass=_KORNIA_EnumMeta):
BBOX_XYXY = 3
BBOX_XYWH = 4
KEYPOINTS = 5
LABEL = 6
CLASS = 6

@classmethod
Expand Down
5 changes: 3 additions & 2 deletions tests/augmentation/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,11 @@ def test_bbox(self, bbox, augmentation, device, dtype):
for i in range(len(bbox)):
assert len(bboxes_transformed[i]) == len(bbox[i])

def test_class(self, device, dtype):
@pytest.mark.parametrize("class_data_key", ["class", "label"])
def test_class(self, class_data_key, device, dtype):
img = torch.zeros((5, 1, 5, 5))
labels = torch.randint(0, 10, size=(5, 1))
aug = K.AugmentationSequential(K.RandomCrop((3, 3), pad_if_needed=True), data_keys=["input", "class"])
aug = K.AugmentationSequential(K.RandomCrop((3, 3), pad_if_needed=True), data_keys=["input", class_data_key])

_, out_labels = aug(img, labels)
assert labels is out_labels
Expand Down