|
11 | 11 | from collections.abc import Iterable
|
12 | 12 |
|
13 | 13 | _DOCS_PAGE_NAME = "standardization"
|
| 14 | +# a small epsilon value to handle near-constant columns during normalization |
| 15 | +_APPROX_EPS = 10e-7 |
14 | 16 |
|
15 | 17 |
|
16 | 18 | class ScaleMinMax(Step):
|
@@ -61,21 +63,18 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
|
61 | 63 | self._fit_expr = [expr]
|
62 | 64 | results = expr.execute().to_dict("records")[0]
|
63 | 65 | for name in columns:
|
64 |
| - col_max = results[f"{name}_max"] |
65 |
| - col_min = results[f"{name}_min"] |
66 |
| - if col_max == col_min: |
67 |
| - raise ValueError( |
68 |
| - f"Cannot standardize {name!r} - " |
69 |
| - "the maximum and minimum values are equal" |
70 |
| - ) |
71 |
| - stats[name] = (col_max, col_min) |
| 66 | + stats[name] = (results[f"{name}_max"], results[f"{name}_min"]) |
72 | 67 |
|
73 | 68 | self.stats_ = stats
|
74 | 69 |
|
75 | 70 | def transform_table(self, table: ir.Table) -> ir.Table:
|
76 | 71 | return table.mutate(
|
77 | 72 | [
|
78 |
| - ((table[c] - min) / (max - min)).name(c) # type: ignore |
| 73 | + # for near-constant column, set the scale to 1.0 |
| 74 | + ( |
| 75 | + (table[c] - min) |
| 76 | + / (1.0 if abs(max - min) < _APPROX_EPS else max - min) |
| 77 | + ).name(c) |
79 | 78 | for c, (max, min) in self.stats_.items()
|
80 | 79 | ]
|
81 | 80 | )
|
@@ -128,19 +127,17 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
|
128 | 127 | self._fit_expr = [table.aggregate(aggs)]
|
129 | 128 | results = self._fit_expr[-1].execute().to_dict("records")[0]
|
130 | 129 | for name in columns:
|
131 |
| - col_std = results[f"{name}_std"] |
132 |
| - if col_std == 0: |
133 |
| - raise ValueError( |
134 |
| - f"Cannot standardize {name!r} - the standard deviation is zero" |
135 |
| - ) |
136 |
| - stats[name] = (results[f"{name}_mean"], col_std) |
| 130 | + stats[name] = (results[f"{name}_mean"], results[f"{name}_std"]) |
137 | 131 |
|
138 | 132 | self.stats_ = stats
|
139 | 133 |
|
140 | 134 | def transform_table(self, table: ir.Table) -> ir.Table:
|
141 | 135 | return table.mutate(
|
142 | 136 | [
|
143 |
| - ((table[c] - center) / scale).name(c) # type: ignore |
| 137 | + # for near-constant column, set the scale to 1.0 |
| 138 | + ( |
| 139 | + (table[c] - center) / (1.0 if abs(scale) < _APPROX_EPS else scale) |
| 140 | + ).name(c) |
144 | 141 | for c, (center, scale) in self.stats_.items()
|
145 | 142 | ]
|
146 | 143 | )
|
0 commit comments