Skip to content

Commit 39a86f8

Browse files
authored
Resolve compression, add support for torchaudio (#19503)
1 parent 2394e2f commit 39a86f8

File tree

8 files changed

+143
-16
lines changed

8 files changed

+143
-16
lines changed

src/lightning/data/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
2929
_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
3030
_BOTO3_AVAILABLE = RequirementCache("boto3")
31+
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
32+
_ZSTD_AVAILABLE = RequirementCache("zstd")
3133

3234
# DON'T CHANGE ORDER
3335
_TORCH_DTYPES_MAPPING = {

src/lightning/data/streaming/compression.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from abc import ABC, abstractmethod
1515
from typing import Dict, TypeVar
1616

17-
from lightning_utilities.core.imports import RequirementCache, requires
17+
from lightning_utilities.core.imports import requires
1818

19-
_ZSTD_AVAILABLE = RequirementCache("zstd")
19+
from lightning.data.constants import _ZSTD_AVAILABLE
2020

2121
if _ZSTD_AVAILABLE:
2222
import zstd

src/lightning/data/streaming/config.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Any, Dict, List, Optional, Tuple
1717

1818
from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
19+
from lightning.data.streaming.compression import _COMPRESSORS, Compressor
1920
from lightning.data.streaming.downloader import get_downloader_cls
2021
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
2122
from lightning.data.streaming.sampler import ChunkedIndex
@@ -66,19 +67,47 @@ def __init__(
6667
if remote_dir:
6768
self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)
6869

70+
self._compressor_name = self._config["compression"]
71+
self._compressor: Optional[Compressor] = None
72+
73+
if self._compressor_name in _COMPRESSORS:
74+
self._compressor = _COMPRESSORS[self._compressor_name]
75+
6976
def download_chunk_from_index(self, chunk_index: int) -> None:
    """Make sure the chunk at ``chunk_index`` is available in the cache directory.

    A chunk that is already present locally is only handed to ``try_decompress``. Otherwise it is
    fetched through the downloader and then decompressed.

    Args:
        chunk_index: Position of the chunk in ``self._chunks``.

    Raises:
        RuntimeError: If the chunk is missing locally and no downloader is defined.

    """
    chunk_filename = self._chunks[chunk_index]["filename"]
    local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

    if os.path.exists(local_chunkpath):
        self.try_decompress(local_chunkpath)
        return

    # ``try_decompress`` deletes the compressed file once it has been expanded, so a missing
    # compressed file doesn't mean the chunk is missing: its decompressed counterpart may already
    # be in place. Without this check the chunk would be downloaded again, or a RuntimeError
    # raised when there is no downloader.
    if self._compressor is not None:
        decompressed_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")
        if os.path.exists(decompressed_chunkpath):
            return

    if self._downloader is None:
        raise RuntimeError("The downloader should be defined.")

    self._downloader.download_chunk_from_index(chunk_index)

    self.try_decompress(local_chunkpath)
91+
92+
def try_decompress(self, local_chunkpath: str) -> None:
    """Decompress the chunk stored at ``local_chunkpath``, if a compressor is configured.

    The decompressed chunk is written to the same path with ``.<compressor name>`` removed, and
    the compressed file is deleted afterwards. Nothing happens when no compressor is configured
    or when the decompressed file already exists.

    Args:
        local_chunkpath: Path to the compressed chunk file.

    """
    if self._compressor is None:
        return

    target_local_chunkpath = local_chunkpath.replace(f".{self._compressor_name}", "")

    if os.path.exists(target_local_chunkpath):
        return

    with open(local_chunkpath, "rb") as f:
        data = f.read()

    # Decompress before deleting anything: if the data is corrupted, the compressed file stays
    # in place instead of being lost.
    data = self._compressor.decompress(data)

    # Write next to the target and rename, so the target path never holds a partially written
    # chunk that a later call would mistake for a finished one.
    tmp_local_chunkpath = f"{target_local_chunkpath}.{os.getpid()}.tmp"
    with open(tmp_local_chunkpath, "wb") as f:
        f.write(data)
    os.replace(tmp_local_chunkpath, target_local_chunkpath)

    if local_chunkpath != target_local_chunkpath:
        os.remove(local_chunkpath)
110+
82111
@property
83112
def intervals(self) -> List[Tuple[int, int]]:
84113
if self._intervals is None:
@@ -132,7 +161,13 @@ def _get_chunk_index_from_index(self, index: int) -> int:
132161
def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
    """Find the associated chunk metadata."""
    chunk_index = index.chunk_index
    filename = self._chunks[chunk_index]["filename"]
    chunk_path = os.path.join(self._cache_dir, filename)

    # With a compressor configured, point at the decompressed file name
    # (the path with ``.<compressor name>`` removed).
    if self._compressor is not None:
        chunk_path = chunk_path.replace(f".{self._compressor_name}", "")

    return (chunk_path, *self._intervals[chunk_index])
136171

137172
def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
138173
"""Retrieves the associated chunk_index for a given chunk filename."""

src/lightning/data/streaming/item_loader.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14+
import functools
1415
import os
1516
from abc import ABC, abstractmethod
1617
from time import sleep
@@ -101,15 +102,25 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
101102
begin, end = np.frombuffer(pair, np.uint32)
102103
fp.seek(begin)
103104
data = fp.read(end - begin)
105+
104106
return self.deserialize(data)
105107

108+
@functools.lru_cache(maxsize=128)
109+
def _data_format_to_key(self, data_format: str) -> str:
110+
if ":" in data_format:
111+
serialier, serializer_sub_type = data_format.split(":")
112+
if serializer_sub_type in self._serializers:
113+
return serializer_sub_type
114+
return serialier
115+
return data_format
116+
106117
def deserialize(self, raw_item_data: bytes) -> "PyTree":
107118
"""Deserialize the raw bytes into their python equivalent."""
108119
idx = len(self._config["data_format"]) * 4
109120
sizes = np.frombuffer(raw_item_data[:idx], np.uint32)
110121
data = []
111122
for size, data_format in zip(sizes, self._config["data_format"]):
112-
serializer = self._serializers[data_format]
123+
serializer = self._serializers[self._data_format_to_key(data_format)]
113124
data_bytes = raw_item_data[idx : idx + size]
114125
data.append(serializer.deserialize(data_bytes))
115126
idx += size

src/lightning/data/streaming/reader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def read(self, index: ChunkedIndex) -> Any:
229229
if self._config is None and self._try_load_config() is None:
230230
raise Exception("The reader index isn't defined.")
231231

232-
if self._config and self._config._remote_dir:
232+
if self._config and (self._config._remote_dir or self._config._compressor):
233233
# Create and start the prepare chunks thread
234234
if self._prepare_thread is None and self._config:
235235
self._prepare_thread = PrepareChunksThread(

src/lightning/data/streaming/serializers.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,8 @@ class FileSerializer(Serializer):
282282
def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
    """Read ``filepath`` and return its bytes with a ``file:<extension>`` data format.

    The extension is lower-cased and stripped of its dot.
    """
    extension = os.path.splitext(filepath)[1].replace(".", "").lower()
    with open(filepath, "rb") as stream:
        content = stream.read()
    return content, f"file:{extension}"
286287

287288
def deserialize(self, data: bytes) -> Any:
288289
return data
@@ -292,12 +293,13 @@ def can_serialize(self, data: Any) -> bool:
292293

293294

294295
class VideoSerializer(Serializer):
295-
_EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv", "wav")
296+
_EXTENSIONS = ("mp4", "ogv", "mjpeg", "avi", "mov", "h264", "mpg", "webm", "wmv")
296297

297298
def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
298299
_, file_extension = os.path.splitext(filepath)
299300
with open(filepath, "rb") as f:
300-
return f.read(), file_extension.replace(".", "").lower()
301+
file_extension = file_extension.replace(".", "").lower()
302+
return f.read(), f"video:{file_extension}"
301303

302304
def deserialize(self, data: bytes) -> Any:
303305
if not _TORCH_VISION_AVAILABLE:

tests/tests_data/processing/test_data_processor.py

+78-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
import torch
1212
from lightning import seed_everything
13+
from lightning.data.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE
1314
from lightning.data.processing import data_processor as data_processor_module
1415
from lightning.data.processing import functions
1516
from lightning.data.processing.data_processor import (
@@ -26,7 +27,7 @@
2627
_wait_for_file_to_exist,
2728
)
2829
from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
29-
from lightning.data.streaming import resolver
30+
from lightning.data.streaming import StreamingDataset, resolver
3031
from lightning.data.streaming.cache import Cache, Dir
3132
from lightning_utilities.core.imports import RequirementCache
3233

@@ -1058,3 +1059,79 @@ def test_empty_optimize(tmpdir):
10581059
)
10591060

10601061
assert os.listdir(tmpdir) == ["index.json"]
1062+
1063+
1064+
def create_synthetic_audio_bytes(index) -> dict:
    """Return a sample whose ``"content"`` is random mono audio encoded as WAV bytes.

    ``index`` is not used in the body.
    """
    from io import BytesIO

    import torchaudio

    # random waveform: 1 channel, 16000 samples
    waveform = torch.randn((1, 16000))

    # encode the waveform as WAV into an in-memory buffer and grab the raw bytes
    with BytesIO() as buffer:
        torchaudio.save(buffer, waveform, 16000, format="wav")
        wav_bytes = buffer.getvalue()

    return {"content": wav_bytes}
1079+
1080+
1081+
# The skip reason lists both packages the condition checks, not only torchaudio.
@pytest.mark.skipif(
    condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio', 'zstd']"
)
@pytest.mark.parametrize("compression", [None, "zstd"])
def test_load_torch_audio(tmpdir, compression):
    """Optimize WAV bytes with and without compression, then decode the first streamed sample."""
    seed_everything(42)

    import torchaudio

    optimize(
        fn=create_synthetic_audio_bytes,
        inputs=list(range(100)),
        output_dir=str(tmpdir),
        num_workers=1,
        chunk_bytes="64MB",
        compression=compression,
    )

    dataset = StreamingDataset(input_dir=str(tmpdir))
    sample = dataset[0]
    tensor = torchaudio.load(sample["content"])
    # torchaudio.load returns (waveform, sample_rate)
    assert tensor[0].shape == torch.Size([1, 16000])
    assert tensor[1] == 16000
1102+
1103+
1104+
def create_synthetic_audio_file(filepath) -> str:
    """Write random mono audio as a WAV file to ``filepath`` and return ``filepath``."""
    import torchaudio

    # random waveform: 1 channel, 16000 samples
    data = torch.randn((1, 16000))

    # encode the waveform as WAV straight into the file
    with open(filepath, "wb") as f:
        torchaudio.save(f, data, 16000, format="wav")

    return filepath
1115+
1116+
1117+
# The skip reason lists both packages the condition checks, not only torchaudio.
# NOTE(review): only ``compression=None`` is parametrized here, so the zstd requirement looks
# unnecessary - confirm before relaxing the condition.
@pytest.mark.skipif(
    condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE, reason="Requires: ['torchaudio', 'zstd']"
)
@pytest.mark.parametrize("compression", [None])
def test_load_torch_audio_from_wav_file(tmpdir, compression):
    """Optimize WAV files given by path, then decode the first streamed sample."""
    seed_everything(42)

    import torchaudio

    optimize(
        fn=create_synthetic_audio_file,
        inputs=[os.path.join(tmpdir, f"{i}.wav") for i in range(5)],
        output_dir=str(tmpdir),
        num_workers=1,
        chunk_bytes="64MB",
        compression=compression,
    )

    dataset = StreamingDataset(input_dir=str(tmpdir))
    sample = dataset[0]
    tensor = torchaudio.load(sample)
    # torchaudio.load returns (waveform, sample_rate)
    assert tensor[0].shape == torch.Size([1, 16000])
    assert tensor[1] == 16000

tests/tests_data/streaming/test_serializer.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,16 @@ def test_assert_no_header_numpy_serializer():
213213
def test_wav_deserialization(tmpdir):
214214
from torch.hub import download_url_to_file
215215

216-
video_file = os.path.join(tmpdir, "video.wav")
217-
key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa E501
216+
video_file = os.path.join(tmpdir, "video.mp4")
217+
key = "tutorial-assets/mptestsrc.mp4" # E501
218218
download_url_to_file(f"https://download.pytorch.org/torchaudio/{key}", video_file)
219219

220220
serializer = VideoSerializer()
221221
assert serializer.can_serialize(video_file)
222222
data, name = serializer.serialize(video_file)
223-
assert len(data) / 1024 / 1024 == 0.10380172729492188
224-
assert name == "wav"
223+
assert len(data) / 1024 / 1024 == 0.2262248992919922
224+
assert name == "video:mp4"
225225
vframes, aframes, info = serializer.deserialize(data)
226-
assert vframes.shape == torch.Size([0, 1, 1, 3])
227-
assert aframes.shape == torch.Size([1, 54400])
228-
assert info == {"audio_fps": 16000}
226+
assert vframes.shape == torch.Size([301, 512, 512, 3])
227+
assert aframes.shape == torch.Size([1, 0])
228+
assert info == {"video_fps": 25.0}

0 commit comments

Comments
 (0)