Skip to content

Commit 0ac7970

Browse files
authored
Fix label access on clustering model (#3030)
* Fix label access on clustering model * format code * update predict * add default shape for label
1 parent e190446 commit 0ac7970

File tree

6 files changed

+21
-11
lines changed

6 files changed

+21
-11
lines changed

go/executor/executor.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ import (
2525
"path"
2626
"path/filepath"
2727
"regexp"
28-
"sqlflow.org/sqlflow/go/codegen/experimental"
2928
"strings"
3029
"sync"
3130

31+
"sqlflow.org/sqlflow/go/codegen/experimental"
32+
3233
"sqlflow.org/sqlflow/go/verifier"
3334

3435
"sqlflow.org/sqlflow/go/codegen/optimize"

go/ir/ir_generator.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ func GenerateTrainStmt(slct *parser.SQLFlowSelectStmt) (*TrainStmt, error) {
114114
}
115115
label := &NumericColumn{
116116
FieldDesc: &FieldDesc{
117-
Name: tc.Label,
117+
Name: tc.Label,
118+
Shape: []int{1},
118119
}}
119120

120121
vslct, _ := parseValidationSelect(attrList)

go/sql/executor_ir_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ func TestExecutorTrainAndPredictDNN(t *testing.T) {
188188
}
189189

190190
func TestExecutorTrainAndPredictClusteringLocalFS(t *testing.T) {
191-
t.Skip("fix random nan loss error then re-enable this test")
192191
a := assert.New(t)
193192
modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models")
194193
a.Nil(e)

python/runtime/local/create_result_table.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ def create_predict_table(conn, select, result_table, train_label_desc,
3232
"""
3333
name_and_types = db.selected_columns_and_types(conn, select)
3434
train_label_index = -1
35-
for i, (name, _) in enumerate(name_and_types):
36-
if name == train_label_desc.name:
37-
train_label_index = i
38-
break
35+
if train_label_desc:
36+
for i, (name, _) in enumerate(name_and_types):
37+
if name == train_label_desc.name:
38+
train_label_index = i
39+
break
3940

4041
if train_label_index >= 0:
4142
del name_and_types[train_label_index]
@@ -45,10 +46,12 @@ def create_predict_table(conn, select, result_table, train_label_desc,
4546
column_strs.append("%s %s" %
4647
(name, db.to_db_field_type(conn.driver, typ)))
4748

48-
if train_label_desc.format == DataFormat.PLAIN:
49+
if train_label_desc and train_label_desc.format == DataFormat.PLAIN:
4950
train_label_field_type = DataType.to_db_field_type(
5051
conn.driver, train_label_desc.dtype)
5152
else:
53+
# if no train lable description is provided (clustering),
54+
# we treat the column type as string
5255
train_label_field_type = DataType.to_db_field_type(
5356
conn.driver, DataType.STRING)
5457

python/runtime/local/tensorflow_submitter/predict.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ def pred(datasource, select, result_table, pred_label_name, model):
4646

4747
model_params = model.get_meta("attributes")
4848
train_fc_map = model.get_meta("features")
49-
train_label_desc = model.get_meta("label").get_field_desc()[0]
50-
train_label_name = train_label_desc.name
49+
label_meta = model.get_meta("label")
50+
train_label_desc = label_meta.get_field_desc()[0] if label_meta else None
51+
train_label_name = train_label_desc.name if train_label_desc else None
5152
estimator_string = model.get_meta("class_name")
5253
save = "model_save"
5354

python/runtime/step/tensorflow/train.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,12 @@ def train_step(original_sql,
8080
feature_column_names = [fd.name for fd in field_descs]
8181
feature_metas = dict([(fd.name, fd.to_dict(dtype_to_string=True))
8282
for fd in field_descs])
83-
label_meta = fc_label_ir.get_field_desc()[0].to_dict(dtype_to_string=True)
83+
84+
# no label for clustering model
85+
label_meta = None
86+
if fc_label_ir:
87+
label_meta = fc_label_ir.get_field_desc()[0].to_dict(
88+
dtype_to_string=True)
8489

8590
feature_column_names_map = dict()
8691
for target in fc_map_ir:

0 commit comments

Comments
 (0)