Skip to content

Commit c4c1fb8

Browse files
authored
chore: add batch_addition code (#272)
* add batch_add code * cleanup: remove duplicated functions * rename diff_stride to binary_stride * remove old batch addition with complex stride pattern * make batch_add public * add initial doc comments * nit: typo
1 parent b47387b commit c4c1fb8

File tree

2 files changed

+244
-0
lines changed

2 files changed

+244
-0
lines changed
+243
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
use crate::batch_inversion::{batch_inverse, batch_inverse_scratch_pad};
2+
use blstrs::{Fp, G1Affine, G1Projective};
3+
use ff::Field;
4+
use group::Group;
5+
6+
/// Adds two elliptic curve points using the point addition/doubling formula.
7+
///
8+
/// Note: The inversion is precomputed and passed as a parameter.
9+
///
10+
/// This function handles both addition of distinct points and point doubling.
11+
#[inline(always)]
12+
fn point_add_double(p1: G1Affine, p2: G1Affine, inv: &blstrs::Fp) -> G1Affine {
13+
use ff::Field;
14+
15+
let lambda = if p1 == p2 {
16+
p1.x().square().mul3() * inv
17+
} else {
18+
(p2.y() - p1.y()) * inv
19+
};
20+
21+
let x = lambda.square() - p1.x() - p2.x();
22+
let y = lambda * (p1.x() - x) - p1.y();
23+
24+
G1Affine::from_raw_unchecked(x, y, false)
25+
}
26+
27+
/// Chooses between point addition and point doubling based on the input points.
28+
///
29+
/// Note: This does not handle the case where p1 == -p2.
30+
///
31+
/// This case is unlikely for our usecase, and is not trivial
32+
/// to handle.
33+
#[inline(always)]
34+
fn choose_add_or_double(p1: G1Affine, p2: G1Affine) -> Fp {
35+
if p1 == p2 {
36+
p2.y().double()
37+
} else {
38+
p2.x() - p1.x()
39+
}
40+
}
41+
42+
/// Cutoff at or below which batching the inversions for the affine
/// formula costs more than simply doing mixed addition.
44+
const BATCH_INVERSE_THRESHOLD: usize = 16;
45+
46+
/// Performs batch addition of elliptic curve points using a binary tree approach with striding.
47+
///
48+
/// This function efficiently adds a large number of points by organizing them into a binary tree
49+
/// and performing batch inversions for the addition formula.
50+
///
51+
// TODO(benedikt): top down balanced tree idea - benedikt
52+
// TODO: search tree for sorted array
53+
pub fn batch_addition_binary_tree_stride(mut points: Vec<G1Affine>) -> G1Projective {
54+
if points.is_empty() {
55+
return G1Projective::identity();
56+
}
57+
58+
let mut new_differences = Vec::with_capacity(points.len());
59+
60+
let mut sum = G1Projective::identity();
61+
62+
while points.len() > BATCH_INVERSE_THRESHOLD {
63+
if points.len() % 2 != 0 {
64+
sum += points
65+
.pop()
66+
.expect("infallible; since points has an odd length");
67+
}
68+
new_differences.clear();
69+
70+
for i in (0..=points.len() - 2).step_by(2) {
71+
let p1 = points[i];
72+
let p2 = points[i + 1];
73+
new_differences.push(choose_add_or_double(p1, p2));
74+
}
75+
76+
batch_inverse(&mut new_differences);
77+
78+
for (i, inv) in (0..=points.len() - 2).step_by(2).zip(&new_differences) {
79+
let p1 = points[i];
80+
let p2 = points[i + 1];
81+
points[i / 2] = point_add_double(p1, p2, inv);
82+
}
83+
84+
// The latter half of the vector is now unused,
85+
// all results are stored in the former half.
86+
points.truncate(new_differences.len())
87+
}
88+
89+
for point in points {
90+
sum += point
91+
}
92+
93+
sum
94+
}
95+
96+
/// Performs multi-batch addition of multiple sets of elliptic curve points.
97+
///
98+
/// This function efficiently adds multiple sets of points amortizing the cost of the
99+
/// inversion over all of the sets, using the same binary tree approach with striding
100+
/// as the single-batch version.
101+
pub fn multi_batch_addition_binary_tree_stride(
102+
mut multi_points: Vec<Vec<G1Affine>>,
103+
) -> Vec<G1Projective> {
104+
let total_num_points: usize = multi_points.iter().map(|p| p.len()).sum();
105+
let mut scratchpad = Vec::with_capacity(total_num_points);
106+
107+
// Find the largest buckets, this will be the bottleneck for the number of iterations
108+
let mut max_bucket_length = 0;
109+
for points in multi_points.iter() {
110+
max_bucket_length = std::cmp::max(max_bucket_length, points.len());
111+
}
112+
113+
// Compute the total number of "unit of work"
114+
// In the single batch addition case this is analogous to
115+
// the batch inversion threshold
116+
#[inline(always)]
117+
fn compute_threshold(points: &[Vec<G1Affine>]) -> usize {
118+
points
119+
.iter()
120+
.map(|p| {
121+
if p.len() % 2 == 0 {
122+
p.len() / 2
123+
} else {
124+
(p.len() - 1) / 2
125+
}
126+
})
127+
.sum()
128+
}
129+
130+
let mut new_differences = Vec::with_capacity(max_bucket_length);
131+
let mut total_amount_of_work = compute_threshold(&multi_points);
132+
133+
let mut sums = vec![G1Projective::identity(); multi_points.len()];
134+
135+
// TODO: total_amount_of_work does not seem to be changing performance that much
136+
while total_amount_of_work > BATCH_INVERSE_THRESHOLD {
137+
// For each point, we check if they are odd and pop off
138+
// one of the points
139+
for (points, sum) in multi_points.iter_mut().zip(sums.iter_mut()) {
140+
// Make the number of points even
141+
if points.len() % 2 != 0 {
142+
*sum += points.pop().unwrap();
143+
}
144+
}
145+
146+
new_differences.clear();
147+
148+
// For each pair of points over all
149+
// vectors, we collect them and put them in the
150+
// inverse array
151+
for points in multi_points.iter() {
152+
if points.len() < 2 {
153+
continue;
154+
}
155+
for i in (0..=points.len() - 2).step_by(2) {
156+
new_differences.push(choose_add_or_double(points[i], points[i + 1]));
157+
}
158+
}
159+
160+
batch_inverse_scratch_pad(&mut new_differences, &mut scratchpad);
161+
162+
let mut new_differences_offset = 0;
163+
164+
for points in multi_points.iter_mut() {
165+
if points.len() < 2 {
166+
continue;
167+
}
168+
for (i, inv) in (0..=points.len() - 2)
169+
.step_by(2)
170+
.zip(&new_differences[new_differences_offset..])
171+
{
172+
let p1 = points[i];
173+
let p2 = points[i + 1];
174+
points[i / 2] = point_add_double(p1, p2, inv);
175+
}
176+
177+
let num_points = points.len() / 2;
178+
// The latter half of the vector is now unused,
179+
// all results are stored in the former half.
180+
points.truncate(num_points);
181+
new_differences_offset += num_points
182+
}
183+
184+
total_amount_of_work = compute_threshold(&multi_points);
185+
}
186+
187+
for (sum, points) in sums.iter_mut().zip(multi_points) {
188+
for point in points {
189+
*sum += point
190+
}
191+
}
192+
193+
sums
194+
}
195+
196+
#[cfg(test)]
mod tests {

    use crate::batch_add::{
        batch_addition_binary_tree_stride, multi_batch_addition_binary_tree_stride,
    };

    use blstrs::{G1Affine, G1Projective};
    use group::Group;

    /// Samples `count` random G1 points and converts them to affine form.
    fn random_affine_points(count: usize) -> Vec<G1Affine> {
        (0..count)
            .map(|_| G1Projective::random(&mut rand::thread_rng()).into())
            .collect()
    }

    /// The batched sum must agree with a plain projective fold over the same points.
    #[test]
    fn test_batch_addition() {
        let points = random_affine_points(101);

        let naive_sum = points
            .iter()
            .fold(G1Projective::identity(), |acc, p| acc + p);
        let expected_result: G1Affine = naive_sum.into();

        let got_result: G1Affine = batch_addition_binary_tree_stride(points).into();
        assert_eq!(expected_result, got_result);
    }

    /// The multi-batch sums must agree with running the single-batch routine per set.
    #[test]
    fn test_multi_batch_addition_binary_stride() {
        let num_points = 99;
        let num_sets = 5;

        let random_sets_of_points: Vec<Vec<G1Affine>> = (0..num_sets)
            .map(|_| random_affine_points(num_points))
            .collect();

        let expected_results: Vec<G1Projective> = random_sets_of_points
            .iter()
            .cloned()
            .map(batch_addition_binary_tree_stride)
            .collect();

        let got_results = multi_batch_addition_binary_tree_stride(random_sets_of_points);
        assert_eq!(got_results, expected_results);
    }
}

cryptography/bls12_381/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod batch_add;
12
pub mod batch_inversion;
23
pub mod fixed_base_msm;
34
pub mod lincomb;

0 commit comments

Comments
 (0)