
Commit 41bb11e

Merge branch 'main' into per_stream_batching
2 parents: 7cf1b41 + 012baa2

22 files changed: +933 −136 lines

README.md

+93
@@ -236,6 +236,99 @@ dataset = StreamingDataset('s3://my-bucket/my-data', cache_dir="/path/to/cache")

</details>

<details>
<summary> ✅ Stream Hugging Face 🤗 datasets</summary>

&nbsp;

To use your favorite Hugging Face dataset with LitData, simply pass its URL to `StreamingDataset`.

<details>
<summary>How to get the HF dataset URI?</summary>

https://github.com/user-attachments/assets/3ba9e2ef-bf6b-41fc-a578-e4b4113a0e72

</details>

```python
import litdata as ld

hf_uri = "hf://datasets/leonardPKU/clevr_cogen_a_train/data"

ds = ld.StreamingDataset(hf_uri)

for _ds in ds:
    print(f"{_ds[1]}; {_ds[2]}")
```

You don’t need to worry about indexing the dataset or any other setup. **LitData** will **handle all the necessary steps automatically** and `cache` the `index.json` file, so you won’t have to index it again.

This ensures that the next time you stream the dataset, the indexing step is skipped.

&nbsp;

### Indexing the HF dataset (Optional)

If the Hugging Face dataset hasn’t been indexed yet, you can index it first using the `index_hf_dataset` function, and then stream it using the code above.

```python
import litdata as ld

hf_uri = "hf://datasets/leonardPKU/clevr_cogen_a_train/data"

ld.index_hf_dataset(hf_uri)
```

- Indexing the Hugging Face dataset ahead of time makes streaming faster, as it avoids real-time indexing during streaming.

- To use an `HF gated dataset`, ensure the `HF_TOKEN` environment variable is set (see the sketch below).

**Note**: For Hugging Face datasets, `indexing` & `streaming` are supported only for datasets in **`Parquet format`**.
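For example, a minimal sketch for a gated dataset; the token value and the dataset URI below are placeholders, not real identifiers:

```python
import os

import litdata as ld

# Placeholder token and gated dataset URI -- substitute your own.
os.environ["HF_TOKEN"] = "hf_xxxxxxxxxxxxxxxxxxxx"

ds = ld.StreamingDataset("hf://datasets/your-org/your-gated-dataset/data")

for item in ds:
    print(item)
    break
```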
&nbsp;
### Full Workflow for Hugging Face Datasets

For full control over the cache path (i.e., where the `index.json` file will be stored) and other configurations, follow these steps:

1. Index the Hugging Face dataset first:

```python
import litdata as ld

hf_uri = "hf://datasets/open-thoughts/OpenThoughts-114k/data"

ld.index_parquet_dataset(hf_uri, "hf-index-dir")
```

2. Now stream the HF dataset by passing the `HF dataset URI`, the path where the `index.json` file is stored, and `ParquetLoader` as the `item_loader` to the **`StreamingDataset`**:

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

hf_uri = "hf://datasets/open-thoughts/OpenThoughts-114k/data"

ds = ld.StreamingDataset(hf_uri, item_loader=ParquetLoader(), index_path="hf-index-dir")

for _ds in ds:
    print(f"{_ds[0]}; {_ds[1]}\n")
```
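To feed these samples into a training loop, the dataset can be wrapped in a dataloader. A minimal sketch using `ld.StreamingDataLoader`; the batch size and worker count below are illustrative assumptions:

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

hf_uri = "hf://datasets/open-thoughts/OpenThoughts-114k/data"

ds = ld.StreamingDataset(hf_uri, item_loader=ParquetLoader(), index_path="hf-index-dir")

# Illustrative settings; tune batch_size and num_workers for your machine.
dl = ld.StreamingDataLoader(ds, batch_size=16, num_workers=4)

for batch in dl:
    print(len(batch))
    break
```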
&nbsp;
### LitData `Optimize` vs `Parquet`

Below is a benchmark on the `ImageNet dataset (155 GB)`, showing that **`optimizing the dataset with LitData is faster and produces a smaller output than raw Parquet files`**.

| **Operation**                     | **Size (GB)** | **Time (seconds)** | **Throughput (images/sec)** |
|-----------------------------------|---------------|--------------------|-----------------------------|
| LitData Optimize Dataset          | 45            | 283.17             | 4000-4700                   |
| Parquet Optimize Dataset          | 51            | 465.96             | 3600-3900                   |
| Index Parquet Dataset (overhead)  | N/A           | 6                   | N/A                         |
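For reference, the "LitData Optimize Dataset" row corresponds to converting the raw images with `ld.optimize`. A minimal sketch of such a conversion; the paths, conversion function, and chunk size are illustrative assumptions, not the exact benchmark setup:

```python
import os
from glob import glob

from PIL import Image

import litdata as ld


def convert(image_path):
    # Store the decoded image together with its class (the parent folder name).
    return {"image": Image.open(image_path), "class": os.path.basename(os.path.dirname(image_path))}


if __name__ == "__main__":
    inputs = glob("imagenet/train/**/*.JPEG", recursive=True)  # assumed local ImageNet layout
    ld.optimize(
        fn=convert,
        inputs=inputs,
        output_dir="optimized-imagenet",  # local or cloud path
        chunk_bytes="64MB",  # illustrative chunk size
        num_workers=8,
    )
```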
</details>
<details>
<summary> ✅ Streams on multi-GPU, multi-node</summary>

pyproject.toml

-1
@@ -72,7 +72,6 @@ lint.ignore = [
     "E731", # Do not assign a lambda expression, use a def
     "S101", # todo: Use of `assert` detected
 ]
-lint.ignore-init-module-imports = true
 # Unlike Flake8, default to a complexity level of 10.
 lint.mccabe.max-complexity = 10
 # Use Google-style docstrings.

requirements/extras.txt

+1
@@ -6,3 +6,4 @@ tqdm
 lightning-sdk==0.1.46 # Must be pinned to ensure compatibility
 google-cloud-storage
 polars
+fsspec

src/litdata/__about__.py

+1-1
@@ -14,7 +14,7 @@

 import time

-__version__ = "0.2.38"
+__version__ = "0.2.39"
 __author__ = "Lightning AI et al."
 __author_email__ = "pytorch@lightning.ai"
 __license__ = "Apache-2.0"

src/litdata/__init__.py

+2
@@ -20,6 +20,7 @@
 from litdata.streaming.item_loader import TokensLoader
 from litdata.streaming.writer import index_parquet_dataset
 from litdata.utilities.breakpoint import breakpoint
+from litdata.utilities.hf_dataset import index_hf_dataset
 from litdata.utilities.train_test_split import train_test_split

 __all__ = [
@@ -33,6 +34,7 @@
     "train_test_split",
     "merge_datasets",
     "index_parquet_dataset",
+    "index_hf_dataset",
     "breakpoint",
 ]
 if RequirementCache("lightning_sdk"):

src/litdata/constants.py

+2
@@ -29,13 +29,15 @@
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
+_FSSPEC_AVAILABLE = RequirementCache("fsspec")
 _TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
 _ZSTD_AVAILABLE = RequirementCache("zstd")
 _CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
 _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
 _AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
 _TQDM_AVAILABLE = RequirementCache("tqdm")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
+_HF_HUB_AVAILABLE = RequirementCache("huggingface_hub")
 _POLARS_AVAILABLE = RequirementCache("polars>1.0.0")
 _DEBUG = bool(int(os.getenv("DEBUG", "1")))
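These `RequirementCache` flags are truthy only when the corresponding package is importable, so call sites can guard optional imports. A minimal sketch of the usual pattern; the helper below is illustrative, not code from this commit:

```python
from litdata.constants import _HF_HUB_AVAILABLE


def _list_parquet_files(repo_id: str) -> list:
    # Import huggingface_hub lazily, and fail with a clear message if it is missing.
    if not _HF_HUB_AVAILABLE:
        raise ModuleNotFoundError("Support for hf:// datasets requires `pip install huggingface_hub`")
    from huggingface_hub import HfFileSystem

    fs = HfFileSystem()
    return [path for path in fs.ls(f"datasets/{repo_id}", detail=False) if path.endswith(".parquet")]
```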

src/litdata/processing/functions.py

+6-3
@@ -22,7 +22,7 @@
 from functools import partial
 from pathlib import Path
 from types import FunctionType
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
 from urllib import parse

 import torch
@@ -52,6 +52,9 @@
 from litdata.utilities.encryption import Encryption
 from litdata.utilities.format import _get_tqdm_iterator_if_available

+if TYPE_CHECKING:
+    from lightning_sdk import Machine
+

 def _is_remote_file(path: str) -> bool:
     obj = parse.urlparse(path)
@@ -194,7 +197,7 @@ def map(
     num_workers: Optional[int] = None,
     fast_dev_run: Union[bool, int] = False,
     num_nodes: Optional[int] = None,
-    machine: Optional[str] = None,
+    machine: Optional[Union["Machine", str]] = None,
     num_downloaders: Optional[int] = None,
     num_uploaders: Optional[int] = None,
     reorder_files: bool = True,
@@ -312,7 +315,7 @@ def optimize(
     num_workers: Optional[int] = None,
     fast_dev_run: bool = False,
     num_nodes: Optional[int] = None,
-    machine: Optional[str] = None,
+    machine: Optional[Union["Machine", str]] = None,
     num_downloaders: Optional[int] = None,
     num_uploaders: Optional[int] = None,
     reorder_files: bool = True,

src/litdata/streaming/config.py

+17-2
@@ -120,11 +120,18 @@ def download_chunk_from_index(self, chunk_index: int) -> None:

         if os.path.exists(local_chunkpath):
             self.try_decompress(local_chunkpath)
+            if self._downloader is not None:
+                # We don't want to redownload the base, but we should mark
+                # it as having been requested by something
+                self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""))
+                pass
             return

         if self._downloader is None:
             return

+        self._downloader._increment_local_lock(local_chunkpath.replace(f".{self._compressor_name}", ""))
+
         self._downloader.download_chunk_from_index(chunk_index)

         self.try_decompress(local_chunkpath)
257264
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)
258265

259266
if isinstance(remote_dir, str):
260-
downloader = get_downloader_cls(remote_dir, cache_dir, [], storage_options)
261-
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)
267+
# for remote_dir, we try downloading `index.json` file.
268+
# If the files are stored on HF, they don't have an index file, so we can skip downloading it.
269+
if remote_dir.startswith("hf://"):
270+
if not os.path.exists(cache_index_filepath):
271+
raise RuntimeError(
272+
f"This should not have happened. No index.json file found in cache: {cache_index_filepath}"
273+
)
274+
else:
275+
downloader = get_downloader_cls(remote_dir, cache_dir, [], storage_options)
276+
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)
262277

263278
if not os.path.exists(cache_index_filepath):
264279
return None

src/litdata/streaming/dataset.py

+15-2
@@ -26,14 +26,15 @@
 from litdata.helpers import _check_version_and_prompt_upgrade
 from litdata.streaming import Cache
 from litdata.streaming.downloader import get_downloader_cls  # noqa: F401
-from litdata.streaming.item_loader import BaseItemLoader
+from litdata.streaming.item_loader import BaseItemLoader, ParquetLoader
 from litdata.streaming.resolver import Dir, _resolve_dir
 from litdata.streaming.sampler import ChunkedIndex
 from litdata.streaming.serializers import Serializer
 from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
 from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
 from litdata.utilities.encryption import Encryption
 from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
+from litdata.utilities.hf_dataset import index_hf_dataset
 from litdata.utilities.shuffle import (
     _find_chunks_per_workers_on_which_to_skip_deletion,
     _map_node_worker_rank_to_chunk_indexes_to_not_delete,
@@ -59,6 +60,7 @@ def __init__(
         encryption: Optional[Encryption] = None,
         storage_options: Optional[Dict] = {},
         max_pre_download: int = 2,
+        index_path: Optional[str] = None,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -79,6 +81,9 @@
             encryption: The encryption object to use for decrypting the data.
             storage_options: Additional connection options for accessing storage services.
             max_pre_download: Maximum number of chunks that can be pre-downloaded by the StreamingDataset.
+            index_path: Path to `index.json` for the Parquet dataset.
+                If `index_path` is a directory, the function will look for `index.json` within it.
+                If `index_path` is a full file path, it will use that directly.

         """
         _check_version_and_prompt_upgrade(__version__)
@@ -93,12 +98,20 @@
         input_dir = _resolve_dir(input_dir)
         cache_dir = _resolve_dir(cache_dir)

+        if input_dir.url is not None and input_dir.url.startswith("hf://"):
+            if index_path is None:
+                # no index path provide, load from cache, or try indexing on the go.
+                index_path = index_hf_dataset(input_dir.url)
+            cache_dir.path = index_path
+            input_dir.path = index_path
+            item_loader = ParquetLoader()
+
         self.input_dir = input_dir
         self.cache_dir = cache_dir
         self.subsampled_files: List[str] = []
         self.region_of_interest: List[Tuple[int, int]] = []
         self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
-            self.input_dir, self.cache_dir, item_loader, subsample, shuffle, seed, storage_options
+            self.input_dir, self.cache_dir, item_loader, subsample, shuffle, seed, storage_options, index_path
         )

         self.item_loader = item_loader