Skip to content

Commit

Permalink
feat: support multithread spectrogram and small perf tweaks (#1674)
Browse files Browse the repository at this point in the history
* feat: support multithread spectrogram and small perf tweaks

* feat: clippy improvement for loop variable

* fix: add back speed up scale down logic

* fix: readd mirroring logic

* feat: prefer scoped thread and simplify/improve logic/traits
  • Loading branch information
drbh authored Feb 8, 2024
1 parent 020a979 commit 9cadd4e
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 28 deletions.
162 changes: 142 additions & 20 deletions candle-transformers/src/models/whisper/audio.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp

pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
use candle::utils::get_num_threads;
use std::sync::Arc;
use std::thread;

/// Floating-point element type accepted by the audio routines in this module.
///
/// Bundles the `num_traits` numeric capabilities with `Send + Sync`, which the
/// spectrogram code relies on when it hands shared buffers to spawned worker
/// threads.
pub trait Float:
    Send + Sync + num_traits::NumAssign + num_traits::FloatConst + num_traits::Float
{
}

// The two primitive float widths are the implementors declared here.
impl Float for f64 {}
impl Float for f32 {}
Expand Down Expand Up @@ -102,22 +109,26 @@ fn log_mel_spectrogram_w<T: Float>(
let half = T::from(0.5).unwrap();
let mut fft_in = vec![zero; fft_size];
let mut mel = vec![zero; n_len * n_mel];
let n_samples = samples.len();
let end = std::cmp::min(n_samples / fft_step + 1, n_len);

for i in (ith..n_len).step_by(n_threads) {
for i in (ith..end).step_by(n_threads) {
let offset = i * fft_step;

// apply Hanning window
for j in 0..fft_size {
fft_in[j] = if offset + j < samples.len() {
hann[j] * samples[offset + j]
} else {
zero
}
for j in 0..std::cmp::min(fft_size, n_samples - offset) {
fft_in[j] = hann[j] * samples[offset + j];
}

// FFT -> mag^2
// fill the rest with zeros
if n_samples - offset < fft_size {
fft_in[n_samples - offset..].fill(zero);
}

// FFT
let mut fft_out: Vec<T> = fft(&fft_in);

// Calculate modulus^2 of complex numbers
for j in 0..fft_size {
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
}
Expand All @@ -136,16 +147,27 @@ fn log_mel_spectrogram_w<T: Float>(
// mel spectrogram
for j in 0..n_mel {
let mut sum = zero;
for k in 0..n_fft {
let mut k = 0;
// Unroll loop
while k < n_fft.saturating_sub(3) {
sum += fft_out[k] * filters[j * n_fft + k]
+ fft_out[k + 1] * filters[j * n_fft + k + 1]
+ fft_out[k + 2] * filters[j * n_fft + k + 2]
+ fft_out[k + 3] * filters[j * n_fft + k + 3];
k += 4;
}
// Handle remainder
while k < n_fft {
sum += fft_out[k] * filters[j * n_fft + k];
k += 1;
}
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
}
}
mel
}

fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
fn log_mel_spectrogram_<T: Float>(
samples: &[T],
filters: &[T],
fft_size: usize,
Expand Down Expand Up @@ -180,10 +202,55 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
samples_padded
};

// Use a single thread for now.
let mut mel = log_mel_spectrogram_w(
0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
);
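// ensure that the number of threads is even and at most 12
// NOTE(review): this evaluates to 0 if get_num_threads() returns 1 — confirm whether a floor of 1 is needed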
let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);

let hann = Arc::new(hann);
let samples = Arc::new(samples);
let filters = Arc::new(filters);

// use scope to allow for non static references to be passed to the threads
// and directly collect the results into a single vector
let all_outputs = thread::scope(|s| {
(0..n_threads)
// create threads and return their handles
.map(|thread_id| {
let hann = Arc::clone(&hann);
let samples = Arc::clone(&samples);
let filters = Arc::clone(&filters);
// spawn new thread and start work
s.spawn(move || {
log_mel_spectrogram_w(
thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
n_mel, n_threads,
)
})
})
.collect::<Vec<_>>()
.into_iter()
// wait for each thread to finish and collect their results
.map(|handle| handle.join().expect("Thread failed"))
.collect::<Vec<_>>()
});

let l = all_outputs[0].len();
let mut mel = vec![zero; l];

// iterate over mel spectrogram segments, dividing work by threads.
for segment_start in (0..l).step_by(n_threads) {
// go through each thread's output.
for thread_output in all_outputs.iter() {
// add each thread's piece to our mel spectrogram.
for offset in 0..n_threads {
let mel_index = segment_start + offset; // find location in mel.
if mel_index < mel.len() {
// Make sure we don't go out of bounds.
mel[mel_index] += thread_output[mel_index];
}
}
}
}

let mmax = mel
.iter()
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
Expand All @@ -197,11 +264,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
mel
}

pub fn pcm_to_mel<T: Float + std::fmt::Display>(
cfg: &super::Config,
samples: &[T],
filters: &[T],
) -> Vec<T> {
pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
log_mel_spectrogram_(
samples,
filters,
Expand All @@ -211,3 +274,62 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
false,
)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `fft` of a four-sample unit impulse at index 1 must reproduce the
    /// pinned interleaved output exactly.
    #[test]
    fn test_fft() {
        let signal = vec![0.0, 1.0, 0.0, 0.0];
        let expected = vec![
            1.0,
            0.0,
            6.123233995736766e-17,
            -1.0,
            -1.0,
            0.0,
            -6.123233995736766e-17,
            1.0,
        ];
        let spectrum = fft(&signal);
        assert_eq!(spectrum, expected);
    }

    /// `dft` of the same impulse must reproduce its own pinned output, which
    /// differs from the `fft` one only in the last rounding-level terms.
    #[test]
    fn test_dft() {
        let signal = vec![0.0, 1.0, 0.0, 0.0];
        let expected = vec![
            1.0,
            0.0,
            6.123233995736766e-17,
            -1.0,
            -1.0,
            -1.2246467991473532e-16,
            -1.8369701987210297e-16,
            1.0,
        ];
        let spectrum = dft(&signal);
        assert_eq!(spectrum, expected);
    }

    /// Pins the output length of the spectrogram for 1000 zero samples and
    /// 1000 zero filter weights.
    #[test]
    fn test_log_mel_spectrogram() {
        let pcm = vec![0.0; 1000];
        let mel_filters = vec![0.0; 1000];
        let mel = log_mel_spectrogram_(&pcm, &mel_filters, 100, 10, 10, false);
        let produced = mel.len();
        assert_eq!(produced, 30_000);
    }

    /// Same length check as above for a much smaller configuration.
    #[test]
    fn test_tiny_log_mel_spectrogram() {
        let pcm = vec![0.0; 100];
        let mel_filters = vec![0.0; 100];
        let mel = log_mel_spectrogram_(&pcm, &mel_filters, 20, 2, 2, false);
        let produced = mel.len();
        assert_eq!(produced, 6_000);
    }
}
8 changes: 4 additions & 4 deletions candle-transformers/src/models/whisper/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,14 @@ impl ResidualAttentionBlock {
}
}

fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, device)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
Expand Down Expand Up @@ -246,7 +246,7 @@ impl AudioEncoder {
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
Expand Down
8 changes: 4 additions & 4 deletions candle-transformers/src/models/whisper/quantized_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ impl ResidualAttentionBlock {
}
}

fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, device)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
Expand Down Expand Up @@ -242,7 +242,7 @@ impl AudioEncoder {
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
Expand Down

0 comments on commit 9cadd4e

Please sign in to comment.