Skip to content

Commit 362c843

Browse files
committed
Implement min/max for consts
1 parent 27cf557 commit 362c843

File tree

6 files changed

+204
-87
lines changed

6 files changed

+204
-87
lines changed

src/ast.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType};
1212
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1313
pub struct Program<T> {
1414
/// The external constants that the top level const definitions depend upon.
15-
pub const_deps: HashMap<String, HashMap<String, (String, T)>>,
15+
pub const_deps: HashMap<String, HashMap<String, T>>,
1616
/// Top level const definitions.
1717
pub const_defs: HashMap<String, ConstDef>,
1818
/// Top level struct type definitions.

src/check.rs

+14-45
Original file line numberDiff line numberDiff line change
@@ -354,18 +354,18 @@ impl UntypedProgram {
354354
struct_names,
355355
enum_names,
356356
};
357-
let mut const_deps: HashMap<String, HashMap<String, (String, Type)>> = HashMap::new();
357+
let mut const_deps: HashMap<String, HashMap<String, Type>> = HashMap::new();
358358
let mut const_types = HashMap::with_capacity(self.const_defs.len());
359359
let mut const_defs = HashMap::with_capacity(self.const_defs.len());
360360
{
361361
for (const_name, const_def) in self.const_defs.iter() {
362362
fn check_const_expr(
363-
const_name: &String,
363+
value: &ConstExpr,
364364
const_def: &ConstDef,
365365
errors: &mut Vec<Option<TypeError>>,
366-
const_deps: &mut HashMap<String, HashMap<String, (String, Type)>>,
366+
const_deps: &mut HashMap<String, HashMap<String, Type>>,
367367
) {
368-
match &const_def.value {
368+
match value {
369369
ConstExpr::True | ConstExpr::False => {
370370
if const_def.ty != Type::Bool {
371371
let e = TypeErrorEnum::UnexpectedType {
@@ -396,52 +396,21 @@ impl UntypedProgram {
396396
}
397397
}
398398
ConstExpr::ExternalValue { party, identifier } => {
399-
const_deps.entry(party.clone()).or_default().insert(
400-
identifier.clone(),
401-
(const_name.clone(), const_def.ty.clone()),
402-
);
399+
const_deps
400+
.entry(party.clone())
401+
.or_default()
402+
.insert(identifier.clone(), const_def.ty.clone());
403+
}
404+
ConstExpr::Max(args) | ConstExpr::Min(args) => {
405+
for arg in args {
406+
check_const_expr(arg, const_def, errors, const_deps);
407+
}
403408
}
404-
ConstExpr::Max(args) => for arg in args {},
405-
ConstExpr::Min(_) => todo!(),
406409
}
407410
}
408-
check_const_expr(&const_name, &const_def, &mut errors, &mut const_deps);
411+
check_const_expr(&const_def.value, &const_def, &mut errors, &mut const_deps);
409412
const_defs.insert(const_name.clone(), const_def.clone());
410413
const_types.insert(const_name.clone(), const_def.ty.clone());
411-
// TODO: remove the following:
412-
/*match &const_def.value {
413-
ConstExpr::Literal(expr) => {
414-
match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) {
415-
Ok(mut expr) => {
416-
if let Err(errs) = check_type(&mut expr, &const_def.ty) {
417-
errors.extend(errs);
418-
}
419-
const_defs.insert(
420-
const_name.clone(),
421-
ConstDef {
422-
ty: const_def.ty.clone(),
423-
value: ConstExpr::Literal(expr),
424-
meta: const_def.meta,
425-
},
426-
);
427-
}
428-
Err(errs) => {
429-
for e in errs.into_iter().flatten() {
430-
if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e {
431-
// ignore this error, constant can be provided later during compilation
432-
const_deps
433-
.entry(p)
434-
.or_default()
435-
.insert(n, (const_name.clone(), const_def.ty.clone()));
436-
} else {
437-
errors.push(Some(e));
438-
}
439-
}
440-
}
441-
}
442-
const_types.insert(const_name.clone(), const_def.ty.clone());
443-
}
444-
}*/
445414
}
446415
}
447416
let mut struct_defs = HashMap::with_capacity(self.struct_defs.len());

src/compile.rs

+147-37
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
//! Compiles a [`crate::ast::Program`] to a [`crate::circuit::Circuit`].
22
3-
use std::{cmp::max, collections::HashMap};
3+
use std::{
4+
cmp::{max, min},
5+
collections::HashMap,
6+
};
47

58
use crate::{
69
ast::{
@@ -56,14 +59,26 @@ impl TypedProgram {
5659
) -> Result<(Circuit, &TypedFnDef), CompilerError> {
5760
let mut env = Env::new();
5861
let mut const_sizes = HashMap::new();
62+
let mut consts_unsigned = HashMap::new();
63+
let mut consts_signed = HashMap::new();
5964
for (party, deps) in self.const_deps.iter() {
60-
for (c, (identifier, ty)) in deps {
65+
for (c, ty) in deps {
6166
let Some(party_deps) = consts.get(party) else {
6267
todo!("missing party dep for {party}");
6368
};
6469
let Some(literal) = party_deps.get(c) else {
6570
todo!("missing value {party}::{c}");
6671
};
72+
let identifier = format!("{party}::{c}");
73+
match literal {
74+
Literal::NumUnsigned(n, _) => {
75+
consts_unsigned.insert(identifier.clone(), *n);
76+
}
77+
Literal::NumSigned(n, _) => {
78+
consts_signed.insert(identifier.clone(), *n);
79+
}
80+
_ => {}
81+
}
6782
if literal.is_of_type(self, ty) {
6883
let bits = literal
6984
.as_bits(self, &const_sizes)
@@ -72,7 +87,7 @@ impl TypedProgram {
7287
.collect();
7388
env.let_in_current_scope(identifier.clone(), bits);
7489
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
75-
const_sizes.insert(identifier.clone(), *size as usize);
90+
const_sizes.insert(identifier, *size as usize);
7691
}
7792
} else {
7893
return Err(CompilerError::InvalidLiteralType(
@@ -84,58 +99,153 @@ impl TypedProgram {
8499
}
85100
let mut input_gates = vec![];
86101
let mut wire = 2;
87-
if let Some(fn_def) = self.fn_defs.get(fn_name) {
88-
for param in fn_def.params.iter() {
89-
let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes);
90-
let mut wires = Vec::with_capacity(type_size);
91-
for _ in 0..type_size {
92-
wires.push(wire);
93-
wire += 1;
102+
let Some(fn_def) = self.fn_defs.get(fn_name) else {
103+
return Err(CompilerError::FnNotFound(fn_name.to_string()));
104+
};
105+
for param in fn_def.params.iter() {
106+
let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes);
107+
let mut wires = Vec::with_capacity(type_size);
108+
for _ in 0..type_size {
109+
wires.push(wire);
110+
wire += 1;
111+
}
112+
input_gates.push(type_size);
113+
env.let_in_current_scope(param.name.clone(), wires);
114+
}
115+
fn resolve_const_expr_unsigned(
116+
expr: &ConstExpr,
117+
consts_unsigned: &HashMap<String, u64>,
118+
) -> u64 {
119+
match expr {
120+
ConstExpr::NumUnsigned(n, _) => *n,
121+
ConstExpr::ExternalValue { party, identifier } => *consts_unsigned
122+
.get(&format!("{party}::{identifier}"))
123+
.unwrap(),
124+
ConstExpr::Max(args) => {
125+
let mut result = 0;
126+
for arg in args {
127+
result = max(result, resolve_const_expr_unsigned(arg, consts_unsigned));
128+
}
129+
result
130+
}
131+
ConstExpr::Min(args) => {
132+
let mut result = u64::MAX;
133+
for arg in args {
134+
result = min(result, resolve_const_expr_unsigned(arg, consts_unsigned));
135+
}
136+
result
137+
}
138+
expr => panic!("Not an unsigned const expr: {expr:?}"),
139+
}
140+
}
141+
fn resolve_const_expr_signed(
142+
expr: &ConstExpr,
143+
consts_signed: &HashMap<String, i64>,
144+
) -> i64 {
145+
match expr {
146+
ConstExpr::NumSigned(n, _) => *n,
147+
ConstExpr::ExternalValue { party, identifier } => *consts_signed
148+
.get(&format!("{party}::{identifier}"))
149+
.unwrap(),
150+
ConstExpr::Max(args) => {
151+
let mut result = 0;
152+
for arg in args {
153+
result = max(result, resolve_const_expr_signed(arg, consts_signed));
154+
}
155+
result
156+
}
157+
ConstExpr::Min(args) => {
158+
let mut result = i64::MAX;
159+
for arg in args {
160+
result = min(result, resolve_const_expr_signed(arg, consts_signed));
161+
}
162+
result
163+
}
164+
expr => panic!("Not an unsigned const expr: {expr:?}"),
165+
}
166+
}
167+
for (const_name, const_def) in self.const_defs.iter() {
168+
if let Type::Unsigned(UnsignedNumType::Usize) = const_def.ty {
169+
if let ConstExpr::ExternalValue { party, identifier } = &const_def.value {
170+
let identifier = format!("{party}::{identifier}");
171+
const_sizes.insert(const_name.clone(), *const_sizes.get(&identifier).unwrap());
94172
}
95-
input_gates.push(type_size);
96-
env.let_in_current_scope(param.name.clone(), wires);
173+
let n = resolve_const_expr_unsigned(&const_def.value, &consts_unsigned);
174+
const_sizes.insert(const_name.clone(), n as usize);
97175
}
98-
let mut circuit = CircuitBuilder::new(input_gates, const_sizes);
99-
for (identifier, const_def) in self.const_defs.iter() {
100-
match &const_def.value {
101-
ConstExpr::True => env.let_in_current_scope(identifier.clone(), vec![1]),
102-
ConstExpr::False => env.let_in_current_scope(identifier.clone(), vec![0]),
103-
ConstExpr::NumUnsigned(n, ty) => {
104-
let ty = Type::Unsigned(*ty);
176+
}
177+
let mut circuit = CircuitBuilder::new(input_gates, const_sizes);
178+
for (const_name, const_def) in self.const_defs.iter() {
179+
match &const_def.value {
180+
ConstExpr::True => env.let_in_current_scope(const_name.clone(), vec![1]),
181+
ConstExpr::False => env.let_in_current_scope(const_name.clone(), vec![0]),
182+
ConstExpr::NumUnsigned(n, ty) => {
183+
let ty = Type::Unsigned(*ty);
184+
let mut bits =
185+
Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes()));
186+
unsigned_to_bits(
187+
*n,
188+
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
189+
&mut bits,
190+
);
191+
let bits = bits.into_iter().map(|b| b as usize).collect();
192+
env.let_in_current_scope(const_name.clone(), bits);
193+
}
194+
ConstExpr::NumSigned(n, ty) => {
195+
let ty = Type::Signed(*ty);
196+
let mut bits =
197+
Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes()));
198+
signed_to_bits(
199+
*n,
200+
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
201+
&mut bits,
202+
);
203+
let bits = bits.into_iter().map(|b| b as usize).collect();
204+
env.let_in_current_scope(const_name.clone(), bits);
205+
}
206+
ConstExpr::ExternalValue { party, identifier } => {
207+
let bits = env.get(&format!("{party}::{identifier}")).unwrap();
208+
env.let_in_current_scope(const_name.clone(), bits);
209+
}
210+
expr @ (ConstExpr::Max(_) | ConstExpr::Min(_)) => {
211+
if let Type::Unsigned(_) = const_def.ty {
212+
let result = resolve_const_expr_unsigned(expr, &consts_unsigned);
105213
let mut bits = Vec::with_capacity(
106-
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
214+
const_def
215+
.ty
216+
.size_in_bits_for_defs(self, circuit.const_sizes()),
107217
);
108218
unsigned_to_bits(
109-
*n,
110-
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
219+
result,
220+
const_def
221+
.ty
222+
.size_in_bits_for_defs(self, circuit.const_sizes()),
111223
&mut bits,
112224
);
113225
let bits = bits.into_iter().map(|b| b as usize).collect();
114-
env.let_in_current_scope(identifier.clone(), bits);
115-
}
116-
ConstExpr::NumSigned(n, ty) => {
117-
let ty = Type::Signed(*ty);
226+
env.let_in_current_scope(const_name.clone(), bits);
227+
} else {
228+
let result = resolve_const_expr_signed(expr, &consts_signed);
118229
let mut bits = Vec::with_capacity(
119-
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
230+
const_def
231+
.ty
232+
.size_in_bits_for_defs(self, circuit.const_sizes()),
120233
);
121234
signed_to_bits(
122-
*n,
123-
ty.size_in_bits_for_defs(self, circuit.const_sizes()),
235+
result,
236+
const_def
237+
.ty
238+
.size_in_bits_for_defs(self, circuit.const_sizes()),
124239
&mut bits,
125240
);
126241
let bits = bits.into_iter().map(|b| b as usize).collect();
127-
env.let_in_current_scope(identifier.clone(), bits);
242+
env.let_in_current_scope(const_name.clone(), bits);
128243
}
129-
ConstExpr::ExternalValue { .. } => {}
130-
ConstExpr::Max(_) => todo!("compile max"),
131-
ConstExpr::Min(_) => todo!("compile min"),
132244
}
133245
}
134-
let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit);
135-
Ok((circuit.build(output_gates), fn_def))
136-
} else {
137-
Err(CompilerError::FnNotFound(fn_name.to_string()))
138246
}
247+
let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit);
248+
Ok((circuit.build(output_gates), fn_def))
139249
}
140250
}
141251

src/eval.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,16 @@ impl<'a> Evaluator<'a> {
4444
) -> Self {
4545
let mut const_sizes = HashMap::new();
4646
for (party, deps) in program.const_deps.iter() {
47-
for (c, (identifier, _)) in deps {
47+
for (c, _) in deps {
4848
let Some(party_deps) = consts.get(party) else {
4949
todo!("missing party dep for {party}");
5050
};
5151
let Some(literal) = party_deps.get(c) else {
5252
todo!("missing value {party}::{c}");
5353
};
5454
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
55-
const_sizes.insert(identifier.clone(), *size as usize);
55+
let identifier = format!("{party}::{c}");
56+
const_sizes.insert(identifier, *size as usize);
5657
}
5758
}
5859
}

src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,16 @@ pub fn compile_with_constants(
122122
let main = main.clone();
123123
let mut const_sizes = HashMap::new();
124124
for (party, deps) in program.const_deps.iter() {
125-
for (c, (identifier, _)) in deps {
125+
for (c, _) in deps {
126126
let Some(party_deps) = consts.get(party) else {
127127
todo!("missing party dep for {party}");
128128
};
129129
let Some(literal) = party_deps.get(c) else {
130130
todo!("missing value {party}::{c}");
131131
};
132132
if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal {
133-
const_sizes.insert(identifier.clone(), *size as usize);
133+
let identifier = format!("{party}::{c}");
134+
const_sizes.insert(identifier, *size as usize);
134135
}
135136
}
136137
}

0 commit comments

Comments
 (0)