|
10 | 10 | import pytest
|
11 | 11 | import torch
|
12 | 12 | from lightning import seed_everything
|
| 13 | +from lightning.data.constants import _TORCH_AUDIO_AVAILABLE, _ZSTD_AVAILABLE |
13 | 14 | from lightning.data.processing import data_processor as data_processor_module
|
14 | 15 | from lightning.data.processing import functions
|
15 | 16 | from lightning.data.processing.data_processor import (
|
|
26 | 27 | _wait_for_file_to_exist,
|
27 | 28 | )
|
28 | 29 | from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
|
29 |
| -from lightning.data.streaming import resolver |
| 30 | +from lightning.data.streaming import StreamingDataset, resolver |
30 | 31 | from lightning.data.streaming.cache import Cache, Dir
|
31 | 32 | from lightning_utilities.core.imports import RequirementCache
|
32 | 33 |
|
@@ -1058,3 +1059,79 @@ def test_empty_optimize(tmpdir):
|
1058 | 1059 | )
|
1059 | 1060 |
|
1060 | 1061 | assert os.listdir(tmpdir) == ["index.json"]
|
| 1062 | + |
| 1063 | + |
def create_synthetic_audio_bytes(index) -> dict:
    """Generate a random 1-second, 16 kHz waveform and return it WAV-encoded as bytes.

    The ``index`` argument is accepted (and ignored) so this function can be passed
    as the ``fn`` of ``optimize``.
    """
    from io import BytesIO

    import torchaudio

    # A dummy single-channel waveform: 16000 samples at 16 kHz == 1 second.
    waveform = torch.randn((1, 16000))

    # Encode the tensor into an in-memory WAV payload.
    with BytesIO() as buffer:
        torchaudio.save(buffer, waveform, 16000, format="wav")
        payload = buffer.getvalue()

    return {"content": payload}
| 1079 | + |
| 1080 | + |
@pytest.mark.skipif(
    condition=not _TORCH_AUDIO_AVAILABLE or not _ZSTD_AVAILABLE,
    # The condition requires both torchaudio and zstd (the latter for the
    # "zstd" compression parametrization), so the reason lists both.
    reason="Requires: ['torchaudio', 'zstd']",
)
@pytest.mark.parametrize("compression", [None, "zstd"])
def test_load_torch_audio(tmpdir, compression):
    """Roundtrip WAV bytes through ``optimize`` and decode them back with torchaudio."""
    seed_everything(42)

    import torchaudio

    # Write 100 synthetic WAV payloads into an optimized streaming dataset.
    optimize(
        fn=create_synthetic_audio_bytes,
        inputs=list(range(100)),
        output_dir=str(tmpdir),
        num_workers=1,
        chunk_bytes="64MB",
        compression=compression,
    )

    # Read the first sample back and decode it; shape and sample rate must
    # match what create_synthetic_audio_bytes produced.
    dataset = StreamingDataset(input_dir=str(tmpdir))
    sample = dataset[0]
    waveform, sample_rate = torchaudio.load(sample["content"])
    assert waveform.shape == torch.Size([1, 16000])
    assert sample_rate == 16000
| 1102 | + |
| 1103 | + |
def create_synthetic_audio_file(filepath) -> str:
    """Write a random 1-second, 16 kHz waveform to ``filepath`` as WAV and return the path.

    Used as the ``fn`` of ``optimize``: each input filepath is materialized on disk
    so the optimizer can ingest the file itself.
    """
    import torchaudio

    # A dummy single-channel waveform: 16000 samples at 16 kHz == 1 second.
    data = torch.randn((1, 16000))

    # Encode the tensor directly into the target file.
    with open(filepath, "wb") as f:
        torchaudio.save(f, data, 16000, format="wav")

    # Return the path (the original annotation claimed ``dict``, which was wrong).
    return filepath
| 1115 | + |
| 1116 | + |
# zstd is not required here: ``compression`` is parametrized only with None,
# so gating on _ZSTD_AVAILABLE would skip the test needlessly.
@pytest.mark.skipif(condition=not _TORCH_AUDIO_AVAILABLE, reason="Requires: ['torchaudio']")
@pytest.mark.parametrize("compression", [None])
def test_load_torch_audio_from_wav_file(tmpdir, compression):
    """Roundtrip on-disk WAV files through ``optimize`` and decode them back with torchaudio."""
    seed_everything(42)

    import torchaudio

    # NOTE(review): tmpdir holds both the input .wav files and the optimized
    # output; presumably the dataset only reads what index.json references,
    # so the stray inputs are ignored — confirm against StreamingDataset.
    optimize(
        fn=create_synthetic_audio_file,
        inputs=[os.path.join(tmpdir, f"{i}.wav") for i in range(5)],
        output_dir=str(tmpdir),
        num_workers=1,
        chunk_bytes="64MB",
        compression=compression,
    )

    # Read the first sample (a filepath) back and decode it.
    dataset = StreamingDataset(input_dir=str(tmpdir))
    sample = dataset[0]
    waveform, sample_rate = torchaudio.load(sample)
    assert waveform.shape == torch.Size([1, 16000])
    assert sample_rate == 16000
0 commit comments