Skip to content

Commit 0c9bfd2

Browse files
committed
Implement const defs in parser / type-checker / compiler
1 parent 090f1d8 commit 0c9bfd2

File tree

9 files changed

+253
-23
lines changed

9 files changed

+253
-23
lines changed

src/ast.rs

+24
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType};
1111
#[derive(Debug, Clone, PartialEq, Eq)]
1212
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
1313
pub struct Program<T> {
14+
/// The external constants that the top level const definitions depend upon.
15+
pub const_deps: HashMap<String, HashMap<String, (String, T)>>,
16+
/// Top level const definitions.
17+
pub const_defs: HashMap<String, ConstDef<T>>,
1418
/// Top level struct type definitions.
1519
pub struct_defs: HashMap<String, StructDef>,
1620
/// Top level enum type definitions.
@@ -19,6 +23,26 @@ pub struct Program<T> {
1923
pub fn_defs: HashMap<String, FnDef<T>>,
2024
}
2125

26+
/// A top level const definition.
27+
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
28+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
29+
pub struct ConstDef<T> {
30+
/// The type of the constant.
31+
pub ty: Type,
32+
/// The value of the constant.
33+
pub value: ConstExpr<T>,
34+
/// The location in the source code.
35+
pub meta: MetaInfo,
36+
}
37+
38+
/// A constant value, either a literal or a namespaced symbol.
39+
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
40+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
41+
pub enum ConstExpr<T> {
42+
/// A constant value specified as a literal value.
43+
Literal(Expr<T>),
44+
}
45+
2246
/// A top level struct type definition.
2347
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
2448
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/check.rs

+86-16
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use std::collections::{HashMap, HashSet};
55

66
use crate::{
77
ast::{
8-
self, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern, PatternEnum, Stmt,
9-
StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum,
8+
self, ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern,
9+
PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr,
10+
VariantExprEnum,
1011
},
1112
env::Env,
1213
token::{MetaInfo, SignedNumType, UnsignedNumType},
@@ -50,7 +51,7 @@ pub enum TypeErrorEnum {
5051
/// The struct constructor is missing the specified field.
5152
MissingStructField(String, String),
5253
/// No enum declaration with the specified name exists.
53-
UnknownEnum(String),
54+
UnknownEnum(String, String),
5455
/// The enum exists, but no variant declaration with the specified name was found.
5556
UnknownEnumVariant(String, String),
5657
/// No variable or function with the specified name exists in the current scope.
@@ -137,8 +138,8 @@ impl std::fmt::Display for TypeErrorEnum {
137138
TypeErrorEnum::MissingStructField(struct_name, struct_field) => f.write_fmt(
138139
format_args!("Field '{struct_field}' is missing for struct '{struct_name}'"),
139140
),
140-
TypeErrorEnum::UnknownEnum(enum_name) => {
141-
f.write_fmt(format_args!("Unknown enum '{enum_name}'"))
141+
TypeErrorEnum::UnknownEnum(enum_name, enum_variant) => {
142+
f.write_fmt(format_args!("Unknown enum '{enum_name}::{enum_variant}'"))
142143
}
143144
TypeErrorEnum::UnknownEnumVariant(enum_name, enum_variant) => f.write_fmt(
144145
format_args!("Unknown enum variant '{enum_name}::{enum_variant}'"),
@@ -277,21 +278,27 @@ impl Type {
277278

278279
/// Static top-level definitions of enums and functions.
279280
pub struct Defs<'a> {
281+
consts: HashMap<&'a str, &'a Type>,
280282
structs: HashMap<&'a str, (Vec<&'a str>, HashMap<&'a str, Type>)>,
281283
enums: HashMap<&'a str, HashMap<&'a str, Option<Vec<Type>>>>,
282284
fns: HashMap<&'a str, &'a UntypedFnDef>,
283285
}
284286

285287
impl<'a> Defs<'a> {
286288
pub(crate) fn new(
289+
const_defs: &'a HashMap<String, Type>,
287290
struct_defs: &'a HashMap<String, StructDef>,
288291
enum_defs: &'a HashMap<String, EnumDef>,
289292
) -> Self {
290293
let mut defs = Self {
294+
consts: HashMap::new(),
291295
structs: HashMap::new(),
292296
enums: HashMap::new(),
293297
fns: HashMap::new(),
294298
};
299+
for (const_name, ty) in const_defs.iter() {
300+
defs.consts.insert(const_name, &ty);
301+
}
295302
for (struct_name, struct_def) in struct_defs.iter() {
296303
let mut field_names = Vec::with_capacity(struct_def.fields.len());
297304
let mut field_types = HashMap::with_capacity(struct_def.fields.len());
@@ -338,6 +345,63 @@ impl UntypedProgram {
338345
struct_names,
339346
enum_names,
340347
};
348+
let mut const_deps: HashMap<String, HashMap<String, (String, Type)>> = HashMap::new();
349+
let mut const_types = HashMap::with_capacity(self.const_defs.len());
350+
let mut const_defs = HashMap::with_capacity(self.const_defs.len());
351+
{
352+
let top_level_defs = TopLevelTypes {
353+
struct_names: HashSet::new(),
354+
enum_names: HashSet::new(),
355+
};
356+
let mut env = Env::new();
357+
let mut fns = TypedFns {
358+
currently_being_checked: HashSet::new(),
359+
typed: HashMap::new(),
360+
};
361+
let defs = Defs {
362+
consts: HashMap::new(),
363+
structs: HashMap::new(),
364+
enums: HashMap::new(),
365+
fns: HashMap::new(),
366+
};
367+
for (const_name, const_def) in self.const_defs.iter() {
368+
match &const_def.value {
369+
ConstExpr::Literal(expr) => {
370+
match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) {
371+
Ok(mut expr) => {
372+
if let Err(errs) = check_type(&mut expr, &const_def.ty) {
373+
errors.extend(errs);
374+
}
375+
const_defs.insert(
376+
const_name.clone(),
377+
ConstDef {
378+
ty: const_def.ty.clone(),
379+
value: ConstExpr::Literal(expr),
380+
meta: const_def.meta,
381+
},
382+
);
383+
}
384+
Err(errs) => {
385+
for e in errs {
386+
if let Some(e) = e {
387+
if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e {
388+
// ignore this error, constant can be provided later during compilation
389+
const_deps.entry(p).or_default().insert(
390+
n,
391+
(const_name.clone(), const_def.ty.clone()),
392+
);
393+
} else {
394+
errors.push(Some(e));
395+
}
396+
}
397+
}
398+
}
399+
}
400+
const_types.insert(const_name.clone(), const_def.ty.clone());
401+
}
402+
}
403+
}
404+
}
341405
let mut struct_defs = HashMap::with_capacity(self.struct_defs.len());
342406
for (struct_name, struct_def) in self.struct_defs.iter() {
343407
let meta = struct_def.meta;
@@ -372,7 +436,7 @@ impl UntypedProgram {
372436
enum_defs.insert(enum_name.clone(), EnumDef { variants, meta });
373437
}
374438

375-
let mut untyped_defs = Defs::new(&struct_defs, &enum_defs);
439+
let mut untyped_defs = Defs::new(&const_types, &struct_defs, &enum_defs);
376440
let mut checked_fn_defs = TypedFns::new();
377441
for (fn_name, fn_def) in self.fn_defs.iter() {
378442
untyped_defs.fns.insert(fn_name, fn_def);
@@ -406,6 +470,8 @@ impl UntypedProgram {
406470
}
407471
if errors.is_empty() {
408472
Ok(TypedProgram {
473+
const_deps,
474+
const_defs,
409475
struct_defs,
410476
enum_defs,
411477
fn_defs,
@@ -677,12 +743,15 @@ impl UntypedExpr {
677743
Some((None, _mutability)) => {
678744
return Err(vec![None]);
679745
}
680-
None => {
681-
return Err(vec![Some(TypeError(
682-
TypeErrorEnum::UnknownIdentifier(identifier.clone()),
683-
meta,
684-
))]);
685-
}
746+
None => match defs.consts.get(identifier.as_str()) {
747+
Some(&ty) => (ExprEnum::Identifier(identifier.clone()), ty.clone()),
748+
None => {
749+
return Err(vec![Some(TypeError(
750+
TypeErrorEnum::UnknownIdentifier(identifier.clone()),
751+
meta,
752+
))]);
753+
}
754+
},
686755
},
687756
ExprEnum::ArrayLiteral(fields) => {
688757
let mut errors = vec![];
@@ -966,9 +1035,9 @@ impl UntypedExpr {
9661035
)
9671036
}
9681037
ExprEnum::EnumLiteral(identifier, variant) => {
1038+
let VariantExpr(variant_name, variant, variant_meta) = variant.as_ref();
9691039
if let Some(enum_def) = defs.enums.get(identifier.as_str()) {
970-
let VariantExpr(variant_name, variant, meta) = variant.as_ref();
971-
let meta = *meta;
1040+
let meta = *variant_meta;
9721041
if let Some(types) = enum_def.get(variant_name.as_str()) {
9731042
match (variant, types) {
9741043
(VariantExprEnum::Unit, None) => {
@@ -1039,7 +1108,8 @@ impl UntypedExpr {
10391108
return Err(vec![Some(TypeError(e, meta))]);
10401109
}
10411110
} else {
1042-
let e = TypeErrorEnum::UnknownEnum(identifier.clone());
1111+
let e =
1112+
TypeErrorEnum::UnknownEnum(identifier.clone(), variant_name.to_string());
10431113
return Err(vec![Some(TypeError(e, meta))]);
10441114
}
10451115
}
@@ -1402,7 +1472,7 @@ impl UntypedPattern {
14021472
return Err(vec![Some(TypeError(e, meta))]);
14031473
}
14041474
} else {
1405-
let e = TypeErrorEnum::UnknownEnum(enum_name.clone());
1475+
let e = TypeErrorEnum::UnknownEnum(enum_name.clone(), variant_name.to_string());
14061476
return Err(vec![Some(TypeError(e, meta))]);
14071477
}
14081478
}

src/compile.rs

+37
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::{
99
},
1010
circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS},
1111
env::Env,
12+
literal::Literal,
1213
token::{SignedNumType, UnsignedNumType},
1314
TypedExpr, TypedFnDef, TypedPattern, TypedProgram, TypedStmt,
1415
};
@@ -18,6 +19,8 @@ use crate::{
1819
pub enum CompilerError {
1920
/// The specified function could not be compiled, as it was not found in the program.
2021
FnNotFound(String),
22+
/// The provided constant was not of the required type.
23+
InvalidLiteralType(Literal, Type),
2124
}
2225

2326
impl std::fmt::Display for CompilerError {
@@ -26,6 +29,9 @@ impl std::fmt::Display for CompilerError {
2629
CompilerError::FnNotFound(fn_name) => f.write_fmt(format_args!(
2730
"Could not find any function with name '{fn_name}'"
2831
)),
32+
CompilerError::InvalidLiteralType(literal, ty) => {
33+
f.write_fmt(format_args!("The literal is not of type '{ty}': {literal}"))
34+
}
2935
}
3036
}
3137
}
@@ -36,9 +42,40 @@ impl TypedProgram {
3642
/// Assumes that the input program has been correctly type-checked and **panics** if
3743
/// incompatible types are found that should have been caught by the type-checker.
3844
pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), CompilerError> {
45+
self.compile_with_constants(fn_name, HashMap::new())
46+
}
47+
48+
/// Compiles the (type-checked) program with provided constants, producing a circuit of gates.
49+
///
50+
/// Assumes that the input program has been correctly type-checked and **panics** if
51+
/// incompatible types are found that should have been caught by the type-checker.
52+
pub fn compile_with_constants(
53+
&self,
54+
fn_name: &str,
55+
consts: HashMap<String, HashMap<String, Literal>>,
56+
) -> Result<(Circuit, &TypedFnDef), CompilerError> {
3957
let mut env = Env::new();
4058
let mut input_gates = vec![];
4159
let mut wire = 2;
60+
for (party, deps) in self.const_deps.iter() {
61+
for (c, (identifier, ty)) in deps {
62+
let Some(party_deps) = consts.get(party) else {
63+
todo!("missing party dep for {party}");
64+
};
65+
let Some(literal) = party_deps.get(c) else {
66+
todo!("missing value {party}::{c}");
67+
};
68+
if literal.is_of_type(self, ty) {
69+
let bits = literal.as_bits(self).iter().map(|b| *b as usize).collect();
70+
env.let_in_current_scope(identifier.clone(), bits);
71+
} else {
72+
return Err(CompilerError::InvalidLiteralType(
73+
literal.clone(),
74+
ty.clone(),
75+
));
76+
}
77+
}
78+
}
4279
if let Some(fn_def) = self.fn_defs.get(fn_name) {
4380
for param in fn_def.params.iter() {
4481
let type_size = param.ty.size_in_bits_for_defs(self);

src/lib.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ use eval::{EvalError, Evaluator};
4949
use literal::Literal;
5050
use parse::ParseError;
5151
use scan::{scan, ScanError};
52-
use std::fmt::{Display, Write as _};
52+
use std::{
53+
collections::HashMap,
54+
fmt::{Display, Write as _},
55+
};
5356
use token::MetaInfo;
5457

5558
#[cfg(feature = "serde")]
@@ -107,6 +110,21 @@ pub fn compile(prg: &str) -> Result<GarbleProgram, Error> {
107110
})
108111
}
109112

113+
/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit.
114+
pub fn compile_with_constants(
115+
prg: &str,
116+
consts: HashMap<String, HashMap<String, Literal>>,
117+
) -> Result<GarbleProgram, Error> {
118+
let program = check(prg)?;
119+
let (circuit, main) = program.compile_with_constants("main", consts)?;
120+
let main = main.clone();
121+
Ok(GarbleProgram {
122+
program,
123+
main,
124+
circuit,
125+
})
126+
}
127+
110128
/// The result of type-checking and compiling a Garble program.
111129
#[derive(Debug, Clone)]
112130
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]

src/literal.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@ impl Literal {
7575
};
7676
let mut env = Env::new();
7777
let mut fns = TypedFns::new();
78-
let defs = Defs::new(&checked.struct_defs, &checked.enum_defs);
78+
let const_types = checked
79+
.const_defs
80+
.iter()
81+
.map(|(n, c)| (n.clone(), c.ty.clone()))
82+
.collect();
83+
let defs = Defs::new(&const_types, &checked.struct_defs, &checked.enum_defs);
7984
let mut expr = scan(literal)?
8085
.parse_literal()?
8186
.type_check(&top_level_defs, &mut env, &mut fns, &defs)

0 commit comments

Comments
 (0)