Skip to content

Commit 26189a7

Browse files
perf: avoid creating a temporary vec in slice ops mul add (#852)
# Rationale for this change The `mul_add_assign` module has a `mul_add_assign` function that is only used by the `MultilinearExtension` `mul_add` function. When calling `mul_add_assign`, a temporary vector is created with the `slice_ops::slice_cast(self) call. We don't need to create this temporary vector. Avoiding this creation will improve performance and memory allocation. Before, `24.35ms` ![image](https://github.com/user-attachments/assets/02765073-b80c-43bb-9936-f06f99b0e6cc) After, `15.72ms`, `1.55x` improvement ![image](https://github.com/user-attachments/assets/8f27c1b0-7db7-46bb-8024-995278051f98) # What changes are included in this PR? - `mul_add_assign` is updated to not require a temporary vector - The `MultilinearExtension` trait no longer creates a temporary vector when calling `mul_add_assign` - Tests are updated to call `mul_add_assign` correctly # Are these changes tested? Yes
2 parents 1c3ab6c + 5835e37 commit 26189a7

File tree

3 files changed

+26
-31
lines changed

3 files changed

+26
-31
lines changed

crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ where
3939
}
4040

4141
fn mul_add(&self, res: &mut [S], multiplier: &S) {
42-
slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self));
42+
slice_ops::mul_add_assign(res, *multiplier, self);
4343
}
4444

4545
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {

crates/proof-of-sql/src/base/slice_ops/mul_add_assign.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@ use core::ops::{AddAssign, Mul};
33
#[cfg(feature = "rayon")]
44
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
55

6-
/// This operation does `result[i] += multiplier * to_mul_add[i]` for `i` in `0..to_mul_add.len()`.
6+
/// This operation does `result[i] += multiplier * to_mul_add[i]` for `i` in `0..to_mul_add.len()`
7+
/// without creating temporary vectors. Works directly with slice references.
78
///
89
/// # Panics
910
/// Panics if the length of `result` is less than the length of `to_mul_add`.
10-
pub fn mul_add_assign<T, S>(result: &mut [T], multiplier: T, to_mul_add: &[S])
11+
pub fn mul_add_assign<'a, T, S>(result: &mut [T], multiplier: T, to_mul_add: &'a [S])
1112
where
1213
T: Send + Sync + Mul<Output = T> + AddAssign + Copy,
13-
S: Into<T> + Sync + Copy,
14+
&'a S: Into<T>,
15+
S: Sync,
1416
{
1517
assert!(result.len() >= to_mul_add.len(), "The length of result must be greater than or equal to the length of the vector of values to be multiplied and added");
1618
if_rayon!(
1719
result.par_iter_mut().with_min_len(super::MIN_RAYON_LEN),
1820
result.iter_mut()
1921
)
2022
.zip(to_mul_add)
21-
.for_each(|(res_i, &data_i)| {
23+
.for_each(|(res_i, data_i)| {
2224
*res_i += multiplier * data_i.into();
2325
});
2426
}

crates/proof-of-sql/src/base/slice_ops/mul_add_assign_test.rs

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,24 @@ use crate::base::scalar::test_scalar::TestScalar;
33

44
#[test]
55
fn test_mul_add_assign() {
6-
let mut a = vec![1, 2, 3, 4];
6+
let mut a = [1, 2, 3, 4].map(TestScalar::from).to_vec();
77
let b = vec![2, 3, 4, 5];
8-
mul_add_assign(&mut a, 10, &b);
9-
let c = vec![1 + 10 * 2, 2 + 10 * 3, 3 + 10 * 4, 4 + 10 * 5];
8+
mul_add_assign(&mut a, TestScalar::from(10i32), &b);
9+
let c = [1 + 10 * 2, 2 + 10 * 3, 3 + 10 * 4, 4 + 10 * 5]
10+
.map(TestScalar::from)
11+
.to_vec();
1012
assert_eq!(a, c);
1113
}
1214

1315
/// test [`mul_add_assign`] with uneven vectors
1416
#[test]
1517
fn test_mul_add_assign_uneven() {
16-
let mut a = vec![1, 2, 3, 4, 5];
17-
let b = vec![2, 3, 4, 5];
18-
mul_add_assign(&mut a, 10, &b);
19-
let c = vec![1 + 10 * 2, 2 + 10 * 3, 3 + 10 * 4, 4 + 10 * 5, 5];
18+
let mut a = [1, 2, 3, 4, 5].map(TestScalar::from).to_vec();
19+
let b = [2, 3, 4, 5].map(TestScalar::from).to_vec();
20+
mul_add_assign(&mut a, TestScalar::from(10u32), &b);
21+
let c = [1 + 10 * 2, 2 + 10 * 3, 3 + 10 * 4, 4 + 10 * 5, 5]
22+
.map(TestScalar::from)
23+
.to_vec();
2024
assert_eq!(a, c);
2125
}
2226

@@ -26,38 +30,27 @@ fn test_mul_add_assign_uneven() {
2630
expected = "The length of result must be greater than or equal to the length of the vector of values to be multiplied and added"
2731
)]
2832
fn test_mul_add_assign_uneven_panic() {
29-
let mut a = vec![1, 2, 3, 4];
33+
let mut a = [1u32, 2u32, 3u32, 4u32].map(TestScalar::from).to_vec();
3034
let b = vec![2, 3, 4, 5, 6];
31-
mul_add_assign(&mut a, 10, &b);
35+
mul_add_assign(&mut a, TestScalar::from(10u32), &b);
3236
}
3337

3438
/// test [`mul_add_assign`] with `TestScalar`
3539
#[test]
3640
fn test_mul_add_assign_testscalar() {
37-
let mut a = vec![TestScalar::from(1u64), TestScalar::from(2u64)];
38-
let b = vec![TestScalar::from(2u64), TestScalar::from(3u64)];
41+
let mut a = [1, 2].map(TestScalar::from).to_vec();
42+
let b = [2, 3].map(TestScalar::from).to_vec();
3943
mul_add_assign(&mut a, TestScalar::from(10u64), &b);
40-
let c = vec![
41-
TestScalar::from(1u64) + TestScalar::from(10u64) * TestScalar::from(2u64),
42-
TestScalar::from(2u64) + TestScalar::from(10u64) * TestScalar::from(3u64),
43-
];
44+
let c = [1 + 10 * 2, 2 + 10 * 3].map(TestScalar::from).to_vec();
4445
assert_eq!(a, c);
4546
}
4647

4748
/// test [`mul_add_assign`] with uneven `TestScalar`
4849
#[test]
4950
fn test_mul_add_assign_testscalar_uneven() {
50-
let mut a = vec![
51-
TestScalar::from(1u64),
52-
TestScalar::from(2u64),
53-
TestScalar::from(3u64),
54-
];
55-
let b = vec![TestScalar::from(2u64), TestScalar::from(3u64)];
51+
let mut a = [1, 2, 3].map(TestScalar::from).to_vec();
52+
let b = [2, 3].map(TestScalar::from).to_vec();
5653
mul_add_assign(&mut a, TestScalar::from(10u64), &b);
57-
let c = vec![
58-
TestScalar::from(1u64) + TestScalar::from(10u64) * TestScalar::from(2u64),
59-
TestScalar::from(2u64) + TestScalar::from(10u64) * TestScalar::from(3u64),
60-
TestScalar::from(3u64),
61-
];
54+
let c = [1 + 10 * 2, 2 + 10 * 3, 3].map(TestScalar::from).to_vec();
6255
assert_eq!(a, c);
6356
}

0 commit comments

Comments
 (0)