Skip to content

Commit a183f44

Browse files
authored
fix(steps): fix min_frequency and max_category (ibis-project#156)
1 parent 2f2aaaf commit a183f44

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

ibis_ml/steps/_encode.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import ibis
88
import ibis.expr.types as ir
9+
from ibis import _
910

1011
from ibis_ml.core import Metadata, Step
1112
from ibis_ml.select import SelectionType, selector
@@ -41,10 +42,14 @@ def collect(col: str) -> ir.Table:
4142
query = (
4243
table.select(value=col)
4344
.group_by("value")
44-
.count("count")
45+
.aggregate(count=_.count())
4546
.mutate(column=ibis.literal(col))
4647
)
47-
return query if max_categories is None else query.limit(max_categories)
48+
return (
49+
query
50+
if max_categories is None
51+
else query.order_by(ibis.desc("count")).limit(max_categories)
52+
)
4853

4954
def process(df: pd.DataFrame) -> list[Any]:
5055
if isinstance(min_frequency, int):

tests/test_encode.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,34 @@ def test_ordinal_encode(t_train, t_test):
6262
tm.assert_frame_equal(res.execute(), expected.execute(), check_dtype=False)
6363

6464

65-
def test_one_hot_encode(t_train, t_test):
66-
step = ml.OneHotEncode("ticker")
65+
@pytest.mark.parametrize(
66+
("min_frequency", "max_categories", "expected"),
67+
[
68+
(
69+
None,
70+
None,
71+
{
72+
"ticker_AAPL": [0, 0, 0, 0, 0, 0],
73+
"ticker_GOOG": [0, 0, 1, 1, 0, 0],
74+
"ticker_MSFT": [1, 1, 0, 0, 0, 0],
75+
"ticker_None": [0, 0, 0, 0, 0, 1],
76+
},
77+
),
78+
(
79+
2,
80+
None,
81+
{"ticker_GOOG": [0, 0, 1, 1, 0, 0], "ticker_MSFT": [1, 1, 0, 0, 0, 0]},
82+
),
83+
(None, 1, {"ticker_MSFT": [1, 1, 0, 0, 0, 0]}),
84+
],
85+
)
86+
def test_onehotencode(t_train, t_test, min_frequency, max_categories, expected):
87+
step = ml.OneHotEncode(
88+
"ticker", min_frequency=min_frequency, max_categories=max_categories
89+
)
6790
step.fit_table(t_train, ml.core.Metadata())
6891
result = step.transform_table(t_test)
69-
expected = pd.DataFrame(
92+
expected_df = pd.DataFrame(
7093
{
7194
"time": [
7295
pd.Timestamp("2016-05-25 13:30:00.023"),
@@ -76,13 +99,10 @@ def test_one_hot_encode(t_train, t_test):
7699
pd.Timestamp("2016-05-25 13:30:00.050"),
77100
pd.Timestamp("2016-05-25 13:30:00.051"),
78101
],
79-
"ticker_AAPL": [0, 0, 0, 0, 0, 0],
80-
"ticker_GOOG": [0, 0, 1, 1, 0, 0],
81-
"ticker_MSFT": [1, 1, 0, 0, 0, 0],
82-
"ticker_None": [0, 0, 0, 0, 0, 1],
102+
**expected,
83103
}
84104
)
85-
tm.assert_frame_equal(result.execute(), expected, check_dtype=False)
105+
tm.assert_frame_equal(result.execute(), expected_df, check_dtype=False)
86106

87107

88108
@pytest.mark.parametrize("smooth", [5000.0, 1.0, 0.0])

0 commit comments

Comments
 (0)