@@ -82,13 +82,28 @@ def normalize_image(
82
82
) -> tf .Tensor :
83
83
"""Normalizes the image to zero mean and unit variance.
84
84
85
- If the input image dtype is float, it is expected to either have values in
86
- [0, 1) and offset is MEAN_NORM, or have values in [0, 255] and offset is
87
- MEAN_RGB.
85
+ This function normalizes the input image by subtracting the `offset`
86
+ and dividing by the `scale`.
87
+
88
+ **Important Note about Input Types and Normalization:**
89
+
90
+ * **Integer Images:** If the input `image` is an integer type (e.g., `uint8`),
91
+ the provided `offset` and `scale` values should be already **normalized**
92
+ to the range [0, 1]. This is because the function converts integer images to
93
+ float32 with values in the range [0, 1] before the normalization happens.
94
+
95
+ * **Float Images:** If the input `image` is a float type (e.g., `float32`),
96
+ the `offset` and `scale` values should be in the **same range** as the
97
+ image data.
98
+ - If the image has values in [0, 1], the `offset` and `scale` should
99
+ also be in [0, 1].
100
+ - If the image has values in [0, 255], the `offset` and `scale` should
101
+ also be in [0, 255].
88
102
89
103
Args:
90
- image: A tf.Tensor in either (1) float dtype with values in range [0, 1) or
91
- [0, 255], or (2) int type with values in range [0, 255].
104
+ image: A `tf.Tensor` in either:
105
+ (1) float dtype with values in range [0, 1) or [0, 255], or
106
+ (2) int type with values in range [0, 255].
92
107
offset: A tuple of mean values to be subtracted from the image.
93
108
scale: A tuple of normalization factors.
94
109
0 commit comments