From 0c9bfd2d4fc30b9a82a955ce46e2692598bc28e4 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 7 May 2024 18:23:53 +0100 Subject: [PATCH 01/22] Implement const defs in parser / type-checker / compiler --- src/ast.rs | 24 +++++++++++ src/check.rs | 102 +++++++++++++++++++++++++++++++++++++++-------- src/compile.rs | 37 +++++++++++++++++ src/lib.rs | 20 +++++++++- src/literal.rs | 7 +++- src/parse.rs | 46 ++++++++++++++++++++- src/scan.rs | 1 + src/token.rs | 7 +++- tests/compile.rs | 32 ++++++++++++++- 9 files changed, 253 insertions(+), 23 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 3671098..6b32809 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -11,6 +11,10 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { + /// The external constants that the top level const definitions depend upon. + pub const_deps: HashMap>, + /// Top level const definitions. + pub const_defs: HashMap>, /// Top level struct type definitions. pub struct_defs: HashMap, /// Top level enum type definitions. @@ -19,6 +23,26 @@ pub struct Program { pub fn_defs: HashMap>, } +/// A top level const definition. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ConstDef { + /// The type of the constant. + pub ty: Type, + /// The value of the constant. + pub value: ConstExpr, + /// The location in the source code. + pub meta: MetaInfo, +} + +/// A constant value, either a literal or a namespaced symbol. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ConstExpr { + /// A constant value specified as a literal value. + Literal(Expr), +} + /// A top level struct type definition. 
#[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/check.rs b/src/check.rs index 817ec46..8f6a240 100644 --- a/src/check.rs +++ b/src/check.rs @@ -5,8 +5,9 @@ use std::collections::{HashMap, HashSet}; use crate::{ ast::{ - self, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern, PatternEnum, Stmt, - StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum, + self, ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern, + PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, + VariantExprEnum, }, env::Env, token::{MetaInfo, SignedNumType, UnsignedNumType}, @@ -50,7 +51,7 @@ pub enum TypeErrorEnum { /// The struct constructor is missing the specified field. MissingStructField(String, String), /// No enum declaration with the specified name exists. - UnknownEnum(String), + UnknownEnum(String, String), /// The enum exists, but no variant declaration with the specified name was found. UnknownEnumVariant(String, String), /// No variable or function with the specified name exists in the current scope. @@ -137,8 +138,8 @@ impl std::fmt::Display for TypeErrorEnum { TypeErrorEnum::MissingStructField(struct_name, struct_field) => f.write_fmt( format_args!("Field '{struct_field}' is missing for struct '{struct_name}'"), ), - TypeErrorEnum::UnknownEnum(enum_name) => { - f.write_fmt(format_args!("Unknown enum '{enum_name}'")) + TypeErrorEnum::UnknownEnum(enum_name, enum_variant) => { + f.write_fmt(format_args!("Unknown enum '{enum_name}::{enum_variant}'")) } TypeErrorEnum::UnknownEnumVariant(enum_name, enum_variant) => f.write_fmt( format_args!("Unknown enum variant '{enum_name}::{enum_variant}'"), @@ -277,6 +278,7 @@ impl Type { /// Static top-level definitions of enums and functions. 
pub struct Defs<'a> { + consts: HashMap<&'a str, &'a Type>, structs: HashMap<&'a str, (Vec<&'a str>, HashMap<&'a str, Type>)>, enums: HashMap<&'a str, HashMap<&'a str, Option>>>, fns: HashMap<&'a str, &'a UntypedFnDef>, @@ -284,14 +286,19 @@ pub struct Defs<'a> { impl<'a> Defs<'a> { pub(crate) fn new( + const_defs: &'a HashMap, struct_defs: &'a HashMap, enum_defs: &'a HashMap, ) -> Self { let mut defs = Self { + consts: HashMap::new(), structs: HashMap::new(), enums: HashMap::new(), fns: HashMap::new(), }; + for (const_name, ty) in const_defs.iter() { + defs.consts.insert(const_name, &ty); + } for (struct_name, struct_def) in struct_defs.iter() { let mut field_names = Vec::with_capacity(struct_def.fields.len()); let mut field_types = HashMap::with_capacity(struct_def.fields.len()); @@ -338,6 +345,63 @@ impl UntypedProgram { struct_names, enum_names, }; + let mut const_deps: HashMap> = HashMap::new(); + let mut const_types = HashMap::with_capacity(self.const_defs.len()); + let mut const_defs = HashMap::with_capacity(self.const_defs.len()); + { + let top_level_defs = TopLevelTypes { + struct_names: HashSet::new(), + enum_names: HashSet::new(), + }; + let mut env = Env::new(); + let mut fns = TypedFns { + currently_being_checked: HashSet::new(), + typed: HashMap::new(), + }; + let defs = Defs { + consts: HashMap::new(), + structs: HashMap::new(), + enums: HashMap::new(), + fns: HashMap::new(), + }; + for (const_name, const_def) in self.const_defs.iter() { + match &const_def.value { + ConstExpr::Literal(expr) => { + match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) { + Ok(mut expr) => { + if let Err(errs) = check_type(&mut expr, &const_def.ty) { + errors.extend(errs); + } + const_defs.insert( + const_name.clone(), + ConstDef { + ty: const_def.ty.clone(), + value: ConstExpr::Literal(expr), + meta: const_def.meta, + }, + ); + } + Err(errs) => { + for e in errs { + if let Some(e) = e { + if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e { + // 
ignore this error, constant can be provided later during compilation + const_deps.entry(p).or_default().insert( + n, + (const_name.clone(), const_def.ty.clone()), + ); + } else { + errors.push(Some(e)); + } + } + } + } + } + const_types.insert(const_name.clone(), const_def.ty.clone()); + } + } + } + } let mut struct_defs = HashMap::with_capacity(self.struct_defs.len()); for (struct_name, struct_def) in self.struct_defs.iter() { let meta = struct_def.meta; @@ -372,7 +436,7 @@ impl UntypedProgram { enum_defs.insert(enum_name.clone(), EnumDef { variants, meta }); } - let mut untyped_defs = Defs::new(&struct_defs, &enum_defs); + let mut untyped_defs = Defs::new(&const_types, &struct_defs, &enum_defs); let mut checked_fn_defs = TypedFns::new(); for (fn_name, fn_def) in self.fn_defs.iter() { untyped_defs.fns.insert(fn_name, fn_def); @@ -406,6 +470,8 @@ impl UntypedProgram { } if errors.is_empty() { Ok(TypedProgram { + const_deps, + const_defs, struct_defs, enum_defs, fn_defs, @@ -677,12 +743,15 @@ impl UntypedExpr { Some((None, _mutability)) => { return Err(vec![None]); } - None => { - return Err(vec![Some(TypeError( - TypeErrorEnum::UnknownIdentifier(identifier.clone()), - meta, - ))]); - } + None => match defs.consts.get(identifier.as_str()) { + Some(&ty) => (ExprEnum::Identifier(identifier.clone()), ty.clone()), + None => { + return Err(vec![Some(TypeError( + TypeErrorEnum::UnknownIdentifier(identifier.clone()), + meta, + ))]); + } + }, }, ExprEnum::ArrayLiteral(fields) => { let mut errors = vec![]; @@ -966,9 +1035,9 @@ impl UntypedExpr { ) } ExprEnum::EnumLiteral(identifier, variant) => { + let VariantExpr(variant_name, variant, variant_meta) = variant.as_ref(); if let Some(enum_def) = defs.enums.get(identifier.as_str()) { - let VariantExpr(variant_name, variant, meta) = variant.as_ref(); - let meta = *meta; + let meta = *variant_meta; if let Some(types) = enum_def.get(variant_name.as_str()) { match (variant, types) { (VariantExprEnum::Unit, None) => { @@ -1039,7 
+1108,8 @@ impl UntypedExpr { return Err(vec![Some(TypeError(e, meta))]); } } else { - let e = TypeErrorEnum::UnknownEnum(identifier.clone()); + let e = + TypeErrorEnum::UnknownEnum(identifier.clone(), variant_name.to_string()); return Err(vec![Some(TypeError(e, meta))]); } } @@ -1402,7 +1472,7 @@ impl UntypedPattern { return Err(vec![Some(TypeError(e, meta))]); } } else { - let e = TypeErrorEnum::UnknownEnum(enum_name.clone()); + let e = TypeErrorEnum::UnknownEnum(enum_name.clone(), variant_name.to_string()); return Err(vec![Some(TypeError(e, meta))]); } } diff --git a/src/compile.rs b/src/compile.rs index fb57fa3..8882acc 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -9,6 +9,7 @@ use crate::{ }, circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS}, env::Env, + literal::Literal, token::{SignedNumType, UnsignedNumType}, TypedExpr, TypedFnDef, TypedPattern, TypedProgram, TypedStmt, }; @@ -18,6 +19,8 @@ use crate::{ pub enum CompilerError { /// The specified function could not be compiled, as it was not found in the program. FnNotFound(String), + /// The provided constant was not of the required type. + InvalidLiteralType(Literal, Type), } impl std::fmt::Display for CompilerError { @@ -26,6 +29,9 @@ impl std::fmt::Display for CompilerError { CompilerError::FnNotFound(fn_name) => f.write_fmt(format_args!( "Could not find any function with name '{fn_name}'" )), + CompilerError::InvalidLiteralType(literal, ty) => { + f.write_fmt(format_args!("The literal is not of type '{ty}': {literal}")) + } } } } @@ -36,9 +42,40 @@ impl TypedProgram { /// Assumes that the input program has been correctly type-checked and **panics** if /// incompatible types are found that should have been caught by the type-checker. 
pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), CompilerError> { + self.compile_with_constants(fn_name, HashMap::new()) + } + + /// Compiles the (type-checked) program with provided constants, producing a circuit of gates. + /// + /// Assumes that the input program has been correctly type-checked and **panics** if + /// incompatible types are found that should have been caught by the type-checker. + pub fn compile_with_constants( + &self, + fn_name: &str, + consts: HashMap>, + ) -> Result<(Circuit, &TypedFnDef), CompilerError> { let mut env = Env::new(); let mut input_gates = vec![]; let mut wire = 2; + for (party, deps) in self.const_deps.iter() { + for (c, (identifier, ty)) in deps { + let Some(party_deps) = consts.get(party) else { + todo!("missing party dep for {party}"); + }; + let Some(literal) = party_deps.get(c) else { + todo!("missing value {party}::{c}"); + }; + if literal.is_of_type(self, ty) { + let bits = literal.as_bits(self).iter().map(|b| *b as usize).collect(); + env.let_in_current_scope(identifier.clone(), bits); + } else { + return Err(CompilerError::InvalidLiteralType( + literal.clone(), + ty.clone(), + )); + } + } + } if let Some(fn_def) = self.fn_defs.get(fn_name) { for param in fn_def.params.iter() { let type_size = param.ty.size_in_bits_for_defs(self); diff --git a/src/lib.rs b/src/lib.rs index 41d2f8c..b88bb46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,10 @@ use eval::{EvalError, Evaluator}; use literal::Literal; use parse::ParseError; use scan::{scan, ScanError}; -use std::fmt::{Display, Write as _}; +use std::{ + collections::HashMap, + fmt::{Display, Write as _}, +}; use token::MetaInfo; #[cfg(feature = "serde")] @@ -107,6 +110,21 @@ pub fn compile(prg: &str) -> Result { }) } +/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit. 
+pub fn compile_with_constants( + prg: &str, + consts: HashMap>, +) -> Result { + let program = check(prg)?; + let (circuit, main) = program.compile_with_constants("main", consts)?; + let main = main.clone(); + Ok(GarbleProgram { + program, + main, + circuit, + }) +} + /// The result of type-checking and compiling a Garble program. #[derive(Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/literal.rs b/src/literal.rs index 02b8941..e942615 100644 --- a/src/literal.rs +++ b/src/literal.rs @@ -75,7 +75,12 @@ impl Literal { }; let mut env = Env::new(); let mut fns = TypedFns::new(); - let defs = Defs::new(&checked.struct_defs, &checked.enum_defs); + let const_types = checked + .const_defs + .iter() + .map(|(n, c)| (n.clone(), c.ty.clone())) + .collect(); + let defs = Defs::new(&const_types, &checked.struct_defs, &checked.enum_defs); let mut expr = scan(literal)? .parse_literal()? .type_check(&top_level_defs, &mut env, &mut fns, &defs) diff --git a/src/parse.rs b/src/parse.rs index 5ca3ad0..2e54b12 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -4,8 +4,8 @@ use std::{collections::HashMap, iter::Peekable, vec::IntoIter}; use crate::{ ast::{ - EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, PatternEnum, Program, Stmt, - StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum, + ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, PatternEnum, + Program, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum, }, scan::Tokens, token::{MetaInfo, SignedNumType, Token, TokenEnum, UnsignedNumType}, @@ -116,6 +116,7 @@ impl Parser { TokenEnum::KeywordStruct, TokenEnum::KeywordEnum, ]; + let mut const_defs = HashMap::new(); let mut struct_defs = HashMap::new(); let mut enum_defs = HashMap::new(); let mut fn_defs = HashMap::new(); @@ -125,6 +126,14 @@ impl Parser { TokenEnum::KeywordPub if is_pub.is_none() => { is_pub = Some(meta); } + TokenEnum::KeywordConst => 
{ + if let Ok((const_name, const_def)) = self.parse_const_def(meta) { + const_defs.insert(const_name, const_def); + } else { + self.consume_until_one_of(&top_level_keywords); + } + is_pub = None; + } TokenEnum::KeywordStruct => { if let Ok((struct_name, struct_def)) = self.parse_struct_def(meta) { struct_defs.insert(struct_name, struct_def); @@ -158,6 +167,8 @@ impl Parser { } if self.errors.is_empty() { return Ok(Program { + const_deps: HashMap::new(), + const_defs, struct_defs, enum_defs, fn_defs, @@ -166,6 +177,37 @@ impl Parser { Err(self.errors) } + fn parse_const_def(&mut self, start: MetaInfo) -> Result<(String, ConstDef<()>), ()> { + // const keyword was already consumed by the top-level parser + let (identifier, _) = self.expect_identifier()?; + + self.expect(&TokenEnum::Colon)?; + + let (ty, _) = self.parse_type()?; + + self.expect(&TokenEnum::Eq)?; + + let Some(token) = self.tokens.next() else { + self.push_error(ParseErrorEnum::InvalidTopLevelDef, start); + return Err(()); + }; + match self.parse_literal(token, true) { + Ok(literal) => { + let end = self.expect(&TokenEnum::Semicolon)?; + let meta = join_meta(start, end); + Ok(( + identifier, + ConstDef { + ty, + value: ConstExpr::Literal(literal), + meta, + }, + )) + } + Err(_) => todo!("non-literal const def"), + } + } + fn parse_struct_def(&mut self, start: MetaInfo) -> Result<(String, StructDef), ()> { // struct keyword was already consumed by the top-level parser let (identifier, _) = self.expect_identifier()?; diff --git a/src/scan.rs b/src/scan.rs index 5227408..f74dfb7 100644 --- a/src/scan.rs +++ b/src/scan.rs @@ -274,6 +274,7 @@ impl<'a> Scanner<'a> { } let identifier: String = chars.into_iter().collect(); match identifier.as_str() { + "const" => self.push_token(TokenEnum::KeywordConst), "struct" => self.push_token(TokenEnum::KeywordStruct), "enum" => self.push_token(TokenEnum::KeywordEnum), "fn" => self.push_token(TokenEnum::KeywordFn), diff --git a/src/token.rs b/src/token.rs index 
a7d3d22..f4ceb69 100644 --- a/src/token.rs +++ b/src/token.rs @@ -18,10 +18,12 @@ pub enum TokenEnum { UnsignedNum(u64, UnsignedNumType), /// Signed number. SignedNum(i64, SignedNumType), - /// `enum` keyword. - KeywordEnum, + /// `const` keyword. + KeywordConst, /// `struct` keyword. KeywordStruct, + /// `enum` keyword. + KeywordEnum, /// `fn` keyword. KeywordFn, /// `let` keyword. @@ -121,6 +123,7 @@ impl std::fmt::Display for TokenEnum { TokenEnum::ConstantIndexOrSize(num) => f.write_fmt(format_args!("{num}")), TokenEnum::UnsignedNum(num, suffix) => f.write_fmt(format_args!("{num}{suffix}")), TokenEnum::SignedNum(num, suffix) => f.write_fmt(format_args!("{num}{suffix}")), + TokenEnum::KeywordConst => f.write_str("const"), TokenEnum::KeywordStruct => f.write_str("struct"), TokenEnum::KeywordEnum => f.write_str("enum"), TokenEnum::KeywordFn => f.write_str("fn"), diff --git a/tests/compile.rs b/tests/compile.rs index 9a9c0f3..f37d9fb 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1,4 +1,8 @@ -use garble_lang::{compile, Error}; +use std::collections::HashMap; + +use garble_lang::{ + compile, compile_with_constants, literal::Literal, token::UnsignedNumType, Error, +}; fn pretty_print>(e: E, prg: &str) -> Error { let e: Error = e.into(); @@ -1834,3 +1838,29 @@ pub fn main(_a: i32, _b: i32) -> () { assert_eq!(r.to_string(), "()"); Ok(()) } + +#[test] +fn compile_consts() -> Result<(), Error> { + let prg = " +const MY_CONST: u16 = PARTY_0::MY_CONST; +pub fn main(x: u16) -> u16 { + x + MY_CONST +} +"; + let consts = HashMap::from_iter(vec![( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::U16), + )]), + )]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 
+ 257 + ); + Ok(()) +} From 641b4c3845d3e41a917ea65ea41795145d1c54aa Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 8 May 2024 12:47:47 +0100 Subject: [PATCH 02/22] Fix endless parser loop for invalid size constants in array literals --- src/parse.rs | 4 +--- tests/compile.rs | 50 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/parse.rs b/src/parse.rs index 2e54b12..e2fc724 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -363,10 +363,7 @@ impl Parser { self.expect(&TokenEnum::Semicolon)?; return Ok(Stmt::new(StmtEnum::Let(pattern, binding), meta)); } else { - self.push_error_for_next(ParseErrorEnum::ExpectedStmt); self.consume_until_one_of(&[TokenEnum::Semicolon]); - self.advance(); - self.parse_expr()?; self.expect(&TokenEnum::Semicolon)?; } } @@ -449,6 +446,7 @@ impl Parser { && !self.peek(&TokenEnum::KeywordFn) && !self.peek(&TokenEnum::KeywordStruct) && !self.peek(&TokenEnum::KeywordEnum) + && !self.tokens.peek().is_none() { if let Ok(stmt) = self.parse_stmt() { stmts.push(stmt); diff --git a/tests/compile.rs b/tests/compile.rs index f37d9fb..0f320b4 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1840,7 +1840,7 @@ pub fn main(_a: i32, _b: i32) -> () { } #[test] -fn compile_consts() -> Result<(), Error> { +fn compile_const() -> Result<(), Error> { let prg = " const MY_CONST: u16 = PARTY_0::MY_CONST; pub fn main(x: u16) -> u16 { @@ -1864,3 +1864,51 @@ pub fn main(x: u16) -> u16 { ); Ok(()) } + +#[test] +fn compile_const_literal() -> Result<(), Error> { + let prg = " +const MY_CONST: u16 = 2u16; +pub fn main(x: u16) -> u16 { + x + MY_CONST +} +"; + let consts = HashMap::from_iter(vec![]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 
257 + ); + Ok(()) +} + +#[test] +fn compile_const_usize() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = PARTY_0::MY_CONST; +pub fn main(x: u16) -> u16 { + // breaks if I use 2usize instead of 2 + let array = [2u16; 2usize]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + )]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} From 482c179b9cbfe916f90363ffae694b2e3c7f0530 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 13 May 2024 18:53:39 +0100 Subject: [PATCH 03/22] Const type checking / compilation --- src/ast.rs | 11 ++++ src/check.rs | 54 ++++++++++++++++- src/circuit.rs | 8 ++- src/compile.rs | 150 ++++++++++++++++++++++++++++++++++------------- src/eval.rs | 44 ++++++++++++-- src/lib.rs | 37 +++++++++--- src/literal.rs | 76 +++++++++++++++++------- src/parse.rs | 66 +++++++++++++-------- tests/compile.rs | 3 +- 9 files changed, 343 insertions(+), 106 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 6b32809..b434d7d 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -161,6 +161,8 @@ pub enum Type { Fn(Vec, Box), /// Array type of a fixed size, containing elements of the specified type. Array(Box, usize), + /// Array type of a fixed size, with the size specified by a constant. + ArrayConst(Box, String), /// Tuple type containing fields of the specified types. Tuple(Vec), /// A struct or an enum, depending on the top level definitions (used only before typechecking). 
@@ -197,6 +199,13 @@ impl std::fmt::Display for Type { size.fmt(f)?; f.write_str("]") } + Type::ArrayConst(ty, size) => { + f.write_str("[")?; + ty.fmt(f)?; + f.write_str("; ")?; + size.fmt(f)?; + f.write_str("]") + } Type::Tuple(fields) => { f.write_str("(")?; let mut fields = fields.iter(); @@ -303,6 +312,8 @@ pub enum ExprEnum { ArrayLiteral(Vec>), /// Array "repeat expression", which specifies 1 element, to be repeated a number of times. ArrayRepeatLiteral(Box>, usize), + /// Array "repeat expression", with the size specified by a constant. + ArrayRepeatLiteralConst(Box>, String), /// Access of an array at the specified index, returning its element. ArrayAccess(Box>, Box>), /// Tuple literal containing the specified fields. diff --git a/src/check.rs b/src/check.rs index 8f6a240..de4c83a 100644 --- a/src/check.rs +++ b/src/check.rs @@ -113,6 +113,8 @@ pub enum TypeErrorEnum { PatternsAreNotExhaustive(Vec), /// The expression cannot be matched upon. TypeDoesNotSupportPatternMatching(Type), + /// The specified identifier is not a constant. 
+ ArraySizeNotConst(String), } impl std::fmt::Display for TypeErrorEnum { @@ -223,6 +225,9 @@ impl std::fmt::Display for TypeErrorEnum { TypeErrorEnum::TypeDoesNotSupportPatternMatching(ty) => { f.write_fmt(format_args!("Type {ty} does not support pattern matching")) } + TypeErrorEnum::ArraySizeNotConst(identifier) => { + f.write_fmt(format_args!("Array sizes must be constants, but '{identifier}' is a variable")) + } } } } @@ -252,6 +257,10 @@ impl Type { let elem = elem.as_concrete_type(types)?; Type::Array(Box::new(elem), *size) } + Type::ArrayConst(elem, size) => { + let elem = elem.as_concrete_type(types)?; + Type::ArrayConst(Box::new(elem), size.clone()) + } Type::Tuple(fields) => { let mut concrete_fields = Vec::with_capacity(fields.len()); for field in fields.iter() { @@ -792,6 +801,37 @@ impl UntypedExpr { let ty = Type::Array(Box::new(value.ty.clone()), *size); (ExprEnum::ArrayRepeatLiteral(Box::new(value), *size), ty) } + ExprEnum::ArrayRepeatLiteralConst(value, size) => match env.get(size) { + None => match defs.consts.get(size.as_str()) { + Some(&ty) if ty == &Type::Unsigned(UnsignedNumType::Usize) => { + let value = value.type_check(top_level_defs, env, fns, defs)?; + let ty = Type::ArrayConst(Box::new(value.ty.clone()), size.clone()); + ( + ExprEnum::ArrayRepeatLiteralConst(Box::new(value), size.clone()), + ty, + ) + } + Some(&ty) => { + let e = TypeErrorEnum::UnexpectedType { + expected: Type::Unsigned(UnsignedNumType::Usize), + actual: ty.clone(), + }; + return Err(vec![Some(TypeError(e, meta))]); + } + None => { + return Err(vec![Some(TypeError( + TypeErrorEnum::UnknownIdentifier(size.clone()), + meta, + ))]); + } + }, + Some(_) => { + return Err(vec![Some(TypeError( + TypeErrorEnum::ArraySizeNotConst(size.clone()), + meta, + ))]); + } + }, ExprEnum::ArrayAccess(arr, index) => { let arr = arr.type_check(top_level_defs, env, fns, defs)?; let mut index = index.type_check(top_level_defs, env, fns, defs)?; @@ -1124,7 +1164,7 @@ impl UntypedExpr { | 
Type::Tuple(_) | Type::Struct(_) | Type::Enum(_) => {} - Type::Fn(_, _) | Type::Array(_, _) => { + Type::Fn(_, _) | Type::Array(_, _) | Type::ArrayConst(_, _) => { let e = TypeErrorEnum::TypeDoesNotSupportPatternMatching(ty.clone()); return Err(vec![Some(TypeError(e, meta))]); } @@ -1519,6 +1559,7 @@ enum Ctor { Struct(String, Vec<(String, Type)>), Variant(String, String, Option>), Array(Box, usize), + ArrayConst(Box, String), } type PatternStack = Vec; @@ -1587,7 +1628,7 @@ fn specialize(ctor: &Ctor, pattern: &[TypedPattern]) -> Vec { } _ => vec![], }, - Ctor::Array(_, _) => match head_enum { + Ctor::Array(_, _) | Ctor::ArrayConst(_, _) => match head_enum { PatternEnum::Identifier(_) => vec![tail.collect()], _ => vec![], }, @@ -1793,6 +1834,7 @@ fn split_ctor(patterns: &[PatternStack], q: &[TypedPattern], defs: &Defs) -> Vec vec![Ctor::Tuple(fields.clone())] } Type::Array(elem_ty, size) => vec![Ctor::Array(elem_ty.clone(), *size)], + Type::ArrayConst(elem_ty, size) => vec![Ctor::ArrayConst(elem_ty.clone(), size.clone())], Type::Fn(_, _) => { panic!("Type {ty:?} does not support pattern matching") } @@ -1889,6 +1931,14 @@ fn usefulness(patterns: Vec, q: PatternStack, defs: &Defs) -> Vec< meta, ), ), + Ctor::ArrayConst(elem_ty, size) => witness.insert( + 0, + Pattern::typed( + PatternEnum::Identifier("_".to_string()), + Type::ArrayConst(elem_ty.clone(), size.clone()), + meta, + ), + ), } witnesses.push(witness); } diff --git a/src/circuit.rs b/src/circuit.rs index ccee5e8..e6a0104 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -269,6 +269,7 @@ pub(crate) struct CircuitBuilder { gates_optimized: usize, gate_counter: usize, panic_gates: PanicResult, + consts: HashMap, } pub(crate) const USIZE_BITS: usize = 32; @@ -393,7 +394,7 @@ impl PanicReason { } impl CircuitBuilder { - pub fn new(input_gates: Vec) -> Self { + pub fn new(input_gates: Vec, consts: HashMap) -> Self { let mut gate_counter = 2; // for const true and false for input_gates_of_party in 
input_gates.iter() { gate_counter += input_gates_of_party; @@ -407,9 +408,14 @@ impl CircuitBuilder { gates_optimized: 0, gate_counter, panic_gates: PanicResult::ok(), + consts, } } + pub fn const_sizes(&self) -> &HashMap { + &self.consts + } + // Pruning of useless gates (gates that are not part of the output nor used by other gates): fn remove_unused_gates(&mut self, output_gates: Vec) -> Vec { // To find all unused gates, we start at the output gates and recursively mark all their diff --git a/src/compile.rs b/src/compile.rs index 8882acc..e6a6129 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -55,8 +55,7 @@ impl TypedProgram { consts: HashMap>, ) -> Result<(Circuit, &TypedFnDef), CompilerError> { let mut env = Env::new(); - let mut input_gates = vec![]; - let mut wire = 2; + let mut const_sizes = HashMap::new(); for (party, deps) in self.const_deps.iter() { for (c, (identifier, ty)) in deps { let Some(party_deps) = consts.get(party) else { @@ -66,8 +65,15 @@ impl TypedProgram { todo!("missing value {party}::{c}"); }; if literal.is_of_type(self, ty) { - let bits = literal.as_bits(self).iter().map(|b| *b as usize).collect(); + let bits = literal + .as_bits(self, &const_sizes) + .iter() + .map(|b| *b as usize) + .collect(); env.let_in_current_scope(identifier.clone(), bits); + if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { + const_sizes.insert(identifier.clone(), *size as usize); + } } else { return Err(CompilerError::InvalidLiteralType( literal.clone(), @@ -76,9 +82,11 @@ impl TypedProgram { } } } + let mut input_gates = vec![]; + let mut wire = 2; if let Some(fn_def) = self.fn_defs.get(fn_name) { for param in fn_def.params.iter() { - let type_size = param.ty.size_in_bits_for_defs(self); + let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); let mut wires = Vec::with_capacity(type_size); for _ in 0..type_size { wires.push(wire); @@ -87,7 +95,7 @@ impl TypedProgram { input_gates.push(type_size); 
env.let_in_current_scope(param.name.clone(), wires); } - let mut circuit = CircuitBuilder::new(input_gates); + let mut circuit = CircuitBuilder::new(input_gates, const_sizes); let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); Ok((circuit.build(output_gates), fn_def)) } else { @@ -136,12 +144,13 @@ impl TypedStmt { vec![] } StmtEnum::ArrayAssign(identifier, index, value) => { - let elem_bits = value.ty.size_in_bits_for_defs(prg); + let elem_bits = value.ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = env.get(identifier).unwrap(); let size = array.len() / elem_bits; let mut index = index.compile(prg, env, circuit); let value = value.compile(prg, env, circuit); - let index_bits = Type::Unsigned(UnsignedNumType::Usize).size_in_bits_for_defs(prg); + let index_bits = Type::Unsigned(UnsignedNumType::Usize) + .size_in_bits_for_defs(prg, circuit.const_sizes()); extend_to_bits( &mut index, &Type::Unsigned(UnsignedNumType::Usize), @@ -185,7 +194,9 @@ impl TypedStmt { } StmtEnum::ForEachLoop(var, array, body) => { let elem_in_bits = match &array.ty { - Type::Array(elem_ty, _size) => elem_ty.size_in_bits_for_defs(prg), + Type::Array(elem_ty, _size) => { + elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()) + } _ => panic!("Found a non-array value in an array access expr"), }; env.push(); @@ -225,18 +236,29 @@ impl TypedExpr { vec![0] } ExprEnum::NumUnsigned(n, _) => { - let mut bits = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); - unsigned_to_bits(*n, ty.size_in_bits_for_defs(prg), &mut bits); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); + unsigned_to_bits( + *n, + ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + &mut bits, + ); bits.into_iter().map(|b| b as usize).collect() } ExprEnum::NumSigned(n, _) => { - let mut bits = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); - signed_to_bits(*n, ty.size_in_bits_for_defs(prg), &mut bits); + let mut bits = + 
Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); + signed_to_bits( + *n, + ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + &mut bits, + ); bits.into_iter().map(|b| b as usize).collect() } ExprEnum::Identifier(s) => env.get(s).unwrap(), ExprEnum::ArrayLiteral(elems) => { - let mut wires = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); + let mut wires = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); for elem in elems { wires.extend(elem.compile(prg, env, circuit)); } @@ -245,23 +267,44 @@ impl TypedExpr { ExprEnum::ArrayRepeatLiteral(elem, size) => { let elem_ty = elem.ty.clone(); let mut elem = elem.compile(prg, env, circuit); - extend_to_bits(&mut elem, &elem_ty, elem_ty.size_in_bits_for_defs(prg)); - let bits = ty.size_in_bits_for_defs(prg); + extend_to_bits( + &mut elem, + &elem_ty, + elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + ); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = Vec::with_capacity(bits); for _ in 0..*size { array.extend_from_slice(&elem); } array } + ExprEnum::ArrayRepeatLiteralConst(elem, size) => { + let size = *circuit.const_sizes().get(size).unwrap(); + let elem_ty = elem.ty.clone(); + let mut elem = elem.compile(prg, env, circuit); + extend_to_bits( + &mut elem, + &elem_ty, + elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + ); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); + let mut array = Vec::with_capacity(bits); + for _ in 0..size { + array.extend_from_slice(&elem); + } + array + } ExprEnum::ArrayAccess(array, index) => { let num_elems = match &array.ty { Type::Array(_, size) => *size, _ => panic!("Found a non-array value in an array access expr"), }; - let elem_bits = ty.size_in_bits_for_defs(prg); + let elem_bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = array.compile(prg, env, circuit); let mut index = index.compile(prg, env, circuit); - let index_bits = 
Type::Unsigned(UnsignedNumType::Usize).size_in_bits_for_defs(prg); + let index_bits = Type::Unsigned(UnsignedNumType::Usize) + .size_in_bits_for_defs(prg, circuit.const_sizes()); extend_to_bits( &mut index, &Type::Unsigned(UnsignedNumType::Usize), @@ -304,7 +347,8 @@ impl TypedExpr { } } ExprEnum::TupleLiteral(tuple) => { - let mut wires = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); + let mut wires = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); for value in tuple { wires.extend(value.compile(prg, env, circuit)); } @@ -315,9 +359,12 @@ impl TypedExpr { Type::Tuple(values) => { let mut wires_before = 0; for v in values[0..*index].iter() { - wires_before += v.size_in_bits_for_defs(prg); + wires_before += v.size_in_bits_for_defs(prg, circuit.const_sizes()); } - (wires_before, values[*index].size_in_bits_for_defs(prg)) + ( + wires_before, + values[*index].size_in_bits_for_defs(prg, circuit.const_sizes()), + ) } _ => panic!("Expected a tuple type, but found {:?}", tuple.meta), }; @@ -632,7 +679,7 @@ impl TypedExpr { ExprEnum::Cast(ty, expr) => { let ty_expr = &expr.ty; let mut expr = expr.compile(prg, env, circuit); - let size_after_cast = ty.size_in_bits_for_defs(prg); + let size_after_cast = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); match size_after_cast.cmp(&expr.len()) { std::cmp::Ordering::Equal => expr, @@ -645,7 +692,8 @@ impl TypedExpr { } ExprEnum::Range((from, elem_ty), (to, _)) => { let size = (to - from) as usize; - let elem_bits = Type::Unsigned(*elem_ty).size_in_bits_for_defs(prg); + let elem_bits = + Type::Unsigned(*elem_ty).size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = Vec::with_capacity(elem_bits * size); for i in *from..*to { for b in (0..elem_bits).rev() { @@ -657,7 +705,7 @@ impl TypedExpr { ExprEnum::EnumLiteral(identifier, variant) => { let enum_def = prg.enum_defs.get(identifier).unwrap(); let tag_size = enum_tag_size(enum_def); - let max_size = enum_max_size(enum_def, prg); + 
let max_size = enum_max_size(enum_def, prg, circuit.const_sizes()); let mut wires = vec![0; max_size]; let VariantExpr(variant_name, variant, _) = variant.as_ref(); let tag_number = enum_tag_number(enum_def, variant_name); @@ -678,7 +726,7 @@ impl TypedExpr { wires } ExprEnum::Match(expr, clauses) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let expr = expr.compile(prg, env, circuit); let mut has_prev_match = 0; let mut muxed_ret_expr = vec![0; bits]; @@ -718,7 +766,8 @@ impl TypedExpr { let struct_def = prg.struct_defs.get(name.as_str()).unwrap(); let mut bits = 0; for (field_name, field_ty) in struct_def.fields.iter() { - let bits_of_field = field_ty.size_in_bits_for_defs(prg); + let bits_of_field = + field_ty.size_in_bits_for_defs(prg, circuit.const_sizes()); if field_name == field { return struct_expr[bits..bits + bits_of_field].to_vec(); } @@ -732,7 +781,8 @@ impl TypedExpr { ExprEnum::StructLiteral(struct_name, fields) => { let fields: HashMap<_, _> = fields.iter().cloned().collect(); let struct_def = prg.struct_defs.get(struct_name.as_str()).unwrap(); - let mut wires = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); + let mut wires = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); for (field_name, _) in struct_def.fields.iter() { let value = fields.get(field_name).unwrap(); wires.extend(value.compile(prg, env, circuit)); @@ -766,7 +816,7 @@ impl TypedPattern { circuit.push_not(match_expr[0]) } PatternEnum::NumUnsigned(n, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let n = unsigned_as_wires(*n, bits); let mut acc = 1; for i in 0..bits { @@ -776,7 +826,7 @@ impl TypedPattern { acc } PatternEnum::NumSigned(n, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let n = signed_as_wires(*n, bits); let mut acc = 1; for i in 
0..bits { @@ -786,7 +836,7 @@ impl TypedPattern { acc } PatternEnum::UnsignedInclusiveRange(min, max, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let min = unsigned_as_wires(*min, bits); let max = unsigned_as_wires(*max, bits); let signed = is_signed(ty); @@ -799,7 +849,7 @@ impl TypedPattern { circuit.push_and(not_lt_min, not_gt_max) } PatternEnum::SignedInclusiveRange(min, max, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let min = signed_as_wires(*min, bits); let max = signed_as_wires(*max, bits); let signed = is_signed(ty); @@ -816,7 +866,7 @@ impl TypedPattern { let mut w = 0; for field in fields { let Pattern(_, _, field_type) = field; - let field_bits = field_type.size_in_bits_for_defs(prg); + let field_bits = field_type.size_in_bits_for_defs(prg, circuit.const_sizes()); let match_expr = &match_expr[w..w + field_bits]; let is_field_match = field.compile(match_expr, prg, env, circuit); is_match = circuit.push_and(is_match, is_field_match); @@ -831,7 +881,7 @@ impl TypedPattern { let mut is_match = 1; let mut w = 0; for (field_name, field_type) in struct_def.fields.iter() { - let field_bits = field_type.size_in_bits_for_defs(prg); + let field_bits = field_type.size_in_bits_for_defs(prg, circuit.const_sizes()); if let Some(field_pattern) = fields.get(field_name) { let match_expr = &match_expr[w..w + field_bits]; let is_field_match = field_pattern.compile(match_expr, prg, env, circuit); @@ -866,7 +916,8 @@ impl TypedPattern { .types() .unwrap_or_default(); for (field, field_type) in fields.iter().zip(field_types) { - let field_bits = field_type.size_in_bits_for_defs(prg); + let field_bits = + field_type.size_in_bits_for_defs(prg, circuit.const_sizes()); let match_expr = &match_expr[w..w + field_bits]; let is_field_match = field.compile(match_expr, prg, env, circuit); is_match = circuit.push_and(is_match, 
is_field_match); @@ -882,7 +933,11 @@ impl TypedPattern { } impl Type { - pub(crate) fn size_in_bits_for_defs(&self, prg: &TypedProgram) -> usize { + pub(crate) fn size_in_bits_for_defs( + &self, + prg: &TypedProgram, + const_sizes: &HashMap, + ) -> usize { match self { Type::Bool => 1, Type::Unsigned(UnsignedNumType::Usize) => USIZE_BITS, @@ -890,17 +945,20 @@ impl Type { Type::Unsigned(UnsignedNumType::U16) | Type::Signed(SignedNumType::I16) => 16, Type::Unsigned(UnsignedNumType::U32) | Type::Signed(SignedNumType::I32) => 32, Type::Unsigned(UnsignedNumType::U64) | Type::Signed(SignedNumType::I64) => 64, - Type::Array(elem, size) => elem.size_in_bits_for_defs(prg) * size, + Type::Array(elem, size) => elem.size_in_bits_for_defs(prg, const_sizes) * size, + Type::ArrayConst(elem, size) => { + elem.size_in_bits_for_defs(prg, const_sizes) * const_sizes.get(size).unwrap() + } Type::Tuple(values) => { let mut size = 0; for v in values { - size += v.size_in_bits_for_defs(prg) + size += v.size_in_bits_for_defs(prg, const_sizes) } size } Type::Fn(_, _) => panic!("Fn types cannot be directly mapped to bits"), - Type::Struct(name) => struct_size(prg.struct_defs.get(name).unwrap(), prg), - Type::Enum(name) => enum_max_size(prg.enum_defs.get(name).unwrap(), prg), + Type::Struct(name) => struct_size(prg.struct_defs.get(name).unwrap(), prg, const_sizes), + Type::Enum(name) => enum_max_size(prg.enum_defs.get(name).unwrap(), prg, const_sizes), Type::UntypedTopLevelDefinition(_, _) => { unreachable!("Untyped top level types should have been typechecked at this point") } @@ -908,10 +966,14 @@ impl Type { } } -pub(crate) fn struct_size(struct_def: &StructDef, prg: &TypedProgram) -> usize { +pub(crate) fn struct_size( + struct_def: &StructDef, + prg: &TypedProgram, + const_sizes: &HashMap, +) -> usize { let mut total_size = 0; for (_, field_ty) in struct_def.fields.iter() { - total_size += field_ty.size_in_bits_for_defs(prg); + total_size += field_ty.size_in_bits_for_defs(prg, 
const_sizes); } total_size } @@ -933,12 +995,16 @@ pub(crate) fn enum_tag_size(enum_def: &EnumDef) -> usize { bits } -pub(crate) fn enum_max_size(enum_def: &EnumDef, prg: &TypedProgram) -> usize { +pub(crate) fn enum_max_size( + enum_def: &EnumDef, + prg: &TypedProgram, + const_sizes: &HashMap, +) -> usize { let mut max = 0; for variant in enum_def.variants.iter() { let mut sum = 0; for field in variant.types().unwrap_or_default() { - sum += field.size_in_bits_for_defs(prg); + sum += field.size_in_bits_for_defs(prg, const_sizes); } if sum > max { max = sum; diff --git a/src/eval.rs b/src/eval.rs index 3a78cc7..aa3cad3 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,6 +1,6 @@ //! Evaluates a [`crate::circuit::Circuit`] with inputs supplied by different parties. -use std::fmt::Debug; +use std::{collections::HashMap, fmt::Debug}; use crate::{ ast::Type, @@ -20,6 +20,7 @@ pub struct Evaluator<'a> { /// The compiled circuit. pub circuit: &'a Circuit, inputs: Vec>, + const_sizes: HashMap, } impl<'a> Evaluator<'a> { @@ -30,6 +31,37 @@ impl<'a> Evaluator<'a> { main_fn, circuit, inputs: vec![], + const_sizes: HashMap::new(), + } + } + + /// Scans, parses, type-checks and then compiles a program for later evaluation. 
+ pub fn new_with_constants( + program: &'a TypedProgram, + main_fn: &'a TypedFnDef, + circuit: &'a Circuit, + consts: &HashMap>, + ) -> Self { + let mut const_sizes = HashMap::new(); + for (party, deps) in program.const_deps.iter() { + for (c, (identifier, _)) in deps { + let Some(party_deps) = consts.get(party) else { + todo!("missing party dep for {party}"); + }; + let Some(literal) = party_deps.get(c) else { + todo!("missing value {party}::{c}"); + }; + if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { + const_sizes.insert(identifier.clone(), *size as usize); + } + } + } + Self { + program, + main_fn, + circuit, + inputs: vec![], + const_sizes, } } } @@ -111,6 +143,7 @@ impl<'a> Evaluator<'a> { program: self.program, main_fn: self.main_fn, output, + const_sizes: self.const_sizes.clone(), }) } @@ -188,7 +221,7 @@ impl<'a> Evaluator<'a> { self.inputs .last_mut() .unwrap() - .extend(literal.as_bits(self.program)); + .extend(literal.as_bits(self.program, &self.const_sizes)); Ok(()) } else { Err(EvalError::InvalidLiteralType(literal, ty.clone())) @@ -218,6 +251,7 @@ pub struct EvalOutput<'a> { program: &'a TypedProgram, main_fn: &'a TypedFnDef, output: Vec, + const_sizes: HashMap, } impl<'a> TryFrom> for bool { @@ -336,7 +370,7 @@ impl<'a> TryFrom> for Vec { impl<'a> EvalOutput<'a> { fn into_unsigned(self, ty: Type) -> Result { let output = EvalPanic::parse(&self.output)?; - let size = ty.size_in_bits_for_defs(self.program); + let size = ty.size_in_bits_for_defs(self.program, &self.const_sizes); if output.len() == size { let mut n = 0; for (i, output) in output.iter().copied().enumerate() { @@ -353,7 +387,7 @@ impl<'a> EvalOutput<'a> { fn into_signed(self, ty: Type) -> Result { let output = EvalPanic::parse(&self.output)?; - let size = ty.size_in_bits_for_defs(self.program); + let size = ty.size_in_bits_for_defs(self.program, &self.const_sizes); if output.len() == size { let mut n = 0; for (i, output) in output.iter().copied().enumerate() { @@ 
-376,6 +410,6 @@ impl<'a> EvalOutput<'a> { /// Decodes the evaluated result as a literal (with enums looked up in the program). pub fn into_literal(self) -> Result { let ret_ty = &self.main_fn.ty; - Literal::from_result_bits(self.program, ret_ty, &self.output) + Literal::from_result_bits(self.program, ret_ty, &self.output, &self.const_sizes) } } diff --git a/src/lib.rs b/src/lib.rs index b88bb46..f740ad8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,7 +53,7 @@ use std::{ collections::HashMap, fmt::{Display, Write as _}, }; -use token::MetaInfo; +use token::{MetaInfo, UnsignedNumType}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -107,6 +107,8 @@ pub fn compile(prg: &str) -> Result { program, main, circuit, + consts: HashMap::new(), + const_sizes: HashMap::new(), }) } @@ -116,12 +118,28 @@ pub fn compile_with_constants( consts: HashMap>, ) -> Result { let program = check(prg)?; - let (circuit, main) = program.compile_with_constants("main", consts)?; + let (circuit, main) = program.compile_with_constants("main", consts.clone())?; let main = main.clone(); + let mut const_sizes = HashMap::new(); + for (party, deps) in program.const_deps.iter() { + for (c, (identifier, _)) in deps { + let Some(party_deps) = consts.get(party) else { + todo!("missing party dep for {party}"); + }; + let Some(literal) = party_deps.get(c) else { + todo!("missing value {party}::{c}"); + }; + if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { + const_sizes.insert(identifier.clone(), *size as usize); + } + } + } Ok(GarbleProgram { program, main, circuit, + consts, + const_sizes, }) } @@ -135,16 +153,19 @@ pub struct GarbleProgram { pub main: TypedFnDef, /// The compilation output, as a circuit of boolean gates. pub circuit: Circuit, + /// The constants used for compiling the circuit. + pub consts: HashMap>, + const_sizes: HashMap, } /// An input argument for a Garble program and circuit. 
#[derive(Debug, Clone)] -pub struct GarbleArgument<'a>(Literal, &'a TypedProgram); +pub struct GarbleArgument<'a>(Literal, &'a TypedProgram, &'a HashMap); impl GarbleProgram { /// Returns an evaluator that can be used to run the compiled circuit. pub fn evaluator(&self) -> Evaluator<'_> { - Evaluator::new(&self.program, &self.main, &self.circuit) + Evaluator::new_with_constants(&self.program, &self.main, &self.circuit, &self.consts) } /// Type-checks and uses the literal as the circuit input argument with the given index. @@ -159,7 +180,7 @@ impl GarbleProgram { if !literal.is_of_type(&self.program, ¶m.ty) { return Err(EvalError::InvalidLiteralType(literal, param.ty.clone())); } - Ok(GarbleArgument(literal, &self.program)) + Ok(GarbleArgument(literal, &self.program, &self.const_sizes)) } /// Tries to parse the string as the circuit input argument with the given index. @@ -173,19 +194,19 @@ impl GarbleProgram { }; let literal = Literal::parse(&self.program, ¶m.ty, literal) .map_err(EvalError::LiteralParseError)?; - Ok(GarbleArgument(literal, &self.program)) + Ok(GarbleArgument(literal, &self.program, &self.const_sizes)) } /// Tries to convert the circuit output back to a Garble literal. pub fn parse_output(&self, bits: &[bool]) -> Result { - Literal::from_result_bits(&self.program, &self.main.ty, bits) + Literal::from_result_bits(&self.program, &self.main.ty, bits, &self.const_sizes) } } impl GarbleArgument<'_> { /// Converts the argument to input bits for the compiled circuit. pub fn as_bits(&self) -> Vec { - self.0.as_bits(self.1) + self.0.as_bits(self.1, self.2) } /// Converts the argument to a Garble literal. 
diff --git a/src/literal.rs b/src/literal.rs index e942615..6211e88 100644 --- a/src/literal.rs +++ b/src/literal.rs @@ -174,9 +174,10 @@ impl Literal { checked: &TypedProgram, ty: &Type, bits: &[bool], + const_sizes: &HashMap, ) -> Result { match EvalPanic::parse(bits) { - Ok(bits) => Literal::from_unwrapped_bits(checked, ty, bits), + Ok(bits) => Literal::from_unwrapped_bits(checked, ty, bits, const_sizes), Err(panic) => Err(EvalError::Panic(panic)), } } @@ -191,6 +192,7 @@ impl Literal { checked: &TypedProgram, ty: &Type, bits: &[bool], + const_sizes: &HashMap, ) -> Result { match ty { Type::Bool => { @@ -208,7 +210,7 @@ impl Literal { } } Type::Unsigned(unsigned_ty) => { - let size = ty.size_in_bits_for_defs(checked); + let size = ty.size_in_bits_for_defs(checked, const_sizes); if bits.len() == size { let mut n = 0; for (i, output) in bits.iter().copied().enumerate() { @@ -223,7 +225,7 @@ impl Literal { } } Type::Signed(signed_ty) => { - let size = ty.size_in_bits_for_defs(checked); + let size = ty.size_in_bits_for_defs(checked, const_sizes); if bits.len() == size { let mut n = 0; for (i, output) in bits.iter().copied().enumerate() { @@ -244,12 +246,34 @@ impl Literal { } } Type::Array(ty, size) => { - let ty_size = ty.size_in_bits_for_defs(checked); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); let mut elems = vec![]; let mut i = 0; for _ in 0..*size { let bits = &bits[i..i + ty_size]; - elems.push(Literal::from_unwrapped_bits(checked, ty, bits)?); + elems.push(Literal::from_unwrapped_bits( + checked, + ty, + bits, + const_sizes, + )?); + i += ty_size; + } + Ok(Literal::Array(elems)) + } + Type::ArrayConst(ty, size) => { + let size = const_sizes.get(size).unwrap(); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); + let mut elems = vec![]; + let mut i = 0; + for _ in 0..*size { + let bits = &bits[i..i + ty_size]; + elems.push(Literal::from_unwrapped_bits( + checked, + ty, + bits, + const_sizes, + )?); i += ty_size; } 
Ok(Literal::Array(elems)) @@ -258,9 +282,14 @@ impl Literal { let mut fields = vec![]; let mut i = 0; for ty in field_types { - let ty_size = ty.size_in_bits_for_defs(checked); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); let bits = &bits[i..i + ty_size]; - fields.push(Literal::from_unwrapped_bits(checked, ty, bits)?); + fields.push(Literal::from_unwrapped_bits( + checked, + ty, + bits, + const_sizes, + )?); i += ty_size; } Ok(Literal::Tuple(fields)) @@ -270,9 +299,9 @@ impl Literal { let mut i = 0; let struct_def = checked.struct_defs.get(struct_name).unwrap(); for (field_name, ty) in struct_def.fields.iter() { - let ty_size = ty.size_in_bits_for_defs(checked); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); let bits = &bits[i..i + ty_size]; - let value = Literal::from_unwrapped_bits(checked, ty, bits)?; + let value = Literal::from_unwrapped_bits(checked, ty, bits, const_sizes)?; fields.push((field_name.clone(), value)); i += ty_size; } @@ -299,10 +328,11 @@ impl Literal { let field = Literal::from_unwrapped_bits( checked, ty, - &bits[i..i + ty.size_in_bits_for_defs(checked)], + &bits[i..i + ty.size_in_bits_for_defs(checked, const_sizes)], + const_sizes, )?; fields.push(field); - i += ty.size_in_bits_for_defs(checked); + i += ty.size_in_bits_for_defs(checked, const_sizes); } let variant = VariantLiteral::Tuple(fields); Ok(Literal::Enum( @@ -321,24 +351,28 @@ impl Literal { } /// Encodes the literal as bits, looking up enum defs in the program. 
- pub fn as_bits(&self, checked: &TypedProgram) -> Vec { + pub fn as_bits( + &self, + checked: &TypedProgram, + const_sizes: &HashMap, + ) -> Vec { match self { Literal::True => vec![true], Literal::False => vec![false], Literal::NumUnsigned(n, ty) => { - let size = Type::Unsigned(*ty).size_in_bits_for_defs(checked); + let size = Type::Unsigned(*ty).size_in_bits_for_defs(checked, const_sizes); let mut bits = vec![]; unsigned_to_bits(*n, size, &mut bits); bits } Literal::NumSigned(n, ty) => { - let size = Type::Signed(*ty).size_in_bits_for_defs(checked); + let size = Type::Signed(*ty).size_in_bits_for_defs(checked, const_sizes); let mut bits = vec![]; signed_to_bits(*n, size, &mut bits); bits } Literal::ArrayRepeat(elem, size) => { - let elem = elem.as_bits(checked); + let elem = elem.as_bits(checked, const_sizes); let elem_size = elem.len(); let mut bits = vec![false; elem_size * size]; for i in 0..*size { @@ -349,28 +383,28 @@ impl Literal { Literal::Array(elems) => { let mut bits = vec![]; for elem in elems { - bits.extend(elem.as_bits(checked)) + bits.extend(elem.as_bits(checked, const_sizes)) } bits } Literal::Tuple(fields) => { let mut bits = vec![]; for f in fields { - bits.extend(f.as_bits(checked)) + bits.extend(f.as_bits(checked, const_sizes)) } bits } Literal::Struct(_, fields) => { let mut bits = vec![]; for (_, f) in fields { - bits.extend(f.as_bits(checked)) + bits.extend(f.as_bits(checked, const_sizes)) } bits } Literal::Enum(enum_name, variant_name, variant) => { let enum_def = checked.enum_defs.get(enum_name).unwrap(); let tag_size = enum_tag_size(enum_def); - let max_size = enum_max_size(enum_def, checked); + let max_size = enum_max_size(enum_def, checked, const_sizes); let mut wires = vec![false; max_size]; let tag_number = enum_tag_number(enum_def, variant_name); for (i, wire) in wires.iter_mut().enumerate().take(tag_size) { @@ -381,7 +415,7 @@ impl Literal { VariantLiteral::Unit => {} VariantLiteral::Tuple(fields) => { for f in fields { - let f 
= f.as_bits(checked); + let f = f.as_bits(checked, const_sizes); wires[w..w + f.len()].copy_from_slice(&f); w += f.len(); } @@ -391,7 +425,7 @@ impl Literal { } Literal::Range((min, min_ty), (max, _)) => { let elems: Vec = (*min as usize..*max as usize).collect(); - let elem_size = Type::Unsigned(*min_ty).size_in_bits_for_defs(checked); + let elem_size = Type::Unsigned(*min_ty).size_in_bits_for_defs(checked, const_sizes); let mut bits = Vec::with_capacity(elems.len() * elem_size); for elem in elems { unsigned_to_bits(elem as u64, elem_size, &mut bits); diff --git a/src/parse.rs b/src/parse.rs index e2fc724..e9de4fa 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1285,19 +1285,28 @@ impl Parser { }; if self.peek(&TokenEnum::Semicolon) { self.expect(&TokenEnum::Semicolon)?; - let size = if let Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) = - self.tokens.peek() - { - let n = *n; - self.advance(); - n as usize - } else { - self.push_error_for_next(ParseErrorEnum::InvalidArraySize); - return Err(()); - }; - let meta_end = self.expect(&TokenEnum::RightBracket)?; - let meta = join_meta(meta, meta_end); - Expr::untyped(ExprEnum::ArrayRepeatLiteral(Box::new(elem), size), meta) + match self.tokens.peek().cloned() { + Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) => { + self.advance(); + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + let size = n as usize; + Expr::untyped(ExprEnum::ArrayRepeatLiteral(Box::new(elem), size), meta) + } + Some(Token(TokenEnum::Identifier(n), _)) => { + self.advance(); + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + Expr::untyped( + ExprEnum::ArrayRepeatLiteralConst(Box::new(elem), n), + meta, + ) + } + _ => { + self.push_error_for_next(ParseErrorEnum::InvalidArraySize); + return Err(()); + } + } } else { let mut elems = vec![elem]; while self.next_matches(&TokenEnum::Comma).is_some() { @@ -1348,18 +1357,25 @@ impl Parser { } else if let 
Some(meta) = self.next_matches(&TokenEnum::LeftBracket) { let (ty, _) = self.parse_type()?; self.expect(&TokenEnum::Semicolon)?; - let size = if let Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) = self.tokens.peek() - { - let n = *n; - self.advance(); - n as usize - } else { - self.push_error_for_next(ParseErrorEnum::InvalidArraySize); - return Err(()); - }; - let meta_end = self.expect(&TokenEnum::RightBracket)?; - let meta = join_meta(meta, meta_end); - Ok((Type::Array(Box::new(ty), size), meta)) + match self.tokens.peek().cloned() { + Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) => { + self.advance(); + let size = n as usize; + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + Ok((Type::Array(Box::new(ty), size), meta)) + } + Some(Token(TokenEnum::Identifier(n), _)) => { + self.advance(); + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + Ok((Type::ArrayConst(Box::new(ty), n), meta)) + } + _ => { + self.push_error_for_next(ParseErrorEnum::InvalidArraySize); + return Err(()); + } + } } else { let (ty, meta) = self.expect_identifier()?; let ty = match ty.as_str() { diff --git a/tests/compile.rs b/tests/compile.rs index 0f320b4..a1ba329 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1890,8 +1890,7 @@ fn compile_const_usize() -> Result<(), Error> { let prg = " const MY_CONST: usize = PARTY_0::MY_CONST; pub fn main(x: u16) -> u16 { - // breaks if I use 2usize instead of 2 - let array = [2u16; 2usize]; + let array = [2u16; MY_CONST]; x + array[1] } "; From 95caecc8cd531164174dcce8d271f14288211bd9 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 14 May 2024 17:22:25 +0100 Subject: [PATCH 04/22] Fix tests related to consts --- src/check.rs | 32 +++++++++++++++----------------- src/compile.rs | 14 +++++++++++++- src/parse.rs | 4 ++-- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/check.rs b/src/check.rs index de4c83a..92cfa50 
100644 --- a/src/check.rs +++ b/src/check.rs @@ -306,7 +306,7 @@ impl<'a> Defs<'a> { fns: HashMap::new(), }; for (const_name, ty) in const_defs.iter() { - defs.consts.insert(const_name, &ty); + defs.consts.insert(const_name, ty); } for (struct_name, struct_def) in struct_defs.iter() { let mut field_names = Vec::with_capacity(struct_def.fields.len()); @@ -391,17 +391,15 @@ impl UntypedProgram { ); } Err(errs) => { - for e in errs { - if let Some(e) = e { - if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e { - // ignore this error, constant can be provided later during compilation - const_deps.entry(p).or_default().insert( - n, - (const_name.clone(), const_def.ty.clone()), - ); - } else { - errors.push(Some(e)); - } + for e in errs.into_iter().flatten() { + if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e { + // ignore this error, constant can be provided later during compilation + const_deps + .entry(p) + .or_default() + .insert(n, (const_name.clone(), const_def.ty.clone())); + } else { + errors.push(Some(e)); } } } @@ -681,7 +679,7 @@ impl UntypedStmt { ast::StmtEnum::ArrayAssign(identifier, index, value) => { match env.get(identifier) { Some((Some(array_ty), Mutability::Mutable)) => { - let (elem_ty, _) = expect_array_type(&array_ty, meta)?; + let elem_ty = expect_array_type(&array_ty, meta)?; let mut index = index.type_check(top_level_defs, env, fns, defs)?; check_type(&mut index, &Type::Unsigned(UnsignedNumType::Usize))?; @@ -710,7 +708,7 @@ impl UntypedStmt { } ast::StmtEnum::ForEachLoop(var, binding, body) => { let binding = binding.type_check(top_level_defs, env, fns, defs)?; - let (elem_ty, _) = expect_array_type(&binding.ty, meta)?; + let elem_ty = expect_array_type(&binding.ty, meta)?; let mut body_typed = Vec::with_capacity(body.len()); env.push(); env.let_in_current_scope(var.clone(), (Some(elem_ty), Mutability::Immutable)); @@ -835,7 +833,7 @@ impl UntypedExpr { ExprEnum::ArrayAccess(arr, index) => { let arr = 
arr.type_check(top_level_defs, env, fns, defs)?; let mut index = index.type_check(top_level_defs, env, fns, defs)?; - let (elem_ty, _) = expect_array_type(&arr.ty, arr.meta)?; + let elem_ty = expect_array_type(&arr.ty, arr.meta)?; check_type(&mut index, &Type::Unsigned(UnsignedNumType::Usize))?; ( ExprEnum::ArrayAccess(Box::new(arr), Box::new(index)), @@ -1948,9 +1946,9 @@ fn usefulness(patterns: Vec, q: PatternStack, defs: &Defs) -> Vec< } } -fn expect_array_type(ty: &Type, meta: MetaInfo) -> Result<(Type, usize), TypeErrors> { +fn expect_array_type(ty: &Type, meta: MetaInfo) -> Result { match ty { - Type::Array(elem, size) => Ok((*elem.clone(), *size)), + Type::Array(elem, _) | Type::ArrayConst(elem, _) => Ok(*elem.clone()), _ => Err(vec![Some(TypeError( TypeErrorEnum::ExpectedArrayType(ty.clone()), meta, diff --git a/src/compile.rs b/src/compile.rs index e6a6129..21219e2 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -4,7 +4,7 @@ use std::{cmp::max, collections::HashMap}; use crate::{ ast::{ - EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, Type, UnaryOp, + ConstExpr, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, Type, UnaryOp, VariantExpr, VariantExprEnum, }, circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS}, @@ -96,6 +96,17 @@ impl TypedProgram { env.let_in_current_scope(param.name.clone(), wires); } let mut circuit = CircuitBuilder::new(input_gates, const_sizes); + for (identifier, const_def) in self.const_defs.iter() { + match &const_def.value { + ConstExpr::Literal(expr) => { + if let ExprEnum::EnumLiteral(_, _) = expr.inner { + } else { + let bits = expr.compile(self, &mut env, &mut circuit); + env.let_in_current_scope(identifier.clone(), bits); + } + } + } + } let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); Ok((circuit.build(output_gates), fn_def)) } else { @@ -298,6 +309,7 @@ impl TypedExpr { ExprEnum::ArrayAccess(array, index) => { let num_elems = 
match &array.ty { Type::Array(_, size) => *size, + Type::ArrayConst(_, size) => *circuit.const_sizes().get(size).unwrap(), _ => panic!("Found a non-array value in an array access expr"), }; let elem_bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); diff --git a/src/parse.rs b/src/parse.rs index e9de4fa..c070af9 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -440,13 +440,13 @@ impl Parser { fn parse_stmts(&mut self) -> Result, ()> { let mut stmts = vec![]; let mut has_error = false; - while !self.peek(&TokenEnum::RightBrace) + while self.tokens.peek().is_some() + && !self.peek(&TokenEnum::RightBrace) && !self.peek(&TokenEnum::Comma) && !self.peek(&TokenEnum::KeywordPub) && !self.peek(&TokenEnum::KeywordFn) && !self.peek(&TokenEnum::KeywordStruct) && !self.peek(&TokenEnum::KeywordEnum) - && !self.tokens.peek().is_none() { if let Ok(stmt) = self.parse_stmt() { stmts.push(stmt); From 27cf5575e89b299d1120f17d298d9c7264d00ac4 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 15 May 2024 16:56:41 +0100 Subject: [PATCH 05/22] Use `ConstExpr` enum instead of reusing `Literal` --- src/ast.rs | 29 +++++++++++++++----- src/check.rs | 69 ++++++++++++++++++++++++++++++++++++------------ src/compile.rs | 36 ++++++++++++++++++++----- src/parse.rs | 69 +++++++++++++++++++++++++++++++++++++----------- tests/compile.rs | 36 +++++++++++++++++++++++++ 5 files changed, 195 insertions(+), 44 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index b434d7d..40b546d 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -14,7 +14,7 @@ pub struct Program { /// The external constants that the top level const definitions depend upon. pub const_deps: HashMap>, /// Top level const definitions. - pub const_defs: HashMap>, + pub const_defs: HashMap, /// Top level struct type definitions. pub struct_defs: HashMap, /// Top level enum type definitions. @@ -26,11 +26,11 @@ pub struct Program { /// A top level const definition. 
#[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct ConstDef { +pub struct ConstDef { /// The type of the constant. pub ty: Type, /// The value of the constant. - pub value: ConstExpr, + pub value: ConstExpr, /// The location in the source code. pub meta: MetaInfo, } @@ -38,9 +38,26 @@ pub struct ConstDef { /// A constant value, either a literal or a namespaced symbol. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum ConstExpr { - /// A constant value specified as a literal value. - Literal(Expr), +pub enum ConstExpr { + /// Boolean `true`. + True, + /// Boolean `false`. + False, + /// Unsigned integer. + NumUnsigned(u64, UnsignedNumType), + /// Signed integer. + NumSigned(i64, SignedNumType), + /// An external value supplied before compilation. + ExternalValue { + /// The party providing the value. + party: String, + /// The variable name of the value. + identifier: String, + }, + /// The maximum of several constant expressions. + Max(Vec), + /// The minimum of several constant expressions. + Min(Vec), } /// A top level struct type definition. 
diff --git a/src/check.rs b/src/check.rs index 92cfa50..c0a4ed1 100644 --- a/src/check.rs +++ b/src/check.rs @@ -358,23 +358,58 @@ impl UntypedProgram { let mut const_types = HashMap::with_capacity(self.const_defs.len()); let mut const_defs = HashMap::with_capacity(self.const_defs.len()); { - let top_level_defs = TopLevelTypes { - struct_names: HashSet::new(), - enum_names: HashSet::new(), - }; - let mut env = Env::new(); - let mut fns = TypedFns { - currently_being_checked: HashSet::new(), - typed: HashMap::new(), - }; - let defs = Defs { - consts: HashMap::new(), - structs: HashMap::new(), - enums: HashMap::new(), - fns: HashMap::new(), - }; for (const_name, const_def) in self.const_defs.iter() { - match &const_def.value { + fn check_const_expr( + const_name: &String, + const_def: &ConstDef, + errors: &mut Vec>, + const_deps: &mut HashMap>, + ) { + match &const_def.value { + ConstExpr::True | ConstExpr::False => { + if const_def.ty != Type::Bool { + let e = TypeErrorEnum::UnexpectedType { + expected: const_def.ty.clone(), + actual: Type::Bool, + }; + errors.extend(vec![Some(TypeError(e, const_def.meta))]); + } + } + ConstExpr::NumUnsigned(_, ty) => { + let ty = Type::Unsigned(ty.clone()); + if const_def.ty != ty { + let e = TypeErrorEnum::UnexpectedType { + expected: const_def.ty.clone(), + actual: ty, + }; + errors.extend(vec![Some(TypeError(e, const_def.meta))]); + } + } + ConstExpr::NumSigned(_, ty) => { + let ty = Type::Signed(ty.clone()); + if const_def.ty != ty { + let e = TypeErrorEnum::UnexpectedType { + expected: const_def.ty.clone(), + actual: ty, + }; + errors.extend(vec![Some(TypeError(e, const_def.meta))]); + } + } + ConstExpr::ExternalValue { party, identifier } => { + const_deps.entry(party.clone()).or_default().insert( + identifier.clone(), + (const_name.clone(), const_def.ty.clone()), + ); + } + ConstExpr::Max(args) => for arg in args {}, + ConstExpr::Min(_) => todo!(), + } + } + check_const_expr(&const_name, &const_def, &mut errors, &mut 
const_deps); + const_defs.insert(const_name.clone(), const_def.clone()); + const_types.insert(const_name.clone(), const_def.ty.clone()); + // TODO: remove the following: + /*match &const_def.value { ConstExpr::Literal(expr) => { match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) { Ok(mut expr) => { @@ -406,7 +441,7 @@ impl UntypedProgram { } const_types.insert(const_name.clone(), const_def.ty.clone()); } - } + }*/ } } let mut struct_defs = HashMap::with_capacity(self.struct_defs.len()); diff --git a/src/compile.rs b/src/compile.rs index 21219e2..746c0ca 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -98,13 +98,37 @@ impl TypedProgram { let mut circuit = CircuitBuilder::new(input_gates, const_sizes); for (identifier, const_def) in self.const_defs.iter() { match &const_def.value { - ConstExpr::Literal(expr) => { - if let ExprEnum::EnumLiteral(_, _) = expr.inner { - } else { - let bits = expr.compile(self, &mut env, &mut circuit); - env.let_in_current_scope(identifier.clone(), bits); - } + ConstExpr::True => env.let_in_current_scope(identifier.clone(), vec![1]), + ConstExpr::False => env.let_in_current_scope(identifier.clone(), vec![0]), + ConstExpr::NumUnsigned(n, ty) => { + let ty = Type::Unsigned(*ty); + let mut bits = Vec::with_capacity( + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + ); + unsigned_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(identifier.clone(), bits); + } + ConstExpr::NumSigned(n, ty) => { + let ty = Type::Signed(*ty); + let mut bits = Vec::with_capacity( + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + ); + signed_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(identifier.clone(), bits); } + ConstExpr::ExternalValue { .. 
} => {} + ConstExpr::Max(_) => todo!("compile max"), + ConstExpr::Min(_) => todo!("compile min"), } } let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); diff --git a/src/parse.rs b/src/parse.rs index c070af9..ee4f81b 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -20,7 +20,7 @@ pub struct ParseError(pub ParseErrorEnum, pub MetaInfo); /// The different kinds of errors found during parsing. pub enum ParseErrorEnum { - /// The top level definition is not a valid enum or function declaration. + /// The top level definition is not a valid enum/struct/const/fn declaration. InvalidTopLevelDef, /// Arrays of the specified size are not supported. InvalidArraySize, @@ -30,6 +30,8 @@ pub enum ParseErrorEnum { InvalidPattern, /// The literal is not valid. InvalidLiteral, + /// Expected a const expr, but found a non-const or invalid expr. + InvalidConstExpr, /// Expected a type, but found a non-type token. ExpectedType, /// Expected a statement, but found a non-statement token. 
@@ -48,7 +50,7 @@ impl std::fmt::Display for ParseErrorEnum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ParseErrorEnum::InvalidTopLevelDef => { - f.write_str("Not a valid function or enum definition") + f.write_str("Not a valid top level declaration (struct/enum/const/fn)") } ParseErrorEnum::InvalidArraySize => { let max = usize::MAX; @@ -59,6 +61,7 @@ impl std::fmt::Display for ParseErrorEnum { ParseErrorEnum::InvalidRangeExpr => f.write_str("Invalid range expression"), ParseErrorEnum::InvalidPattern => f.write_str("Invalid pattern"), ParseErrorEnum::InvalidLiteral => f.write_str("Invalid literal"), + ParseErrorEnum::InvalidConstExpr => f.write_str("Invalid const expr"), ParseErrorEnum::ExpectedType => f.write_str("Expected a type"), ParseErrorEnum::ExpectedStmt => f.write_str("Expected a statement"), ParseErrorEnum::ExpectedExpr => f.write_str("Expected an expression"), @@ -177,7 +180,7 @@ impl Parser { Err(self.errors) } - fn parse_const_def(&mut self, start: MetaInfo) -> Result<(String, ConstDef<()>), ()> { + fn parse_const_def(&mut self, start: MetaInfo) -> Result<(String, ConstDef), ()> { // const keyword was already consumed by the top-level parser let (identifier, _) = self.expect_identifier()?; @@ -187,24 +190,60 @@ impl Parser { self.expect(&TokenEnum::Eq)?; - let Some(token) = self.tokens.next() else { + let Ok(expr) = self.parse_primary() else { self.push_error(ParseErrorEnum::InvalidTopLevelDef, start); return Err(()); }; - match self.parse_literal(token, true) { - Ok(literal) => { + fn parse_const_expr( + expr: UntypedExpr, + ) -> Result> { + match expr.inner { + ExprEnum::True => Ok(ConstExpr::True), + ExprEnum::False => Ok(ConstExpr::False), + ExprEnum::NumUnsigned(n, ty) => Ok(ConstExpr::NumUnsigned(n, ty)), + ExprEnum::NumSigned(n, ty) => Ok(ConstExpr::NumSigned(n, ty)), + ExprEnum::EnumLiteral(party, variant) => { + // TODO: check that this is a unit variant + let VariantExpr(identifier, _, _) = *variant; + 
Ok(ConstExpr::ExternalValue { party, identifier }) + } + ExprEnum::FnCall(f, args) if f == "max" || f == "min" => { + let mut const_exprs = vec![]; + let mut arg_errs = vec![]; + for arg in args { + match parse_const_expr(arg) { + Ok(value) => { + const_exprs.push(value); + } + Err(errs) => { + arg_errs.extend(errs); + } + } + } + if !arg_errs.is_empty() { + return Err(arg_errs); + } + if f == "max" { + Ok(ConstExpr::Max(const_exprs)) + } else { + Ok(ConstExpr::Min(const_exprs)) + } + } + _ => Err(vec![(ParseErrorEnum::InvalidConstExpr, expr.meta)]), + } + } + match parse_const_expr(expr) { + Ok(value) => { let end = self.expect(&TokenEnum::Semicolon)?; let meta = join_meta(start, end); - Ok(( - identifier, - ConstDef { - ty, - value: ConstExpr::Literal(literal), - meta, - }, - )) + Ok((identifier, ConstDef { ty, value, meta })) + } + Err(errs) => { + for (e, meta) in errs { + self.push_error(e, meta); + } + return Err(()); } - Err(_) => todo!("non-literal const def"), } } diff --git a/tests/compile.rs b/tests/compile.rs index a1ba329..91069fa 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1911,3 +1911,39 @@ pub fn main(x: u16) -> u16 { ); Ok(()) } + +#[test] +fn compile_const_aggregated_max() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(1, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + 
u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} From 362c84323b24b5ba167f40d1d107e00fc568b457 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 21 May 2024 16:51:16 +0100 Subject: [PATCH 06/22] Implement min/max for consts --- src/ast.rs | 2 +- src/check.rs | 59 ++++----------- src/compile.rs | 184 +++++++++++++++++++++++++++++++++++++---------- src/eval.rs | 5 +- src/lib.rs | 5 +- tests/compile.rs | 36 ++++++++++ 6 files changed, 204 insertions(+), 87 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 40b546d..f14631e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { /// The external constants that the top level const definitions depend upon. - pub const_deps: HashMap>, + pub const_deps: HashMap>, /// Top level const definitions. pub const_defs: HashMap, /// Top level struct type definitions. 
diff --git a/src/check.rs b/src/check.rs index c0a4ed1..52e6eaf 100644 --- a/src/check.rs +++ b/src/check.rs @@ -354,18 +354,18 @@ impl UntypedProgram { struct_names, enum_names, }; - let mut const_deps: HashMap> = HashMap::new(); + let mut const_deps: HashMap> = HashMap::new(); let mut const_types = HashMap::with_capacity(self.const_defs.len()); let mut const_defs = HashMap::with_capacity(self.const_defs.len()); { for (const_name, const_def) in self.const_defs.iter() { fn check_const_expr( - const_name: &String, + value: &ConstExpr, const_def: &ConstDef, errors: &mut Vec>, - const_deps: &mut HashMap>, + const_deps: &mut HashMap>, ) { - match &const_def.value { + match value { ConstExpr::True | ConstExpr::False => { if const_def.ty != Type::Bool { let e = TypeErrorEnum::UnexpectedType { @@ -396,52 +396,21 @@ impl UntypedProgram { } } ConstExpr::ExternalValue { party, identifier } => { - const_deps.entry(party.clone()).or_default().insert( - identifier.clone(), - (const_name.clone(), const_def.ty.clone()), - ); + const_deps + .entry(party.clone()) + .or_default() + .insert(identifier.clone(), const_def.ty.clone()); + } + ConstExpr::Max(args) | ConstExpr::Min(args) => { + for arg in args { + check_const_expr(arg, const_def, errors, const_deps); + } } - ConstExpr::Max(args) => for arg in args {}, - ConstExpr::Min(_) => todo!(), } } - check_const_expr(&const_name, &const_def, &mut errors, &mut const_deps); + check_const_expr(&const_def.value, &const_def, &mut errors, &mut const_deps); const_defs.insert(const_name.clone(), const_def.clone()); const_types.insert(const_name.clone(), const_def.ty.clone()); - // TODO: remove the following: - /*match &const_def.value { - ConstExpr::Literal(expr) => { - match expr.type_check(&top_level_defs, &mut env, &mut fns, &defs) { - Ok(mut expr) => { - if let Err(errs) = check_type(&mut expr, &const_def.ty) { - errors.extend(errs); - } - const_defs.insert( - const_name.clone(), - ConstDef { - ty: const_def.ty.clone(), - value: 
ConstExpr::Literal(expr), - meta: const_def.meta, - }, - ); - } - Err(errs) => { - for e in errs.into_iter().flatten() { - if let TypeError(TypeErrorEnum::UnknownEnum(p, n), _) = e { - // ignore this error, constant can be provided later during compilation - const_deps - .entry(p) - .or_default() - .insert(n, (const_name.clone(), const_def.ty.clone())); - } else { - errors.push(Some(e)); - } - } - } - } - const_types.insert(const_name.clone(), const_def.ty.clone()); - } - }*/ } } let mut struct_defs = HashMap::with_capacity(self.struct_defs.len()); diff --git a/src/compile.rs b/src/compile.rs index 746c0ca..c76d2ce 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -1,6 +1,9 @@ //! Compiles a [`crate::ast::Program`] to a [`crate::circuit::Circuit`]. -use std::{cmp::max, collections::HashMap}; +use std::{ + cmp::{max, min}, + collections::HashMap, +}; use crate::{ ast::{ @@ -56,14 +59,26 @@ impl TypedProgram { ) -> Result<(Circuit, &TypedFnDef), CompilerError> { let mut env = Env::new(); let mut const_sizes = HashMap::new(); + let mut consts_unsigned = HashMap::new(); + let mut consts_signed = HashMap::new(); for (party, deps) in self.const_deps.iter() { - for (c, (identifier, ty)) in deps { + for (c, ty) in deps { let Some(party_deps) = consts.get(party) else { todo!("missing party dep for {party}"); }; let Some(literal) = party_deps.get(c) else { todo!("missing value {party}::{c}"); }; + let identifier = format!("{party}::{c}"); + match literal { + Literal::NumUnsigned(n, _) => { + consts_unsigned.insert(identifier.clone(), *n); + } + Literal::NumSigned(n, _) => { + consts_signed.insert(identifier.clone(), *n); + } + _ => {} + } if literal.is_of_type(self, ty) { let bits = literal .as_bits(self, &const_sizes) @@ -72,7 +87,7 @@ impl TypedProgram { .collect(); env.let_in_current_scope(identifier.clone(), bits); if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - const_sizes.insert(identifier.clone(), *size as usize); + 
const_sizes.insert(identifier, *size as usize); } } else { return Err(CompilerError::InvalidLiteralType( @@ -84,58 +99,153 @@ impl TypedProgram { } let mut input_gates = vec![]; let mut wire = 2; - if let Some(fn_def) = self.fn_defs.get(fn_name) { - for param in fn_def.params.iter() { - let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); - let mut wires = Vec::with_capacity(type_size); - for _ in 0..type_size { - wires.push(wire); - wire += 1; + let Some(fn_def) = self.fn_defs.get(fn_name) else { + return Err(CompilerError::FnNotFound(fn_name.to_string())); + }; + for param in fn_def.params.iter() { + let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); + let mut wires = Vec::with_capacity(type_size); + for _ in 0..type_size { + wires.push(wire); + wire += 1; + } + input_gates.push(type_size); + env.let_in_current_scope(param.name.clone(), wires); + } + fn resolve_const_expr_unsigned( + expr: &ConstExpr, + consts_unsigned: &HashMap, + ) -> u64 { + match expr { + ConstExpr::NumUnsigned(n, _) => *n, + ConstExpr::ExternalValue { party, identifier } => *consts_unsigned + .get(&format!("{party}::{identifier}")) + .unwrap(), + ConstExpr::Max(args) => { + let mut result = 0; + for arg in args { + result = max(result, resolve_const_expr_unsigned(arg, consts_unsigned)); + } + result + } + ConstExpr::Min(args) => { + let mut result = u64::MAX; + for arg in args { + result = min(result, resolve_const_expr_unsigned(arg, consts_unsigned)); + } + result + } + expr => panic!("Not an unsigned const expr: {expr:?}"), + } + } + fn resolve_const_expr_signed( + expr: &ConstExpr, + consts_signed: &HashMap, + ) -> i64 { + match expr { + ConstExpr::NumSigned(n, _) => *n, + ConstExpr::ExternalValue { party, identifier } => *consts_signed + .get(&format!("{party}::{identifier}")) + .unwrap(), + ConstExpr::Max(args) => { + let mut result = 0; + for arg in args { + result = max(result, resolve_const_expr_signed(arg, consts_signed)); + } + result + } + 
ConstExpr::Min(args) => { + let mut result = i64::MAX; + for arg in args { + result = min(result, resolve_const_expr_signed(arg, consts_signed)); + } + result + } + expr => panic!("Not an unsigned const expr: {expr:?}"), + } + } + for (const_name, const_def) in self.const_defs.iter() { + if let Type::Unsigned(UnsignedNumType::Usize) = const_def.ty { + if let ConstExpr::ExternalValue { party, identifier } = &const_def.value { + let identifier = format!("{party}::{identifier}"); + const_sizes.insert(const_name.clone(), *const_sizes.get(&identifier).unwrap()); } - input_gates.push(type_size); - env.let_in_current_scope(param.name.clone(), wires); + let n = resolve_const_expr_unsigned(&const_def.value, &consts_unsigned); + const_sizes.insert(const_name.clone(), n as usize); } - let mut circuit = CircuitBuilder::new(input_gates, const_sizes); - for (identifier, const_def) in self.const_defs.iter() { - match &const_def.value { - ConstExpr::True => env.let_in_current_scope(identifier.clone(), vec![1]), - ConstExpr::False => env.let_in_current_scope(identifier.clone(), vec![0]), - ConstExpr::NumUnsigned(n, ty) => { - let ty = Type::Unsigned(*ty); + } + let mut circuit = CircuitBuilder::new(input_gates, const_sizes); + for (const_name, const_def) in self.const_defs.iter() { + match &const_def.value { + ConstExpr::True => env.let_in_current_scope(const_name.clone(), vec![1]), + ConstExpr::False => env.let_in_current_scope(const_name.clone(), vec![0]), + ConstExpr::NumUnsigned(n, ty) => { + let ty = Type::Unsigned(*ty); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); + unsigned_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExpr::NumSigned(n, ty) => { + let ty = Type::Signed(*ty); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(self, 
circuit.const_sizes())); + signed_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExpr::ExternalValue { party, identifier } => { + let bits = env.get(&format!("{party}::{identifier}")).unwrap(); + env.let_in_current_scope(const_name.clone(), bits); + } + expr @ (ConstExpr::Max(_) | ConstExpr::Min(_)) => { + if let Type::Unsigned(_) = const_def.ty { + let result = resolve_const_expr_unsigned(expr, &consts_unsigned); let mut bits = Vec::with_capacity( - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), ); unsigned_to_bits( - *n, - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + result, + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), &mut bits, ); let bits = bits.into_iter().map(|b| b as usize).collect(); - env.let_in_current_scope(identifier.clone(), bits); - } - ConstExpr::NumSigned(n, ty) => { - let ty = Type::Signed(*ty); + env.let_in_current_scope(const_name.clone(), bits); + } else { + let result = resolve_const_expr_signed(expr, &consts_signed); let mut bits = Vec::with_capacity( - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), ); signed_to_bits( - *n, - ty.size_in_bits_for_defs(self, circuit.const_sizes()), + result, + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), &mut bits, ); let bits = bits.into_iter().map(|b| b as usize).collect(); - env.let_in_current_scope(identifier.clone(), bits); + env.let_in_current_scope(const_name.clone(), bits); } - ConstExpr::ExternalValue { .. 
} => {} - ConstExpr::Max(_) => todo!("compile max"), - ConstExpr::Min(_) => todo!("compile min"), } } - let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); - Ok((circuit.build(output_gates), fn_def)) - } else { - Err(CompilerError::FnNotFound(fn_name.to_string())) } + let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); + Ok((circuit.build(output_gates), fn_def)) } } diff --git a/src/eval.rs b/src/eval.rs index aa3cad3..9a1b43c 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -44,7 +44,7 @@ impl<'a> Evaluator<'a> { ) -> Self { let mut const_sizes = HashMap::new(); for (party, deps) in program.const_deps.iter() { - for (c, (identifier, _)) in deps { + for (c, _) in deps { let Some(party_deps) = consts.get(party) else { todo!("missing party dep for {party}"); }; @@ -52,7 +52,8 @@ impl<'a> Evaluator<'a> { todo!("missing value {party}::{c}"); }; if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - const_sizes.insert(identifier.clone(), *size as usize); + let identifier = format!("{party}::{c}"); + const_sizes.insert(identifier, *size as usize); } } } diff --git a/src/lib.rs b/src/lib.rs index f740ad8..15b684f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,7 +122,7 @@ pub fn compile_with_constants( let main = main.clone(); let mut const_sizes = HashMap::new(); for (party, deps) in program.const_deps.iter() { - for (c, (identifier, _)) in deps { + for (c, _) in deps { let Some(party_deps) = consts.get(party) else { todo!("missing party dep for {party}"); }; @@ -130,7 +130,8 @@ pub fn compile_with_constants( todo!("missing value {party}::{c}"); }; if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - const_sizes.insert(identifier.clone(), *size as usize); + let identifier = format!("{party}::{c}"); + const_sizes.insert(identifier, *size as usize); } } } diff --git a/tests/compile.rs b/tests/compile.rs index 91069fa..57f8013 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1947,3 
+1947,39 @@ pub fn main(x: u16) -> u16 { ); Ok(()) } + +#[test] +fn compile_const_aggregated_min() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = min(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(3, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} From fc6e179238f020a1001ae53bc01d220a8270b571 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 22 May 2024 16:21:54 +0100 Subject: [PATCH 07/22] Clean up code for consts --- language_tour.md | 34 ++++++++++++++++++++++---- src/ast.rs | 7 +----- src/check.rs | 42 ++++++++++++--------------------- src/compile.rs | 21 ++++++++++------- src/eval.rs | 34 ++++---------------------- src/lib.rs | 25 ++++---------------- src/literal.rs | 27 ++++++++------------- src/parse.rs | 22 ++++++++--------- tests/credit_scoring_example.rs | 10 +++++++- tests/smart_cookie_example.rs | 16 +++++++++---- 10 files changed, 108 insertions(+), 130 deletions(-) diff --git a/language_tour.md b/language_tour.md index e3e5896..1300496 100644 --- a/language_tour.md +++ b/language_tour.md @@ -123,7 +123,7 @@ pub fn main(x: i32) -> i32 { } ``` -Garble supports for-each loops as the only looping / recursion construct in the language. For-each loops can only loop over _fixed-size_ arrays. 
This is by design, as it disallows any form of unbounded recursion and thus enables the Garble compiler to generate fixed circuits consisting only of boolean gates. Garble programs are thus computationally equivalent to [LOOP programs](https://en.wikipedia.org/wiki/LOOP_(programming_language)) and capture the class of _primitive recursive functions_. +Garble supports for-each loops as the only looping / recursion construct in the language. For-each loops can only loop over _fixed-size_ arrays. This is by design, as it disallows any form of unbounded recursion and thus enables the Garble compiler to generate fixed circuits consisting only of boolean gates. Garble programs are thus computationally equivalent to [LOOP programs](<https://en.wikipedia.org/wiki/LOOP_(programming_language)>) and capture the class of _primitive recursive functions_. ```rust pub fn main(_x: i32) -> i32 { @@ -196,7 +196,7 @@ Panic due to Overflow on line 17:43. Garble will also panic on integer overflows caused by other arithmetic operations (such as subtraction and multiplication), divisions by zero, and out-of-bounds array indexing. -*Circuit logic for panics is always compiled into the final circuit (and includes the line and column number of the code that caused the panic), it is your responsibility to ensure that no sensitive information can be leaked by causing a panic.* +_Circuit logic for panics is always compiled into the final circuit (and includes the line and column number of the code that caused the panic), it is your responsibility to ensure that no sensitive information can be leaked by causing a panic._ ## Collection Types @@ -366,13 +366,37 @@ The patterns are not exhaustive. Missing cases: | } ``` +### Constants + +Garble supports boolean and integer constants, which need to be declared at the top level and must be provided before compilation. This can be helpful for modelling "pseudo-dynamic" collections, i.e.
collections whose size is not known during type-checking but will be known before compilation and execution: + +```rust +const MY_CONST: usize = PARTY_0::MY_CONST; + +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +``` + +Garble also supports taking the minimum / maximum of several constants as part of the declaration of a constant, which can be useful to set the size of a collection to the size of the biggest collection provided by different parties: + +```rust +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); + +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +``` + ## Mental Model of Garble Programs -Garble programs are boolean *circuits* consisting of a graph of logic gates, not a sequentially executed program of instructions on a von Neumann architecture with main memory and CPU. This has deep consequences for the programming style that leads to efficient Garble programs, with programs that would be efficient in "normal" programming languages resulting in highly inefficient circuits and vice versa. +Garble programs are boolean _circuits_ consisting of a graph of logic gates, not a sequentially executed program of instructions on a von Neumann architecture with main memory and CPU. This has deep consequences for the programming style that leads to efficient Garble programs, with programs that would be efficient in "normal" programming languages resulting in highly inefficient circuits and vice versa. One example has already been mentioned: Copying whole arrays in Garble is essentially free, because arrays (and their elements) are just a collection of output wires from a bunch of boolean logic gates. Duplicating these wires does not increase the complexity of the circuit, because no additional logic gates are required. 
-Replacing the element at a *constant* index in an array with a new value is equally cheap, because the Garble compiler can just duplicate the output wires of all the other elements and only has to use the wires of the replacement element where previously the old element was being used. In contrast, replacing the element at a *non-constant* index (i.e. an index that depends on a runtime value) is a much more expensive operation in a boolean circuit than it would be on a normal computer, because the Garble compiler has to generate a nested multiplexer circuit. +Replacing the element at a _constant_ index in an array with a new value is equally cheap, because the Garble compiler can just duplicate the output wires of all the other elements and only has to use the wires of the replacement element where previously the old element was being used. In contrast, replacing the element at a _non-constant_ index (i.e. an index that depends on a runtime value) is a much more expensive operation in a boolean circuit than it would be on a normal computer, because the Garble compiler has to generate a nested multiplexer circuit. Here's an additional example: Let's assume that you want to implement an MPC function that on each invocation adds a value into a (fixed-size) collection of values, overwriting previous values if the buffer is full. In most languages, this could be easily done using a ring buffer and the same is possible in Garble: @@ -402,4 +426,4 @@ The difference in circuit size is staggering: While the first version (with `i` Such an example might be a bit contrived, since it is possible to infer the inputs of both parties (except for the element that is dropped from the array) from the output of the above function, defeating the purpose of MPC, which is to keep each party's input private. But it does highlight how unintuitive the computational model of pure boolean circuits can be from the perspective of a load-and-store architecture with main memory and CPU. 
-It can be helpful to think of Garble programs as being executed on a computer with infinite memory, free copying and no garbage collection: Nothing ever goes out of scope, it is therefore trivial to reuse old values. But any form of branching or looping needs to be compiled into a circuit where each possible branch or loop invocation is "unrolled" and requires its own dedicated logic gates. In normal programming languages, looping a few additional times does not increase the program size, but in Garble programs additional gates are necessary. The size of Garble programs therefore reflects the *worst case* algorithm performance: While normal programming languages can return early and will often require much less time in the best or average case than in the worst case, the evaluation of Garble programs will always take constant time, because the full circuit must always be evaluated. +It can be helpful to think of Garble programs as being executed on a computer with infinite memory, free copying and no garbage collection: Nothing ever goes out of scope, it is therefore trivial to reuse old values. But any form of branching or looping needs to be compiled into a circuit where each possible branch or loop invocation is "unrolled" and requires its own dedicated logic gates. In normal programming languages, looping a few additional times does not increase the program size, but in Garble programs additional gates are necessary. The size of Garble programs therefore reflects the _worst case_ algorithm performance: While normal programming languages can return early and will often require much less time in the best or average case than in the worst case, the evaluation of Garble programs will always take constant time, because the full circuit must always be evaluated. diff --git a/src/ast.rs b/src/ast.rs index f14631e..65ef635 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -342,7 +342,7 @@ pub enum ExprEnum { /// Struct literal with the specified fields. 
StructLiteral(String, Vec<(String, Expr)>), /// Enum literal of the specified variant, possibly with fields. - EnumLiteral(String, Box>), + EnumLiteral(String, String, VariantExprEnum), /// Matching the specified expression with a list of clauses (pattern + expression). Match(Box>, Vec<(Pattern, Expr)>), /// Application of a unary operator. @@ -361,11 +361,6 @@ pub enum ExprEnum { Range((u64, UnsignedNumType), (u64, UnsignedNumType)), } -/// A variant literal, used by [`ExprEnum::EnumLiteral`], with its location in the source code. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct VariantExpr(pub String, pub VariantExprEnum, pub MetaInfo); - /// The different kinds of variant literals. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/check.rs b/src/check.rs index 52e6eaf..2c682ce 100644 --- a/src/check.rs +++ b/src/check.rs @@ -6,8 +6,7 @@ use std::collections::{HashMap, HashSet}; use crate::{ ast::{ self, ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern, - PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, - VariantExprEnum, + PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExprEnum, }, env::Env, token::{MetaInfo, SignedNumType, UnsignedNumType}, @@ -376,7 +375,7 @@ impl UntypedProgram { } } ConstExpr::NumUnsigned(_, ty) => { - let ty = Type::Unsigned(ty.clone()); + let ty = Type::Unsigned(*ty); if const_def.ty != ty { let e = TypeErrorEnum::UnexpectedType { expected: const_def.ty.clone(), @@ -386,7 +385,7 @@ impl UntypedProgram { } } ConstExpr::NumSigned(_, ty) => { - let ty = Type::Signed(ty.clone()); + let ty = Type::Signed(*ty); if const_def.ty != ty { let e = TypeErrorEnum::UnexpectedType { expected: const_def.ty.clone(), @@ -408,7 +407,7 @@ impl UntypedProgram { } } } - check_const_expr(&const_def.value, &const_def, &mut errors, &mut 
const_deps); + check_const_expr(&const_def.value, const_def, &mut errors, &mut const_deps); const_defs.insert(const_name.clone(), const_def.clone()); const_types.insert(const_name.clone(), const_def.ty.clone()); } @@ -1076,24 +1075,18 @@ impl UntypedExpr { ty, ) } - ExprEnum::EnumLiteral(identifier, variant) => { - let VariantExpr(variant_name, variant, variant_meta) = variant.as_ref(); + ExprEnum::EnumLiteral(identifier, variant_name, variant) => { if let Some(enum_def) = defs.enums.get(identifier.as_str()) { - let meta = *variant_meta; if let Some(types) = enum_def.get(variant_name.as_str()) { match (variant, types) { - (VariantExprEnum::Unit, None) => { - let variant = VariantExpr( - variant_name.to_string(), + (VariantExprEnum::Unit, None) => ( + ExprEnum::EnumLiteral( + identifier.clone(), + variant_name.clone(), VariantExprEnum::Unit, - meta, - ); - let ty = Type::Enum(identifier.clone()); - ( - ExprEnum::EnumLiteral(identifier.clone(), Box::new(variant)), - ty, - ) - } + ), + Type::Enum(identifier.clone()), + ), (VariantExprEnum::Tuple(values), Some(types)) => { if values.len() != types.len() { let e = TypeErrorEnum::UnexpectedEnumVariantArity { @@ -1115,19 +1108,14 @@ impl UntypedExpr { Err(e) => errors.extend(e), } } - let variant = VariantExpr( - variant_name.to_string(), - VariantExprEnum::Tuple(exprs), - meta, - ); - let ty = Type::Enum(identifier.clone()); if errors.is_empty() { ( ExprEnum::EnumLiteral( identifier.clone(), - Box::new(variant), + variant_name.clone(), + VariantExprEnum::Tuple(exprs), ), - ty, + Type::Enum(identifier.clone()), ) } else { return Err(errors); diff --git a/src/compile.rs b/src/compile.rs index c76d2ce..2ddb009 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -8,7 +8,7 @@ use std::{ use crate::{ ast::{ ConstExpr, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, Type, UnaryOp, - VariantExpr, VariantExprEnum, + VariantExprEnum, }, circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, 
USIZE_BITS}, env::Env, @@ -24,6 +24,8 @@ pub enum CompilerError { FnNotFound(String), /// The provided constant was not of the required type. InvalidLiteralType(Literal, Type), + /// The constant was declared in the program but not provided during compilation. + MissingConstant(String, String), } impl std::fmt::Display for CompilerError { @@ -35,6 +37,9 @@ impl std::fmt::Display for CompilerError { CompilerError::InvalidLiteralType(literal, ty) => { f.write_fmt(format_args!("The literal is not of type '{ty}': {literal}")) } + CompilerError::MissingConstant(party, identifier) => f.write_fmt(format_args!( + "The constant {party}::{identifier} was declared in the program but never provided" + )), } } } @@ -46,6 +51,7 @@ impl TypedProgram { /// incompatible types are found that should have been caught by the type-checker. pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), CompilerError> { self.compile_with_constants(fn_name, HashMap::new()) + .map(|(c, f, _)| (c, f)) } /// Compiles the (type-checked) program with provided constants, producing a circuit of gates. 
@@ -56,7 +62,7 @@ impl TypedProgram { &self, fn_name: &str, consts: HashMap>, - ) -> Result<(Circuit, &TypedFnDef), CompilerError> { + ) -> Result<(Circuit, &TypedFnDef, HashMap), CompilerError> { let mut env = Env::new(); let mut const_sizes = HashMap::new(); let mut consts_unsigned = HashMap::new(); @@ -64,10 +70,10 @@ impl TypedProgram { for (party, deps) in self.const_deps.iter() { for (c, ty) in deps { let Some(party_deps) = consts.get(party) else { - todo!("missing party dep for {party}"); + return Err(CompilerError::MissingConstant(party.clone(), c.clone())); }; let Some(literal) = party_deps.get(c) else { - todo!("missing value {party}::{c}"); + return Err(CompilerError::MissingConstant(party.clone(), c.clone())); }; let identifier = format!("{party}::{c}"); match literal { @@ -174,7 +180,7 @@ impl TypedProgram { const_sizes.insert(const_name.clone(), n as usize); } } - let mut circuit = CircuitBuilder::new(input_gates, const_sizes); + let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone()); for (const_name, const_def) in self.const_defs.iter() { match &const_def.value { ConstExpr::True => env.let_in_current_scope(const_name.clone(), vec![1]), @@ -245,7 +251,7 @@ impl TypedProgram { } } let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); - Ok((circuit.build(output_gates), fn_def)) + Ok((circuit.build(output_gates), fn_def, const_sizes)) } } @@ -848,12 +854,11 @@ impl TypedExpr { } array } - ExprEnum::EnumLiteral(identifier, variant) => { + ExprEnum::EnumLiteral(identifier, variant_name, variant) => { let enum_def = prg.enum_defs.get(identifier).unwrap(); let tag_size = enum_tag_size(enum_def); let max_size = enum_max_size(enum_def, prg, circuit.const_sizes()); let mut wires = vec![0; max_size]; - let VariantExpr(variant_name, variant, _) = variant.as_ref(); let tag_number = enum_tag_number(enum_def, variant_name); for (i, wire) in wires.iter_mut().enumerate().take(tag_size) { *wire = (tag_number >> (tag_size - i - 
1)) & 1; diff --git a/src/eval.rs b/src/eval.rs index 9a1b43c..588b37f 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -20,43 +20,17 @@ pub struct Evaluator<'a> { /// The compiled circuit. pub circuit: &'a Circuit, inputs: Vec>, - const_sizes: HashMap, + const_sizes: &'a HashMap, } impl<'a> Evaluator<'a> { /// Scans, parses, type-checks and then compiles a program for later evaluation. - pub fn new(program: &'a TypedProgram, main_fn: &'a TypedFnDef, circuit: &'a Circuit) -> Self { - Self { - program, - main_fn, - circuit, - inputs: vec![], - const_sizes: HashMap::new(), - } - } - - /// Scans, parses, type-checks and then compiles a program for later evaluation. - pub fn new_with_constants( + pub fn new( program: &'a TypedProgram, main_fn: &'a TypedFnDef, circuit: &'a Circuit, - consts: &HashMap>, + const_sizes: &'a HashMap, ) -> Self { - let mut const_sizes = HashMap::new(); - for (party, deps) in program.const_deps.iter() { - for (c, _) in deps { - let Some(party_deps) = consts.get(party) else { - todo!("missing party dep for {party}"); - }; - let Some(literal) = party_deps.get(c) else { - todo!("missing value {party}::{c}"); - }; - if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - let identifier = format!("{party}::{c}"); - const_sizes.insert(identifier, *size as usize); - } - } - } Self { program, main_fn, @@ -222,7 +196,7 @@ impl<'a> Evaluator<'a> { self.inputs .last_mut() .unwrap() - .extend(literal.as_bits(self.program, &self.const_sizes)); + .extend(literal.as_bits(self.program, self.const_sizes)); Ok(()) } else { Err(EvalError::InvalidLiteralType(literal, ty.clone())) diff --git a/src/lib.rs b/src/lib.rs index 15b684f..c4fc969 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ #![deny(missing_docs)] #![deny(rustdoc::broken_intra_doc_links)] -use ast::{Expr, FnDef, Pattern, Program, Stmt, Type, VariantExpr}; +use ast::{Expr, FnDef, Pattern, Program, Stmt, Type}; use check::TypeError; use circuit::Circuit; use 
compile::CompilerError; @@ -53,7 +53,7 @@ use std::{ collections::HashMap, fmt::{Display, Write as _}, }; -use token::{MetaInfo, UnsignedNumType}; +use token::MetaInfo; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -79,8 +79,6 @@ pub type TypedStmt = Stmt; pub type TypedExpr = Expr; /// [`crate::ast::Pattern`] after typechecking. pub type TypedPattern = Pattern; -/// [`crate::ast::VariantExpr`] after typechecking. -pub type TypedVariantExpr = VariantExpr; pub mod ast; pub mod check; @@ -118,23 +116,8 @@ pub fn compile_with_constants( consts: HashMap>, ) -> Result { let program = check(prg)?; - let (circuit, main) = program.compile_with_constants("main", consts.clone())?; + let (circuit, main, const_sizes) = program.compile_with_constants("main", consts.clone())?; let main = main.clone(); - let mut const_sizes = HashMap::new(); - for (party, deps) in program.const_deps.iter() { - for (c, _) in deps { - let Some(party_deps) = consts.get(party) else { - todo!("missing party dep for {party}"); - }; - let Some(literal) = party_deps.get(c) else { - todo!("missing value {party}::{c}"); - }; - if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { - let identifier = format!("{party}::{c}"); - const_sizes.insert(identifier, *size as usize); - } - } - } Ok(GarbleProgram { program, main, @@ -166,7 +149,7 @@ pub struct GarbleArgument<'a>(Literal, &'a TypedProgram, &'a HashMap Evaluator<'_> { - Evaluator::new_with_constants(&self.program, &self.main, &self.circuit, &self.consts) + Evaluator::new(&self.program, &self.main, &self.circuit, &self.const_sizes) } /// Type-checks and uses the literal as the circuit input argument with the given index. 
diff --git a/src/literal.rs b/src/literal.rs index 6211e88..c65cdfb 100644 --- a/src/literal.rs +++ b/src/literal.rs @@ -10,7 +10,7 @@ use std::{ use serde::{Deserialize, Serialize}; use crate::{ - ast::{Expr, ExprEnum, Type, Variant, VariantExpr, VariantExprEnum}, + ast::{Expr, ExprEnum, Type, Variant, VariantExprEnum}, check::{check_type, Defs, TopLevelTypes, TypeError, TypedFns}, circuit::EvalPanic, compile::{enum_max_size, enum_tag_number, enum_tag_size, signed_to_bits, unsigned_to_bits}, @@ -18,7 +18,7 @@ use crate::{ eval::EvalError, scan::scan, token::{SignedNumType, UnsignedNumType}, - CompileTimeError, TypedExpr, TypedProgram, TypedVariantExpr, + CompileTimeError, TypedExpr, TypedProgram, }; /// A subset of [`crate::ast::Expr`] that is used as input / output by an @@ -545,9 +545,14 @@ impl TypedExpr { .map(|(name, value)| (name, value.into_literal())) .collect(), ), - ExprEnum::EnumLiteral(name, variant) => { - let VariantExpr(variant_name, _, _) = &variant.as_ref(); - Literal::Enum(name, variant_name.clone(), variant.into_literal()) + ExprEnum::EnumLiteral(name, variant_name, variant) => { + let variant = match variant { + VariantExprEnum::Unit => VariantLiteral::Unit, + VariantExprEnum::Tuple(fields) => VariantLiteral::Tuple( + fields.into_iter().map(|f| f.into_literal()).collect(), + ), + }; + Literal::Enum(name, variant_name.clone(), variant) } ExprEnum::Range(min, max) => Literal::Range(min, max), _ => unreachable!("This should result in a literal parse error instead"), @@ -555,18 +560,6 @@ impl TypedExpr { } } -impl TypedVariantExpr { - fn into_literal(self) -> VariantLiteral { - let VariantExpr(_, variant, _) = self; - match variant { - VariantExprEnum::Unit => VariantLiteral::Unit, - VariantExprEnum::Tuple(fields) => { - VariantLiteral::Tuple(fields.into_iter().map(|f| f.into_literal()).collect()) - } - } - } -} - impl From for Literal { fn from(b: bool) -> Self { if b { diff --git a/src/parse.rs b/src/parse.rs index ee4f81b..ab0ffca 100644 --- 
a/src/parse.rs +++ b/src/parse.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, iter::Peekable, vec::IntoIter}; use crate::{ ast::{ ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, PatternEnum, - Program, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum, + Program, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExprEnum, }, scan::Tokens, token::{MetaInfo, SignedNumType, Token, TokenEnum, UnsignedNumType}, @@ -202,9 +202,7 @@ impl Parser { ExprEnum::False => Ok(ConstExpr::False), ExprEnum::NumUnsigned(n, ty) => Ok(ConstExpr::NumUnsigned(n, ty)), ExprEnum::NumSigned(n, ty) => Ok(ConstExpr::NumSigned(n, ty)), - ExprEnum::EnumLiteral(party, variant) => { - // TODO: check that this is a unit variant - let VariantExpr(identifier, _, _) = *variant; + ExprEnum::EnumLiteral(party, identifier, VariantExprEnum::Unit) => { Ok(ConstExpr::ExternalValue { party, identifier }) } ExprEnum::FnCall(f, args) if f == "max" || f == "min" => { @@ -242,7 +240,7 @@ impl Parser { for (e, meta) in errs { self.push_error(e, meta); } - return Err(()); + Err(()) } } } @@ -1189,7 +1187,7 @@ impl Parser { "false" => Expr::untyped(ExprEnum::False, meta), _ => { if self.next_matches(&TokenEnum::DoubleColon).is_some() { - let (variant_name, variant_meta) = self.expect_identifier()?; + let (variant_name, _) = self.expect_identifier()?; let variant = if self.next_matches(&TokenEnum::LeftParen).is_some() { let mut fields = vec![]; if !self.peek(&TokenEnum::RightParen) { @@ -1211,13 +1209,15 @@ impl Parser { fields.push(child); } } - let end = self.expect(&TokenEnum::RightParen)?; - let variant_meta = join_meta(variant_meta, end); - VariantExpr(variant_name, VariantExprEnum::Tuple(fields), variant_meta) + self.expect(&TokenEnum::RightParen)?; + VariantExprEnum::Tuple(fields) } else { - VariantExpr(variant_name, VariantExprEnum::Unit, variant_meta) + VariantExprEnum::Unit }; - Expr::untyped(ExprEnum::EnumLiteral(identifier, 
Box::new(variant)), meta) + Expr::untyped( + ExprEnum::EnumLiteral(identifier, variant_name, variant), + meta, + ) } else if self.next_matches(&TokenEnum::LeftBrace).is_some() && self.struct_literals_allowed { diff --git a/tests/credit_scoring_example.rs b/tests/credit_scoring_example.rs index 493fa50..ac0f940 100644 --- a/tests/credit_scoring_example.rs +++ b/tests/credit_scoring_example.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use garble_lang::{ast::Type, check, circuit::Circuit, eval::Evaluator, literal::Literal, Error}; #[test] @@ -126,7 +128,13 @@ fn credit_scoring_single_run() -> Result<(), Error> { )?; let (compute_score_circuit, compute_score_fn) = typed_prg.compile("compute_score")?; - let mut eval = Evaluator::new(&typed_prg, compute_score_fn, &compute_score_circuit); + let const_sizes = HashMap::new(); + let mut eval = Evaluator::new( + &typed_prg, + compute_score_fn, + &compute_score_circuit, + &const_sizes, + ); eval.set_literal(scoring_algorithm)?; eval.set_literal(user)?; diff --git a/tests/smart_cookie_example.rs b/tests/smart_cookie_example.rs index b274b93..9271a9e 100644 --- a/tests/smart_cookie_example.rs +++ b/tests/smart_cookie_example.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use garble_lang::{ check, circuit::Circuit, @@ -89,7 +91,8 @@ fn smart_cookie_simple_interaction() -> Result<(), Error> { )], ); - let mut init_eval = Evaluator::new(&program, init_fn, &init_circuit); + let const_sizes = HashMap::new(); + let mut init_eval = Evaluator::new(&program, init_fn, &init_circuit, &const_sizes); init_eval .set_literal(website_signing_key.clone()) .map_err(|e| pretty_print(e, smart_cookie))?; @@ -104,8 +107,12 @@ fn smart_cookie_simple_interaction() -> Result<(), Error> { for (i, interest) in interests.iter().enumerate() { println!(" {i}: logging '{interest}'"); - let mut log_interest_eval = - Evaluator::new(&program, log_interest_fn, &log_interest_circuit); + let mut log_interest_eval = Evaluator::new( + &program, + 
log_interest_fn, + &log_interest_circuit, + &const_sizes, + ); let interest = Literal::Enum( "UserInterest".to_string(), interest.to_string(), @@ -133,7 +140,8 @@ fn smart_cookie_simple_interaction() -> Result<(), Error> { .unwrap(); } - let mut decide_ad_eval = Evaluator::new(&program, decide_ad_fn, &decide_ad_circuit); + let mut decide_ad_eval = + Evaluator::new(&program, decide_ad_fn, &decide_ad_circuit, &const_sizes); decide_ad_eval .set_literal(website_signing_key) .map_err(|e| pretty_print(e, smart_cookie))?; From 9e1c4e1dd3305f67b82aab90cbeb92909c04563e Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 22 May 2024 16:25:58 +0100 Subject: [PATCH 08/22] Use new `Evaluator::new` signature in `main.rs` --- src/main.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index f3c83e2..b3601ed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{fs::File, io::Read, path::PathBuf, process::exit}; +use std::{collections::HashMap, fs::File, io::Read, path::PathBuf, process::exit}; use garble_lang::{check, eval::Evaluator, literal::Literal}; @@ -82,7 +82,8 @@ fn run(file: PathBuf, inputs: Vec, function: String) -> Result<(), std:: arguments.push(input); } - let mut evaluator = Evaluator::new(&program, main_fn, &circuit); + let const_sizes = HashMap::new(); + let mut evaluator = Evaluator::new(&program, main_fn, &circuit, &const_sizes); let main_params = &evaluator.main_fn.params; if main_params.len() != arguments.len() { eprintln!( From 47a299fe3245b1792d169726ea13886bb48eed76 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 22 May 2024 16:29:14 +0100 Subject: [PATCH 09/22] Bump version to 0.3.0 --- Cargo.lock | 2 +- Cargo.toml | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b28ba7c..95cf4d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,7 +123,7 @@ dependencies = [ [[package]] name = "garble_lang" -version = "0.2.0" 
+version = "0.3.0" dependencies = [ "clap", "quickcheck", diff --git a/Cargo.toml b/Cargo.toml index c0af85c..b35d506 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,19 @@ [package] name = "garble_lang" -version = "0.2.0" +version = "0.3.0" edition = "2021" rust-version = "1.60.0" description = "Turing-Incomplete Programming Language for Multi-Party Computation with Garbled Circuits" repository = "https://github.com/sine-fdn/garble/" license = "MIT" categories = ["command-line-utilities", "compilers"] -keywords = ["programming-language", "secure-computation", "garbled-circuits", "circuit-description", "smpc"] +keywords = [ + "programming-language", + "secure-computation", + "garbled-circuits", + "circuit-description", + "smpc", +] [[bin]] name = "garble" From b388e3832969620c63b33f5cec54b54113959e0c Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Mon, 27 May 2024 18:23:56 +0100 Subject: [PATCH 10/22] Expose `const_sizes` as a pub field of `GarbleProgram` --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index c4fc969..0f3b0b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,7 +139,8 @@ pub struct GarbleProgram { pub circuit: Circuit, /// The constants used for compiling the circuit. pub consts: HashMap>, - const_sizes: HashMap, + /// The values of usize constants used for compiling the circuit. + pub const_sizes: HashMap, } /// An input argument for a Garble program and circuit. 
From 147582cc9de003abb77854c606c3bfd591cba49c Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 28 May 2024 21:47:20 +0100 Subject: [PATCH 11/22] Implement const array sizes in fn params --- src/compile.rs | 67 +++++++++++++++++++++++++++++------------------- src/eval.rs | 32 +++++++++++++++++++++-- tests/compile.rs | 32 +++++++++++++++++++++++ 3 files changed, 103 insertions(+), 28 deletions(-) diff --git a/src/compile.rs b/src/compile.rs index 2ddb009..e75bb47 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -67,6 +67,7 @@ impl TypedProgram { let mut const_sizes = HashMap::new(); let mut consts_unsigned = HashMap::new(); let mut consts_signed = HashMap::new(); + for (party, deps) in self.const_deps.iter() { for (c, ty) in deps { let Some(party_deps) = consts.get(party) else { @@ -86,38 +87,12 @@ impl TypedProgram { _ => {} } if literal.is_of_type(self, ty) { - let bits = literal - .as_bits(self, &const_sizes) - .iter() - .map(|b| *b as usize) - .collect(); - env.let_in_current_scope(identifier.clone(), bits); if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { const_sizes.insert(identifier, *size as usize); } - } else { - return Err(CompilerError::InvalidLiteralType( - literal.clone(), - ty.clone(), - )); } } } - let mut input_gates = vec![]; - let mut wire = 2; - let Some(fn_def) = self.fn_defs.get(fn_name) else { - return Err(CompilerError::FnNotFound(fn_name.to_string())); - }; - for param in fn_def.params.iter() { - let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); - let mut wires = Vec::with_capacity(type_size); - for _ in 0..type_size { - wires.push(wire); - wire += 1; - } - input_gates.push(type_size); - env.let_in_current_scope(param.name.clone(), wires); - } fn resolve_const_expr_unsigned( expr: &ConstExpr, consts_unsigned: &HashMap, @@ -180,6 +155,46 @@ impl TypedProgram { const_sizes.insert(const_name.clone(), n as usize); } } + + for (party, deps) in self.const_deps.iter() { + for (c, ty) in deps { + 
let Some(party_deps) = consts.get(party) else { + return Err(CompilerError::MissingConstant(party.clone(), c.clone())); + }; + let Some(literal) = party_deps.get(c) else { + return Err(CompilerError::MissingConstant(party.clone(), c.clone())); + }; + let identifier = format!("{party}::{c}"); + if literal.is_of_type(self, ty) { + let bits = literal + .as_bits(self, &const_sizes) + .iter() + .map(|b| *b as usize) + .collect(); + env.let_in_current_scope(identifier.clone(), bits); + } else { + return Err(CompilerError::InvalidLiteralType( + literal.clone(), + ty.clone(), + )); + } + } + } + let mut input_gates = vec![]; + let mut wire = 2; + let Some(fn_def) = self.fn_defs.get(fn_name) else { + return Err(CompilerError::FnNotFound(fn_name.to_string())); + }; + for param in fn_def.params.iter() { + let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); + let mut wires = Vec::with_capacity(type_size); + for _ in 0..type_size { + wires.push(wire); + wire += 1; + } + input_gates.push(type_size); + env.let_in_current_scope(param.name.clone(), wires); + } let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone()); for (const_name, const_def) in self.const_defs.iter() { match &const_def.value { diff --git a/src/eval.rs b/src/eval.rs index 588b37f..ba2a9aa 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -191,7 +191,8 @@ impl<'a> Evaluator<'a> { pub fn set_literal(&mut self, literal: Literal) -> Result<(), EvalError> { if self.inputs.len() < self.main_fn.params.len() { let ty = &self.main_fn.params[self.inputs.len()].ty; - if literal.is_of_type(self.program, ty) { + let ty = resolve_const_type(&ty, self.const_sizes); + if literal.is_of_type(self.program, &ty) { self.inputs.push(vec![]); self.inputs .last_mut() @@ -210,8 +211,9 @@ impl<'a> Evaluator<'a> { pub fn parse_literal(&mut self, literal: &str) -> Result<(), EvalError> { if self.inputs.len() < self.main_fn.params.len() { let ty = &self.main_fn.params[self.inputs.len()].ty; + let ty = 
resolve_const_type(&ty, self.const_sizes); let parsed = - Literal::parse(self.program, ty, literal).map_err(EvalError::LiteralParseError)?; + Literal::parse(self.program, &ty, literal).map_err(EvalError::LiteralParseError)?; self.set_literal(parsed)?; Ok(()) } else { @@ -220,6 +222,32 @@ impl<'a> Evaluator<'a> { } } +fn resolve_const_type(ty: &Type, const_sizes: &HashMap) -> Type { + match ty { + Type::Fn(params, ret_ty) => Type::Fn( + params + .iter() + .map(|ty| resolve_const_type(ty, const_sizes)) + .collect(), + Box::new(resolve_const_type(ret_ty, const_sizes)), + ), + Type::Array(elem_ty, size) => { + Type::Array(Box::new(resolve_const_type(elem_ty, const_sizes)), *size) + } + Type::ArrayConst(elem_ty, size) => Type::Array( + Box::new(resolve_const_type(elem_ty, const_sizes)), + *const_sizes.get(size).unwrap(), + ), + Type::Tuple(elems) => Type::Tuple( + elems + .iter() + .map(|ty| resolve_const_type(ty, const_sizes)) + .collect(), + ), + ty => ty.clone(), + } +} + /// The encoded result of a circuit evaluation. 
#[derive(Debug, Clone)] pub struct EvalOutput<'a> { diff --git a/tests/compile.rs b/tests/compile.rs index 57f8013..405ef25 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1983,3 +1983,35 @@ pub fn main(x: u16) -> u16 { ); Ok(()) } + +#[test] +fn compile_const_size_in_fn_param() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(array: [u16; MY_CONST]) -> u16 { + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(1, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.parse_literal("[7u16, 8u16]").unwrap(); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!(u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 8); + Ok(()) +} From 51700467ec5ea2828eebf050aedfabbcf7da44f2 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 29 May 2024 10:43:17 +0100 Subject: [PATCH 12/22] Fix clippy warnings --- src/eval.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index ba2a9aa..0dc7ea4 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -191,7 +191,7 @@ impl<'a> Evaluator<'a> { pub fn set_literal(&mut self, literal: Literal) -> Result<(), EvalError> { if self.inputs.len() < self.main_fn.params.len() { let ty = &self.main_fn.params[self.inputs.len()].ty; - let ty = resolve_const_type(&ty, self.const_sizes); + let ty = resolve_const_type(ty, self.const_sizes); if literal.is_of_type(self.program, &ty) { self.inputs.push(vec![]); self.inputs @@ -211,7 +211,7 @@ impl<'a> Evaluator<'a> { pub fn parse_literal(&mut self, literal: &str) -> Result<(), 
EvalError> { if self.inputs.len() < self.main_fn.params.len() { let ty = &self.main_fn.params[self.inputs.len()].ty; - let ty = resolve_const_type(&ty, self.const_sizes); + let ty = resolve_const_type(ty, self.const_sizes); let parsed = Literal::parse(self.program, &ty, literal).map_err(EvalError::LiteralParseError)?; self.set_literal(parsed)?; From 18bf814dfadc3dbd3da7720c9a0fd172c1659178 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 29 May 2024 11:23:00 +0100 Subject: [PATCH 13/22] Update language_tour.md slightly --- language_tour.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/language_tour.md b/language_tour.md index 1300496..4fba889 100644 --- a/language_tour.md +++ b/language_tour.md @@ -379,7 +379,7 @@ pub fn main(x: u16) -> u16 { } ``` -Garble also supports taking the minimum / maximum of several constants as part of the declaration of a constant, which can be useful to set the size of a collection to the size of the biggest collection provided by different parties: +Garble also supports taking the minimum / maximum of several constants as part of the declaration of a constant, which, for instance, can be useful to set the size of a collection to the size of the biggest collection provided by different parties: ```rust const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); From d8c210db5602f63f607c113a0bbb8a34dd1a8bf0 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 29 May 2024 16:59:06 +0100 Subject: [PATCH 14/22] Fix for-each loop for arrays with const size --- src/compile.rs | 2 +- tests/compile.rs | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/compile.rs b/src/compile.rs index e75bb47..801c08d 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -360,7 +360,7 @@ impl TypedStmt { } StmtEnum::ForEachLoop(var, array, body) => { let elem_in_bits = match &array.ty { - Type::Array(elem_ty, _size) => { + Type::Array(elem_ty, _) | 
Type::ArrayConst(elem_ty, _) => { elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()) } _ => panic!("Found a non-array value in an array access expr"), diff --git a/tests/compile.rs b/tests/compile.rs index 405ef25..3bb0b4c 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -2015,3 +2015,39 @@ pub fn main(array: [u16; MY_CONST]) -> u16 { assert_eq!(u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 8); Ok(()) } + +#[test] +fn compile_const_size_for_each_loop() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(array: [u16; MY_CONST]) -> u16 { + let mut result = 0u16; + for elem in array { + result = result + elem; + } + result +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(1, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.parse_literal("[7u16, 8u16]").unwrap(); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!(u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 15); + Ok(()) +} From 456067178b5ae683901e64223d7e9549751e3af4 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 29 May 2024 17:29:26 +0100 Subject: [PATCH 15/22] Fix `GarbleProgram::literal_arg` for arrays with const size --- src/eval.rs | 2 +- src/lib.rs | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index 0dc7ea4..8c46a74 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -222,7 +222,7 @@ impl<'a> Evaluator<'a> { } } -fn resolve_const_type(ty: &Type, const_sizes: &HashMap) -> Type { +pub(crate) fn resolve_const_type(ty: &Type, const_sizes: &HashMap) -> Type 
{ match ty { Type::Fn(params, ret_ty) => Type::Fn( params diff --git a/src/lib.rs b/src/lib.rs index 0f3b0b8..5a11f8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,7 @@ use ast::{Expr, FnDef, Pattern, Program, Stmt, Type}; use check::TypeError; use circuit::Circuit; use compile::CompilerError; -use eval::{EvalError, Evaluator}; +use eval::{resolve_const_type, EvalError, Evaluator}; use literal::Literal; use parse::ParseError; use scan::{scan, ScanError}; @@ -162,8 +162,9 @@ impl GarbleProgram { let Some(param) = self.main.params.get(arg_index) else { return Err(EvalError::InvalidArgIndex(arg_index)); }; - if !literal.is_of_type(&self.program, ¶m.ty) { - return Err(EvalError::InvalidLiteralType(literal, param.ty.clone())); + let ty = resolve_const_type(¶m.ty, &self.const_sizes); + if !literal.is_of_type(&self.program, &ty) { + return Err(EvalError::InvalidLiteralType(literal, ty)); } Ok(GarbleArgument(literal, &self.program, &self.const_sizes)) } From 54c99cf5efcad3eeeb0486b14d0ee7fc22be041d Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 4 Jun 2024 17:03:52 +0100 Subject: [PATCH 16/22] Improve parser errors for multiple const defs --- src/parse.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parse.rs b/src/parse.rs index ab0ffca..56aff87 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -118,6 +118,7 @@ impl Parser { TokenEnum::KeywordFn, TokenEnum::KeywordStruct, TokenEnum::KeywordEnum, + TokenEnum::KeywordConst, ]; let mut const_defs = HashMap::new(); let mut struct_defs = HashMap::new(); From 4360f8887157fa09ad8df35b32100896d52b5c18 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 4 Jun 2024 17:24:24 +0100 Subject: [PATCH 17/22] Improve type checking errors for const defs --- src/ast.rs | 9 +++++++-- src/check.rs | 23 +++++++++++++---------- src/compile.rs | 48 ++++++++++++++++++++++++++---------------------- src/parse.rs | 25 +++++++++++++++---------- 4 files changed, 61 insertions(+), 44 deletions(-) diff --git 
a/src/ast.rs b/src/ast.rs index 65ef635..2c4d3ce 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -35,10 +35,15 @@ pub struct ConstDef { pub meta: MetaInfo, } -/// A constant value, either a literal or a namespaced symbol. +/// A constant value, either a literal, a namespaced symbol or an aggregate. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum ConstExpr { +pub struct ConstExpr(pub ConstExprEnum, pub MetaInfo); + +/// The different kinds of constant expressions. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ConstExprEnum { /// Boolean `true`. True, /// Boolean `false`. diff --git a/src/check.rs b/src/check.rs index 2c682ce..cdfea16 100644 --- a/src/check.rs +++ b/src/check.rs @@ -5,8 +5,9 @@ use std::collections::{HashMap, HashSet}; use crate::{ ast::{ - self, ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern, - PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExprEnum, + self, ConstDef, ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, Mutability, Op, + ParamDef, Pattern, PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, + VariantExprEnum, }, env::Env, token::{MetaInfo, SignedNumType, UnsignedNumType}, @@ -364,43 +365,45 @@ impl UntypedProgram { errors: &mut Vec>, const_deps: &mut HashMap>, ) { + let ConstExpr(value, meta) = value; + let meta = *meta; match value { - ConstExpr::True | ConstExpr::False => { + ConstExprEnum::True | ConstExprEnum::False => { if const_def.ty != Type::Bool { let e = TypeErrorEnum::UnexpectedType { expected: const_def.ty.clone(), actual: Type::Bool, }; - errors.extend(vec![Some(TypeError(e, const_def.meta))]); + errors.extend(vec![Some(TypeError(e, meta))]); } } - ConstExpr::NumUnsigned(_, ty) => { + ConstExprEnum::NumUnsigned(_, ty) => { let ty = Type::Unsigned(*ty); if const_def.ty != ty { let e = TypeErrorEnum::UnexpectedType { 
expected: const_def.ty.clone(), actual: ty, }; - errors.extend(vec![Some(TypeError(e, const_def.meta))]); + errors.extend(vec![Some(TypeError(e, meta))]); } } - ConstExpr::NumSigned(_, ty) => { + ConstExprEnum::NumSigned(_, ty) => { let ty = Type::Signed(*ty); if const_def.ty != ty { let e = TypeErrorEnum::UnexpectedType { expected: const_def.ty.clone(), actual: ty, }; - errors.extend(vec![Some(TypeError(e, const_def.meta))]); + errors.extend(vec![Some(TypeError(e, meta))]); } } - ConstExpr::ExternalValue { party, identifier } => { + ConstExprEnum::ExternalValue { party, identifier } => { const_deps .entry(party.clone()) .or_default() .insert(identifier.clone(), const_def.ty.clone()); } - ConstExpr::Max(args) | ConstExpr::Min(args) => { + ConstExprEnum::Max(args) | ConstExprEnum::Min(args) => { for arg in args { check_const_expr(arg, const_def, errors, const_deps); } diff --git a/src/compile.rs b/src/compile.rs index 801c08d..0fee313 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -7,8 +7,8 @@ use std::{ use crate::{ ast::{ - ConstExpr, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, Type, UnaryOp, - VariantExprEnum, + ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, + Type, UnaryOp, VariantExprEnum, }, circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS}, env::Env, @@ -94,22 +94,22 @@ impl TypedProgram { } } fn resolve_const_expr_unsigned( - expr: &ConstExpr, + ConstExpr(expr, _): &ConstExpr, consts_unsigned: &HashMap, ) -> u64 { match expr { - ConstExpr::NumUnsigned(n, _) => *n, - ConstExpr::ExternalValue { party, identifier } => *consts_unsigned + ConstExprEnum::NumUnsigned(n, _) => *n, + ConstExprEnum::ExternalValue { party, identifier } => *consts_unsigned .get(&format!("{party}::{identifier}")) .unwrap(), - ConstExpr::Max(args) => { + ConstExprEnum::Max(args) => { let mut result = 0; for arg in args { result = max(result, resolve_const_expr_unsigned(arg, 
consts_unsigned)); } result } - ConstExpr::Min(args) => { + ConstExprEnum::Min(args) => { let mut result = u64::MAX; for arg in args { result = min(result, resolve_const_expr_unsigned(arg, consts_unsigned)); @@ -120,22 +120,22 @@ impl TypedProgram { } } fn resolve_const_expr_signed( - expr: &ConstExpr, + ConstExpr(expr, _): &ConstExpr, consts_signed: &HashMap, ) -> i64 { match expr { - ConstExpr::NumSigned(n, _) => *n, - ConstExpr::ExternalValue { party, identifier } => *consts_signed + ConstExprEnum::NumSigned(n, _) => *n, + ConstExprEnum::ExternalValue { party, identifier } => *consts_signed .get(&format!("{party}::{identifier}")) .unwrap(), - ConstExpr::Max(args) => { + ConstExprEnum::Max(args) => { let mut result = 0; for arg in args { result = max(result, resolve_const_expr_signed(arg, consts_signed)); } result } - ConstExpr::Min(args) => { + ConstExprEnum::Min(args) => { let mut result = i64::MAX; for arg in args { result = min(result, resolve_const_expr_signed(arg, consts_signed)); @@ -147,7 +147,9 @@ impl TypedProgram { } for (const_name, const_def) in self.const_defs.iter() { if let Type::Unsigned(UnsignedNumType::Usize) = const_def.ty { - if let ConstExpr::ExternalValue { party, identifier } = &const_def.value { + if let ConstExpr(ConstExprEnum::ExternalValue { party, identifier }, _) = + &const_def.value + { let identifier = format!("{party}::{identifier}"); const_sizes.insert(const_name.clone(), *const_sizes.get(&identifier).unwrap()); } @@ -197,10 +199,11 @@ impl TypedProgram { } let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone()); for (const_name, const_def) in self.const_defs.iter() { - match &const_def.value { - ConstExpr::True => env.let_in_current_scope(const_name.clone(), vec![1]), - ConstExpr::False => env.let_in_current_scope(const_name.clone(), vec![0]), - ConstExpr::NumUnsigned(n, ty) => { + let ConstExpr(expr, _) = &const_def.value; + match expr { + ConstExprEnum::True => env.let_in_current_scope(const_name.clone(), 
vec![1]), + ConstExprEnum::False => env.let_in_current_scope(const_name.clone(), vec![0]), + ConstExprEnum::NumUnsigned(n, ty) => { let ty = Type::Unsigned(*ty); let mut bits = Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); @@ -212,7 +215,7 @@ impl TypedProgram { let bits = bits.into_iter().map(|b| b as usize).collect(); env.let_in_current_scope(const_name.clone(), bits); } - ConstExpr::NumSigned(n, ty) => { + ConstExprEnum::NumSigned(n, ty) => { let ty = Type::Signed(*ty); let mut bits = Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); @@ -224,13 +227,14 @@ impl TypedProgram { let bits = bits.into_iter().map(|b| b as usize).collect(); env.let_in_current_scope(const_name.clone(), bits); } - ConstExpr::ExternalValue { party, identifier } => { + ConstExprEnum::ExternalValue { party, identifier } => { let bits = env.get(&format!("{party}::{identifier}")).unwrap(); env.let_in_current_scope(const_name.clone(), bits); } - expr @ (ConstExpr::Max(_) | ConstExpr::Min(_)) => { + ConstExprEnum::Max(_) | ConstExprEnum::Min(_) => { if let Type::Unsigned(_) = const_def.ty { - let result = resolve_const_expr_unsigned(expr, &consts_unsigned); + let result = + resolve_const_expr_unsigned(&const_def.value, &consts_unsigned); let mut bits = Vec::with_capacity( const_def .ty @@ -246,7 +250,7 @@ impl TypedProgram { let bits = bits.into_iter().map(|b| b as usize).collect(); env.let_in_current_scope(const_name.clone(), bits); } else { - let result = resolve_const_expr_signed(expr, &consts_signed); + let result = resolve_const_expr_signed(&const_def.value, &consts_signed); let mut bits = Vec::with_capacity( const_def .ty diff --git a/src/parse.rs b/src/parse.rs index 56aff87..1c0e216 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -4,8 +4,8 @@ use std::{collections::HashMap, iter::Peekable, vec::IntoIter}; use crate::{ ast::{ - ConstDef, ConstExpr, EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, PatternEnum, - Program, Stmt, 
StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExprEnum, + ConstDef, ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, + PatternEnum, Program, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExprEnum, }, scan::Tokens, token::{MetaInfo, SignedNumType, Token, TokenEnum, UnsignedNumType}, @@ -199,13 +199,18 @@ impl Parser { expr: UntypedExpr, ) -> Result> { match expr.inner { - ExprEnum::True => Ok(ConstExpr::True), - ExprEnum::False => Ok(ConstExpr::False), - ExprEnum::NumUnsigned(n, ty) => Ok(ConstExpr::NumUnsigned(n, ty)), - ExprEnum::NumSigned(n, ty) => Ok(ConstExpr::NumSigned(n, ty)), - ExprEnum::EnumLiteral(party, identifier, VariantExprEnum::Unit) => { - Ok(ConstExpr::ExternalValue { party, identifier }) + ExprEnum::True => Ok(ConstExpr(ConstExprEnum::True, expr.meta)), + ExprEnum::False => Ok(ConstExpr(ConstExprEnum::False, expr.meta)), + ExprEnum::NumUnsigned(n, ty) => { + Ok(ConstExpr(ConstExprEnum::NumUnsigned(n, ty), expr.meta)) } + ExprEnum::NumSigned(n, ty) => { + Ok(ConstExpr(ConstExprEnum::NumSigned(n, ty), expr.meta)) + } + ExprEnum::EnumLiteral(party, identifier, VariantExprEnum::Unit) => Ok(ConstExpr( + ConstExprEnum::ExternalValue { party, identifier }, + expr.meta, + )), ExprEnum::FnCall(f, args) if f == "max" || f == "min" => { let mut const_exprs = vec![]; let mut arg_errs = vec![]; @@ -223,9 +228,9 @@ impl Parser { return Err(arg_errs); } if f == "max" { - Ok(ConstExpr::Max(const_exprs)) + Ok(ConstExpr(ConstExprEnum::Max(const_exprs), expr.meta)) } else { - Ok(ConstExpr::Min(const_exprs)) + Ok(ConstExpr(ConstExprEnum::Min(const_exprs), expr.meta)) } } _ => Err(vec![(ParseErrorEnum::InvalidConstExpr, expr.meta)]), From af63a61808a33ad872306bc7bbe28ad71c215336 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 4 Jun 2024 17:43:11 +0100 Subject: [PATCH 18/22] Improve compiler errors for const defs --- src/ast.rs | 4 ++-- src/check.rs | 6 +++--- src/compile.rs | 28 
++++++++++++++++++++-------- src/lib.rs | 18 +++++++++--------- src/parse.rs | 8 ++++++-- 5 files changed, 40 insertions(+), 24 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 2c4d3ce..7bdd385 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,6 +1,6 @@ //! The untyped Abstract Syntax Tree (AST). -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { /// The external constants that the top level const definitions depend upon. - pub const_deps: HashMap>, + pub const_deps: BTreeMap>, /// Top level const definitions. pub const_defs: HashMap, /// Top level struct type definitions. diff --git a/src/check.rs b/src/check.rs index cdfea16..b8cb5fb 100644 --- a/src/check.rs +++ b/src/check.rs @@ -1,7 +1,7 @@ //! Type-checker, transforming an untyped [`crate::ast::Program`] into a typed //! [`crate::ast::Program`]. 
-use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use crate::{ ast::{ @@ -354,7 +354,7 @@ impl UntypedProgram { struct_names, enum_names, }; - let mut const_deps: HashMap> = HashMap::new(); + let mut const_deps: BTreeMap> = BTreeMap::new(); let mut const_types = HashMap::with_capacity(self.const_defs.len()); let mut const_defs = HashMap::with_capacity(self.const_defs.len()); { @@ -363,7 +363,7 @@ impl UntypedProgram { value: &ConstExpr, const_def: &ConstDef, errors: &mut Vec>, - const_deps: &mut HashMap>, + const_deps: &mut BTreeMap>, ) { let ConstExpr(value, meta) = value; let meta = *meta; diff --git a/src/compile.rs b/src/compile.rs index 0fee313..f207c59 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -44,12 +44,14 @@ impl std::fmt::Display for CompilerError { } } +type CompiledProgram<'a> = (Circuit, &'a TypedFnDef, HashMap); + impl TypedProgram { /// Compiles the (type-checked) program, producing a circuit of gates. /// /// Assumes that the input program has been correctly type-checked and **panics** if /// incompatible types are found that should have been caught by the type-checker. 
- pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), CompilerError> { + pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), Vec> { self.compile_with_constants(fn_name, HashMap::new()) .map(|(c, f, _)| (c, f)) } @@ -62,19 +64,22 @@ impl TypedProgram { &self, fn_name: &str, consts: HashMap>, - ) -> Result<(Circuit, &TypedFnDef, HashMap), CompilerError> { + ) -> Result> { let mut env = Env::new(); let mut const_sizes = HashMap::new(); let mut consts_unsigned = HashMap::new(); let mut consts_signed = HashMap::new(); + let mut errs = vec![]; for (party, deps) in self.const_deps.iter() { for (c, ty) in deps { let Some(party_deps) = consts.get(party) else { - return Err(CompilerError::MissingConstant(party.clone(), c.clone())); + errs.push(CompilerError::MissingConstant(party.clone(), c.clone())); + continue; }; let Some(literal) = party_deps.get(c) else { - return Err(CompilerError::MissingConstant(party.clone(), c.clone())); + errs.push(CompilerError::MissingConstant(party.clone(), c.clone())); + continue; }; let identifier = format!("{party}::{c}"); match literal { @@ -93,6 +98,9 @@ impl TypedProgram { } } } + if !errs.is_empty() { + return Err(errs); + } fn resolve_const_expr_unsigned( ConstExpr(expr, _): &ConstExpr, consts_unsigned: &HashMap, @@ -158,13 +166,14 @@ impl TypedProgram { } } + let mut errs = vec![]; for (party, deps) in self.const_deps.iter() { for (c, ty) in deps { let Some(party_deps) = consts.get(party) else { - return Err(CompilerError::MissingConstant(party.clone(), c.clone())); + continue; }; let Some(literal) = party_deps.get(c) else { - return Err(CompilerError::MissingConstant(party.clone(), c.clone())); + continue; }; let identifier = format!("{party}::{c}"); if literal.is_of_type(self, ty) { @@ -175,17 +184,20 @@ impl TypedProgram { .collect(); env.let_in_current_scope(identifier.clone(), bits); } else { - return Err(CompilerError::InvalidLiteralType( + errs.push(CompilerError::InvalidLiteralType( 
literal.clone(), ty.clone(), )); } } } + if !errs.is_empty() { + return Err(errs); + } let mut input_gates = vec![]; let mut wire = 2; let Some(fn_def) = self.fn_defs.get(fn_name) else { - return Err(CompilerError::FnNotFound(fn_name.to_string())); + return Err(vec![CompilerError::FnNotFound(fn_name.to_string())]); }; for param in fn_def.params.iter() { let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); diff --git a/src/lib.rs b/src/lib.rs index 5a11f8e..e4bcb09 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -217,7 +217,7 @@ pub enum CompileTimeError { /// Errors originating in the type-checking phase. TypeError(Vec), /// Errors originating in the compilation phase. - CompilerError(CompilerError), + CompilerError(Vec), } /// A generic error that combines compile-time and run-time errors. @@ -249,8 +249,8 @@ impl From> for CompileTimeError { } } -impl From for CompileTimeError { - fn from(e: CompilerError) -> Self { +impl From> for CompileTimeError { + fn from(e: Vec) -> Self { Self::CompilerError(e) } } @@ -323,12 +323,12 @@ impl CompileTimeError { errs_for_display.push(("Type error", format!("{e}"), *meta)); } } - CompileTimeError::CompilerError(e) => { - let meta = MetaInfo { - start: (0, 0), - end: (0, 0), - }; - errs_for_display.push(("Compiler error", format!("{e}"), meta)) + CompileTimeError::CompilerError(errs) => { + let mut pretty = String::new(); + for e in errs { + pretty += &format!("\nCompiler error: {e}"); + } + return pretty; } } let mut msg = "".to_string(); diff --git a/src/parse.rs b/src/parse.rs index 1c0e216..5ef14aa 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,6 +1,10 @@ //! Parses a stream of [`crate::scan::Tokens`] into an untyped [`crate::ast::Program`]. 
-use std::{collections::HashMap, iter::Peekable, vec::IntoIter}; +use std::{ + collections::{BTreeMap, HashMap}, + iter::Peekable, + vec::IntoIter, +}; use crate::{ ast::{ @@ -171,7 +175,7 @@ impl Parser { } if self.errors.is_empty() { return Ok(Program { - const_deps: HashMap::new(), + const_deps: BTreeMap::new(), const_defs, struct_defs, enum_defs, From add7d23288e1190c5cd7053551ff67430beaf6e0 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Tue, 4 Jun 2024 20:21:21 +0100 Subject: [PATCH 19/22] Fix main.rs --- src/main.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index b3601ed..1b0ae3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,8 +60,10 @@ fn run(file: PathBuf, inputs: Vec, function: String) -> Result<(), std:: eprintln!("{}", e.prettify(&prg)); exit(65); }); - let (circuit, main_fn) = program.compile(&function).unwrap_or_else(|e| { - eprintln!("{e}"); + let (circuit, main_fn) = program.compile(&function).unwrap_or_else(|errs| { + for e in errs { + eprintln!("{e}"); + } exit(65); }); From e01cdefbb4196ab1d606ea55c8e6f10c81298d11 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 5 Jun 2024 11:00:10 +0100 Subject: [PATCH 20/22] Improve error messages for missing constants --- src/ast.rs | 2 +- src/check.rs | 6 +++--- src/compile.rs | 22 +++++++++++++++------- src/lib.rs | 39 ++++++++++++++++++++++++--------------- 4 files changed, 43 insertions(+), 26 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 7bdd385..a35c991 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { /// The external constants that the top level const definitions depend upon. - pub const_deps: BTreeMap>, + pub const_deps: BTreeMap>, /// Top level const definitions. pub const_defs: HashMap, /// Top level struct type definitions. 
diff --git a/src/check.rs b/src/check.rs index b8cb5fb..35ae966 100644 --- a/src/check.rs +++ b/src/check.rs @@ -354,7 +354,7 @@ impl UntypedProgram { struct_names, enum_names, }; - let mut const_deps: BTreeMap> = BTreeMap::new(); + let mut const_deps: BTreeMap> = BTreeMap::new(); let mut const_types = HashMap::with_capacity(self.const_defs.len()); let mut const_defs = HashMap::with_capacity(self.const_defs.len()); { @@ -363,7 +363,7 @@ impl UntypedProgram { value: &ConstExpr, const_def: &ConstDef, errors: &mut Vec>, - const_deps: &mut BTreeMap>, + const_deps: &mut BTreeMap>, ) { let ConstExpr(value, meta) = value; let meta = *meta; @@ -401,7 +401,7 @@ impl UntypedProgram { const_deps .entry(party.clone()) .or_default() - .insert(identifier.clone(), const_def.ty.clone()); + .insert(identifier.clone(), (const_def.ty.clone(), meta)); } ConstExprEnum::Max(args) | ConstExprEnum::Min(args) => { for arg in args { diff --git a/src/compile.rs b/src/compile.rs index f207c59..9f3ea7f 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -13,7 +13,7 @@ use crate::{ circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS}, env::Env, literal::Literal, - token::{SignedNumType, UnsignedNumType}, + token::{MetaInfo, SignedNumType, UnsignedNumType}, TypedExpr, TypedFnDef, TypedPattern, TypedProgram, TypedStmt, }; @@ -25,7 +25,7 @@ pub enum CompilerError { /// The provided constant was not of the required type. InvalidLiteralType(Literal, Type), /// The constant was declared in the program but not provided during compilation. 
- MissingConstant(String, String), + MissingConstant(String, String, MetaInfo), } impl std::fmt::Display for CompilerError { @@ -37,7 +37,7 @@ impl std::fmt::Display for CompilerError { CompilerError::InvalidLiteralType(literal, ty) => { f.write_fmt(format_args!("The literal is not of type '{ty}': {literal}")) } - CompilerError::MissingConstant(party, identifier) => f.write_fmt(format_args!( + CompilerError::MissingConstant(party, identifier, _) => f.write_fmt(format_args!( "The constant {party}::{identifier} was declared in the program but never provided" )), } @@ -72,13 +72,21 @@ impl TypedProgram { let mut errs = vec![]; for (party, deps) in self.const_deps.iter() { - for (c, ty) in deps { + for (c, (ty, meta)) in deps { let Some(party_deps) = consts.get(party) else { - errs.push(CompilerError::MissingConstant(party.clone(), c.clone())); + errs.push(CompilerError::MissingConstant( + party.clone(), + c.clone(), + *meta, + )); continue; }; let Some(literal) = party_deps.get(c) else { - errs.push(CompilerError::MissingConstant(party.clone(), c.clone())); + errs.push(CompilerError::MissingConstant( + party.clone(), + c.clone(), + *meta, + )); continue; }; let identifier = format!("{party}::{c}"); @@ -168,7 +176,7 @@ impl TypedProgram { let mut errs = vec![]; for (party, deps) in self.const_deps.iter() { - for (c, ty) in deps { + for (c, (ty, _)) in deps { let Some(party_deps) = consts.get(party) else { continue; }; diff --git a/src/lib.rs b/src/lib.rs index e4bcb09..bf49175 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -310,39 +310,48 @@ impl CompileTimeError { match self { CompileTimeError::ScanErrors(errs) => { for ScanError(e, meta) in errs { - errs_for_display.push(("Scan error", format!("{e}"), *meta)); + errs_for_display.push(("Scan error", format!("{e}"), Some(*meta))); } } CompileTimeError::ParseError(errs) => { for ParseError(e, meta) in errs { - errs_for_display.push(("Parse error", format!("{e}"), *meta)); + errs_for_display.push(("Parse error", 
format!("{e}"), Some(*meta))); } } CompileTimeError::TypeError(errs) => { for TypeError(e, meta) in errs { - errs_for_display.push(("Type error", format!("{e}"), *meta)); + errs_for_display.push(("Type error", format!("{e}"), Some(*meta))); } } CompileTimeError::CompilerError(errs) => { - let mut pretty = String::new(); for e in errs { - pretty += &format!("\nCompiler error: {e}"); + match e { + CompilerError::MissingConstant(_, _, meta) => { + errs_for_display.push(("Compiler error", format!("{e}"), Some(*meta))) + } + e => errs_for_display.push(("Compiler error", format!("{e}"), None)), + } } - return pretty; } } let mut msg = "".to_string(); for (err_type, err, meta) in errs_for_display { - writeln!( - msg, - "\n{} on line {}:{}.", - err_type, - meta.start.0 + 1, - meta.start.1 + 1 - ) - .unwrap(); + if let Some(meta) = meta { + writeln!( + msg, + "\n{} on line {}:{}.", + err_type, + meta.start.0 + 1, + meta.start.1 + 1 + ) + .unwrap(); + } else { + writeln!(msg, "\n{}:", err_type).unwrap(); + } writeln!(msg, "{err}:").unwrap(); - msg += &prettify_meta(prg, meta); + if let Some(meta) = meta { + msg += &prettify_meta(prg, meta); + } } msg } From 9a307fb04db266abeb10f3c2d14c017e4bdd0fe4 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 5 Jun 2024 11:00:38 +0100 Subject: [PATCH 21/22] Fix MetaInfo end bounds for enum literals --- src/parse.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/parse.rs b/src/parse.rs index 5ef14aa..b2173aa 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1197,7 +1197,8 @@ impl Parser { "false" => Expr::untyped(ExprEnum::False, meta), _ => { if self.next_matches(&TokenEnum::DoubleColon).is_some() { - let (variant_name, _) = self.expect_identifier()?; + let (variant_name, variant_meta) = self.expect_identifier()?; + let meta = join_meta(meta, variant_meta); let variant = if self.next_matches(&TokenEnum::LeftParen).is_some() { let mut fields = vec![]; if !self.peek(&TokenEnum::RightParen) { From 
58bc1ab2d47e3309e36a02e7970a1c8ca526a545 Mon Sep 17 00:00:00 2001 From: Frederic Kettelhoit Date: Wed, 5 Jun 2024 11:28:42 +0100 Subject: [PATCH 22/22] Order compiler errors by MetaInfo --- src/ast.rs | 4 ++-- src/check.rs | 6 +++--- src/compile.rs | 34 +++++++++++++++++++++++++++++++++- src/literal.rs | 4 ++-- src/parse.rs | 8 ++------ src/token.rs | 4 ++-- 6 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index a35c991..3762420 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,6 +1,6 @@ //! The untyped Abstract Syntax Tree (AST). -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -12,7 +12,7 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { /// The external constants that the top level const definitions depend upon. - pub const_deps: BTreeMap>, + pub const_deps: HashMap>, /// Top level const definitions. pub const_defs: HashMap, /// Top level struct type definitions. diff --git a/src/check.rs b/src/check.rs index 35ae966..48f41d3 100644 --- a/src/check.rs +++ b/src/check.rs @@ -1,7 +1,7 @@ //! Type-checker, transforming an untyped [`crate::ast::Program`] into a typed //! [`crate::ast::Program`]. 
-use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use crate::{ ast::{ @@ -354,7 +354,7 @@ impl UntypedProgram { struct_names, enum_names, }; - let mut const_deps: BTreeMap> = BTreeMap::new(); + let mut const_deps: HashMap> = HashMap::new(); let mut const_types = HashMap::with_capacity(self.const_defs.len()); let mut const_defs = HashMap::with_capacity(self.const_defs.len()); { @@ -363,7 +363,7 @@ impl UntypedProgram { value: &ConstExpr, const_def: &ConstDef, errors: &mut Vec>, - const_deps: &mut BTreeMap>, + const_deps: &mut HashMap>, ) { let ConstExpr(value, meta) = value; let meta = *meta; diff --git a/src/compile.rs b/src/compile.rs index 9f3ea7f..e314d11 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -18,7 +18,7 @@ use crate::{ }; /// An error that occurred during compilation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum CompilerError { /// The specified function could not be compiled, as it was not found in the program. 
FnNotFound(String), @@ -28,6 +28,36 @@ pub enum CompilerError { MissingConstant(String, String, MetaInfo), } +impl PartialOrd for CompilerError { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for CompilerError { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self, other) { + (CompilerError::FnNotFound(fn1), CompilerError::FnNotFound(fn2)) => fn1.cmp(fn2), + (CompilerError::FnNotFound(_), _) => std::cmp::Ordering::Less, + (CompilerError::InvalidLiteralType(_, _), CompilerError::FnNotFound(_)) => { + std::cmp::Ordering::Greater + } + ( + CompilerError::InvalidLiteralType(literal1, _), + CompilerError::InvalidLiteralType(literal2, _), + ) => literal1.cmp(literal2), + (CompilerError::InvalidLiteralType(_, _), CompilerError::MissingConstant(_, _, _)) => { + std::cmp::Ordering::Less + } + ( + CompilerError::MissingConstant(_, _, meta1), + CompilerError::MissingConstant(_, _, meta2), + ) => meta1.cmp(meta2), + (CompilerError::MissingConstant(_, _, _), _) => std::cmp::Ordering::Greater, + } + } +} + impl std::fmt::Display for CompilerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -107,6 +137,7 @@ impl TypedProgram { } } if !errs.is_empty() { + errs.sort(); return Err(errs); } fn resolve_const_expr_unsigned( @@ -200,6 +231,7 @@ impl TypedProgram { } } if !errs.is_empty() { + errs.sort(); return Err(errs); } let mut input_gates = vec![]; diff --git a/src/literal.rs b/src/literal.rs index c65cdfb..88f6388 100644 --- a/src/literal.rs +++ b/src/literal.rs @@ -23,7 +23,7 @@ use crate::{ /// A subset of [`crate::ast::Expr`] that is used as input / output by an /// [`crate::eval::Evaluator`]. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Literal { /// Literal `true`. 
@@ -49,7 +49,7 @@ pub enum Literal { } /// A variant literal (either of unit type or containing fields), used by [`Literal::Enum`]. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum VariantLiteral { /// A unit variant, containing no fields. diff --git a/src/parse.rs b/src/parse.rs index b2173aa..be23571 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,10 +1,6 @@ //! Parses a stream of [`crate::scan::Tokens`] into an untyped [`crate::ast::Program`]. -use std::{ - collections::{BTreeMap, HashMap}, - iter::Peekable, - vec::IntoIter, -}; +use std::{collections::HashMap, iter::Peekable, vec::IntoIter}; use crate::{ ast::{ @@ -175,7 +171,7 @@ impl Parser { } if self.errors.is_empty() { return Ok(Program { - const_deps: BTreeMap::new(), + const_deps: HashMap::new(), const_defs, struct_defs, enum_defs, diff --git a/src/token.rs b/src/token.rs index f4ceb69..cd757ef 100644 --- a/src/token.rs +++ b/src/token.rs @@ -176,7 +176,7 @@ impl std::fmt::Display for TokenEnum { } /// A suffix indicating the explicit unsigned number type of the literal. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum UnsignedNumType { /// Unsigned integer type used to index arrays, length depends on the host platform. @@ -217,7 +217,7 @@ impl std::fmt::Display for UnsignedNumType { } /// A suffix indicating the explicit signed number type of the literal. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum SignedNumType { /// 8-bit signed integer type.