Skip to content

Commit 53ea76a

Browse files
authored
Prevent dataset to break if it already exists (#19491)
1 parent ddf2ac4 commit 53ea76a

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

src/lightning/data/processing/data_processor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_TORCH_GREATER_EQUAL_2_1_0,
3030
)
3131
from lightning.data.processing.readers import BaseReader
32+
from lightning.data.processing.utilities import _create_dataset
3233
from lightning.data.streaming import Cache
3334
from lightning.data.streaming.cache import Dir
3435
from lightning.data.streaming.client import S3Client
@@ -41,7 +42,6 @@
4142

4243
if _LIGHTNING_CLOUD_LATEST:
4344
from lightning_cloud.openapi import V1DatasetType
44-
from lightning_cloud.utils.dataset import _create_dataset
4545

4646

4747
if _BOTO3_AVAILABLE:
@@ -973,7 +973,8 @@ def run(self, data_recipe: DataRecipe) -> None:
973973
print("Workers are finished.")
974974
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
975975

976-
if num_nodes == node_rank + 1 and self.output_dir.url:
976+
if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO:
977+
assert self.output_dir.path
977978
_create_dataset(
978979
input_dir=self.input_dir.path,
979980
storage_dir=self.output_dir.path,

src/lightning/data/processing/utilities.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,75 @@
33
import urllib
44
from contextlib import contextmanager
55
from subprocess import Popen
6-
from typing import Any, Callable, Optional, Tuple
6+
from typing import Any, Callable, List, Optional, Tuple, Union
7+
8+
from lightning.data.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST
9+
10+
if _LIGHTNING_CLOUD_LATEST:
11+
from lightning_cloud.openapi import (
12+
ProjectIdDatasetsBody,
13+
V1DatasetType,
14+
)
15+
from lightning_cloud.openapi.rest import ApiException
16+
from lightning_cloud.rest_client import LightningClient
17+
18+
19+
def _create_dataset(
    input_dir: Optional[str],
    storage_dir: str,
    dataset_type: "V1DatasetType",
    empty: Optional[bool] = None,
    size: Optional[int] = None,
    num_bytes: Optional[str] = None,
    data_format: Optional[Union[str, Tuple[str]]] = None,
    compression: Optional[str] = None,
    num_chunks: Optional[int] = None,
    num_bytes_per_chunk: Optional[List[int]] = None,
    name: Optional[str] = None,
    version: Optional[int] = None,
) -> None:
    """Create a dataset with metadata information about its source and destination.

    The project, cluster, user, cloud space and app identifiers are read from the
    ``LIGHTNING_CLOUD_PROJECT_ID``, ``LIGHTNING_CLUSTER_ID``, ``LIGHTNING_USER_ID``,
    ``LIGHTNING_CLOUD_SPACE_ID`` and ``LIGHTNING_CLOUD_APP_ID`` environment variables.
    When ``LIGHTNING_CLOUD_PROJECT_ID`` is not set the function returns without doing
    anything.

    An ``ApiException`` whose body mentions ``"already exists"`` is ignored, so calling
    this for a dataset that has already been registered does not fail.

    Args:
        input_dir: Source location of the data, forwarded as-is to the request body.
        storage_dir: Destination location of the dataset. Must be non-empty.
        dataset_type: The ``V1DatasetType`` value forwarded as the ``type`` field.
        empty: Forwarded as-is to the request body.
        size: Forwarded as-is to the request body.
        num_bytes: Forwarded as-is to the request body.
        data_format: Converted with ``str()`` when truthy, otherwise forwarded unchanged.
        compression: Forwarded as-is to the request body.
        num_chunks: Forwarded as-is to the request body.
        num_bytes_per_chunk: Forwarded as-is to the request body.
        name: Forwarded as-is to the request body.
        version: Forwarded as-is to the request body.

    Raises:
        ValueError: If a project id is set but ``storage_dir`` is empty.
        ApiException: If the create call fails for any reason other than the dataset
            already existing.

    """
    # NOTE: ``dataset_type`` is annotated with a string (forward reference) because
    # ``V1DatasetType`` is only imported when ``_LIGHTNING_CLOUD_LATEST`` is true; an
    # unquoted annotation would be evaluated at definition time and raise ``NameError``
    # on import when ``lightning_cloud`` is unavailable.
    project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
    cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
    user_id = os.getenv("LIGHTNING_USER_ID", None)
    cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
    lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)

    # Without a project id there is nothing to register against.
    if project_id is None:
        return

    if not storage_dir:
        raise ValueError("The storage_dir should be defined.")

    client = LightningClient(retry=False)

    try:
        client.dataset_service_create_dataset(
            body=ProjectIdDatasetsBody(
                # The cloud space id is only sent when no app id is set.
                cloud_space_id=cloud_space_id if lightning_app_id is None else None,
                cluster_id=cluster_id,
                creator_id=user_id,
                empty=empty,
                input_dir=input_dir,
                lightning_app_id=lightning_app_id,
                name=name,
                size=size,
                num_bytes=num_bytes,
                data_format=str(data_format) if data_format else data_format,
                compression=compression,
                num_chunks=num_chunks,
                num_bytes_per_chunk=num_bytes_per_chunk,
                storage_dir=storage_dir,
                type=dataset_type,
                version=version,
            ),
            project_id=project_id,
        )
    except ApiException as ex:
        # A dataset that already exists is not an error; anything else is re-raised
        # with its original traceback.
        if "already exists" not in str(ex.body):
            raise
975

1076

1177
def get_worker_rank() -> Optional[str]:

0 commit comments

Comments
 (0)