Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from kornia:main #23

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: pyproject-fmt

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.9
rev: v0.4.10
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
74 changes: 67 additions & 7 deletions kornia/augmentation/container/augment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import torch

from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
from kornia.augmentation.base import _AugmentationBase
Expand All @@ -21,7 +23,11 @@

_BOXES_OPTIONS = {DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH}
_KEYPOINTS_OPTIONS = {DataKey.KEYPOINTS}
_IMG_MSK_OPTIONS = {DataKey.INPUT, DataKey.MASK}
_IMG_OPTIONS = {DataKey.INPUT, DataKey.IMAGE}
_MSK_OPTIONS = {DataKey.MASK}
_CLS_OPTIONS = {DataKey.CLASS, DataKey.LABEL}

MaskDataType = Union[Tensor, List[Tensor]]


class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
Expand Down Expand Up @@ -195,6 +201,9 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
dict_keys(['image', 'mask', 'mask-b', 'bbox', 'bbox-other'])
"""

input_dtype = None
mask_dtype = None

def __init__(
self,
*args: Union[_AugmentationBase, ImageSequential],
Expand Down Expand Up @@ -332,13 +341,23 @@ def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) ->
def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
inp: List[DataType] = []
for arg, dcate in zip(args, data_keys):
if DataKey.get(dcate) in _IMG_MSK_OPTIONS:
if DataKey.get(dcate) in _IMG_OPTIONS:
arg = cast(Tensor, arg)
self.input_dtype = arg.dtype
inp.append(arg)
elif DataKey.get(dcate) in _MSK_OPTIONS:
if isinstance(inp, list):
arg = cast(List[Tensor], arg)
self.mask_dtype = arg[0].dtype
else:
arg = cast(Tensor, arg)
self.mask_dtype = arg.dtype
inp.append(self._preproc_mask(arg))
elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
inp.append(self._preproc_keypoints(arg, dcate))
elif DataKey.get(dcate) in _BOXES_OPTIONS:
inp.append(self._preproc_boxes(arg, dcate))
elif DataKey.get(dcate) is DataKey.CLASS:
elif DataKey.get(dcate) in _CLS_OPTIONS:
inp.append(arg)
else:
raise NotImplementedError(f"input type of {dcate} is not implemented.")
Expand All @@ -349,10 +368,13 @@ def _arguments_postproc(
) -> List[DataType]:
out: List[DataType] = []
for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
if DataKey.get(dcate) in _IMG_MSK_OPTIONS:
if DataKey.get(dcate) in _IMG_OPTIONS:
# It is tensor type already.
out.append(out_arg)
# TODO: may add the float to integer (for masks), etc.
elif DataKey.get(dcate) in _MSK_OPTIONS:
_out_m = self._postproc_mask(cast(MaskDataType, out_arg))
out.append(_out_m)

elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
_out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
Expand All @@ -372,7 +394,7 @@ def _arguments_postproc(
_out_b = _out_b.type(in_arg.dtype)
out.append(_out_b)

elif DataKey.get(dcate) is DataKey.CLASS:
elif DataKey.get(dcate) in _CLS_OPTIONS:
out.append(out_arg)

else:
Expand Down Expand Up @@ -472,6 +494,30 @@ def retrieve_key(key: str) -> DataKey:

return [DataKey.get(retrieve_key(k)) for k in keys]

def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
new_arg = []
for a in arg:
a_new = a.to(self.input_dtype) if self.input_dtype else a.to(torch.float)
new_arg.append(a_new)
return new_arg

else:
arg = arg.to(self.input_dtype) if self.input_dtype else arg.to(torch.float)
return arg

def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
new_arg = []
for a in arg:
a_new = a.to(self.mask_dtype) if self.mask_dtype else a.to(torch.float)
new_arg.append(a_new)
return new_arg

else:
arg = arg.to(self.mask_dtype) if self.mask_dtype else arg.to(torch.float)
return arg

def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
if DataKey.get(dcate) in [DataKey.BBOX]:
mode = "vertices_plus"
Expand Down Expand Up @@ -509,17 +555,31 @@ def _postproc_boxes(self, in_arg: DataType, out_arg: Boxes, dcate: DataKey) -> U
return out_arg.to_tensor(mode=mode)

def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
dtype = None

if self.contains_video_sequential:
arg = cast(Union[Tensor, List[Tensor]], arg)
return VideoKeypoints.from_tensor(arg)
if isinstance(arg, list):
if not torch.is_floating_point(arg[0]):
dtype = arg[0].dtype
arg = [a.float() for a in arg]
elif not torch.is_floating_point(arg):
dtype = arg.dtype
arg = arg.float()
video_result = VideoKeypoints.from_tensor(arg)
return video_result.type(dtype) if dtype else video_result
elif self.contains_3d_augmentation:
raise NotImplementedError("3D keypoint handlers are not yet supported.")
elif isinstance(arg, (Keypoints,)):
return arg
else:
arg = cast(Tensor, arg)
if not torch.is_floating_point(arg):
dtype = arg.dtype
arg = arg.float()
# TODO: Add List[Tensor] in the future.
return Keypoints.from_tensor(arg)
result = Keypoints.from_tensor(arg)
return result.type(dtype) if dtype else result

def _postproc_keypoint(
self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
Expand Down
9 changes: 6 additions & 3 deletions kornia/augmentation/container/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def transform(
outputs = []
for inp, dcate in zip(arg, _data_keys):
op = self._get_op(dcate)
extra_arg = extra_args[dcate] if dcate in extra_args else {}
extra_arg = extra_args.get(dcate, {})
if dcate.name == "MASK" and isinstance(inp, list):
outputs.append(MaskSequentialOps.transform_list(inp, module, param=param, extra_args=extra_arg))
else:
Expand Down Expand Up @@ -240,6 +240,7 @@ def transform(cls, input: Tensor, module: Module, param: ParamItem, extra_args:
to apply transformations.
param: the corresponding parameters to the module.
"""

if isinstance(module, (K.GeometricAugmentationBase2D,)):
input = module.transform_masks(
input,
Expand Down Expand Up @@ -269,7 +270,8 @@ def transform(cls, input: Tensor, module: Module, param: ParamItem, extra_args:
input = module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

elif isinstance(module, (K.auto.operations.OperationBase,)):
return MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)
input = MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)

return input

@classmethod
Expand Down Expand Up @@ -344,6 +346,7 @@ def inverse(cls, input: Tensor, module: Module, param: ParamItem, extra_args: Di
to apply transformations.
param: the corresponding parameters to the module.
"""

if isinstance(module, (K.GeometricAugmentationBase2D,)):
if module.transform_matrix is None:
raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
Expand All @@ -365,7 +368,7 @@ def inverse(cls, input: Tensor, module: Module, param: ParamItem, extra_args: Di
input = module.inverse_masks(input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

elif isinstance(module, (K.auto.operations.OperationBase,)):
return MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
input = MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)

return input

Expand Down
14 changes: 8 additions & 6 deletions tests/augmentation/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,17 +555,18 @@ def test_masks_without_channel_dim(self, device, dtype, B, C_i, C_m, keepdim):
else:
assert (*(1,) * (4 - len(img_shape)), *img_shape) == out[0].shape

out_mask_shape = tuple(x if x else 1 for x in (B, C_m, *img_shape[-2:]))
out_mask_shape = tuple(x or 1 for x in (B, C_m, *img_shape[-2:]))
assert out[1].shape == out_mask_shape

@pytest.mark.slow
@pytest.mark.parametrize("random_apply", [1, (2, 2), (1, 2), (2,), 10, True, False])
def test_forward_and_inverse(self, random_apply, device, dtype):
@pytest.mark.parametrize("mask_dtype", [torch.int32, torch.int64, torch.float32])
def test_forward_and_inverse(self, random_apply, device, dtype, mask_dtype):
inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
mask = bbox_to_mask(
torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500
torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=mask_dtype), 1000, 500
)[:, None]
aug = K.AugmentationSequential(
K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0)),
Expand Down Expand Up @@ -715,11 +716,12 @@ def test_inverse_and_forward_return_transform(self, random_apply, device, dtype)
if random_apply is False:
reproducibility_test((inp, mask, bbox, keypoints, bbox_2, bbox_wh, bbox_wh_2), aug)

def test_transform_list_of_masks_and_boxes(self, device, dtype):
@pytest.mark.parametrize("mask_dtype", [torch.int32, torch.int64, torch.float32])
def test_transform_list_of_masks_and_boxes(self, device, dtype, mask_dtype):
input = torch.randn(2, 3, 256, 256, device=device, dtype=dtype)
mask = [
torch.ones(1, 3, 256, 256, device=device, dtype=dtype),
torch.ones(1, 2, 256, 256, device=device, dtype=dtype),
torch.ones(1, 3, 256, 256, device=device, dtype=mask_dtype),
torch.ones(1, 2, 256, 256, device=device, dtype=mask_dtype),
]

bbox = [
Expand Down