Skip to content

Commit dae6de6

Browse files
authored
color jitter compile and set right device / dtype (kornia#2863)
* don't force random generator device * colorjitter compile * less intrusive compile * compile method * fix tests * skip compile for old versions * more fixes * more ci fixes * more dtype fixes * more test fixes
1 parent 2c761ee commit dae6de6

File tree

5 files changed

+1611
-385
lines changed

5 files changed

+1611
-385
lines changed

kornia/augmentation/_2d/intensity/color_jitter.py

+76-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Dict, List, Optional, Tuple, Union
22

3+
import torch
4+
35
from kornia.augmentation import random_generator as rg
46
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
57
from kornia.constants import pi
@@ -84,25 +86,87 @@ def __init__(
8486
self.hue = hue
8587
self._param_generator = rg.ColorJitterGenerator(brightness, contrast, saturation, hue)
8688

89+
# native functions
90+
self.brightness_fn = adjust_brightness_accumulative
91+
self.contrast_fn = adjust_contrast_with_mean_subtraction
92+
self.saturation_fn = adjust_saturation_with_gray_subtraction
93+
self.hue_fn = adjust_hue
94+
8795
def apply_transform(
88-
self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
96+
self,
97+
input: Tensor,
98+
params: Dict[str, Tensor],
99+
flags: Dict[str, Any],
100+
transform: Optional[Tensor] = None,
89101
) -> Tensor:
90102
transforms = [
91-
lambda img: adjust_brightness_accumulative(img, params["brightness_factor"])
92-
if (params["brightness_factor"] != 0).any()
93-
else img,
94-
lambda img: adjust_contrast_with_mean_subtraction(img, params["contrast_factor"])
95-
if (params["contrast_factor"] != 1).any()
96-
else img,
97-
lambda img: adjust_saturation_with_gray_subtraction(img, params["saturation_factor"])
98-
if (params["saturation_factor"] != 1).any()
99-
else img,
100-
lambda img: adjust_hue(img, params["hue_factor"] * 2 * pi) if (params["hue_factor"] != 0).any() else img,
103+
lambda img: (
104+
self.brightness_fn(img, params["brightness_factor"])
105+
if (params["brightness_factor"] != 0).any()
106+
else img
107+
),
108+
lambda img: (
109+
self.contrast_fn(img, params["contrast_factor"]) if (params["contrast_factor"] != 1).any() else img
110+
),
111+
lambda img: (
112+
self.saturation_fn(img, params["saturation_factor"])
113+
if (params["saturation_factor"] != 1).any()
114+
else img
115+
),
116+
lambda img: (self.hue_fn(img, params["hue_factor"] * 2 * pi) if (params["hue_factor"] != 0).any() else img),
101117
]
102118

103119
jittered = input
104-
for idx in params["order"].tolist():
120+
for idx in params["order"]:
105121
t = transforms[idx]
106122
jittered = t(jittered)
107123

108124
return jittered
125+
126+
def compile(
127+
self,
128+
*,
129+
fullgraph: bool = False,
130+
dynamic: bool = False,
131+
backend: str = "inductor",
132+
mode: Optional[str] = None,
133+
options: Optional[Dict[Any, Any]] = None,
134+
disable: bool = False,
135+
) -> "ColorJitter":
136+
self.brightness_fn = torch.compile(
137+
self.brightness_fn,
138+
fullgraph=fullgraph,
139+
dynamic=dynamic,
140+
backend=backend,
141+
mode=mode,
142+
options=options,
143+
disable=disable,
144+
)
145+
self.contrast_fn = torch.compile(
146+
self.contrast_fn,
147+
fullgraph=fullgraph,
148+
dynamic=dynamic,
149+
backend=backend,
150+
mode=mode,
151+
options=options,
152+
disable=disable,
153+
)
154+
self.saturation_fn = torch.compile(
155+
self.saturation_fn,
156+
fullgraph=fullgraph,
157+
dynamic=dynamic,
158+
backend=backend,
159+
mode=mode,
160+
options=options,
161+
disable=disable,
162+
)
163+
self.hue_fn = torch.compile(
164+
self.hue_fn,
165+
fullgraph=fullgraph,
166+
dynamic=dynamic,
167+
backend=backend,
168+
mode=mode,
169+
options=options,
170+
disable=disable,
171+
)
172+
return self

kornia/augmentation/random_generator/_2d/color_jitter.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
from functools import partial
21
from typing import Dict, List, Tuple, Union
32

43
import torch
54

6-
from kornia.augmentation.random_generator.base import RandomGeneratorBase, UniformDistribution
7-
from kornia.augmentation.utils import _adapted_rsampling, _common_param_check, _joint_range_check, _range_bound
5+
from kornia.augmentation.random_generator.base import (
6+
RandomGeneratorBase,
7+
UniformDistribution,
8+
)
9+
from kornia.augmentation.utils import (
10+
_adapted_rsampling,
11+
_joint_range_check,
12+
_range_bound,
13+
)
814
from kornia.core import Tensor
9-
from kornia.utils.helpers import _extract_device_dtype
1015

1116
__all__ = ["ColorJitterGenerator"]
1217

@@ -71,20 +76,18 @@ def make_samplers(self, device: torch.device, dtype: torch.dtype) -> None:
7176
self.contrast_sampler = UniformDistribution(contrast[0], contrast[1], validate_args=False)
7277
self.hue_sampler = UniformDistribution(hue[0], hue[1], validate_args=False)
7378
self.saturation_sampler = UniformDistribution(saturation[0], saturation[1], validate_args=False)
74-
self.randperm = partial(torch.randperm, device=device, dtype=dtype)
7579

7680
def forward(self, batch_shape: Tuple[int, ...], same_on_batch: bool = False) -> Dict[str, Tensor]:
7781
batch_size = batch_shape[0]
78-
_common_param_check(batch_size, same_on_batch)
79-
_device, _dtype = _extract_device_dtype([self.brightness, self.contrast, self.hue, self.saturation])
8082
brightness_factor = _adapted_rsampling((batch_size,), self.brightness_sampler, same_on_batch)
8183
contrast_factor = _adapted_rsampling((batch_size,), self.contrast_sampler, same_on_batch)
8284
hue_factor = _adapted_rsampling((batch_size,), self.hue_sampler, same_on_batch)
8385
saturation_factor = _adapted_rsampling((batch_size,), self.saturation_sampler, same_on_batch)
86+
8487
return {
85-
"brightness_factor": brightness_factor.to(device=_device, dtype=_dtype),
86-
"contrast_factor": contrast_factor.to(device=_device, dtype=_dtype),
87-
"hue_factor": hue_factor.to(device=_device, dtype=_dtype),
88-
"saturation_factor": saturation_factor.to(device=_device, dtype=_dtype),
89-
"order": self.randperm(4).to(device=_device, dtype=_dtype).long(),
88+
"brightness_factor": brightness_factor,
89+
"contrast_factor": contrast_factor,
90+
"hue_factor": hue_factor,
91+
"saturation_factor": saturation_factor,
92+
"order": torch.randperm(4, dtype=torch.long),
9093
}

kornia/enhance/adjust.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
from torch.nn import Module, Parameter
77

88
from kornia.color import hsv_to_rgb, rgb_to_grayscale, rgb_to_hsv
9-
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_COLOR_OR_GRAY, KORNIA_CHECK_IS_TENSOR
9+
from kornia.core.check import (
10+
KORNIA_CHECK,
11+
KORNIA_CHECK_IS_COLOR_OR_GRAY,
12+
KORNIA_CHECK_IS_TENSOR,
13+
)
1014
from kornia.utils.helpers import _torch_histc_cast
1115
from kornia.utils.image import perform_keep_shape_image, perform_keep_shape_video
1216

@@ -393,7 +397,7 @@ def adjust_contrast_with_mean_subtraction(image: Tensor, factor: Union[float, Te
393397
while len(factor.shape) != len(image.shape):
394398
factor = factor[..., None]
395399

396-
KORNIA_CHECK(any(factor >= 0), "Contrast factor must be positive.")
400+
# KORNIA_CHECK(any(factor >= 0), "Contrast factor must be positive.")
397401

398402
if image.shape[-3] == 3:
399403
img_mean = rgb_to_grayscale(image).mean((-2, -1), True)
@@ -625,7 +629,9 @@ def _solarize(input: Tensor, thresholds: Union[float, Tensor] = 0.5) -> Tensor:
625629

626630

627631
def solarize(
628-
input: Tensor, thresholds: Union[float, Tensor] = 0.5, additions: Optional[Union[float, Tensor]] = None
632+
input: Tensor,
633+
thresholds: Union[float, Tensor] = 0.5,
634+
additions: Optional[Union[float, Tensor]] = None,
629635
) -> Tensor:
630636
r"""For each pixel in the image less than threshold.
631637

0 commit comments

Comments
 (0)