Skip to content

Commit baaa27a

Browse files
zhr1201hzhou245
and
hzhou245
authored
[runtime] Whisper inference support in cpp runtime (#2320)
* Working version * Refactor how params are passed in * Passing in through CLI * Fix minor issues * Remove extra dump for debug * Remove unused arrays for debugging * Remove unused header * Change naming style of Enum * Move init_mel_filters to its own method * Fix a bug introduced in the last two commits * Use const instead of macro --------- Co-authored-by: hzhou245 <hzhou245@bloomberg.net>
1 parent 6e68e01 commit baaa27a

File tree

4 files changed

+232
-38
lines changed

4 files changed

+232
-38
lines changed

runtime/core/decoder/params.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ DEFINE_int32(core_number, 1, "Core number of process");
6262
// FeaturePipelineConfig flags
6363
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
6464
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
65+
DEFINE_string(feat_type, "kaldi", "Type of feature extraction: kaldi, whisper");
6566

6667
// TLG fst
6768
DEFINE_string(fst_path, "", "TLG fst path");
@@ -115,9 +116,20 @@ DEFINE_int32(language_type, 0,
115116
DEFINE_bool(lowercase, true, "lowercase final result if needed");
116117

117118
namespace wenet {
119+
120+
// Maps the value of the --feat_type flag to its FeatureType enumerator.
//
// Accepts exactly "kaldi" and "whisper" (the comparison is case-sensitive).
// Any other value throws std::invalid_argument; the message names the
// rejected string so a mistyped flag can be diagnosed from the error alone.
//
// Declared inline because this definition lives in a header (params.h):
// without it, including the header from a second translation unit of the
// same binary would define the function twice.
inline FeatureType StringToFeatureType(const std::string& feat_type_str) {
  if (feat_type_str == "kaldi") return FeatureType::kKaldi;
  if (feat_type_str == "whisper") return FeatureType::kWhisper;
  throw std::invalid_argument("Unsupported feat type: '" + feat_type_str +
                              "' (expected 'kaldi' or 'whisper')");
}
128+
118129
std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
130+
FeatureType feat_type = StringToFeatureType(FLAGS_feat_type);
119131
auto feature_config = std::make_shared<FeaturePipelineConfig>(
120-
FLAGS_num_bins, FLAGS_sample_rate);
132+
FLAGS_num_bins, FLAGS_sample_rate, feat_type);
121133
return feature_config;
122134
}
123135

runtime/core/frontend/fbank.h

Lines changed: 175 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,39 @@ namespace wenet {
2828

2929
// This code is based on kaldi Fbank implementation, please see
3030
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
31+
32+
// 2^15 = 32768, the magnitude of the most negative signed 16-bit value.
// Compute() divides every sample by this constant when scale_input_to_unit_
// is set.
static constexpr int kS16AbsMax = 32768;
33+
34+
// Analysis window multiplied into every frame before the FFT; the
// coefficients are generated by Fbank::InitWindow().
enum class WindowType {
  kPovey = 0,    // (0.5 - 0.5 * cos(2*pi*i / (N - 1))) ^ 0.85
  kHanning = 1,  // periodic Hann: 0.5 * (1 - cos(2*pi*i / N))
};
38+
39+
// Mel-scale formula used by Fbank::MelScale() / Fbank::InverseMelScale() and
// by the triangular filters built in Fbank::InitMelFilters().
enum class MelType {
  kHTK = 0,     // mel = 1127 * ln(1 + hz / 700)
  kSlaney = 1,  // linear below 1000 Hz, logarithmic above; the filter weights
                // are additionally scaled by 2 / (right_hz - left_hz)
};
43+
44+
// Post-processing applied to the finished feature matrix at the end of
// Fbank::Compute().
enum class NormalizationType {
  kKaldi = 0,    // no extra step in the visible part of Compute()
  kWhisper = 1,  // Compute() calls WhisperNorm() on the features
};
48+
49+
// Logarithm applied to each floored mel energy in Fbank::Compute() when
// use_log_ is set.
enum class LogBase {
  kBaseE = 0,   // natural logarithm (logf)
  kBase10 = 1,  // base-10 logarithm (log10)
};
53+
3154
class Fbank {
3255
public:
33-
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
56+
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
57+
float low_freq = 20, bool pre_emphasis = true,
58+
bool scale_input_to_unit = false,
59+
float log_floor = std::numeric_limits<float>::epsilon(),
60+
LogBase log_base = LogBase::kBaseE,
61+
WindowType window_type = WindowType::kPovey,
62+
MelType mel_type = MelType::kHTK,
63+
NormalizationType norm_type = NormalizationType::kKaldi)
3464
: num_bins_(num_bins),
3565
sample_rate_(sample_rate),
3666
frame_length_(frame_length),
@@ -39,40 +69,69 @@ class Fbank {
3969
remove_dc_offset_(true),
4070
generator_(0),
4171
distribution_(0, 1.0),
42-
dither_(0.0) {
72+
dither_(0.0),
73+
low_freq_(low_freq),
74+
high_freq_(sample_rate / 2),
75+
pre_emphasis_(pre_emphasis),
76+
scale_input_to_unit_(scale_input_to_unit),
77+
log_floor_(log_floor),
78+
log_base_(log_base),
79+
norm_type_(norm_type) {
4380
fft_points_ = UpperPowerOfTwo(frame_length_);
4481
// generate bit reversal table and trigonometric function table
4582
const int fft_points_4 = fft_points_ / 4;
4683
bitrev_.resize(fft_points_);
4784
sintbl_.resize(fft_points_ + fft_points_4);
4885
make_sintbl(fft_points_, sintbl_.data());
4986
make_bitrev(fft_points_, bitrev_.data());
87+
InitMelFilters(mel_type);
88+
InitWindow(window_type);
89+
}
5090

91+
void InitMelFilters(MelType mel_type) {
5192
int num_fft_bins = fft_points_ / 2;
5293
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
53-
int low_freq = 20, high_freq = sample_rate_ / 2;
54-
float mel_low_freq = MelScale(low_freq);
55-
float mel_high_freq = MelScale(high_freq);
56-
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
94+
float mel_low_freq = MelScale(low_freq_, mel_type);
95+
float mel_high_freq = MelScale(high_freq_, mel_type);
96+
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins_ + 1);
5797
bins_.resize(num_bins_);
5898
center_freqs_.resize(num_bins_);
59-
for (int bin = 0; bin < num_bins; ++bin) {
99+
100+
for (int bin = 0; bin < num_bins_; ++bin) {
60101
float left_mel = mel_low_freq + bin * mel_freq_delta,
61102
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
62103
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
63-
center_freqs_[bin] = InverseMelScale(center_mel);
104+
center_freqs_[bin] = InverseMelScale(center_mel, mel_type);
64105
std::vector<float> this_bin(num_fft_bins);
65106
int first_index = -1, last_index = -1;
66107
for (int i = 0; i < num_fft_bins; ++i) {
67108
float freq = (fft_bin_width * i); // Center frequency of this fft
68109
// bin.
69-
float mel = MelScale(freq);
110+
float mel = MelScale(freq, mel_type);
70111
if (mel > left_mel && mel < right_mel) {
71112
float weight;
72-
if (mel <= center_mel)
73-
weight = (mel - left_mel) / (center_mel - left_mel);
74-
else
75-
weight = (right_mel - mel) / (right_mel - center_mel);
113+
if (mel_type == MelType::kHTK) {
114+
if (mel <= center_mel)
115+
weight = (mel - left_mel) / (center_mel - left_mel);
116+
else if (mel > center_mel)
117+
weight = (right_mel - mel) / (right_mel - center_mel);
118+
} else if (mel_type == MelType::kSlaney) {
119+
if (mel <= center_mel) {
120+
weight = (InverseMelScale(mel, mel_type) -
121+
InverseMelScale(left_mel, mel_type)) /
122+
(InverseMelScale(center_mel, mel_type) -
123+
InverseMelScale(left_mel, mel_type));
124+
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
125+
InverseMelScale(left_mel, mel_type));
126+
} else if (mel > center_mel) {
127+
weight = (InverseMelScale(right_mel, mel_type) -
128+
InverseMelScale(mel, mel_type)) /
129+
(InverseMelScale(right_mel, mel_type) -
130+
InverseMelScale(center_mel, mel_type));
131+
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
132+
InverseMelScale(left_mel, mel_type));
133+
}
134+
}
76135
this_bin[i] = weight;
77136
if (first_index == -1) first_index = i;
78137
last_index = i;
@@ -86,12 +145,20 @@ class Fbank {
86145
bins_[bin].second[i] = this_bin[first_index + i];
87146
}
88147
}
148+
}
89149

90-
// povey window
91-
povey_window_.resize(frame_length_);
92-
double a = M_2PI / (frame_length - 1);
93-
for (int i = 0; i < frame_length; ++i) {
94-
povey_window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
150+
// Fills window_ with frame_length_ coefficients of the requested analysis
// window; ApplyWindow() later multiplies each frame by them.
//
//   kPovey:   (0.5 - 0.5 * cos(2*pi*i / (frame_length_ - 1))) ^ 0.85
//   kHanning: 0.5 * (1 - cos(2*pi*i / frame_length_))   (periodic Hann)
//
// Throws std::invalid_argument for any other window_type, matching how
// MelScale()/InverseMelScale() reject an unknown MelType. Without this check
// an unrecognised value left window_ holding only the zeros produced by
// resize(), which would silently zero every frame.
//
// NOTE(review): the Povey step divides by (frame_length_ - 1), so a
// frame_length_ of 1 is not meaningful for that window.
void InitWindow(WindowType window_type) {
  window_.resize(frame_length_);
  switch (window_type) {
    case WindowType::kPovey: {
      const double step = M_2PI / (frame_length_ - 1);
      for (int i = 0; i < frame_length_; ++i)
        window_[i] = pow(0.5 - 0.5 * cos(step * i), 0.85);
      break;
    }
    case WindowType::kHanning: {
      // Period is frame_length_ (not frame_length_ - 1): periodic variant.
      const double step = M_2PI / (frame_length_);
      for (int i = 0; i < frame_length_; ++i)
        window_[i] = 0.5 * (1.0 - cos(i * step));
      break;
    }
    default:
      throw std::invalid_argument("Unsupported window type!");
  }
}
97164

@@ -105,12 +172,45 @@ class Fbank {
105172

106173
int num_bins() const { return num_bins_; }
107174

108-
static inline float InverseMelScale(float mel_freq) {
109-
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
175+
// Converts a mel-scale value back to a frequency in Hz (inverse of
// MelScale() for the same mel_type).
//
//   kHTK:    700 * (exp(mel / 1127) - 1)
//   kSlaney: linear, 200/3 Hz per mel, below the 1000 Hz break point;
//            above it, 1000 * exp(logstep * (mel - break_mel)) with
//            logstep = ln(6.4) / 27.
//
// Throws std::invalid_argument for any other mel_type.
static inline float InverseMelScale(float mel_freq,
                                    MelType mel_type = MelType::kHTK) {
  switch (mel_type) {
    case MelType::kHTK:
      return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
    case MelType::kSlaney: {
      const float f_min = 0.0;
      const float hz_per_mel = 200.0f / 3.0f;
      const float break_hz = 1000.0;
      // Mel value at which the scale switches from linear to logarithmic.
      const float break_mel = (break_hz - f_min) / hz_per_mel;
      if (mel_freq >= break_mel) {
        const float log_step = logf(6.4) / 27.0f;
        return break_hz * expf(log_step * (mel_freq - break_mel));
      }
      return f_min + hz_per_mel * mel_freq;
    }
    default:
      throw std::invalid_argument("Unsupported mel type!");
  }
}
111195

112-
static inline float MelScale(float freq) {
113-
return 1127.0f * logf(1.0f + freq / 700.0f);
196+
// Converts a frequency in Hz to the mel scale selected by mel_type.
//
//   kHTK:    1127 * ln(1 + hz / 700)
//   kSlaney: hz / (200/3) below the 1000 Hz break point; at or above it,
//            break_mel + ln(hz / 1000) / logstep with logstep = ln(6.4) / 27.
//
// Throws std::invalid_argument for any other mel_type.
static inline float MelScale(float freq, MelType mel_type = MelType::kHTK) {
  switch (mel_type) {
    case MelType::kHTK:
      return 1127.0f * logf(1.0f + freq / 700.0f);
    case MelType::kSlaney: {
      const float f_min = 0.0;
      const float hz_per_mel = 200.0f / 3.0f;
      const float break_hz = 1000.0;
      if (freq >= break_hz) {
        // Logarithmic region: offset from the mel value of the break point.
        const float break_mel = (break_hz - f_min) / hz_per_mel;
        const float log_step = logf(6.4) / 27.0f;
        return break_mel + logf(freq / break_hz) / log_step;
      }
      return (freq - f_min) / hz_per_mel;
    }
    default:
      throw std::invalid_argument("Unsupported mel type!");
  }
}
115215

116216
static int UpperPowerOfTwo(int n) {
@@ -125,26 +225,50 @@ class Fbank {
125225
(*data)[0] -= coeff * (*data)[0];
126226
}
127227

128-
// Apply povey window on data in place
129-
void Povey(std::vector<float>* data) const {
130-
CHECK_GE(data->size(), povey_window_.size());
131-
for (size_t i = 0; i < povey_window_.size(); ++i) {
132-
(*data)[i] *= povey_window_[i];
228+
// Apply window on data in place
229+
// Multiplies the first window_.size() samples of *data by the matching
// coefficient in window_. The CHECK requires *data to hold at least that many
// samples; any samples beyond the window length are left untouched.
void ApplyWindow(std::vector<float>* data) const {
  CHECK_GE(data->size(), window_.size());
  std::vector<float>& samples = *data;
  const size_t window_length = window_.size();
  for (size_t k = 0; k < window_length; ++k) samples[k] *= window_[k];
}
235+
236+
void WhisperNorm(std::vector<std::vector<float>>* feat,
237+
float max_mel_engery) {
238+
int num_frames = feat->size();
239+
for (int i = 0; i < num_frames; ++i) {
240+
for (int j = 0; j < num_bins_; ++j) {
241+
float energy = (*feat)[i][j];
242+
if (energy < max_mel_engery - 8) energy = max_mel_engery - 8;
243+
energy = (energy + 4.0) / 4.0;
244+
(*feat)[i][j] = energy;
245+
}
133246
}
134247
}
135248

136249
// Compute fbank feat, return num frames
137250
int Compute(const std::vector<float>& wave,
138251
std::vector<std::vector<float>>* feat) {
139252
int num_samples = wave.size();
253+
140254
if (num_samples < frame_length_) return 0;
141255
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
142256
feat->resize(num_frames);
143257
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
144258
std::vector<float> power(fft_points_ / 2);
259+
260+
float max_mel_engery = std::numeric_limits<float>::min();
261+
145262
for (int i = 0; i < num_frames; ++i) {
146263
std::vector<float> data(wave.data() + i * frame_shift_,
147264
wave.data() + i * frame_shift_ + frame_length_);
265+
266+
if (scale_input_to_unit_) {
267+
for (int j = 0; j < frame_length_; ++j) {
268+
data[j] = data[j] / kS16AbsMax;
269+
}
270+
}
271+
148272
// optional add noise
149273
if (dither_ != 0.0) {
150274
for (size_t j = 0; j < data.size(); ++j)
@@ -158,8 +282,10 @@ class Fbank {
158282
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
159283
}
160284

161-
PreEmphasis(0.97, &data);
162-
Povey(&data);
285+
if (pre_emphasis_) {
286+
PreEmphasis(0.97, &data);
287+
}
288+
ApplyWindow(&data);
163289
// copy data to fft_real
164290
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
165291
memset(fft_real.data() + frame_length_, 0,
@@ -174,6 +300,7 @@ class Fbank {
174300

175301
(*feat)[i].resize(num_bins_);
176302
// cepstral coefficients, triangle filter array
303+
177304
for (int j = 0; j < num_bins_; ++j) {
178305
float mel_energy = 0.0;
179306
int s = bins_[j].first;
@@ -182,14 +309,20 @@ class Fbank {
182309
}
183310
// optional use log
184311
if (use_log_) {
185-
if (mel_energy < std::numeric_limits<float>::epsilon())
186-
mel_energy = std::numeric_limits<float>::epsilon();
187-
mel_energy = logf(mel_energy);
188-
}
312+
if (mel_energy < log_floor_) mel_energy = log_floor_;
189313

314+
if (log_base_ == LogBase::kBaseE)
315+
mel_energy = logf(mel_energy);
316+
else if (log_base_ == LogBase::kBase10)
317+
mel_energy = log10(mel_energy);
318+
}
319+
if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
190320
(*feat)[i][j] = mel_energy;
191321
}
192322
}
323+
if (norm_type_ == NormalizationType::kWhisper)
324+
WhisperNorm(feat, max_mel_engery);
325+
193326
return num_frames;
194327
}
195328

@@ -200,9 +333,17 @@ class Fbank {
200333
int fft_points_;
201334
bool use_log_;
202335
bool remove_dc_offset_;
336+
bool pre_emphasis_;
337+
bool scale_input_to_unit_;
338+
float low_freq_;
339+
float log_floor_;
340+
float high_freq_;
341+
LogBase log_base_;
342+
NormalizationType norm_type_;
343+
203344
std::vector<float> center_freqs_;
204345
std::vector<std::pair<int, std::vector<float>>> bins_;
205-
std::vector<float> povey_window_;
346+
std::vector<float> window_;
206347
std::default_random_engine generator_;
207348
std::normal_distribution<float> distribution_;
208349
float dither_;

runtime/core/frontend/feature_pipeline.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
2323
: config_(config),
2424
feature_dim_(config.num_bins),
2525
fbank_(config.num_bins, config.sample_rate, config.frame_length,
26-
config.frame_shift),
26+
config.frame_shift, config.low_freq, config.pre_emphasis,
27+
config.scale_input_to_unit, config.log_floor, config.log_base,
28+
config.window_type, config.mel_type, config.norm_type),
2729
num_frames_(0),
2830
input_finished_(false) {}
2931

0 commit comments

Comments
 (0)