@@ -3,9 +3,75 @@
 import urllib
 from contextlib import contextmanager
 from subprocess import Popen
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from lightning.data.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST
+
+if _LIGHTNING_CLOUD_LATEST:
+    from lightning_cloud.openapi import (
+        ProjectIdDatasetsBody,
+        V1DatasetType,
+    )
+    from lightning_cloud.openapi.rest import ApiException
+    from lightning_cloud.rest_client import LightningClient
+
+
+def _create_dataset(
+    input_dir: Optional[str],
+    storage_dir: str,
+    dataset_type: V1DatasetType,
+    empty: Optional[bool] = None,
+    size: Optional[int] = None,
+    num_bytes: Optional[str] = None,
+    data_format: Optional[Union[str, Tuple[str]]] = None,
+    compression: Optional[str] = None,
+    num_chunks: Optional[int] = None,
+    num_bytes_per_chunk: Optional[List[int]] = None,
+    name: Optional[str] = None,
+    version: Optional[int] = None,
+) -> None:
+    """Create a dataset with metadata information about its source and destination."""
+    project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
+    cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
+    user_id = os.getenv("LIGHTNING_USER_ID", None)
+    cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
+    lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)
+
+    if project_id is None:
+        return
+
+    if not storage_dir:
+        raise ValueError("The storage_dir should be defined.")
+
+    client = LightningClient(retry=False)
 
-from lightning.data.constants import _IS_IN_STUDIO
+    try:
+        client.dataset_service_create_dataset(
+            body=ProjectIdDatasetsBody(
+                cloud_space_id=cloud_space_id if lightning_app_id is None else None,
+                cluster_id=cluster_id,
+                creator_id=user_id,
+                empty=empty,
+                input_dir=input_dir,
+                lightning_app_id=lightning_app_id,
+                name=name,
+                size=size,
+                num_bytes=num_bytes,
+                data_format=str(data_format) if data_format else data_format,
+                compression=compression,
+                num_chunks=num_chunks,
+                num_bytes_per_chunk=num_bytes_per_chunk,
+                storage_dir=storage_dir,
+                type=dataset_type,
+                version=version,
+            ),
+            project_id=project_id,
+        )
+    except ApiException as ex:
+        if "already exists" in str(ex.body):
+            pass
+        else:
+            raise ex
 
 
 def get_worker_rank() -> Optional[str]:
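
For readers skimming the diff, here is a minimal usage sketch of the new helper, assuming it runs inside the same module (so `_create_dataset` and `V1DatasetType` are already in scope via the imports added above). Every value below is a placeholder, and `V1DatasetType.CHUNKED` is an assumed member of that enum rather than something shown in this diff; the call is harmless outside the Lightning cloud because the helper returns early when `LIGHTNING_CLOUD_PROJECT_ID` is not set in the environment.

# Hypothetical call site; only the keyword names and types come from the diff.
# The helper reads LIGHTNING_* identifiers from the environment and is a no-op
# when LIGHTNING_CLOUD_PROJECT_ID is missing.
_create_dataset(
    input_dir="/teamspace/studios/this_studio/raw_images",  # assumed source path
    storage_dir="/teamspace/datasets/cats_vs_dogs",         # assumed destination
    dataset_type=V1DatasetType.CHUNKED,  # assumed enum member, for illustration
    empty=False,
    size=25_000,                         # number of items in the dataset
    num_bytes="1073741824",              # total size, passed as a string per the signature
    data_format="jpeg",
    compression=None,
    num_chunks=64,
    num_bytes_per_chunk=[16_777_216] * 64,
    name="cats_vs_dogs",
    version=1,
)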