Skip to content

Commit b8b39f2

Browse files
committed
handle near constant column
1 parent 72eb5b3 commit b8b39f2

File tree

2 files changed

+13
-35
lines changed

2 files changed

+13
-35
lines changed

ibis_ml/steps/_standardize.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from collections.abc import Iterable
1212

1313
_DOCS_PAGE_NAME = "standardization"
14+
# a small epsilon value to handle near-constant columns during normalization
15+
_APPROX_EPS = 10e-7
1416

1517

1618
class ScaleMinMax(Step):
@@ -61,21 +63,18 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
6163
self._fit_expr = [expr]
6264
results = expr.execute().to_dict("records")[0]
6365
for name in columns:
64-
col_max = results[f"{name}_max"]
65-
col_min = results[f"{name}_min"]
66-
if col_max == col_min:
67-
raise ValueError(
68-
f"Cannot standardize {name!r} - "
69-
"the maximum and minimum values are equal"
70-
)
71-
stats[name] = (col_max, col_min)
66+
stats[name] = (results[f"{name}_max"], results[f"{name}_min"])
7267

7368
self.stats_ = stats
7469

7570
def transform_table(self, table: ir.Table) -> ir.Table:
7671
return table.mutate(
7772
[
78-
((table[c] - min) / (max - min)).name(c) # type: ignore
73+
# for near-constant column, set the scale to 1.0
74+
(
75+
(table[c] - min)
76+
/ (1.0 if abs(max - min) < _APPROX_EPS else max - min)
77+
).name(c)
7978
for c, (max, min) in self.stats_.items()
8079
]
8180
)
@@ -128,19 +127,17 @@ def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
128127
self._fit_expr = [table.aggregate(aggs)]
129128
results = self._fit_expr[-1].execute().to_dict("records")[0]
130129
for name in columns:
131-
col_std = results[f"{name}_std"]
132-
if col_std == 0:
133-
raise ValueError(
134-
f"Cannot standardize {name!r} - the standard deviation is zero"
135-
)
136-
stats[name] = (results[f"{name}_mean"], col_std)
130+
stats[name] = (results[f"{name}_mean"], results[f"{name}_std"])
137131

138132
self.stats_ = stats
139133

140134
def transform_table(self, table: ir.Table) -> ir.Table:
141135
return table.mutate(
142136
[
143-
((table[c] - center) / scale).name(c) # type: ignore
137+
# for near-constant column, set the scale to 1.0
138+
(
139+
(table[c] - center) / (1.0 if abs(scale) < _APPROX_EPS else scale)
140+
).name(c)
144141
for c, (center, scale) in self.stats_.items()
145142
]
146143
)

tests/test_standardize.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
import pandas as pd
44
import pandas.testing as tm
5-
import pytest
65

76
import ibis_ml as ml
87

@@ -29,21 +28,3 @@ def test_scaleminmax():
2928
result = step.transform_table(table)
3029
expected = pd.DataFrame({"col": (cols - min_val) / (max_val - min_val)})
3130
tm.assert_frame_equal(result.execute(), expected, check_exact=False)
32-
33-
34-
@pytest.mark.parametrize(
35-
("model", "msg"),
36-
[
37-
("ScaleStandard", "Cannot standardize 'col' - the standard deviation is zero"),
38-
(
39-
"ScaleMinMax",
40-
"Cannot standardize 'col' - the maximum and minimum values are equal",
41-
),
42-
],
43-
)
44-
def test_scale_unique_col(model, msg):
45-
table = ibis.memtable({"col": [1]})
46-
scale_class = getattr(ml, model)
47-
step = scale_class("col")
48-
with pytest.raises(ValueError, match=msg):
49-
step.fit_table(table, ml.core.Metadata())

0 commit comments

Comments
 (0)