chore: Test Spartan2 optimizations #370

Closed · wants to merge 1 commit
103 changes: 88 additions & 15 deletions src/r1cs/sparse.rs
@@ -215,22 +215,95 @@ impl<F: PrimeField> SparseMatrix<F> {
  /// This does not check that the shape of the matrix/vector are compatible.
  pub fn multiply_witness_into_unchecked(&self, W: &[F], u: &F, X: &[F], sink: &mut Vec<F>) {
    let num_vars = W.len();
-    self
-      .indptr
-      .par_windows(2)
-      .map(|ptrs| {
-        self
-          .get_row_unchecked(ptrs.try_into().unwrap())
-          .fold(F::ZERO, |acc, (val, col_idx)| {
-            let val = match col_idx.cmp(&num_vars) {
-              Ordering::Less => *val * W[*col_idx],
-              Ordering::Equal => *val * *u,
-              Ordering::Greater => *val * X[*col_idx - num_vars - 1],
-            };
-            acc + val
-          })
-      })
-      .collect_into_vec(sink);
+    sink.clear();
+    // The parallelism strategy below splits the (row, column, value) tuples into
+    // num_threads chunks. The tuples are assumed to be (row, column) ordered; we
+    // exploit this ordering to place a mutex over each chunk, since it guarantees
+    // that only one thread writes to a given chunk at a time.
+
+    let num_threads = rayon::current_num_threads() * 4; // over-subscribe so rayon can work-steal in case of thread work imbalance
+    let row_chunk_size = (self.num_rows() as f64 / num_threads as f64).ceil() as usize;
+
+    let mut chunks: Vec<std::sync::Mutex<Vec<F>>> = Vec::with_capacity(num_threads);
+    let mut remaining_rows = self.num_rows();
+    (0..num_threads).for_each(|i| {
+      if i == num_threads - 1 {
+        // the final chunk may be smaller
+        let inner = std::sync::Mutex::new(vec![F::ZERO; remaining_rows]);
+        chunks.push(inner);
+      } else {
+        let inner = std::sync::Mutex::new(vec![F::ZERO; row_chunk_size]);
+        chunks.push(inner);
+        remaining_rows -= row_chunk_size;
+      }
+    });
+
+    let get_chunk = |row_index: usize| -> usize { row_index / row_chunk_size };
+    let get_index = |row_index: usize| -> usize { row_index % row_chunk_size };
+    let get_value = |col_idx: usize| -> F {
+      match col_idx.cmp(&num_vars) {
+        Ordering::Less => W[col_idx],
+        Ordering::Equal => *u,
+        Ordering::Greater => X[col_idx - num_vars - 1],
+      }
+    };
+    let mul_row = |row: &RowData| -> F {
+      self.get_row(row).fold(F::ZERO, |acc, (&val, col_idx)| {
+        let col_val = get_value(*col_idx);
+        let val = if val == F::ONE {
+          col_val
+        } else if col_val == F::ONE {
+          val
+        } else {
+          val * col_val
+        };
+        acc + val
+      })
+    };
+
+    let span = tracing::span!(tracing::Level::TRACE, "all_chunks_multiplication");
+    let _enter = span.enter();
+    self
+      .par_iter_rows()
+      .enumerate()
+      .chunks(row_chunk_size)
+      .for_each(|sub_matrix| {
+        let (init_row_idx, init_row) = sub_matrix[0];
+        let mut prev_chunk_index = get_chunk(init_row_idx);
+        let curr_row_index = get_index(init_row_idx);
+        let mut curr_chunk = chunks[prev_chunk_index].lock().unwrap();
+
+        curr_chunk[curr_row_index] = mul_row(init_row);
+
+        let span_a = tracing::span!(tracing::Level::TRACE, "chunk_multiplication");
+        let _enter_b = span_a.enter();
+        for (row_idx, row) in sub_matrix {
+          let curr_chunk_index = get_chunk(row_idx);
+          if prev_chunk_index != curr_chunk_index {
+            // only swap locks when the row crosses into a new chunk
+            drop(curr_chunk); // release the current chunk before blocking on the next lock, so a neighboring thread waiting on it is not held up
+            let new_chunk = chunks[curr_chunk_index].lock().unwrap();
+            curr_chunk = new_chunk;
+
+            prev_chunk_index = curr_chunk_index;
+          }
+
+          let curr_row_index = get_index(row_idx);
+          curr_chunk[curr_row_index] = mul_row(row);
+        }
+      });
+    drop(_enter);
+    drop(span);
+
+    let span_a = tracing::span!(tracing::Level::TRACE, "chunks_mutex_unwrap");
+    let _enter_a = span_a.enter();
+    // TODO(sragss): Mutex unwrap takes about 30% of the time due to clone, likely unnecessary.
+    for chunk in chunks {
+      let inner_vec = chunk.into_inner().unwrap();
+      sink.extend(inner_vec.iter());
+    }
+    drop(_enter_a);
+    drop(span_a);
  }

  /// number of non-zero entries
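
For reviewers skimming the diff: both the old and the new code compute the same thing, the sparse matrix-vector product against the virtual vector z = (W, u, X). The sketch below is a hypothetical serial reference, not code from this PR: `Csr` and `multiply_witness_serial` are stand-ins for `SparseMatrix<F>`'s CSR layout (suggested by the diff's use of `indptr` and `get_row_unchecked`), and `u64` stands in for the field type F.

use core::cmp::Ordering;

// Hypothetical minimal CSR layout, standing in for `SparseMatrix<F>`.
struct Csr {
  indptr: Vec<usize>,  // row i spans data[indptr[i]..indptr[i + 1]]
  indices: Vec<usize>, // column index of each stored value
  data: Vec<u64>,      // the stored values (u64 in place of a field element)
}

// Serial reference for `multiply_witness_into_unchecked`: for each row,
// accumulate val * z[col], where the virtual vector z is (W, u, X).
fn multiply_witness_serial(m: &Csr, w: &[u64], u: u64, x: &[u64], sink: &mut Vec<u64>) {
  let num_vars = w.len();
  sink.clear();
  for row in m.indptr.windows(2) {
    let mut acc = 0u64;
    for k in row[0]..row[1] {
      let (val, col) = (m.data[k], m.indices[k]);
      acc += val
        * match col.cmp(&num_vars) {
          Ordering::Less => w[col],                  // witness segment
          Ordering::Equal => u,                      // the scalar u
          Ordering::Greater => x[col - num_vars - 1], // public-input segment
        };
    }
    sink.push(acc);
  }
}

fn main() {
  // 2x4 matrix over z = (W, u, X) with num_vars = 2:
  // row 0: 3*z[0] + 5*z[2]  (z[2] is the scalar u)
  // row 1: 2*z[1] + 7*z[3]  (z[3] is X[0])
  let m = Csr {
    indptr: vec![0, 2, 4],
    indices: vec![0, 2, 1, 3],
    data: vec![3, 5, 2, 7],
  };
  let mut sink = Vec::new();
  multiply_witness_serial(&m, &[10, 20], 1, &[100], &mut sink);
  assert_eq!(sink, vec![3 * 10 + 5 * 1, 2 * 20 + 7 * 100]);
}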
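The chunk-and-mutex strategy from the diff's comment can also be exercised in isolation. The following is a minimal, self-contained sketch assuming only the rayon crate, with u64 values and a made-up doubling workload in place of mul_row; unlike the PR, it re-locks on every write instead of holding a chunk's lock across consecutive rows, which keeps the toy short.

use rayon::prelude::*;
use std::sync::Mutex;

fn main() {
  let n = 10usize;
  let num_chunks = 4usize;
  let chunk_size = (n as f64 / num_chunks as f64).ceil() as usize; // 3

  // One mutex per output chunk; the final chunk may be smaller.
  let chunks: Vec<Mutex<Vec<u64>>> = (0..num_chunks)
    .map(|i| {
      let len = chunk_size.min(n.saturating_sub(i * chunk_size));
      Mutex::new(vec![0u64; len])
    })
    .collect();

  // Each output index is written exactly once, and work items are processed
  // in order, so a chunk's lock is only contended at the boundary between
  // two rayon work chunks.
  (0..n).into_par_iter().chunks(chunk_size).for_each(|work| {
    for i in work {
      let mut guard = chunks[i / chunk_size].lock().unwrap();
      guard[i % chunk_size] = (i as u64) * 2; // stand-in for mul_row
    }
  });

  // Sequential reassembly, mirroring the final `sink.extend` loop.
  let mut sink = Vec::with_capacity(n);
  for chunk in chunks {
    sink.extend(chunk.into_inner().unwrap());
  }
  assert_eq!(sink, (0..n as u64).map(|i| i * 2).collect::<Vec<_>>());
}

The pattern works because every output index has exactly one writer, so the mutexes only serialize the handoff at chunk boundaries; the final into_inner loop is the same sequential reassembly step that the TODO in the diff flags as a hot spot.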