import os
import csv

import numpy as np
import scipy.io as sio
from scipy.io import wavfile
from scipy.io.wavfile import write
import scipy.signal as sis
import scipy.fftpack as fftpack
import librosa
import librosa.display
import matplotlib.pyplot as plt
import torch

from Utils import *
from VPM import *
from ANN import *
# Hardcoded feature-extraction settings and normalization constants
n_ffts = 2048
overlap = 0.75
n_mels = 256
global_max_mels = 18.21612
global_max_mfcc = 9.774869
global_max_timbre = 3.7099004
sr = 44100
maxMelsGeneral = [
    32.550102, 32.25072, 32.205074, 31.656128, 31.714958, 18.21612,
    19.796812, 21.5951, 21.045912, 19.812326, 22.308287,
]
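# maxMelsGeneral is indexed by shiftIdx = pitchShiftAmt + 5, so entries 0-10 correspond
# to pitch shifts of -5 through +5 (entry 5 is the unshifted/identity case).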
# Per-architecture model info files and layer sizes
paths = [ "Archi-0_ModelInfo.txt", "Archi-1_ModelInfo.txt", "Archi-2_ModelInfo.txt", "Archi-3_ModelInfo.txt" ]
pathsGeneral = [ "Archi-0_ModelInfoGeneral.txt", "Archi-1_ModelInfoGeneral.txt", "Archi-2_ModelInfoGeneral.txt", "Archi-3_ModelInfoGeneral.txt" ]
n_inputs = [ 256, 296, 266, 512 ]
nhids = [ 256, 296, 266, 512 ]
# The Timbre Encoder model, used by the timbre-based architecture (mode 2)
TE_path = os.path.join('model_data', 'TimbreEncoder', 'TimbreVAE-IdealFixedNorm-40-36-10.pt')
TE = TimbreVAE(n_mfcc=40, n_hid=36, n_timb=10)
TE.load_state_dict(torch.load(TE_path, map_location=torch.device('cpu')))
TE.eval()
def loadModels(archi=3, modelsDirectory=os.path.join("model_data", "VPMModel"), general=False):
    """Load the models from the given model info directory.

    Args:
        archi (int): The model architecture to make use of (0-3).
        modelsDirectory (str): The directory where the models and model information are stored.
            It should follow the structure:
                - model_data
                    - VPMModel
                        - ModelInfo.txt
                        - <.pt models>
        general (bool): If True, load the general models (trained on all pitch shifts)
            rather than the single-shift models.

    Returns:
        models (list): A list of PyTorch models, one per row in the model info file.
    """
    models = []
    specifiedPath = os.path.join(modelsDirectory, pathsGeneral[archi] if general else paths[archi])
    with open(specifiedPath, 'r', newline='') as f:
        reader = csv.reader(f, delimiter=',')
        for idx, row in enumerate(reader):
            if idx == 0:
                continue
            _, shiftAmt, maxMel, path = row
            # Hardcoded settings here
            model = BaseFFN(n_input=n_inputs[archi], n_hid=nhids[archi], n_output=256)
            model.load_state_dict(torch.load(os.path.join(modelsDirectory, path), map_location=torch.device('cpu')))
            model.eval()
            models.append(model)
    return models
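# A minimal usage sketch for loadModels (assumes the pretrained .pt files listed in the
# corresponding model info file are present on disk):
#   generalModels = loadModels(archi=3, general=True)
#   perShiftModels = loadModels(archi=3, general=False)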
def pitchShift(mode, models, wavform, pitchShiftAmt, getPureShift=False):
    """Perform the pitch shift using the models provided.

    Args:
        mode (int): Which architecture to use (0-3).
        models (list): A list of PyTorch models.
        wavform (np.array): An array of floats representing the waveform to shift.
        pitchShiftAmt (int): The pitch shift amount; must be in [-5, 5].
        getPureShift (bool): If True, also return the purely resampled shift.

    Returns:
        gl_abs_waveform (np.array): The shifted waveform reconstructed from the model output.
        shifted_wav (np.array): The purely shifted waveform (only returned when getPureShift is True).
    """
    pitchShiftAmt = int(pitchShiftAmt)
    assert -5 <= pitchShiftAmt <= 5
    shiftIdx = pitchShiftAmt + 5
    spectrogram = stft(wavform, win_length=n_ffts, overlap=overlap, plot=False)
    # Only modes 1 and 2 use MFCCs
    if mode == 1 or mode == 2:
        wav_mels_prenorm, wav_mfcc_prenorm = ffts_to_mel(spectrogram, n_mels=n_mels, n_mfcc=40, skip_mfcc=False)
    else:
        wav_mels_prenorm = ffts_to_mel(spectrogram, n_mels=n_mels, skip_mfcc=True)
    wav_mels_prenorm = wav_mels_prenorm.T
    wav_mels_logged = np.log(wav_mels_prenorm)
    wav_mels = wav_mels_logged / global_max_mels
    # Modes 1 and 2 use MFCCs
    if mode == 1 or mode == 2:
        wav_mfcc_prenorm = wav_mfcc_prenorm.T
        wav_mfcc_logged = np.log(np.abs(wav_mfcc_prenorm))
        wav_mfcc = wav_mfcc_logged / global_max_mfcc
    # Mode 2 uses the timbre encoder
    if mode == 2:
        wav_mfcc_tensor = torch.Tensor(wav_mfcc)
        wav_timbre_prenorm = TE.get_z(wav_mfcc_tensor).detach().numpy()
        wav_timbre = wav_timbre_prenorm / global_max_timbre
    shifted_wav, shifted_spectrogram = resample_pitch_shift(np.array([wavform]), pitchShiftAmt, overlap, n_ffts)
    shifted_wav = shifted_wav[0]
    shifted_spectrogram = shifted_spectrogram[0]
    shifted_wav_mels_prenorm = ffts_to_mel(shifted_spectrogram, n_mels=n_mels, skip_mfcc=True)
    shifted_wav_mels_prenorm = shifted_wav_mels_prenorm.T
    shifted_wav_mels_logged = np.log(shifted_wav_mels_prenorm)
    shifted_wav_mels = shifted_wav_mels_logged / maxMelsGeneral[shiftIdx]
    # Truncate excess windows if the frame counts are off by a few
    # print("Shapes: Orig: {}, Shifted: {}".format(wav_mels.shape, shifted_wav_mels.shape))
    if wav_mels.shape[0] < shifted_wav_mels.shape[0]:
        shifted_wav_mels = shifted_wav_mels[0:wav_mels.shape[0]]
    if wav_mels.shape[0] > shifted_wav_mels.shape[0]:
        wav_mels = wav_mels[0:shifted_wav_mels.shape[0]]
    if mode == 1 or mode == 2:
        wav_mfcc = wav_mfcc[0:shifted_wav_mels.shape[0]]
    if mode == 2:
        wav_timbre = wav_timbre[0:shifted_wav_mels.shape[0]]
    model = models[0] if len(models) == 1 else models[shiftIdx]
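    # Input layout per mode (these widths match n_inputs above):
    #   mode 0: 256 original mel bins
    #   mode 1: 40 MFCCs + 256 shifted mel bins = 296
    #   mode 2: 10 timbre latents + 256 shifted mel bins = 266
    #   mode 3: 256 original + 256 shifted mel bins = 512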
    if mode == 0:
        wav_input = torch.tensor(wav_mels).float()
    elif mode == 1:
        wav_input = torch.tensor(np.concatenate((wav_mfcc, shifted_wav_mels), axis=1)).float()
    elif mode == 2:
        wav_input = torch.tensor(np.concatenate((wav_timbre, shifted_wav_mels), axis=1)).float()
    elif mode == 3:
        wav_input = torch.tensor(np.concatenate((wav_mels, shifted_wav_mels), axis=1)).float()
    # print("Input to model shape: {}".format(wav_input.shape))
    wav_predicted = model(wav_input).detach().numpy()
    wav_denorm_mels = np.e ** (wav_predicted * global_max_mels)
    # wav_denorm_mels = np.e ** (wav_predicted * maxMels[modelIdx])  # This denormalization doesn't work well
    gl_abs_waveform = librosa.feature.inverse.mel_to_audio(
        wav_denorm_mels.T, sr=sr, n_fft=n_ffts,
        hop_length=compute_hop_length(n_ffts, overlap),
        win_length=n_ffts)
    if getPureShift:
        return np.array(gl_abs_waveform, dtype=np.float32), np.array(shifted_wav, dtype=np.float32)
    return np.array(gl_abs_waveform, dtype=np.float32)
# The wrapper functions
modelss = [ loadModels(archi, os.path.join("model_data", "VPMModel")) for archi in range(4) ]
# These use the NNs trained on a single pitch shift each
PitchShift0 = lambda wavform, pitchShiftAmt: pitchShift(0, modelss[0], wavform, pitchShiftAmt)
PitchShift1 = lambda wavform, pitchShiftAmt: pitchShift(1, modelss[1], wavform, pitchShiftAmt)
PitchShift2 = lambda wavform, pitchShiftAmt: pitchShift(2, modelss[2], wavform, pitchShiftAmt)
PitchShift3 = lambda wavform, pitchShiftAmt: pitchShift(3, modelss[3], wavform, pitchShiftAmt)
# General models
modelsGeneral = [ loadModels(archi, os.path.join("model_data", "VPMModel"), True) for archi in range(4) ]
# These use the NNs trained on 10 different pitch shifts (and the identity)
PitchShiftGen0 = lambda wavform, pitchShiftAmt: pitchShift(0, modelsGeneral[0], wavform, pitchShiftAmt)
PitchShiftGen1 = lambda wavform, pitchShiftAmt: pitchShift(1, modelsGeneral[1], wavform, pitchShiftAmt)
PitchShiftGen2 = lambda wavform, pitchShiftAmt: pitchShift(2, modelsGeneral[2], wavform, pitchShiftAmt)
PitchShiftGen3 = lambda wavform, pitchShiftAmt: pitchShift(3, modelsGeneral[3], wavform, pitchShiftAmt)
# The recommended default is architecture 3 with the general models
PitchShift = lambda wavform, pitchShiftAmt: pitchShift(3, modelsGeneral[3], wavform, pitchShiftAmt)
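# A minimal end-to-end usage sketch (assumes a hypothetical mono "input.wav" in the working
# directory and that the pretrained models above loaded successfully):
if __name__ == "__main__":
    in_wav, _ = librosa.load("input.wav", sr=sr, mono=True)  # resample to the module's 44.1 kHz rate
    shifted = PitchShift(in_wav, 3)                          # shift amount must be an int in [-5, 5]
    write("shifted_output.wav", sr, shifted)                 # scipy.io.wavfile.write; output samples are float32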