Skip to content

Commit 6239db1

Browse files
committed
Fixes #438
1 parent b39a095 commit 6239db1

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

neanderthal/tech/v3/dataset/tribuo_test.clj

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
[org.tribuo.regression.xgboost XGBoostRegressionTrainer]))
1212

1313

14-
1514
(defn classification-example-ds
1615
[x]
1716
(let [x (if (integer? x)
@@ -80,3 +79,11 @@
8079
(is (= "class org.tribuo.classification.dtree.CARTClassificationTrainer"
8180
(str (class trainer))))))
8281

82+
83+
(deftest test-keyword-name
84+
(testing "string name (OK)"
85+
(is (-> (ds/->dataset [{"a" 1}] {:dataset-name "string name"})
86+
(tribuo/make-regression-datasource "a"))))
87+
(testing "keyword name (Error)"
88+
(is (-> (ds/->dataset [{"a" 1}] {:dataset-name :keyword/name})
89+
(tribuo/make-regression-datasource "a")))))

src/tech/v3/libs/tribuo.clj

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ _unnamed [5 1]:
5555
[org.tribuo.regression.evaluation RegressionEvaluator RegressionEvaluation]
5656
[com.oracle.labs.mlrg.olcut.config ConfigurationManager]
5757
[com.oracle.labs.mlrg.olcut.config.json JsonConfigFactory]))
58-
58+
5959

6060
(set! *warn-on-reflection* true)
6161

@@ -157,13 +157,22 @@ _unnamed [5 1]:
157157
cnames (->double-array (feat-data idx))))
158158
(meta outputs))))
159159

160+
(defn- safe-str
161+
[n]
162+
(cond (string? n)
163+
n
164+
(or (keyword? n) (symbol? n))
165+
(if-let [nn (namespace n)]
166+
(str nn "/" (name n))
167+
(str (name n)))))
168+
160169

161170
(defn- ds->datasource
162171
^DataSource [ds ds->outputs]
163172
(let [examples (ds->examples ds ds->outputs)
164173
{:keys [output-factory provenance]} (meta examples)
165174
provenance (or provenance
166-
(SimpleDataSourceProvenance. (:name (meta ds)) output-factory))]
175+
(SimpleDataSourceProvenance. (safe-str (:name (meta ds))) output-factory))]
167176
(when-not output-factory
168177
(throw (RuntimeException. "Output factory not present in example metadata")))
169178
(reify DataSource

0 commit comments

Comments
 (0)