
Commit 463a08d

Make the pai_local executor run using experimental codegen (#3034)
* make pai_local run successfully
* fix circular reference
1 parent 0ac7970 commit 463a08d

File tree

9 files changed: +138 -46 lines

python/runtime/model/model.py (+33 -22)

@@ -99,7 +99,7 @@ def _from_dict(d):
         typ = d.pop("model_type")
         return Model(typ, d)

-    def _zip(self, local_dir, tarball, save_to_db=False):
+    def _zip(self, local_dir, tarball):
         """
         Zip the model information and all files in local_dir into a tarball.
@@ -110,20 +110,23 @@ def _zip(self, local_dir, tarball, save_to_db=False):
         Returns:
             None.
         """
-        if not save_to_db:
-            model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME)
-            with open(model_obj_file, "w") as f:
-                d = self._to_dict()
-                f.write(json.dumps(d, cls=JSONEncoderWithFeatureColumn))
-        else:
-            model_obj_file = None
+        # NOTE: files unpacked from the job tarball should be skipped
+        from runtime.pai.prepare_archive import ALL_TAR_FILES
+
+        def filter(tarinfo):
+            name = tarinfo.name
+            if name.startswith("./"):
+                name = name[2:]
+
+            if name in ALL_TAR_FILES:
+                return None
+
+            return tarinfo

-        zip_dir(local_dir, tarball, arcname="./")
-        if model_obj_file:
-            os.remove(model_obj_file)
+        zip_dir(local_dir, tarball, arcname="./", filter=filter)

     @staticmethod
-    def _unzip(local_dir, tarball, load_from_db=False):
+    def _unzip(local_dir, tarball):
         """
         Unzip the tarball into local_dir and deserialize the model
         information.
@@ -137,13 +140,6 @@ def _unzip(local_dir, tarball, load_from_db=False):
         information.
         """
         unzip_dir(tarball, local_dir)
-        if not load_from_db:
-            model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME)
-            with open(model_obj_file, "r") as f:
-                d = json.loads(f.read(), cls=JSONDecoderWithFeatureColumn)
-            model = Model._from_dict(d)
-            os.remove(model_obj_file)
-            return model

     def save_to_db(self, datasource, table, local_dir=None):
         """
@@ -164,7 +160,7 @@ def save_to_db(self, datasource, table, local_dir=None):

         with temp_file.TemporaryDirectory() as tmp_dir:
             tarball = os.path.join(tmp_dir, TARBALL_NAME)
-            self._zip(local_dir, tarball, save_to_db=True)
+            self._zip(local_dir, tarball)

             def _bytes_reader(filename, buf_size=8 * 32):
                 def _gen():
@@ -212,7 +208,7 @@ def load_from_db(datasource, table, local_dir=None):
             for data in gen():
                 f.write(bytes(data))

-        Model._unzip(local_dir, tarball, load_from_db=True)
+        Model._unzip(local_dir, tarball)

         return Model._from_dict(metadata)

@@ -237,6 +233,14 @@ def save_to_oss(self, oss_model_dir, local_dir=None):
             self._zip(local_dir, tarball)
             oss.save_file(oss_model_dir, tarball, TARBALL_NAME)

+        with temp_file.TemporaryDirectory() as tmp_dir:
+            model_obj_file = os.path.join(tmp_dir, MODEL_OBJ_FILE_NAME)
+            with open(model_obj_file, "w") as f:
+                f.write(
+                    json.dumps(self._to_dict(),
+                               cls=JSONEncoderWithFeatureColumn))
+            oss.save_file(oss_model_dir, model_obj_file, MODEL_OBJ_FILE_NAME)
+
     @staticmethod
     def load_from_oss(oss_model_dir, local_dir=None):
         """
@@ -257,7 +261,14 @@ def load_from_oss(oss_model_dir, local_dir=None):
         with temp_file.TemporaryDirectory() as tmp_dir:
             tarball = os.path.join(tmp_dir, TARBALL_NAME)
             oss.load_file(oss_model_dir, tarball, TARBALL_NAME)
-            return Model._unzip(local_dir, tarball)
+            Model._unzip(local_dir, tarball)
+
+            model_obj_file = os.path.join(tmp_dir, MODEL_OBJ_FILE_NAME)
+            oss.load_file(oss_model_dir, model_obj_file, MODEL_OBJ_FILE_NAME)
+            with open(model_obj_file, "r") as f:
+                d = json.loads(f.read(), cls=JSONDecoderWithFeatureColumn)
+            model = Model._from_dict(d)
+            return model


 def _decompose_model_name(name):
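
The net effect in this file: model metadata no longer travels inside the tarball. save_to_oss uploads it as a separate object file next to the archive, and load_from_oss fetches both and rebuilds the Model. A minimal sketch of that round trip, using plain json and local paths as stand-ins for the oss.save_file/oss.load_file calls (all names below are illustrative, not SQLFlow APIs):

    # Sketch: archive the model files and store the metadata beside them,
    # not inside them. Local directories stand in for OSS objects here.
    import json
    import os
    import shutil

    MODEL_OBJ_FILE_NAME = "model_obj.json"  # name assumed for illustration

    def save_sketch(fake_oss_dir, local_model_dir, meta):
        os.makedirs(fake_oss_dir, exist_ok=True)
        # 1) archive only the model files (stands in for _zip + oss.save_file)
        shutil.make_archive(os.path.join(fake_oss_dir, "model"), "gztar",
                            local_model_dir)
        # 2) store the metadata as its own object, outside the tarball
        with open(os.path.join(fake_oss_dir, MODEL_OBJ_FILE_NAME), "w") as f:
            json.dump(meta, f)

    def load_sketch(fake_oss_dir, local_dir):
        os.makedirs(local_dir, exist_ok=True)
        # 1) restore the model files (stands in for oss.load_file + _unzip)
        shutil.unpack_archive(os.path.join(fake_oss_dir, "model.tar.gz"),
                              local_dir)
        # 2) read the metadata back from the separate object file
        with open(os.path.join(fake_oss_dir, MODEL_OBJ_FILE_NAME)) as f:
            return json.load(f)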

python/runtime/model/tar.py (+2 -2)

@@ -17,7 +17,7 @@
 import tarfile


-def zip_dir(src_dir, tarball, arcname=None):
+def zip_dir(src_dir, tarball, arcname=None, filter=None):
     """To compress a directory into tarball.

     Args:
@@ -31,7 +31,7 @@ def zip_dir(src_dir, tarball, arcname=None):
         The output name of src_dir in the tarball.
     """
     with tarfile.open(tarball, "w:gz") as tar:
-        tar.add(src_dir, arcname=arcname, recursive=True)
+        tar.add(src_dir, arcname=arcname, recursive=True, filter=filter)


 def unzip_dir(tarball, dest_dir=None):
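
tarfile.TarFile.add has accepted a filter callable since Python 3.2: it receives each TarInfo, and returning None drops that entry (for a directory, recursion into it is skipped as well). zip_dir now simply forwards it. A self-contained demo; the names in SKIP are examples, not the real ALL_TAR_FILES values:

    import os
    import tarfile

    SKIP = {"requirements.txt", "train_params.pkl"}  # example skip list

    def skip_job_files(tarinfo):
        name = tarinfo.name
        if name.startswith("./"):
            name = name[2:]  # normalize the "./" arcname prefix, as _zip does
        return None if name in SKIP else tarinfo

    os.makedirs("model_dir", exist_ok=True)
    for fname in ("weights.bin", "requirements.txt"):
        open(os.path.join("model_dir", fname), "w").close()

    with tarfile.open("model.tar.gz", "w:gz") as tar:
        tar.add("model_dir", arcname="./", recursive=True,
                filter=skip_job_files)

    with tarfile.open("model.tar.gz") as tar:
        print(tar.getnames())  # requirements.txt filtered out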

python/runtime/pai/entry.py (+4 -4)

@@ -72,9 +72,7 @@ def call_fun(func, params):
     return func(**dict_args)


-def entrypoint():
-    with open("train_params.pkl", "rb") as file:
-        params = pickle.load(file)
+def entrypoint(params):
     if params["entry_type"] == "train_tf":
         call_fun(train_tf, params)
     elif params["entry_type"] == "train_xgb":
@@ -96,4 +94,6 @@ def entrypoint():
 if __name__ == "__main__":
     FLAGS = define_tf_flags()
     set_oss_environs(FLAGS)
-    entrypoint()
+    with open("train_params.pkl", "rb") as file:
+        params = pickle.load(file)
+    entrypoint(params)
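
Moving the pickle.load into the __main__ block turns entrypoint into a pure dispatcher, so the pai_local path can call it with an in-memory dict instead of shipping train_params.pkl through a job tarball. A trimmed sketch of the two call paths (the params dicts are illustrative; real ones carry many more keys):

    import pickle

    def entrypoint(params):
        # dispatch only -- no file I/O, so any caller can supply params
        print("dispatching", params["entry_type"])

    # PAI job path: params arrive pickled inside the job archive
    with open("train_params.pkl", "wb") as f:  # create a stand-in file
        pickle.dump({"entry_type": "train_tf"}, f)
    with open("train_params.pkl", "rb") as f:
        entrypoint(pickle.load(f))

    # pai_local path: params are already in memory, no round trip needed
    entrypoint({"entry_type": "train_xgb"})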

python/runtime/pai/prepare_archive.py (+10 -4)

@@ -41,6 +41,15 @@
 sklearn_pandas==1.6.0
 """

+ALL_TAR_FILES = [
+    JOB_ARCHIVE_FILE,
+    ENTRY_FILE,
+    "runtime",
+    "sqlflow_models",
+    "requirements.txt",
+    TRAIN_PARAMS_FILE,
+]
+

 def prepare_archive(cwd, estimator, model_save_path, train_params):
     """package needed resource into a tarball"""
@@ -60,10 +69,7 @@ def prepare_archive(cwd, estimator, model_save_path, train_params):
     _copy_python_package("sqlflow_models", cwd)
     _copy_custom_package(estimator, cwd)

-    args = [
-        "tar", "czf", JOB_ARCHIVE_FILE, ENTRY_FILE, "runtime",
-        "sqlflow_models", "requirements.txt", TRAIN_PARAMS_FILE
-    ]
+    args = ["tar", "czf"] + ALL_TAR_FILES
     if subprocess.call(args, cwd=cwd) != 0:
         raise SQLFlowDiagnostic("Can't zip resource")
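
ALL_TAR_FILES now does double duty: it is the tar argument list here and the skip list consumed by Model._zip's filter in model.py, so the two can no longer drift apart. Note that its first element is the output archive name, exactly as in the old inline list. How the command expands (the constant values below are assumptions for illustration; the real ones are defined in prepare_archive.py):

    JOB_ARCHIVE_FILE = "job.tar.gz"         # output archive (follows "czf")
    ENTRY_FILE = "entry.py"                 # assumed value
    TRAIN_PARAMS_FILE = "train_params.pkl"  # assumed value
    ALL_TAR_FILES = [
        JOB_ARCHIVE_FILE, ENTRY_FILE, "runtime",
        "sqlflow_models", "requirements.txt", TRAIN_PARAMS_FILE,
    ]

    args = ["tar", "czf"] + ALL_TAR_FILES
    print(" ".join(args))
    # -> tar czf job.tar.gz entry.py runtime sqlflow_models
    #    requirements.txt train_params.pkl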

python/runtime/pai/submitter_train.py (+4)

@@ -21,6 +21,7 @@
 from runtime.pai.pai_ml.random_forest import get_train_random_forest_pai_cmd
 from runtime.pai.prepare_archive import prepare_archive
 from runtime.pai.submit_pai_task import submit_pai_task
+from runtime.pai_local.try_run import try_pai_local_run


 def get_pai_train_cmd(datasource, estimator_string, model_name, train_table,
@@ -137,6 +138,9 @@ def submit_pai_train(datasource,
         pai_model.clean_oss_model_path(oss_path_to_save + "/")
     train_params["oss_path_to_load"] = oss_path_to_load

+    if try_pai_local_run(params, oss_path_to_save):
+        return
+
     with temp_file.TemporaryDirectory(prefix="sqlflow", dir="/tmp") as cwd:
         # zip all required resource to a tarball
         prepare_archive(cwd, estimator_string, oss_path_to_save, params)

python/runtime/pai_local/__init__.py (new file, +30)

@@ -0,0 +1,30 @@
+# Copyright 2020 The SQLFlow Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+
+def _gen_pai_local_method(name):
+    def impl(*args, **kwargs):
+        import runtime.pai as pai
+        method = getattr(pai, name)
+        os.environ["SQLFLOW_submitter"] = "pai_local"
+        return method(*args, **kwargs)
+
+    return impl
+
+
+train = _gen_pai_local_method('train')
+predict = _gen_pai_local_method('predict')
+evaluate = _gen_pai_local_method('evaluate')
+explain = _gen_pai_local_method('explain')
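
Importing runtime.pai inside impl rather than at module top level is the circular-reference fix from the commit message: runtime.pai.submitter_train now imports from runtime.pai_local, so a load-time import back into runtime.pai would close the cycle. Deferring the import to call time lets both modules finish initializing first. A generic two-module sketch of the pattern (pkg_a and pkg_b are hypothetical names):

    # pkg_a.py -- may import pkg_b at load time without creating a cycle
    #
    #     import pkg_b
    #
    #     def run():
    #         return "ran"

    # pkg_b.py -- reaches back into pkg_a, but only at call time
    def call_run():
        import pkg_a          # executes after pkg_a has finished loading
        return pkg_a.run()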

python/runtime/pai_local/try_run.py (new file, +40)

@@ -0,0 +1,40 @@
+# Copyright 2020 The SQLFlow Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from runtime.pai import pai_model
+from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
+
+
+def init_pai_local_tf_flags_and_envs(oss_model_dir):
+    FLAGS = define_tf_flags()
+    FLAGS.sqlflow_oss_ak = os.getenv("SQLFLOW_OSS_AK")
+    FLAGS.sqlflow_oss_sk = os.getenv("SQLFLOW_OSS_SK")
+    FLAGS.sqlflow_oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
+    if not oss_model_dir.startswith("oss://"):
+        oss_model_dir = pai_model.get_oss_model_url(oss_model_dir)
+    FLAGS.sqlflow_oss_modeldir = oss_model_dir
+    FLAGS.checkpointDir = os.getcwd()
+    set_oss_environs(FLAGS)
+
+
+def try_pai_local_run(params, oss_model_dir):
+    if os.getenv("SQLFLOW_submitter") == "pai_local":
+        from runtime.pai.entry import entrypoint
+        init_pai_local_tf_flags_and_envs(oss_model_dir)
+        print('start to run using pai_local submitter ...')
+        entrypoint(params)
+        return True
+    else:
+        return False
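
The gate is the SQLFLOW_submitter environment variable, which the runtime.pai_local wrappers set before delegating to runtime.pai; when it reads "pai_local", training runs in-process through entrypoint instead of being packaged and submitted as a PAI job. A usage sketch, assuming it runs inside the SQLFlow runtime with OSS credentials in the environment (the params dict and OSS path are trimmed illustrations):

    import os

    from runtime.pai_local.try_run import try_pai_local_run

    os.environ["SQLFLOW_submitter"] = "pai_local"  # set by runtime.pai_local
    params = {"entry_type": "train_tf"}  # real params carry many more keys
    if try_pai_local_run(params, "oss://bucket/models/my_model"):
        print("ran in-process; no PAI job was submitted")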

python/runtime/step/tensorflow/train.py (+13 -11)

@@ -19,7 +19,7 @@
 from runtime.feature.compile import compile_ir_feature_columns
 from runtime.feature.derivation import (get_ordered_field_descs,
                                         infer_feature_columns)
-from runtime.model import EstimatorType, Model, collect_metadata, oss
+from runtime.model import EstimatorType, Model, collect_metadata
 from runtime.pai.pai_distributed import define_tf_flags, set_oss_environs
 from runtime.step.tensorflow.train_estimator import estimator_train_and_save
 from runtime.step.tensorflow.train_keras import keras_train_and_save
@@ -69,9 +69,15 @@ def train_step(original_sql,
     if validation_params is None:
         validation_params = {}

+    is_pai = True if pai_table else False
+    if is_pai:
+        actual_select = "SELECT * FROM %s" % pai_table
+    else:
+        actual_select = select
+
     conn = db.connect_with_data_source(datasource)
     fc_map_ir, fc_label_ir = infer_feature_columns(conn,
-                                                   select,
+                                                   actual_select,
                                                    feature_column_map,
                                                    label_column,
                                                    n=1000)
@@ -124,7 +130,6 @@ def train_step(original_sql,
     estimator = import_model(estimator_string)
     is_estimator = is_tf_estimator(estimator)

-    is_pai = True if pai_table else False
     # always use verbose == 1 when using PAI to get more logs
     if verbose < 1:
         verbose = 1
@@ -150,7 +155,7 @@ def train_step(original_sql,
                                     num_workers=num_workers,
                                     worker_id=worker_id)
     val_dataset_fn = None
-    if validation_select:
+    if validation_select or pai_val_table:
         val_dataset_fn = get_dataset_fn(validation_select, datasource,
                                         feature_column_names, feature_metas,
                                         label_meta, is_pai, pai_val_table,
@@ -185,19 +190,16 @@ def train_step(original_sql,
                                  save_checkpoints_steps, validation_metrics,
                                  load, model_meta)

-    # save model to DB
+    # save model to DB/OSS
+    model = Model(EstimatorType.TENSORFLOW, model_meta)
     if num_workers == 1 or worker_id == 0:
         if is_pai:
             oss_model_dir = FLAGS.sqlflow_oss_modeldir
-            oss.save_oss_model(oss_model_dir, estimator_string, is_estimator,
-                               feature_column_names, feature_column_names_map,
-                               feature_metas, label_meta, model_params,
-                               fc_map_ir, num_workers)
+            model.save_to_oss(oss_model_dir)
             print("Model saved to OSS: %s" % oss_model_dir)
         else:
-            model = Model(EstimatorType.TENSORFLOW, model_meta)
             model.save_to_db(datasource, save)
-            print("Model saved to db: %s" % save)
+            print("Model saved to DB: %s" % save)

     print("Done training")
     conn.close()
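
The actual_select change makes feature derivation sample the table the trainer will actually read: on PAI the training data has already been materialized into a temporary table (pai_table), so scanning the user's original select would look at the wrong source. A condensed illustration (table names are made up):

    pai_table = "my_project.sqlflow_tmp_train"  # hypothetical PAI temp table
    select = "SELECT * FROM iris.train"         # user's original statement

    is_pai = True if pai_table else False
    actual_select = "SELECT * FROM %s" % pai_table if is_pai else select
    print(actual_select)  # SELECT * FROM my_project.sqlflow_tmp_train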

python/runtime/xgboost/train.py (+2 -3)

@@ -13,13 +13,13 @@

 import sys

-import runtime.pai.pai_distributed as pai_dist
 import six
 import xgboost as xgb
 from runtime.local.xgboost_submitter.save import save_model_to_local_file
 from runtime.model import collect_metadata
 from runtime.model import oss as pai_model_store
 from runtime.model import save_metadata
+from runtime.pai.pai_distributed import make_distributed_info_without_evaluator
 from runtime.xgboost.dataset import xgb_dataset
 from runtime.xgboost.pai_rabit import PaiXGBoostTracker, PaiXGBoostWorker

@@ -50,8 +50,7 @@ def dist_train(flags,
         "XGBoost distributed training is only supported on PAI")

     num_workers = len(flags.worker_hosts.split(","))
-    cluster, node, task_id = pai_dist.make_distributed_info_without_evaluator(
-        flags)
+    cluster, node, task_id = make_distributed_info_without_evaluator(flags)
     master_addr = cluster["ps"][0].split(":")
     master_host = master_addr[0]
     master_port = int(master_addr[1]) + 1
