forked from descriptinc/melgan-neurips
dataset.py
import torch
import torch.utils.data
import torch.nn.functional as F
from librosa.core import load
from librosa.util import normalize
from pathlib import Path
import numpy as np
import random


def files_to_list(filename):
"""
Takes a text file of filenames and makes a list of filenames
"""
with open(filename, encoding="utf-8") as f:
files = f.readlines()
files = [f.rstrip() for f in files]
return files


class AudioDataset(torch.utils.data.Dataset):
"""
This is the main class that calculates the spectrogram and returns the
spectrogram, audio pair.
"""

    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
self.sampling_rate = sampling_rate
self.segment_length = segment_length
self.audio_files = files_to_list(training_files)
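        # Resolve each listed path relative to the directory containing the file list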
self.audio_files = [Path(training_files).parent / x for x in self.audio_files]
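        # Shuffle with a fixed seed so the file order is reproducible across runs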
random.seed(1234)
random.shuffle(self.audio_files)
self.augment = augment

    def __getitem__(self, index):
# Read audio
filename = self.audio_files[index]
audio, sampling_rate = self.load_wav_to_torch(filename)
# Take segment
if audio.size(0) >= self.segment_length:
max_audio_start = audio.size(0) - self.segment_length
audio_start = random.randint(0, max_audio_start)
audio = audio[audio_start : audio_start + self.segment_length]
else:
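            # Clip is shorter than segment_length: right-pad with zeros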
audio = F.pad(
audio, (0, self.segment_length - audio.size(0)), "constant"
).data
# audio = audio / 32768.0
return audio.unsqueeze(0)

    def __len__(self):
return len(self.audio_files)

    def load_wav_to_torch(self, full_path):
"""
        Loads wav data into a torch tensor, resampled to self.sampling_rate
"""
data, sampling_rate = load(full_path, sr=self.sampling_rate)
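        # Peak-normalize, then scale by 0.95 to leave a little headroom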
data = 0.95 * normalize(data)
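        # Optional augmentation: randomly scale the overall amplitude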
if self.augment:
amplitude = np.random.uniform(low=0.3, high=1.0)
data = data * amplitude
return torch.from_numpy(data).float(), sampling_rate
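

if __name__ == "__main__":
    # Usage sketch (not part of the original training code): shows how this
    # dataset pairs with a standard DataLoader. The file list path and the
    # segment_length / sampling_rate values below are illustrative
    # placeholders; the repository's training script supplies the real ones.
    from torch.utils.data import DataLoader

    dataset = AudioDataset(
        training_files="train_files.txt",  # hypothetical list of wav paths
        segment_length=8192,
        sampling_rate=22050,
        augment=True,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    for batch in loader:
        # Each item is audio.unsqueeze(0), so a batch has shape
        # (batch_size, 1, segment_length).
        print(batch.shape)
        break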