Skip to content

Commit 07a31c7

Browse files
committed
Fixes #414 - categorical maps are integers now by default
1 parent 7802024 commit 07a31c7

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

src/tech/v3/dataset/categorical.clj

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -104,15 +104,15 @@ Non integers found: " (vec bad-mappings)))))
104104
m
105105
(set/unique (ds-base/column dataset colname)))
106106
colname
107-
(or res-dtype :float64))))
107+
(or res-dtype :int64))))
108108

109109

110110

111111
(defn transform-categorical-map
112112
"Apply a categorical mapping transformation fit with fit-categorical-map."
113113
[dataset fit-data]
114114
(let [colname (:src-column fit-data)
115-
result-datatype (or (:result-datatype fit-data) :float64)
115+
result-datatype (or (:result-datatype fit-data) :int64)
116116
lookup-table (:lookup-table fit-data)
117117
column (ds-base/column dataset colname)
118118
missing (ds-proto/missing column)
@@ -231,7 +231,7 @@ user> (ds-cat/dataset->categorical-maps catds)
231231
dataset (dissoc dataset src-column)
232232
n-elems (dtype/ecount column)
233233
op-space (casting/simple-operation-space (dtype-proto/operational-elemwise-datatype column))]
234-
(merge dataset
234+
(merge dataset
235235
(->> one-hot-table
236236
(lznc/map
237237
(fn [[k v]]

test/data/local_date.json

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
[
2+
{"test": 1, "time-period": "2024-06-20"},
3+
{"test": 2, "time-period": "2024-06-21"},
4+
{"test": 3, "time-period": "2024-06-22"}]

test/tech/v3/dataset/categorical_test.clj

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,3 +77,13 @@
7777
(dtype/emap val-map :keyword col))))
7878
(ds/categorical->number cf/categorical)
7979
(ds/column "Survived")))))
80+
(deftest categorical-assignments-are-integers
81+
(is (= #{0 1 2 3}
82+
(->
83+
(ds/->dataset {:x1 [1 2 4 5 6 5 6 7]
84+
:x2 [5 6 6 7 8 2 4 6]
85+
:y [:a :b :b :a :c :a :b :d]})
86+
(ds/categorical->number [:y])
87+
(get :y)
88+
distinct
89+
set))))

0 commit comments

Comments
 (0)