Skip to content

Commit aa01552

Browse files
committed
Implement bitonic sort network for join compilation
1 parent 5b9c5be commit aa01552

File tree

5 files changed

+111
-5
lines changed

5 files changed

+111
-5
lines changed

src/ast.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ pub enum StmtEnum<T> {
279279
/// Binds an identifier to each value of an array expr, evaluating the body.
280280
ForEachLoop(String, Expr<T>, Vec<Stmt<T>>),
281281
/// Binds an identifier to each joined row of two tables, evaluating the body.
282-
JoinLoop(String, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
282+
JoinLoop(String, T, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
283283
/// An expression (all expressions are statements, but not all statements expressions).
284284
Expr(Expr<T>),
285285
}

src/check.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ impl UntypedStmt {
762762
meta,
763763
))]);
764764
}
765+
let join_ty = tuple_a[0].clone();
765766
let elem_ty = Type::Tuple(vec![elem_ty_a, elem_ty_b]);
766767
let mut body_typed = Vec::with_capacity(body.len());
767768
env.push();
@@ -771,7 +772,7 @@ impl UntypedStmt {
771772
}
772773
env.pop();
773774
Ok(Stmt::new(
774-
StmtEnum::JoinLoop(var.clone(), (a, b), body_typed),
775+
StmtEnum::JoinLoop(var.clone(), join_ty, (a, b), body_typed),
775776
meta,
776777
))
777778
}
@@ -791,7 +792,7 @@ impl UntypedStmt {
791792
))
792793
}
793794
},
794-
ast::StmtEnum::JoinLoop(_, _, _) => {
795+
ast::StmtEnum::JoinLoop(_, _, _, _) => {
795796
unreachable!("Untyped expressions should never be join loops")
796797
}
797798
}

src/circuit.rs

+33
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,39 @@ impl CircuitBuilder {
993993
}
994994
(acc_lt, acc_gt)
995995
}
996+
997+
pub fn push_condswap(
998+
&mut self,
999+
s: GateIndex,
1000+
x: GateIndex,
1001+
y: GateIndex,
1002+
) -> (GateIndex, GateIndex) {
1003+
if x == y {
1004+
return (x, y);
1005+
}
1006+
let x_xor_y = self.push_xor(x, y);
1007+
let swap = self.push_and(x_xor_y, s);
1008+
let x_swapped = self.push_xor(x, swap);
1009+
let y_swapped = self.push_xor(y, swap);
1010+
(x_swapped, y_swapped)
1011+
}
1012+
1013+
pub fn push_sorter(
1014+
&mut self,
1015+
bits: usize,
1016+
x: &[GateIndex],
1017+
y: &[GateIndex],
1018+
) -> (Vec<GateIndex>, Vec<GateIndex>) {
1019+
let (_, gt) = self.push_comparator_circuit(bits, x, false, y, false);
1020+
let mut min = vec![];
1021+
let mut max = vec![];
1022+
for (x, y) in x.iter().zip(y.iter()) {
1023+
let (a, b) = self.push_condswap(gt, *x, *y);
1024+
min.push(a);
1025+
max.push(b);
1026+
}
1027+
(min, max)
1028+
}
9961029
}
9971030

9981031
fn unsigned_as_usize_bits(n: u64) -> [usize; USIZE_BITS] {

src/compile.rs

+63-1
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,69 @@ impl TypedStmt {
437437
env.pop();
438438
vec![]
439439
}
440-
StmtEnum::JoinLoop(_, _, _) => {
440+
StmtEnum::JoinLoop(var, join_ty, (a, b), body) => {
441+
let (elem_bits_a, num_elems_a) = match &a.ty {
442+
Type::Array(elem_ty, size) => (
443+
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
444+
*size,
445+
),
446+
Type::ArrayConst(elem_ty, size) => (
447+
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
448+
*circuit.const_sizes().get(size).unwrap(),
449+
),
450+
_ => panic!("Found a non-array value in an array access expr"),
451+
};
452+
let (elem_bits_b, num_elems_b) = match &b.ty {
453+
Type::Array(elem_ty, size) => (
454+
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
455+
*size,
456+
),
457+
Type::ArrayConst(elem_ty, size) => (
458+
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
459+
*circuit.const_sizes().get(size).unwrap(),
460+
),
461+
_ => panic!("Found a non-array value in an array access expr"),
462+
};
463+
let max_elem_bits = max(elem_bits_a, elem_bits_b);
464+
let num_elems = num_elems_a + num_elems_b;
465+
let join_ty_size = join_ty.size_in_bits_for_defs(prg, circuit.const_sizes());
466+
let a = a.compile(prg, env, circuit);
467+
let b = b.compile(prg, env, circuit);
468+
let mut bitonic = vec![];
469+
for i in 0..num_elems_a {
470+
let mut v = a[i * elem_bits_a..(i + 1) * elem_bits_a].to_vec();
471+
for _ in 0..(max_elem_bits - elem_bits_a) {
472+
v.push(0);
473+
}
474+
bitonic.push(v);
475+
}
476+
for i in (0..num_elems_b).rev() {
477+
let mut v = b[i * elem_bits_b..(i + 1) * elem_bits_b].to_vec();
478+
for _ in 0..(max_elem_bits - elem_bits_b) {
479+
v.push(0);
480+
}
481+
bitonic.push(v);
482+
}
483+
let mut offset = num_elems / 2;
484+
while offset > 0 {
485+
let mut result = vec![];
486+
for _ in 0..num_elems {
487+
result.push(vec![]);
488+
}
489+
let rounds = num_elems / 2 / offset;
490+
for r in 0..rounds {
491+
for i in 0..offset {
492+
let i = i + r * offset * 2;
493+
let x = &bitonic[i];
494+
let y = &bitonic[i + offset];
495+
let (min, max) = circuit.push_sorter(join_ty_size, x, y);
496+
result[i] = min;
497+
result[i + offset] = max;
498+
}
499+
}
500+
offset /= 2;
501+
bitonic = result;
502+
}
441503
todo!("compile join loop")
442504
}
443505
}

tests/compile.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -2055,7 +2055,7 @@ pub fn main(array: [u16; MY_CONST]) -> u16 {
20552055
#[test]
20562056
fn compile_join_fn() -> Result<(), Error> {
20572057
let prg = "
2058-
pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16 {
2058+
pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 4]) -> u16 {
20592059
let mut result = 0u16;
20602060
for row in join(rows1, rows2) {
20612061
let ((_, field1), (_, field2, field3)) = row;
@@ -2091,6 +2091,11 @@ pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16
20912091
Literal::NumUnsigned(117, UnsignedNumType::U8),
20922092
Literal::NumUnsigned(120, UnsignedNumType::U8),
20932093
]);
2094+
let id_xxx = Literal::Array(vec![
2095+
Literal::NumUnsigned(120, UnsignedNumType::U8),
2096+
Literal::NumUnsigned(120, UnsignedNumType::U8),
2097+
Literal::NumUnsigned(120, UnsignedNumType::U8),
2098+
]);
20942099
eval.set_literal(Literal::Array(vec![
20952100
Literal::Tuple(vec![
20962101
id_aaa.clone(),
@@ -2126,6 +2131,11 @@ pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16
21262131
Literal::NumUnsigned(8, UnsignedNumType::U16),
21272132
Literal::NumUnsigned(9, UnsignedNumType::U16),
21282133
]),
2134+
Literal::Tuple(vec![
2135+
id_xxx.clone(),
2136+
Literal::NumUnsigned(10, UnsignedNumType::U16),
2137+
Literal::NumUnsigned(11, UnsignedNumType::U16),
2138+
]),
21292139
]))
21302140
.unwrap();
21312141
let output = eval.run().map_err(|e| pretty_print(e, prg))?;

0 commit comments

Comments
 (0)