Skip to content

Commit e461f92

Browse files
vgilabert94johnnv1
andauthored
feat: support batched and float data in apply colormap (kornia#2886)
* initial commit new apply_colormap * update generate doc * Update kornia/color/colormap.py Co-authored-by: João Gustavo A. Amorim <joaogustavoamorim@gmail.com> * Update kornia/color/colormap.py Co-authored-by: João Gustavo A. Amorim <joaogustavoamorim@gmail.com> --------- Co-authored-by: João Gustavo A. Amorim <joaogustavoamorim@gmail.com>
1 parent 2606afb commit e461f92

File tree

3 files changed

+82
-69
lines changed

3 files changed

+82
-69
lines changed

docs/generate_examples.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def main():
370370
# ITERATE OVER THE COLORMAPS
371371
for colormap_name, args in colormaps_list.items():
372372
cm = K.color.ColorMap(base=colormap_name, num_colors=args[0])
373-
out = K.color.rgb_to_bgr(K.color.apply_colormap(bar_img_gray, cm))
373+
out = K.color.rgb_to_bgr(K.color.apply_colormap(bar_img_gray, cm))[0]
374374

375375
out = torch.cat([bar_img, out], dim=-1)
376376

@@ -399,7 +399,7 @@ def main():
399399
for i, ax in enumerate(axes.flat):
400400
if i < num_colormaps:
401401
cmap = K.color.ColorMap(base=colormap_list[i], num_colors=num_colors)
402-
res = K.color.ApplyColorMap(colormap=cmap)(input_tensor)
402+
res = K.color.ApplyColorMap(colormap=cmap)(input_tensor)[0]
403403
ax.imshow(res.permute(1, 2, 0).numpy())
404404
ax.set_title(colormap_list[i], fontsize=12)
405405
ax.axis("off")

kornia/color/colormap.py

+49-38
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import kornia.color._colormap_data as cm_data
1010
from kornia.color._colormap_data import RGBColor
1111
from kornia.core import Module, Tensor, tensor
12-
from kornia.core.check import KORNIA_CHECK_IS_GRAY
12+
from kornia.core.check import KORNIA_CHECK
1313
from kornia.utils.helpers import deprecated
1414

1515

@@ -84,8 +84,8 @@ class ColorMap:
8484
the `ColorMapType` enum class to view all available colormaps.
8585
8686
Args:
87-
base: A list of RGB colors to define a new custom colormap or
88-
the name of a built-in colormap as str or using ColorMapType class.
87+
base: A list of RGB colors to define a new custom colormap or the name of a built-in colormap as str or
88+
using `ColorMapType` class.
8989
num_colors: Number of colors in the colormap.
9090
device: The device to put the generated colormap on.
9191
dtype: The data type of the generated colormap.
@@ -164,7 +164,7 @@ def apply_colormap(input_tensor: Tensor, colormap: ColorMap) -> Tensor:
164164
.. image:: _static/img/apply_colormap.png
165165
166166
Args:
167-
input_tensor: the input tensor of a gray image.
167+
input_tensor: the input tensor of image.
168168
colormap: the colormap desired to be applied to the input tensor.
169169
170170
Returns:
@@ -174,41 +174,49 @@ def apply_colormap(input_tensor: Tensor, colormap: ColorMap) -> Tensor:
174174
ValueError: If `colormap` is not a ColorMap object.
175175
176176
.. note::
177-
The image data is assumed to be integer values in range of [0-255].
177+
The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1].
178178
179179
Example:
180-
>>> input_tensor = torch.tensor([[[0, 1, 2], [25, 50, 63]]])
181-
>>> colormap = ColorMap(base='autumn')
180+
>>> input_tensor = torch.tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188]]])
181+
>>> colormap = ColorMap(base=ColorMapType.autumn)
182182
>>> apply_colormap(input_tensor, colormap)
183-
tensor([[[1.0000, 1.0000, 1.0000],
184-
[1.0000, 1.0000, 1.0000]],
183+
tensor([[[[1.0000, 1.0000, 1.0000],
184+
[1.0000, 1.0000, 1.0000],
185+
[1.0000, 1.0000, 1.0000]],
185186
<BLANKLINE>
186-
[[0.0000, 0.0159, 0.0317],
187-
[0.3968, 0.7937, 1.0000]],
187+
[[0.0000, 0.0159, 0.0159],
188+
[0.0635, 0.1111, 0.1429],
189+
[0.5079, 0.6190, 0.7302]],
188190
<BLANKLINE>
189-
[[0.0000, 0.0000, 0.0000],
190-
[0.0000, 0.0000, 0.0000]]])
191+
[[0.0000, 0.0000, 0.0000],
192+
[0.0000, 0.0000, 0.0000],
193+
[0.0000, 0.0000, 0.0000]]]])
191194
"""
192-
# FIXME: implement to work with RGB images
193-
# should work with KORNIA_CHECK_SHAPE(x, ["B","C", "H", "W"])
194195

195-
KORNIA_CHECK_IS_GRAY(input_tensor)
196+
KORNIA_CHECK(isinstance(input_tensor, Tensor), f"`input_tensor` must be a Tensor. Got: {type(input_tensor)}")
197+
valid_types = [torch.half, torch.float, torch.double, torch.uint8, torch.int, torch.long, torch.short]
198+
KORNIA_CHECK(
199+
input_tensor.dtype in valid_types, f"`input_tensor` must be a {valid_types}. Got: {input_tensor.dtype}"
200+
)
201+
KORNIA_CHECK(len(input_tensor.shape) in (3, 4), "Wrong input tensor dimension.")
202+
if len(input_tensor.shape) == 3:
203+
input_tensor = input_tensor.unsqueeze_(0)
196204

197-
if len(input_tensor.shape) == 4 and input_tensor.shape[1] == 1: # if (B x 1 X H x W)
198-
input_tensor = input_tensor[:, 0, ...] # (B x H x W)
199-
elif len(input_tensor.shape) == 3 and input_tensor.shape[0] == 1: # if (1 X H x W)
200-
input_tensor = input_tensor[0, ...] # (H x W)
205+
B, C, H, W = input_tensor.shape
206+
input_tensor = input_tensor.reshape(B, C, -1)
207+
max_value = 1.0 if input_tensor.max() <= 1.0 else 255.0
208+
input_tensor = input_tensor.float().div_(max_value)
201209

202-
keys = torch.arange(0, len(colormap) - 1, dtype=input_tensor.dtype, device=input_tensor.device) # (num_colors)
210+
colors = colormap.colors.permute(1, 0)
211+
num_colors, channels_cmap = colors.shape
212+
keys = torch.linspace(0.0, 1.0, num_colors - 1, device=input_tensor.device, dtype=input_tensor.dtype)
213+
indices = torch.bucketize(input_tensor, keys).unsqueeze(-1).expand(-1, -1, -1, 3)
203214

204-
index = torch.bucketize(input_tensor, keys) # shape equals <input_tensor>: (B x H x W) or (H x W)
215+
output = torch.gather(colors.expand(B, C, -1, -1), 2, indices)
216+
# (B, C, H*W, channels_cmap) -> (B, C*channels_cmap, H, W)
217+
output = output.permute(0, 1, 3, 2).reshape(B, C * channels_cmap, H, W)
205218

206-
output = colormap.colors[:, index] # (3 x B x H x W) or (3 x H x W)
207-
208-
if len(output.shape) == 4:
209-
output = output.permute(1, 0, -2, -1) # (B x 3 x H x W)
210-
211-
return output # (B x 3 x H x W) or (3 x H x W)
219+
return output
212220

213221

214222
class ApplyColorMap(Module):
@@ -229,20 +237,23 @@ class ApplyColorMap(Module):
229237
ValueError: If `colormap` is not a ColorMap object.
230238
231239
.. note::
232-
The image data is assumed to be integer values in range of [0-255].
240+
The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1].
233241
234242
Example:
235-
>>> input_tensor = torch.tensor([[[0, 1, 2], [25, 50, 63]]])
236-
>>> colormap = ColorMap(base='autumn')
243+
>>> input_tensor = torch.tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188]]])
244+
>>> colormap = ColorMap(base=ColorMapType.autumn)
237245
>>> ApplyColorMap(colormap=colormap)(input_tensor)
238-
tensor([[[1.0000, 1.0000, 1.0000],
239-
[1.0000, 1.0000, 1.0000]],
246+
tensor([[[[1.0000, 1.0000, 1.0000],
247+
[1.0000, 1.0000, 1.0000],
248+
[1.0000, 1.0000, 1.0000]],
240249
<BLANKLINE>
241-
[[0.0000, 0.0159, 0.0317],
242-
[0.3968, 0.7937, 1.0000]],
250+
[[0.0000, 0.0159, 0.0159],
251+
[0.0635, 0.1111, 0.1429],
252+
[0.5079, 0.6190, 0.7302]],
243253
<BLANKLINE>
244-
[[0.0000, 0.0000, 0.0000],
245-
[0.0000, 0.0000, 0.0000]]])
254+
[[0.0000, 0.0000, 0.0000],
255+
[0.0000, 0.0000, 0.0000],
256+
[0.0000, 0.0000, 0.0000]]]])
246257
"""
247258

248259
def __init__(
@@ -259,7 +270,7 @@ def forward(self, input_tensor: Tensor) -> Tensor:
259270
input_tensor: The input tensor representing the grayscale image.
260271
261272
.. note::
262-
The image data is assumed to be integer values in range of [0-255].
273+
The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1].
263274
264275
Returns:
265276
The output tensor representing the image with the applied colormap.

tests/color/test_colormap.py

+31-29
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,26 @@ def test_autumn(device, dtype):
2626

2727
class TestApplyColorMap(BaseTester):
2828
def test_smoke(self, device, dtype):
29-
input_tensor = tensor([[[0, 1, 3], [25, 50, 63]]], device=device, dtype=dtype)
30-
29+
input_tensor = tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188]]], device=device, dtype=dtype)
3130
expected_tensor = tensor(
3231
[
33-
[[1, 1, 1], [1, 1, 1]],
34-
[[0, 0.01587301587301587, 0.04761904761904762], [0.3968253968253968, 0.7936507936507936, 1]],
35-
[[0, 0, 0], [0, 0, 0]],
32+
[
33+
[
34+
[1.0000000000, 1.0000000000, 1.0000000000],
35+
[1.0000000000, 1.0000000000, 1.0000000000],
36+
[1.0000000000, 1.0000000000, 1.0000000000],
37+
],
38+
[
39+
[0.0000000000, 0.0158730168, 0.0158730168],
40+
[0.0634920672, 0.1111111119, 0.1428571492],
41+
[0.5079365373, 0.6190476418, 0.7301587462],
42+
],
43+
[
44+
[0.0000000000, 0.0000000000, 0.0000000000],
45+
[0.0000000000, 0.0000000000, 0.0000000000],
46+
[0.0000000000, 0.0000000000, 0.0000000000],
47+
],
48+
]
3649
],
3750
device=device,
3851
dtype=dtype,
@@ -42,39 +55,28 @@ def test_smoke(self, device, dtype):
4255

4356
self.assert_close(actual, expected_tensor)
4457

45-
def test_eye(self, device, dtype):
46-
input_tensor = torch.stack(
47-
[torch.eye(2, dtype=dtype, device=device) * 255, torch.eye(2, dtype=dtype, device=device) * 150]
48-
).view(2, -1, 2, 2)
49-
50-
expected_tensor = tensor(
51-
[
52-
[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],
53-
[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],
54-
],
55-
device=device,
56-
dtype=dtype,
57-
)
58-
59-
actual = apply_colormap(input_tensor, ColorMap(base="autumn", device=device, dtype=dtype))
60-
self.assert_close(actual, expected_tensor)
61-
6258
def test_exception(self, device, dtype):
6359
cm = ColorMap(base="autumn", device=device, dtype=dtype)
64-
with pytest.raises(TypeError):
65-
apply_colormap(torch.rand(size=(5, 1, 1), dtype=dtype, device=device), cm)
60+
with pytest.raises(Exception):
61+
apply_colormap(torch.rand(size=(3, 3), dtype=dtype, device=device), cm)
62+
63+
with pytest.raises(Exception):
64+
apply_colormap(torch.rand(size=(3), dtype=dtype, device=device), cm)
65+
66+
with pytest.raises(Exception):
67+
apply_colormap(torch.rand(size=(3), dtype=dtype, device=device).item(), cm)
6668

67-
@pytest.mark.parametrize("shape", [(2, 1, 4, 4), (1, 4, 4), (4, 4)])
69+
@pytest.mark.parametrize("shape", [(2, 1, 3, 3), (1, 3, 3, 3), (1, 3, 3)])
6870
@pytest.mark.parametrize("cmap_base", ColorMapType)
6971
def test_cardinality(self, shape, device, dtype, cmap_base):
70-
cm = ColorMap(base=cmap_base, device=device, dtype=dtype)
71-
input_tensor = torch.randint(0, 63, shape, device=device, dtype=dtype)
72+
cm = ColorMap(base=cmap_base, num_colors=256, device=device, dtype=dtype)
73+
input_tensor = torch.randint(0, 256, shape, device=device, dtype=dtype)
7274
actual = apply_colormap(input_tensor, cm)
7375

7476
if len(shape) == 4:
75-
expected_shape = (shape[0], 3, shape[-2], shape[-1])
77+
expected_shape = (shape[-4], shape[-3] * 3, shape[-2], shape[-1])
7678
else:
77-
expected_shape = (3, shape[-2], shape[-1])
79+
expected_shape = (1, shape[-3] * 3, shape[-2], shape[-1])
7880

7981
assert actual.shape == expected_shape
8082

0 commit comments

Comments
 (0)