@@ -99,7 +99,7 @@ def _from_dict(d):
99
99
typ = d .pop ("model_type" )
100
100
return Model (typ , d )
101
101
102
- def _zip (self , local_dir , tarball , save_to_db = False ):
102
+ def _zip (self , local_dir , tarball ):
103
103
"""
104
104
Zip the model information and all files in local_dir into a tarball.
105
105
@@ -110,20 +110,23 @@ def _zip(self, local_dir, tarball, save_to_db=False):
110
110
Returns:
111
111
None.
112
112
"""
113
- if not save_to_db :
114
- model_obj_file = os .path .join (local_dir , MODEL_OBJ_FILE_NAME )
115
- with open (model_obj_file , "w" ) as f :
116
- d = self ._to_dict ()
117
- f .write (json .dumps (d , cls = JSONEncoderWithFeatureColumn ))
118
- else :
119
- model_obj_file = None
113
+ # NOTE: the unzip files of the job tarball should be skipped
114
+ from runtime .pai .prepare_archive import ALL_TAR_FILES
115
+
116
+ def filter (tarinfo ):
117
+ name = tarinfo .name
118
+ if name .startswith ("./" ):
119
+ name = name [2 :]
120
+
121
+ if name in ALL_TAR_FILES :
122
+ return None
123
+
124
+ return tarinfo
120
125
121
- zip_dir (local_dir , tarball , arcname = "./" )
122
- if model_obj_file :
123
- os .remove (model_obj_file )
126
+ zip_dir (local_dir , tarball , arcname = "./" , filter = filter )
124
127
125
128
@staticmethod
126
- def _unzip (local_dir , tarball , load_from_db = False ):
129
+ def _unzip (local_dir , tarball ):
127
130
"""
128
131
Unzip the tarball into local_dir and deserialize the model
129
132
information.
@@ -137,13 +140,6 @@ def _unzip(local_dir, tarball, load_from_db=False):
137
140
information.
138
141
"""
139
142
unzip_dir (tarball , local_dir )
140
- if not load_from_db :
141
- model_obj_file = os .path .join (local_dir , MODEL_OBJ_FILE_NAME )
142
- with open (model_obj_file , "r" ) as f :
143
- d = json .loads (f .read (), cls = JSONDecoderWithFeatureColumn )
144
- model = Model ._from_dict (d )
145
- os .remove (model_obj_file )
146
- return model
147
143
148
144
def save_to_db (self , datasource , table , local_dir = None ):
149
145
"""
@@ -164,7 +160,7 @@ def save_to_db(self, datasource, table, local_dir=None):
164
160
165
161
with temp_file .TemporaryDirectory () as tmp_dir :
166
162
tarball = os .path .join (tmp_dir , TARBALL_NAME )
167
- self ._zip (local_dir , tarball , save_to_db = True )
163
+ self ._zip (local_dir , tarball )
168
164
169
165
def _bytes_reader (filename , buf_size = 8 * 32 ):
170
166
def _gen ():
@@ -212,7 +208,7 @@ def load_from_db(datasource, table, local_dir=None):
212
208
for data in gen ():
213
209
f .write (bytes (data ))
214
210
215
- Model ._unzip (local_dir , tarball , load_from_db = True )
211
+ Model ._unzip (local_dir , tarball )
216
212
217
213
return Model ._from_dict (metadata )
218
214
@@ -237,6 +233,14 @@ def save_to_oss(self, oss_model_dir, local_dir=None):
237
233
self ._zip (local_dir , tarball )
238
234
oss .save_file (oss_model_dir , tarball , TARBALL_NAME )
239
235
236
+ with temp_file .TemporaryDirectory () as tmp_dir :
237
+ model_obj_file = os .path .join (tmp_dir , MODEL_OBJ_FILE_NAME )
238
+ with open (model_obj_file , "w" ) as f :
239
+ f .write (
240
+ json .dumps (self ._to_dict (),
241
+ cls = JSONEncoderWithFeatureColumn ))
242
+ oss .save_file (oss_model_dir , model_obj_file , MODEL_OBJ_FILE_NAME )
243
+
240
244
@staticmethod
241
245
def load_from_oss (oss_model_dir , local_dir = None ):
242
246
"""
@@ -257,7 +261,14 @@ def load_from_oss(oss_model_dir, local_dir=None):
257
261
with temp_file .TemporaryDirectory () as tmp_dir :
258
262
tarball = os .path .join (tmp_dir , TARBALL_NAME )
259
263
oss .load_file (oss_model_dir , tarball , TARBALL_NAME )
260
- return Model ._unzip (local_dir , tarball )
264
+ Model ._unzip (local_dir , tarball )
265
+
266
+ model_obj_file = os .path .join (tmp_dir , MODEL_OBJ_FILE_NAME )
267
+ oss .load_file (oss_model_dir , model_obj_file , MODEL_OBJ_FILE_NAME )
268
+ with open (model_obj_file , "r" ) as f :
269
+ d = json .loads (f .read (), cls = JSONDecoderWithFeatureColumn )
270
+ model = Model ._from_dict (d )
271
+ return model
261
272
262
273
263
274
def _decompose_model_name (name ):
0 commit comments