@@ -190,7 +190,14 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
     s3 = S3Client()

     while True:
-        local_filepath: Optional[str] = upload_queue.get()
+        data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get()
+
+        tmpdir = None
+
+        if isinstance(data, str) or data is None:
+            local_filepath = data
+        else:
+            tmpdir, local_filepath = data

         # Terminate the process if we received a termination signal
         if local_filepath is None:
@@ -202,15 +209,25 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_

         if obj.scheme == "s3":
             try:
+                if tmpdir is None:
+                    output_filepath = os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath))
+                else:
+                    output_filepath = os.path.join(str(obj.path).lstrip("/"), local_filepath.replace(tmpdir, "")[1:])
+
                 s3.client.upload_file(
                     local_filepath,
                     obj.netloc,
-                    os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)),
+                    output_filepath,
                 )
             except Exception as e:
                 print(e)
         elif output_dir.path and os.path.isdir(output_dir.path):
-            shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            if tmpdir is None:
+                shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            else:
+                output_filepath = os.path.join(output_dir.path, local_filepath.replace(tmpdir, "")[1:])
+                os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+                shutil.copyfile(local_filepath, output_filepath)
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")

@@ -435,12 +452,15 @@ def _create_cache(self) -> None:
         )
         self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index

-    def _try_upload(self, filepath: Optional[str]) -> None:
-        if not filepath or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None:
+    def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
+        if not data or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None:
             return

-        assert os.path.exists(filepath), filepath
-        self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
+        if isinstance(data, str):
+            assert os.path.exists(data), data
+        else:
+            assert os.path.exists(data[-1]), data
+        self.to_upload_queues[self._counter % self.num_uploaders].put(data)

     def _collect_paths(self) -> None:
         if self.input_dir.path is None:
@@ -582,7 +602,7 @@ def _handle_data_transform_recipe(self, index: int) -> None:
                 filepaths.append(os.path.join(directory, filename))

         for filepath in filepaths:
-            self._try_upload(filepath)
+            self._try_upload((output_dir, filepath))


 class DataWorkerProcess(BaseWorker, Process):
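For context, a minimal standalone sketch of the destination-path mapping this change introduces: when a file is queued together with its temporary directory, the tmpdir prefix is stripped so the file's relative layout is preserved at the destination; otherwise only the base name is kept, as before. The helper name and example paths below are illustrative, not part of the patch.

import os
from typing import Optional

def resolve_output_path(output_root: str, local_filepath: str, tmpdir: Optional[str] = None) -> str:
    # No tmpdir was provided: keep a flat layout, preserving only the file name.
    if tmpdir is None:
        return os.path.join(output_root, os.path.basename(local_filepath))
    # A tmpdir was provided: strip the tmpdir prefix and the leading separator that
    # remains, so the directory structure under tmpdir is reproduced at the destination.
    relative = local_filepath.replace(tmpdir, "")[1:]
    return os.path.join(output_root, relative)

# e.g. resolve_output_path("bucket-prefix", "/tmp/xyz/ckpt/epoch=1.ckpt", "/tmp/xyz")
#      -> "bucket-prefix/ckpt/epoch=1.ckpt"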