@@ -28,9 +28,39 @@ namespace wenet {
28
28
29
29
// This code is based on kaldi Fbank implementation, please see
30
30
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
31
+
32
+ static const int kS16AbsMax = 1 << 15 ;
33
+
34
+ enum class WindowType {
35
+ kPovey = 0 ,
36
+ kHanning ,
37
+ };
38
+
39
+ enum class MelType {
40
+ kHTK = 0 ,
41
+ kSlaney ,
42
+ };
43
+
44
+ enum class NormalizationType {
45
+ kKaldi = 0 ,
46
+ kWhisper ,
47
+ };
48
+
49
+ enum class LogBase {
50
+ kBaseE = 0 ,
51
+ kBase10 ,
52
+ };
53
+
31
54
class Fbank {
32
55
public:
33
- Fbank (int num_bins, int sample_rate, int frame_length, int frame_shift)
56
+ Fbank (int num_bins, int sample_rate, int frame_length, int frame_shift,
57
+ float low_freq = 20 , bool pre_emphasis = true ,
58
+ bool scale_input_to_unit = false ,
59
+ float log_floor = std::numeric_limits<float >::epsilon(),
60
+ LogBase log_base = LogBase::kBaseE ,
61
+ WindowType window_type = WindowType::kPovey ,
62
+ MelType mel_type = MelType::kHTK ,
63
+ NormalizationType norm_type = NormalizationType::kKaldi )
34
64
: num_bins_(num_bins),
35
65
sample_rate_ (sample_rate),
36
66
frame_length_(frame_length),
@@ -39,40 +69,69 @@ class Fbank {
39
69
remove_dc_offset_(true ),
40
70
generator_(0 ),
41
71
distribution_(0 , 1.0 ),
42
- dither_(0.0 ) {
72
+ dither_(0.0 ),
73
+ low_freq_(low_freq),
74
+ high_freq_(sample_rate / 2 ),
75
+ pre_emphasis_(pre_emphasis),
76
+ scale_input_to_unit_(scale_input_to_unit),
77
+ log_floor_(log_floor),
78
+ log_base_(log_base),
79
+ norm_type_(norm_type) {
43
80
fft_points_ = UpperPowerOfTwo (frame_length_);
44
81
// generate bit reversal table and trigonometric function table
45
82
const int fft_points_4 = fft_points_ / 4 ;
46
83
bitrev_.resize (fft_points_);
47
84
sintbl_.resize (fft_points_ + fft_points_4);
48
85
make_sintbl (fft_points_, sintbl_.data ());
49
86
make_bitrev (fft_points_, bitrev_.data ());
87
+ InitMelFilters (mel_type);
88
+ InitWindow (window_type);
89
+ }
50
90
91
+ void InitMelFilters (MelType mel_type) {
51
92
int num_fft_bins = fft_points_ / 2 ;
52
93
float fft_bin_width = static_cast <float >(sample_rate_) / fft_points_;
53
- int low_freq = 20 , high_freq = sample_rate_ / 2 ;
54
- float mel_low_freq = MelScale (low_freq);
55
- float mel_high_freq = MelScale (high_freq);
56
- float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1 );
94
+ float mel_low_freq = MelScale (low_freq_, mel_type);
95
+ float mel_high_freq = MelScale (high_freq_, mel_type);
96
+ float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins_ + 1 );
57
97
bins_.resize (num_bins_);
58
98
center_freqs_.resize (num_bins_);
59
- for (int bin = 0 ; bin < num_bins; ++bin) {
99
+
100
+ for (int bin = 0 ; bin < num_bins_; ++bin) {
60
101
float left_mel = mel_low_freq + bin * mel_freq_delta,
61
102
center_mel = mel_low_freq + (bin + 1 ) * mel_freq_delta,
62
103
right_mel = mel_low_freq + (bin + 2 ) * mel_freq_delta;
63
- center_freqs_[bin] = InverseMelScale (center_mel);
104
+ center_freqs_[bin] = InverseMelScale (center_mel, mel_type );
64
105
std::vector<float > this_bin (num_fft_bins);
65
106
int first_index = -1 , last_index = -1 ;
66
107
for (int i = 0 ; i < num_fft_bins; ++i) {
67
108
float freq = (fft_bin_width * i); // Center frequency of this fft
68
109
// bin.
69
- float mel = MelScale (freq);
110
+ float mel = MelScale (freq, mel_type );
70
111
if (mel > left_mel && mel < right_mel) {
71
112
float weight;
72
- if (mel <= center_mel)
73
- weight = (mel - left_mel) / (center_mel - left_mel);
74
- else
75
- weight = (right_mel - mel) / (right_mel - center_mel);
113
+ if (mel_type == MelType::kHTK ) {
114
+ if (mel <= center_mel)
115
+ weight = (mel - left_mel) / (center_mel - left_mel);
116
+ else if (mel > center_mel)
117
+ weight = (right_mel - mel) / (right_mel - center_mel);
118
+ } else if (mel_type == MelType::kSlaney ) {
119
+ if (mel <= center_mel) {
120
+ weight = (InverseMelScale (mel, mel_type) -
121
+ InverseMelScale (left_mel, mel_type)) /
122
+ (InverseMelScale (center_mel, mel_type) -
123
+ InverseMelScale (left_mel, mel_type));
124
+ weight *= 2.0 / (InverseMelScale (right_mel, mel_type) -
125
+ InverseMelScale (left_mel, mel_type));
126
+ } else if (mel > center_mel) {
127
+ weight = (InverseMelScale (right_mel, mel_type) -
128
+ InverseMelScale (mel, mel_type)) /
129
+ (InverseMelScale (right_mel, mel_type) -
130
+ InverseMelScale (center_mel, mel_type));
131
+ weight *= 2.0 / (InverseMelScale (right_mel, mel_type) -
132
+ InverseMelScale (left_mel, mel_type));
133
+ }
134
+ }
76
135
this_bin[i] = weight;
77
136
if (first_index == -1 ) first_index = i;
78
137
last_index = i;
@@ -86,12 +145,20 @@ class Fbank {
86
145
bins_[bin].second [i] = this_bin[first_index + i];
87
146
}
88
147
}
148
+ }
89
149
90
- // povey window
91
- povey_window_.resize (frame_length_);
92
- double a = M_2PI / (frame_length - 1 );
93
- for (int i = 0 ; i < frame_length; ++i) {
94
- povey_window_[i] = pow (0.5 - 0.5 * cos (a * i), 0.85 );
150
+ void InitWindow (WindowType window_type) {
151
+ window_.resize (frame_length_);
152
+ if (window_type == WindowType::kPovey ) {
153
+ // povey window
154
+ double a = M_2PI / (frame_length_ - 1 );
155
+ for (int i = 0 ; i < frame_length_; ++i)
156
+ window_[i] = pow (0.5 - 0.5 * cos (a * i), 0.85 );
157
+ } else if (window_type == WindowType::kHanning ) {
158
+ // periodic hanning window
159
+ double a = M_2PI / (frame_length_);
160
+ for (int i = 0 ; i < frame_length_; ++i)
161
+ window_[i] = 0.5 * (1.0 - cos (i * a));
95
162
}
96
163
}
97
164
@@ -105,12 +172,45 @@ class Fbank {
105
172
106
173
int num_bins () const { return num_bins_; }
107
174
108
- static inline float InverseMelScale (float mel_freq) {
109
- return 700 .0f * (expf (mel_freq / 1127 .0f ) - 1 .0f );
175
+ static inline float InverseMelScale (float mel_freq,
176
+ MelType mel_type = MelType::kHTK ) {
177
+ if (mel_type == MelType::kHTK ) {
178
+ return 700 .0f * (expf (mel_freq / 1127 .0f ) - 1 .0f );
179
+ } else if (mel_type == MelType::kSlaney ) {
180
+ float f_min = 0.0 ;
181
+ float f_sp = 200 .0f / 3 .0f ;
182
+ float min_log_hz = 1000.0 ;
183
+ float freq = f_min + f_sp * mel_freq;
184
+ float min_log_mel = (min_log_hz - f_min) / f_sp;
185
+ float logstep = logf (6.4 ) / 27 .0f ;
186
+ if (mel_freq >= min_log_mel) {
187
+ return min_log_hz * expf (logstep * (mel_freq - min_log_mel));
188
+ } else {
189
+ return freq;
190
+ }
191
+ } else {
192
+ throw std::invalid_argument (" Unsupported mel type!" );
193
+ }
110
194
}
111
195
112
- static inline float MelScale (float freq) {
113
- return 1127 .0f * logf (1 .0f + freq / 700 .0f );
196
+ static inline float MelScale (float freq, MelType mel_type = MelType::kHTK ) {
197
+ if (mel_type == MelType::kHTK ) {
198
+ return 1127 .0f * logf (1 .0f + freq / 700 .0f );
199
+ } else if (mel_type == MelType::kSlaney ) {
200
+ float f_min = 0.0 ;
201
+ float f_sp = 200 .0f / 3 .0f ;
202
+ float min_log_hz = 1000.0 ;
203
+ float mel = (freq - f_min) / f_sp;
204
+ float min_log_mel = (min_log_hz - f_min) / f_sp;
205
+ float logstep = logf (6.4 ) / 27 .0f ;
206
+ if (freq >= min_log_hz) {
207
+ return min_log_mel + logf (freq / min_log_hz) / logstep;
208
+ } else {
209
+ return mel;
210
+ }
211
+ } else {
212
+ throw std::invalid_argument (" Unsupported mel type!" );
213
+ }
114
214
}
115
215
116
216
static int UpperPowerOfTwo (int n) {
@@ -125,26 +225,50 @@ class Fbank {
125
225
(*data)[0 ] -= coeff * (*data)[0 ];
126
226
}
127
227
128
- // Apply povey window on data in place
129
- void Povey (std::vector<float >* data) const {
130
- CHECK_GE (data->size (), povey_window_.size ());
131
- for (size_t i = 0 ; i < povey_window_.size (); ++i) {
132
- (*data)[i] *= povey_window_[i];
228
+ // Apply window on data in place
229
+ void ApplyWindow (std::vector<float >* data) const {
230
+ CHECK_GE (data->size (), window_.size ());
231
+ for (size_t i = 0 ; i < window_.size (); ++i) {
232
+ (*data)[i] *= window_[i];
233
+ }
234
+ }
235
+
236
+ void WhisperNorm (std::vector<std::vector<float >>* feat,
237
+ float max_mel_engery) {
238
+ int num_frames = feat->size ();
239
+ for (int i = 0 ; i < num_frames; ++i) {
240
+ for (int j = 0 ; j < num_bins_; ++j) {
241
+ float energy = (*feat)[i][j];
242
+ if (energy < max_mel_engery - 8 ) energy = max_mel_engery - 8 ;
243
+ energy = (energy + 4.0 ) / 4.0 ;
244
+ (*feat)[i][j] = energy;
245
+ }
133
246
}
134
247
}
135
248
136
249
// Compute fbank feat, return num frames
137
250
int Compute (const std::vector<float >& wave,
138
251
std::vector<std::vector<float >>* feat) {
139
252
int num_samples = wave.size ();
253
+
140
254
if (num_samples < frame_length_) return 0 ;
141
255
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
142
256
feat->resize (num_frames);
143
257
std::vector<float > fft_real (fft_points_, 0 ), fft_img (fft_points_, 0 );
144
258
std::vector<float > power (fft_points_ / 2 );
259
+
260
+ float max_mel_engery = std::numeric_limits<float >::min ();
261
+
145
262
for (int i = 0 ; i < num_frames; ++i) {
146
263
std::vector<float > data (wave.data () + i * frame_shift_,
147
264
wave.data () + i * frame_shift_ + frame_length_);
265
+
266
+ if (scale_input_to_unit_) {
267
+ for (int j = 0 ; j < frame_length_; ++j) {
268
+ data[j] = data[j] / kS16AbsMax ;
269
+ }
270
+ }
271
+
148
272
// optional add noise
149
273
if (dither_ != 0.0 ) {
150
274
for (size_t j = 0 ; j < data.size (); ++j)
@@ -158,8 +282,10 @@ class Fbank {
158
282
for (size_t j = 0 ; j < data.size (); ++j) data[j] -= mean;
159
283
}
160
284
161
- PreEmphasis (0.97 , &data);
162
- Povey (&data);
285
+ if (pre_emphasis_) {
286
+ PreEmphasis (0.97 , &data);
287
+ }
288
+ ApplyWindow (&data);
163
289
// copy data to fft_real
164
290
memset (fft_img.data (), 0 , sizeof (float ) * fft_points_);
165
291
memset (fft_real.data () + frame_length_, 0 ,
@@ -174,6 +300,7 @@ class Fbank {
174
300
175
301
(*feat)[i].resize (num_bins_);
176
302
// cepstral coefficients, triangle filter array
303
+
177
304
for (int j = 0 ; j < num_bins_; ++j) {
178
305
float mel_energy = 0.0 ;
179
306
int s = bins_[j].first ;
@@ -182,14 +309,20 @@ class Fbank {
182
309
}
183
310
// optional use log
184
311
if (use_log_) {
185
- if (mel_energy < std::numeric_limits<float >::epsilon ())
186
- mel_energy = std::numeric_limits<float >::epsilon ();
187
- mel_energy = logf (mel_energy);
188
- }
312
+ if (mel_energy < log_floor_) mel_energy = log_floor_;
189
313
314
+ if (log_base_ == LogBase::kBaseE )
315
+ mel_energy = logf (mel_energy);
316
+ else if (log_base_ == LogBase::kBase10 )
317
+ mel_energy = log10 (mel_energy);
318
+ }
319
+ if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
190
320
(*feat)[i][j] = mel_energy;
191
321
}
192
322
}
323
+ if (norm_type_ == NormalizationType::kWhisper )
324
+ WhisperNorm (feat, max_mel_engery);
325
+
193
326
return num_frames;
194
327
}
195
328
@@ -200,9 +333,17 @@ class Fbank {
200
333
int fft_points_;
201
334
bool use_log_;
202
335
bool remove_dc_offset_;
336
+ bool pre_emphasis_;
337
+ bool scale_input_to_unit_;
338
+ float low_freq_;
339
+ float log_floor_;
340
+ float high_freq_;
341
+ LogBase log_base_;
342
+ NormalizationType norm_type_;
343
+
203
344
std::vector<float > center_freqs_;
204
345
std::vector<std::pair<int , std::vector<float >>> bins_;
205
- std::vector<float > povey_window_ ;
346
+ std::vector<float > window_ ;
206
347
std::default_random_engine generator_;
207
348
std::normal_distribution<float > distribution_;
208
349
float dither_;
0 commit comments