9
9
import kornia .color ._colormap_data as cm_data
10
10
from kornia .color ._colormap_data import RGBColor
11
11
from kornia .core import Module , Tensor , tensor
12
- from kornia .core .check import KORNIA_CHECK_IS_GRAY
12
+ from kornia .core .check import KORNIA_CHECK
13
13
from kornia .utils .helpers import deprecated
14
14
15
15
@@ -84,8 +84,8 @@ class ColorMap:
84
84
the `ColorMapType` enum class to view all available colormaps.
85
85
86
86
Args:
87
- base: A list of RGB colors to define a new custom colormap or
88
- the name of a built-in colormap as str or using ColorMapType class.
87
+ base: A list of RGB colors to define a new custom colormap or the name of a built-in colormap as str or
88
+ using ` ColorMapType` class.
89
89
num_colors: Number of colors in the colormap.
90
90
device: The device to put the generated colormap on.
91
91
dtype: The data type of the generated colormap.
@@ -164,7 +164,7 @@ def apply_colormap(input_tensor: Tensor, colormap: ColorMap) -> Tensor:
164
164
.. image:: _static/img/apply_colormap.png
165
165
166
166
Args:
167
- input_tensor: the input tensor of a gray image.
167
+ input_tensor: the input tensor of image.
168
168
colormap: the colormap desired to be applied to the input tensor.
169
169
170
170
Returns:
@@ -174,41 +174,49 @@ def apply_colormap(input_tensor: Tensor, colormap: ColorMap) -> Tensor:
174
174
ValueError: If `colormap` is not a ColorMap object.
175
175
176
176
.. note::
177
- The image data is assumed to be integer values in range of [0-255].
177
+ The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1 ].
178
178
179
179
Example:
180
- >>> input_tensor = torch.tensor([[[0, 1, 2], [25, 50, 63 ]]])
181
- >>> colormap = ColorMap(base=' autumn' )
180
+ >>> input_tensor = torch.tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188 ]]])
181
+ >>> colormap = ColorMap(base=ColorMapType. autumn)
182
182
>>> apply_colormap(input_tensor, colormap)
183
- tensor([[[1.0000, 1.0000, 1.0000],
184
- [1.0000, 1.0000, 1.0000]],
183
+ tensor([[[[1.0000, 1.0000, 1.0000],
184
+ [1.0000, 1.0000, 1.0000],
185
+ [1.0000, 1.0000, 1.0000]],
185
186
<BLANKLINE>
186
- [[0.0000, 0.0159, 0.0317],
187
- [0.3968, 0.7937, 1.0000]],
187
+ [[0.0000, 0.0159, 0.0159],
188
+ [0.0635, 0.1111, 0.1429],
189
+ [0.5079, 0.6190, 0.7302]],
188
190
<BLANKLINE>
189
- [[0.0000, 0.0000, 0.0000],
190
- [0.0000, 0.0000, 0.0000]]])
191
+ [[0.0000, 0.0000, 0.0000],
192
+ [0.0000, 0.0000, 0.0000],
193
+ [0.0000, 0.0000, 0.0000]]]])
191
194
"""
192
- # FIXME: implement to work with RGB images
193
- # should work with KORNIA_CHECK_SHAPE(x, ["B","C", "H", "W"])
194
195
195
- KORNIA_CHECK_IS_GRAY (input_tensor )
196
+ KORNIA_CHECK (isinstance (input_tensor , Tensor ), f"`input_tensor` must be a Tensor. Got: { type (input_tensor )} " )
197
+ valid_types = [torch .half , torch .float , torch .double , torch .uint8 , torch .int , torch .long , torch .short ]
198
+ KORNIA_CHECK (
199
+ input_tensor .dtype in valid_types , f"`input_tensor` must be a { valid_types } . Got: { input_tensor .dtype } "
200
+ )
201
+ KORNIA_CHECK (len (input_tensor .shape ) in (3 , 4 ), "Wrong input tensor dimension." )
202
+ if len (input_tensor .shape ) == 3 :
203
+ input_tensor = input_tensor .unsqueeze_ (0 )
196
204
197
- if len ( input_tensor . shape ) == 4 and input_tensor .shape [ 1 ] == 1 : # if (B x 1 X H x W)
198
- input_tensor = input_tensor [:, 0 , ...] # (B x H x W )
199
- elif len ( input_tensor . shape ) == 3 and input_tensor .shape [ 0 ] == 1 : # if (1 X H x W)
200
- input_tensor = input_tensor [ 0 , ...] # (H x W )
205
+ B , C , H , W = input_tensor .shape
206
+ input_tensor = input_tensor . reshape ( B , C , - 1 )
207
+ max_value = 1.0 if input_tensor .max () <= 1.0 else 255.0
208
+ input_tensor = input_tensor . float (). div_ ( max_value )
201
209
202
- keys = torch .arange (0 , len (colormap ) - 1 , dtype = input_tensor .dtype , device = input_tensor .device ) # (num_colors)
210
+ colors = colormap .colors .permute (1 , 0 )
211
+ num_colors , channels_cmap = colors .shape
212
+ keys = torch .linspace (0.0 , 1.0 , num_colors - 1 , device = input_tensor .device , dtype = input_tensor .dtype )
213
+ indices = torch .bucketize (input_tensor , keys ).unsqueeze (- 1 ).expand (- 1 , - 1 , - 1 , 3 )
203
214
204
- index = torch .bucketize (input_tensor , keys ) # shape equals <input_tensor>: (B x H x W) or (H x W)
215
+ output = torch .gather (colors .expand (B , C , - 1 , - 1 ), 2 , indices )
216
+ # (B, C, H*W, channels_cmap) -> (B, C*channels_cmap, H, W)
217
+ output = output .permute (0 , 1 , 3 , 2 ).reshape (B , C * channels_cmap , H , W )
205
218
206
- output = colormap .colors [:, index ] # (3 x B x H x W) or (3 x H x W)
207
-
208
- if len (output .shape ) == 4 :
209
- output = output .permute (1 , 0 , - 2 , - 1 ) # (B x 3 x H x W)
210
-
211
- return output # (B x 3 x H x W) or (3 x H x W)
219
+ return output
212
220
213
221
214
222
class ApplyColorMap (Module ):
@@ -229,20 +237,23 @@ class ApplyColorMap(Module):
229
237
ValueError: If `colormap` is not a ColorMap object.
230
238
231
239
.. note::
232
- The image data is assumed to be integer values in range of [0-255].
240
+ The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1 ].
233
241
234
242
Example:
235
- >>> input_tensor = torch.tensor([[[0, 1, 2], [25, 50, 63 ]]])
236
- >>> colormap = ColorMap(base=' autumn' )
243
+ >>> input_tensor = torch.tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188 ]]])
244
+ >>> colormap = ColorMap(base=ColorMapType. autumn)
237
245
>>> ApplyColorMap(colormap=colormap)(input_tensor)
238
- tensor([[[1.0000, 1.0000, 1.0000],
239
- [1.0000, 1.0000, 1.0000]],
246
+ tensor([[[[1.0000, 1.0000, 1.0000],
247
+ [1.0000, 1.0000, 1.0000],
248
+ [1.0000, 1.0000, 1.0000]],
240
249
<BLANKLINE>
241
- [[0.0000, 0.0159, 0.0317],
242
- [0.3968, 0.7937, 1.0000]],
250
+ [[0.0000, 0.0159, 0.0159],
251
+ [0.0635, 0.1111, 0.1429],
252
+ [0.5079, 0.6190, 0.7302]],
243
253
<BLANKLINE>
244
- [[0.0000, 0.0000, 0.0000],
245
- [0.0000, 0.0000, 0.0000]]])
254
+ [[0.0000, 0.0000, 0.0000],
255
+ [0.0000, 0.0000, 0.0000],
256
+ [0.0000, 0.0000, 0.0000]]]])
246
257
"""
247
258
248
259
def __init__ (
@@ -259,7 +270,7 @@ def forward(self, input_tensor: Tensor) -> Tensor:
259
270
input_tensor: The input tensor representing the grayscale image.
260
271
261
272
.. note::
262
- The image data is assumed to be integer values in range of [0-255].
273
+ The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1 ].
263
274
264
275
Returns:
265
276
The output tensor representing the image with the applied colormap.
0 commit comments