diff --git a/Cargo.lock b/Cargo.lock index b28ba7c..95cf4d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,7 +123,7 @@ dependencies = [ [[package]] name = "garble_lang" -version = "0.2.0" +version = "0.3.0" dependencies = [ "clap", "quickcheck", diff --git a/Cargo.toml b/Cargo.toml index c0af85c..b35d506 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,19 @@ [package] name = "garble_lang" -version = "0.2.0" +version = "0.3.0" edition = "2021" rust-version = "1.60.0" description = "Turing-Incomplete Programming Language for Multi-Party Computation with Garbled Circuits" repository = "https://github.com/sine-fdn/garble/" license = "MIT" categories = ["command-line-utilities", "compilers"] -keywords = ["programming-language", "secure-computation", "garbled-circuits", "circuit-description", "smpc"] +keywords = [ + "programming-language", + "secure-computation", + "garbled-circuits", + "circuit-description", + "smpc", +] [[bin]] name = "garble" diff --git a/language_tour.md b/language_tour.md index e3e5896..4fba889 100644 --- a/language_tour.md +++ b/language_tour.md @@ -123,7 +123,7 @@ pub fn main(x: i32) -> i32 { } ``` -Garble supports for-each loops as the only looping / recursion construct in the language. For-each loops can only loop over _fixed-size_ arrays. This is by design, as it disallows any form of unbounded recursion and thus enables the Garble compiler to generate fixed circuits consisting only of boolean gates. Garble programs are thus computationally equivalent to [LOOP programs](https://en.wikipedia.org/wiki/LOOP_(programming_language)) and capture the class of _primitive recursive functions_. +Garble supports for-each loops as the only looping / recursion construct in the language. For-each loops can only loop over _fixed-size_ arrays. This is by design, as it disallows any form of unbounded recursion and thus enables the Garble compiler to generate fixed circuits consisting only of boolean gates. Garble programs are thus computationally equivalent to [LOOP programs]() and capture the class of _primitive recursive functions_. ```rust pub fn main(_x: i32) -> i32 { @@ -196,7 +196,7 @@ Panic due to Overflow on line 17:43. Garble will also panic on integer overflows caused by other arithmetic operations (such as subtraction and multiplication), divisions by zero, and out-of-bounds array indexing. -*Circuit logic for panics is always compiled into the final circuit (and includes the line and column number of the code that caused the panic), it is your responsibility to ensure that no sensitive information can be leaked by causing a panic.* +_Circuit logic for panics is always compiled into the final circuit (and includes the line and column number of the code that caused the panic), it is your responsibility to ensure that no sensitive information can be leaked by causing a panic._ ## Collection Types @@ -366,13 +366,37 @@ The patterns are not exhaustive. Missing cases: | } ``` +### Constants + +Garble supports boolean and integer constants, which need to be declared at the top level and must be provided before compilation. This can be helpful for modelling "pseudo-dynamic" collections, i.e. collections whose size is not known during type-checking but will be known before compilation and execution: + +```rust +const MY_CONST: usize = PARTY_0::MY_CONST; + +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +``` + +Garble also supports taking the minimum / maximum of several constants as part of the declaration of a constant, which, for instance, can be useful to set the size of a collection to the size of the biggest collection provided by different parties: + +```rust +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); + +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +``` + ## Mental Model of Garble Programs -Garble programs are boolean *circuits* consisting of a graph of logic gates, not a sequentially executed program of instructions on a von Neumann architecture with main memory and CPU. This has deep consequences for the programming style that leads to efficient Garble programs, with programs that would be efficient in "normal" programming languages resulting in highly inefficient circuits and vice versa. +Garble programs are boolean _circuits_ consisting of a graph of logic gates, not a sequentially executed program of instructions on a von Neumann architecture with main memory and CPU. This has deep consequences for the programming style that leads to efficient Garble programs, with programs that would be efficient in "normal" programming languages resulting in highly inefficient circuits and vice versa. One example has already been mentioned: Copying whole arrays in Garble is essentially free, because arrays (and their elements) are just a collection of output wires from a bunch of boolean logic gates. Duplicating these wires does not increase the complexity of the circuit, because no additional logic gates are required. -Replacing the element at a *constant* index in an array with a new value is equally cheap, because the Garble compiler can just duplicate the output wires of all the other elements and only has to use the wires of the replacement element where previously the old element was being used. In contrast, replacing the element at a *non-constant* index (i.e. an index that depends on a runtime value) is a much more expensive operation in a boolean circuit than it would be on a normal computer, because the Garble compiler has to generate a nested multiplexer circuit. +Replacing the element at a _constant_ index in an array with a new value is equally cheap, because the Garble compiler can just duplicate the output wires of all the other elements and only has to use the wires of the replacement element where previously the old element was being used. In contrast, replacing the element at a _non-constant_ index (i.e. an index that depends on a runtime value) is a much more expensive operation in a boolean circuit than it would be on a normal computer, because the Garble compiler has to generate a nested multiplexer circuit. Here's an additional example: Let's assume that you want to implement an MPC function that on each invocation adds a value into a (fixed-size) collection of values, overwriting previous values if the buffer is full. In most languages, this could be easily done using a ring buffer and the same is possible in Garble: @@ -402,4 +426,4 @@ The difference in circuit size is staggering: While the first version (with `i` Such an example might be a bit contrived, since it is possible to infer the inputs of both parties (except for the element that is dropped from the array) from the output of the above function, defeating the purpose of MPC, which is to keep each party's input private. But it does highlight how unintuitive the computational model of pure boolean circuits can be from the perspective of a load-and-store architecture with main memory and CPU. -It can be helpful to think of Garble programs as being executed on a computer with infinite memory, free copying and no garbage collection: Nothing ever goes out of scope, it is therefore trivial to reuse old values. But any form of branching or looping needs to be compiled into a circuit where each possible branch or loop invocation is "unrolled" and requires its own dedicated logic gates. In normal programming languages, looping a few additional times does not increase the program size, but in Garble programs additional gates are necessary. The size of Garble programs therefore reflects the *worst case* algorithm performance: While normal programming languages can return early and will often require much less time in the best or average case than in the worst case, the evaluation of Garble programs will always take constant time, because the full circuit must always be evaluated. +It can be helpful to think of Garble programs as being executed on a computer with infinite memory, free copying and no garbage collection: Nothing ever goes out of scope, it is therefore trivial to reuse old values. But any form of branching or looping needs to be compiled into a circuit where each possible branch or loop invocation is "unrolled" and requires its own dedicated logic gates. In normal programming languages, looping a few additional times does not increase the program size, but in Garble programs additional gates are necessary. The size of Garble programs therefore reflects the _worst case_ algorithm performance: While normal programming languages can return early and will often require much less time in the best or average case than in the worst case, the evaluation of Garble programs will always take constant time, because the full circuit must always be evaluated. diff --git a/src/ast.rs b/src/ast.rs index 3671098..3762420 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -11,6 +11,10 @@ use crate::token::{MetaInfo, SignedNumType, UnsignedNumType}; #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Program { + /// The external constants that the top level const definitions depend upon. + pub const_deps: HashMap>, + /// Top level const definitions. + pub const_defs: HashMap, /// Top level struct type definitions. pub struct_defs: HashMap, /// Top level enum type definitions. @@ -19,6 +23,48 @@ pub struct Program { pub fn_defs: HashMap>, } +/// A top level const definition. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ConstDef { + /// The type of the constant. + pub ty: Type, + /// The value of the constant. + pub value: ConstExpr, + /// The location in the source code. + pub meta: MetaInfo, +} + +/// A constant value, either a literal, a namespaced symbol or an aggregate. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ConstExpr(pub ConstExprEnum, pub MetaInfo); + +/// The different kinds of constant expressions. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ConstExprEnum { + /// Boolean `true`. + True, + /// Boolean `false`. + False, + /// Unsigned integer. + NumUnsigned(u64, UnsignedNumType), + /// Signed integer. + NumSigned(i64, SignedNumType), + /// An external value supplied before compilation. + ExternalValue { + /// The party providing the value. + party: String, + /// The variable name of the value. + identifier: String, + }, + /// The maximum of several constant expressions. + Max(Vec), + /// The minimum of several constant expressions. + Min(Vec), +} + /// A top level struct type definition. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -137,6 +183,8 @@ pub enum Type { Fn(Vec, Box), /// Array type of a fixed size, containing elements of the specified type. Array(Box, usize), + /// Array type of a fixed size, with the size specified by a constant. + ArrayConst(Box, String), /// Tuple type containing fields of the specified types. Tuple(Vec), /// A struct or an enum, depending on the top level definitions (used only before typechecking). @@ -173,6 +221,13 @@ impl std::fmt::Display for Type { size.fmt(f)?; f.write_str("]") } + Type::ArrayConst(ty, size) => { + f.write_str("[")?; + ty.fmt(f)?; + f.write_str("; ")?; + size.fmt(f)?; + f.write_str("]") + } Type::Tuple(fields) => { f.write_str("(")?; let mut fields = fields.iter(); @@ -279,6 +334,8 @@ pub enum ExprEnum { ArrayLiteral(Vec>), /// Array "repeat expression", which specifies 1 element, to be repeated a number of times. ArrayRepeatLiteral(Box>, usize), + /// Array "repeat expression", with the size specified by a constant. + ArrayRepeatLiteralConst(Box>, String), /// Access of an array at the specified index, returning its element. ArrayAccess(Box>, Box>), /// Tuple literal containing the specified fields. @@ -290,7 +347,7 @@ pub enum ExprEnum { /// Struct literal with the specified fields. StructLiteral(String, Vec<(String, Expr)>), /// Enum literal of the specified variant, possibly with fields. - EnumLiteral(String, Box>), + EnumLiteral(String, String, VariantExprEnum), /// Matching the specified expression with a list of clauses (pattern + expression). Match(Box>, Vec<(Pattern, Expr)>), /// Application of a unary operator. @@ -309,11 +366,6 @@ pub enum ExprEnum { Range((u64, UnsignedNumType), (u64, UnsignedNumType)), } -/// A variant literal, used by [`ExprEnum::EnumLiteral`], with its location in the source code. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct VariantExpr(pub String, pub VariantExprEnum, pub MetaInfo); - /// The different kinds of variant literals. #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/check.rs b/src/check.rs index 817ec46..48f41d3 100644 --- a/src/check.rs +++ b/src/check.rs @@ -5,8 +5,9 @@ use std::collections::{HashMap, HashSet}; use crate::{ ast::{ - self, EnumDef, Expr, ExprEnum, Mutability, Op, ParamDef, Pattern, PatternEnum, Stmt, - StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum, + self, ConstDef, ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, Mutability, Op, + ParamDef, Pattern, PatternEnum, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, + VariantExprEnum, }, env::Env, token::{MetaInfo, SignedNumType, UnsignedNumType}, @@ -50,7 +51,7 @@ pub enum TypeErrorEnum { /// The struct constructor is missing the specified field. MissingStructField(String, String), /// No enum declaration with the specified name exists. - UnknownEnum(String), + UnknownEnum(String, String), /// The enum exists, but no variant declaration with the specified name was found. UnknownEnumVariant(String, String), /// No variable or function with the specified name exists in the current scope. @@ -112,6 +113,8 @@ pub enum TypeErrorEnum { PatternsAreNotExhaustive(Vec), /// The expression cannot be matched upon. TypeDoesNotSupportPatternMatching(Type), + /// The specified identifier is not a constant. + ArraySizeNotConst(String), } impl std::fmt::Display for TypeErrorEnum { @@ -137,8 +140,8 @@ impl std::fmt::Display for TypeErrorEnum { TypeErrorEnum::MissingStructField(struct_name, struct_field) => f.write_fmt( format_args!("Field '{struct_field}' is missing for struct '{struct_name}'"), ), - TypeErrorEnum::UnknownEnum(enum_name) => { - f.write_fmt(format_args!("Unknown enum '{enum_name}'")) + TypeErrorEnum::UnknownEnum(enum_name, enum_variant) => { + f.write_fmt(format_args!("Unknown enum '{enum_name}::{enum_variant}'")) } TypeErrorEnum::UnknownEnumVariant(enum_name, enum_variant) => f.write_fmt( format_args!("Unknown enum variant '{enum_name}::{enum_variant}'"), @@ -222,6 +225,9 @@ impl std::fmt::Display for TypeErrorEnum { TypeErrorEnum::TypeDoesNotSupportPatternMatching(ty) => { f.write_fmt(format_args!("Type {ty} does not support pattern matching")) } + TypeErrorEnum::ArraySizeNotConst(identifier) => { + f.write_fmt(format_args!("Array sizes must be constants, but '{identifier}' is a variable")) + } } } } @@ -251,6 +257,10 @@ impl Type { let elem = elem.as_concrete_type(types)?; Type::Array(Box::new(elem), *size) } + Type::ArrayConst(elem, size) => { + let elem = elem.as_concrete_type(types)?; + Type::ArrayConst(Box::new(elem), size.clone()) + } Type::Tuple(fields) => { let mut concrete_fields = Vec::with_capacity(fields.len()); for field in fields.iter() { @@ -277,6 +287,7 @@ impl Type { /// Static top-level definitions of enums and functions. pub struct Defs<'a> { + consts: HashMap<&'a str, &'a Type>, structs: HashMap<&'a str, (Vec<&'a str>, HashMap<&'a str, Type>)>, enums: HashMap<&'a str, HashMap<&'a str, Option>>>, fns: HashMap<&'a str, &'a UntypedFnDef>, @@ -284,14 +295,19 @@ pub struct Defs<'a> { impl<'a> Defs<'a> { pub(crate) fn new( + const_defs: &'a HashMap, struct_defs: &'a HashMap, enum_defs: &'a HashMap, ) -> Self { let mut defs = Self { + consts: HashMap::new(), structs: HashMap::new(), enums: HashMap::new(), fns: HashMap::new(), }; + for (const_name, ty) in const_defs.iter() { + defs.consts.insert(const_name, ty); + } for (struct_name, struct_def) in struct_defs.iter() { let mut field_names = Vec::with_capacity(struct_def.fields.len()); let mut field_types = HashMap::with_capacity(struct_def.fields.len()); @@ -338,6 +354,67 @@ impl UntypedProgram { struct_names, enum_names, }; + let mut const_deps: HashMap> = HashMap::new(); + let mut const_types = HashMap::with_capacity(self.const_defs.len()); + let mut const_defs = HashMap::with_capacity(self.const_defs.len()); + { + for (const_name, const_def) in self.const_defs.iter() { + fn check_const_expr( + value: &ConstExpr, + const_def: &ConstDef, + errors: &mut Vec>, + const_deps: &mut HashMap>, + ) { + let ConstExpr(value, meta) = value; + let meta = *meta; + match value { + ConstExprEnum::True | ConstExprEnum::False => { + if const_def.ty != Type::Bool { + let e = TypeErrorEnum::UnexpectedType { + expected: const_def.ty.clone(), + actual: Type::Bool, + }; + errors.extend(vec![Some(TypeError(e, meta))]); + } + } + ConstExprEnum::NumUnsigned(_, ty) => { + let ty = Type::Unsigned(*ty); + if const_def.ty != ty { + let e = TypeErrorEnum::UnexpectedType { + expected: const_def.ty.clone(), + actual: ty, + }; + errors.extend(vec![Some(TypeError(e, meta))]); + } + } + ConstExprEnum::NumSigned(_, ty) => { + let ty = Type::Signed(*ty); + if const_def.ty != ty { + let e = TypeErrorEnum::UnexpectedType { + expected: const_def.ty.clone(), + actual: ty, + }; + errors.extend(vec![Some(TypeError(e, meta))]); + } + } + ConstExprEnum::ExternalValue { party, identifier } => { + const_deps + .entry(party.clone()) + .or_default() + .insert(identifier.clone(), (const_def.ty.clone(), meta)); + } + ConstExprEnum::Max(args) | ConstExprEnum::Min(args) => { + for arg in args { + check_const_expr(arg, const_def, errors, const_deps); + } + } + } + } + check_const_expr(&const_def.value, const_def, &mut errors, &mut const_deps); + const_defs.insert(const_name.clone(), const_def.clone()); + const_types.insert(const_name.clone(), const_def.ty.clone()); + } + } let mut struct_defs = HashMap::with_capacity(self.struct_defs.len()); for (struct_name, struct_def) in self.struct_defs.iter() { let meta = struct_def.meta; @@ -372,7 +449,7 @@ impl UntypedProgram { enum_defs.insert(enum_name.clone(), EnumDef { variants, meta }); } - let mut untyped_defs = Defs::new(&struct_defs, &enum_defs); + let mut untyped_defs = Defs::new(&const_types, &struct_defs, &enum_defs); let mut checked_fn_defs = TypedFns::new(); for (fn_name, fn_def) in self.fn_defs.iter() { untyped_defs.fns.insert(fn_name, fn_def); @@ -406,6 +483,8 @@ impl UntypedProgram { } if errors.is_empty() { Ok(TypedProgram { + const_deps, + const_defs, struct_defs, enum_defs, fn_defs, @@ -606,7 +685,7 @@ impl UntypedStmt { ast::StmtEnum::ArrayAssign(identifier, index, value) => { match env.get(identifier) { Some((Some(array_ty), Mutability::Mutable)) => { - let (elem_ty, _) = expect_array_type(&array_ty, meta)?; + let elem_ty = expect_array_type(&array_ty, meta)?; let mut index = index.type_check(top_level_defs, env, fns, defs)?; check_type(&mut index, &Type::Unsigned(UnsignedNumType::Usize))?; @@ -635,7 +714,7 @@ impl UntypedStmt { } ast::StmtEnum::ForEachLoop(var, binding, body) => { let binding = binding.type_check(top_level_defs, env, fns, defs)?; - let (elem_ty, _) = expect_array_type(&binding.ty, meta)?; + let elem_ty = expect_array_type(&binding.ty, meta)?; let mut body_typed = Vec::with_capacity(body.len()); env.push(); env.let_in_current_scope(var.clone(), (Some(elem_ty), Mutability::Immutable)); @@ -677,12 +756,15 @@ impl UntypedExpr { Some((None, _mutability)) => { return Err(vec![None]); } - None => { - return Err(vec![Some(TypeError( - TypeErrorEnum::UnknownIdentifier(identifier.clone()), - meta, - ))]); - } + None => match defs.consts.get(identifier.as_str()) { + Some(&ty) => (ExprEnum::Identifier(identifier.clone()), ty.clone()), + None => { + return Err(vec![Some(TypeError( + TypeErrorEnum::UnknownIdentifier(identifier.clone()), + meta, + ))]); + } + }, }, ExprEnum::ArrayLiteral(fields) => { let mut errors = vec![]; @@ -723,10 +805,41 @@ impl UntypedExpr { let ty = Type::Array(Box::new(value.ty.clone()), *size); (ExprEnum::ArrayRepeatLiteral(Box::new(value), *size), ty) } + ExprEnum::ArrayRepeatLiteralConst(value, size) => match env.get(size) { + None => match defs.consts.get(size.as_str()) { + Some(&ty) if ty == &Type::Unsigned(UnsignedNumType::Usize) => { + let value = value.type_check(top_level_defs, env, fns, defs)?; + let ty = Type::ArrayConst(Box::new(value.ty.clone()), size.clone()); + ( + ExprEnum::ArrayRepeatLiteralConst(Box::new(value), size.clone()), + ty, + ) + } + Some(&ty) => { + let e = TypeErrorEnum::UnexpectedType { + expected: Type::Unsigned(UnsignedNumType::Usize), + actual: ty.clone(), + }; + return Err(vec![Some(TypeError(e, meta))]); + } + None => { + return Err(vec![Some(TypeError( + TypeErrorEnum::UnknownIdentifier(size.clone()), + meta, + ))]); + } + }, + Some(_) => { + return Err(vec![Some(TypeError( + TypeErrorEnum::ArraySizeNotConst(size.clone()), + meta, + ))]); + } + }, ExprEnum::ArrayAccess(arr, index) => { let arr = arr.type_check(top_level_defs, env, fns, defs)?; let mut index = index.type_check(top_level_defs, env, fns, defs)?; - let (elem_ty, _) = expect_array_type(&arr.ty, arr.meta)?; + let elem_ty = expect_array_type(&arr.ty, arr.meta)?; check_type(&mut index, &Type::Unsigned(UnsignedNumType::Usize))?; ( ExprEnum::ArrayAccess(Box::new(arr), Box::new(index)), @@ -965,24 +1078,18 @@ impl UntypedExpr { ty, ) } - ExprEnum::EnumLiteral(identifier, variant) => { + ExprEnum::EnumLiteral(identifier, variant_name, variant) => { if let Some(enum_def) = defs.enums.get(identifier.as_str()) { - let VariantExpr(variant_name, variant, meta) = variant.as_ref(); - let meta = *meta; if let Some(types) = enum_def.get(variant_name.as_str()) { match (variant, types) { - (VariantExprEnum::Unit, None) => { - let variant = VariantExpr( - variant_name.to_string(), + (VariantExprEnum::Unit, None) => ( + ExprEnum::EnumLiteral( + identifier.clone(), + variant_name.clone(), VariantExprEnum::Unit, - meta, - ); - let ty = Type::Enum(identifier.clone()); - ( - ExprEnum::EnumLiteral(identifier.clone(), Box::new(variant)), - ty, - ) - } + ), + Type::Enum(identifier.clone()), + ), (VariantExprEnum::Tuple(values), Some(types)) => { if values.len() != types.len() { let e = TypeErrorEnum::UnexpectedEnumVariantArity { @@ -1004,19 +1111,14 @@ impl UntypedExpr { Err(e) => errors.extend(e), } } - let variant = VariantExpr( - variant_name.to_string(), - VariantExprEnum::Tuple(exprs), - meta, - ); - let ty = Type::Enum(identifier.clone()); if errors.is_empty() { ( ExprEnum::EnumLiteral( identifier.clone(), - Box::new(variant), + variant_name.clone(), + VariantExprEnum::Tuple(exprs), ), - ty, + Type::Enum(identifier.clone()), ) } else { return Err(errors); @@ -1039,7 +1141,8 @@ impl UntypedExpr { return Err(vec![Some(TypeError(e, meta))]); } } else { - let e = TypeErrorEnum::UnknownEnum(identifier.clone()); + let e = + TypeErrorEnum::UnknownEnum(identifier.clone(), variant_name.to_string()); return Err(vec![Some(TypeError(e, meta))]); } } @@ -1054,7 +1157,7 @@ impl UntypedExpr { | Type::Tuple(_) | Type::Struct(_) | Type::Enum(_) => {} - Type::Fn(_, _) | Type::Array(_, _) => { + Type::Fn(_, _) | Type::Array(_, _) | Type::ArrayConst(_, _) => { let e = TypeErrorEnum::TypeDoesNotSupportPatternMatching(ty.clone()); return Err(vec![Some(TypeError(e, meta))]); } @@ -1402,7 +1505,7 @@ impl UntypedPattern { return Err(vec![Some(TypeError(e, meta))]); } } else { - let e = TypeErrorEnum::UnknownEnum(enum_name.clone()); + let e = TypeErrorEnum::UnknownEnum(enum_name.clone(), variant_name.to_string()); return Err(vec![Some(TypeError(e, meta))]); } } @@ -1449,6 +1552,7 @@ enum Ctor { Struct(String, Vec<(String, Type)>), Variant(String, String, Option>), Array(Box, usize), + ArrayConst(Box, String), } type PatternStack = Vec; @@ -1517,7 +1621,7 @@ fn specialize(ctor: &Ctor, pattern: &[TypedPattern]) -> Vec { } _ => vec![], }, - Ctor::Array(_, _) => match head_enum { + Ctor::Array(_, _) | Ctor::ArrayConst(_, _) => match head_enum { PatternEnum::Identifier(_) => vec![tail.collect()], _ => vec![], }, @@ -1723,6 +1827,7 @@ fn split_ctor(patterns: &[PatternStack], q: &[TypedPattern], defs: &Defs) -> Vec vec![Ctor::Tuple(fields.clone())] } Type::Array(elem_ty, size) => vec![Ctor::Array(elem_ty.clone(), *size)], + Type::ArrayConst(elem_ty, size) => vec![Ctor::ArrayConst(elem_ty.clone(), size.clone())], Type::Fn(_, _) => { panic!("Type {ty:?} does not support pattern matching") } @@ -1819,6 +1924,14 @@ fn usefulness(patterns: Vec, q: PatternStack, defs: &Defs) -> Vec< meta, ), ), + Ctor::ArrayConst(elem_ty, size) => witness.insert( + 0, + Pattern::typed( + PatternEnum::Identifier("_".to_string()), + Type::ArrayConst(elem_ty.clone(), size.clone()), + meta, + ), + ), } witnesses.push(witness); } @@ -1828,9 +1941,9 @@ fn usefulness(patterns: Vec, q: PatternStack, defs: &Defs) -> Vec< } } -fn expect_array_type(ty: &Type, meta: MetaInfo) -> Result<(Type, usize), TypeErrors> { +fn expect_array_type(ty: &Type, meta: MetaInfo) -> Result { match ty { - Type::Array(elem, size) => Ok((*elem.clone(), *size)), + Type::Array(elem, _) | Type::ArrayConst(elem, _) => Ok(*elem.clone()), _ => Err(vec![Some(TypeError( TypeErrorEnum::ExpectedArrayType(ty.clone()), meta, diff --git a/src/circuit.rs b/src/circuit.rs index ccee5e8..e6a0104 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -269,6 +269,7 @@ pub(crate) struct CircuitBuilder { gates_optimized: usize, gate_counter: usize, panic_gates: PanicResult, + consts: HashMap, } pub(crate) const USIZE_BITS: usize = 32; @@ -393,7 +394,7 @@ impl PanicReason { } impl CircuitBuilder { - pub fn new(input_gates: Vec) -> Self { + pub fn new(input_gates: Vec, consts: HashMap) -> Self { let mut gate_counter = 2; // for const true and false for input_gates_of_party in input_gates.iter() { gate_counter += input_gates_of_party; @@ -407,9 +408,14 @@ impl CircuitBuilder { gates_optimized: 0, gate_counter, panic_gates: PanicResult::ok(), + consts, } } + pub fn const_sizes(&self) -> &HashMap { + &self.consts + } + // Pruning of useless gates (gates that are not part of the output nor used by other gates): fn remove_unused_gates(&mut self, output_gates: Vec) -> Vec { // To find all unused gates, we start at the output gates and recursively mark all their diff --git a/src/compile.rs b/src/compile.rs index fb57fa3..e314d11 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -1,23 +1,61 @@ //! Compiles a [`crate::ast::Program`] to a [`crate::circuit::Circuit`]. -use std::{cmp::max, collections::HashMap}; +use std::{ + cmp::{max, min}, + collections::HashMap, +}; use crate::{ ast::{ - EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, Type, UnaryOp, - VariantExpr, VariantExprEnum, + ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, + Type, UnaryOp, VariantExprEnum, }, circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, PanicResult, USIZE_BITS}, env::Env, - token::{SignedNumType, UnsignedNumType}, + literal::Literal, + token::{MetaInfo, SignedNumType, UnsignedNumType}, TypedExpr, TypedFnDef, TypedPattern, TypedProgram, TypedStmt, }; /// An error that occurred during compilation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum CompilerError { /// The specified function could not be compiled, as it was not found in the program. FnNotFound(String), + /// The provided constant was not of the required type. + InvalidLiteralType(Literal, Type), + /// The constant was declared in the program but not provided during compilation. + MissingConstant(String, String, MetaInfo), +} + +impl PartialOrd for CompilerError { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for CompilerError { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self, other) { + (CompilerError::FnNotFound(fn1), CompilerError::FnNotFound(fn2)) => fn1.cmp(fn2), + (CompilerError::FnNotFound(_), _) => std::cmp::Ordering::Less, + (CompilerError::InvalidLiteralType(_, _), CompilerError::FnNotFound(_)) => { + std::cmp::Ordering::Greater + } + ( + CompilerError::InvalidLiteralType(literal1, _), + CompilerError::InvalidLiteralType(literal2, _), + ) => literal1.cmp(literal2), + (CompilerError::InvalidLiteralType(_, _), CompilerError::MissingConstant(_, _, _)) => { + std::cmp::Ordering::Less + } + ( + CompilerError::MissingConstant(_, _, meta1), + CompilerError::MissingConstant(_, _, meta2), + ) => meta1.cmp(meta2), + (CompilerError::MissingConstant(_, _, _), _) => std::cmp::Ordering::Greater, + } + } } impl std::fmt::Display for CompilerError { @@ -26,36 +64,265 @@ impl std::fmt::Display for CompilerError { CompilerError::FnNotFound(fn_name) => f.write_fmt(format_args!( "Could not find any function with name '{fn_name}'" )), + CompilerError::InvalidLiteralType(literal, ty) => { + f.write_fmt(format_args!("The literal is not of type '{ty}': {literal}")) + } + CompilerError::MissingConstant(party, identifier, _) => f.write_fmt(format_args!( + "The constant {party}::{identifier} was declared in the program but never provided" + )), } } } +type CompiledProgram<'a> = (Circuit, &'a TypedFnDef, HashMap); + impl TypedProgram { /// Compiles the (type-checked) program, producing a circuit of gates. /// /// Assumes that the input program has been correctly type-checked and **panics** if /// incompatible types are found that should have been caught by the type-checker. - pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), CompilerError> { + pub fn compile(&self, fn_name: &str) -> Result<(Circuit, &TypedFnDef), Vec> { + self.compile_with_constants(fn_name, HashMap::new()) + .map(|(c, f, _)| (c, f)) + } + + /// Compiles the (type-checked) program with provided constants, producing a circuit of gates. + /// + /// Assumes that the input program has been correctly type-checked and **panics** if + /// incompatible types are found that should have been caught by the type-checker. + pub fn compile_with_constants( + &self, + fn_name: &str, + consts: HashMap>, + ) -> Result> { let mut env = Env::new(); + let mut const_sizes = HashMap::new(); + let mut consts_unsigned = HashMap::new(); + let mut consts_signed = HashMap::new(); + + let mut errs = vec![]; + for (party, deps) in self.const_deps.iter() { + for (c, (ty, meta)) in deps { + let Some(party_deps) = consts.get(party) else { + errs.push(CompilerError::MissingConstant( + party.clone(), + c.clone(), + *meta, + )); + continue; + }; + let Some(literal) = party_deps.get(c) else { + errs.push(CompilerError::MissingConstant( + party.clone(), + c.clone(), + *meta, + )); + continue; + }; + let identifier = format!("{party}::{c}"); + match literal { + Literal::NumUnsigned(n, _) => { + consts_unsigned.insert(identifier.clone(), *n); + } + Literal::NumSigned(n, _) => { + consts_signed.insert(identifier.clone(), *n); + } + _ => {} + } + if literal.is_of_type(self, ty) { + if let Literal::NumUnsigned(size, UnsignedNumType::Usize) = literal { + const_sizes.insert(identifier, *size as usize); + } + } + } + } + if !errs.is_empty() { + errs.sort(); + return Err(errs); + } + fn resolve_const_expr_unsigned( + ConstExpr(expr, _): &ConstExpr, + consts_unsigned: &HashMap, + ) -> u64 { + match expr { + ConstExprEnum::NumUnsigned(n, _) => *n, + ConstExprEnum::ExternalValue { party, identifier } => *consts_unsigned + .get(&format!("{party}::{identifier}")) + .unwrap(), + ConstExprEnum::Max(args) => { + let mut result = 0; + for arg in args { + result = max(result, resolve_const_expr_unsigned(arg, consts_unsigned)); + } + result + } + ConstExprEnum::Min(args) => { + let mut result = u64::MAX; + for arg in args { + result = min(result, resolve_const_expr_unsigned(arg, consts_unsigned)); + } + result + } + expr => panic!("Not an unsigned const expr: {expr:?}"), + } + } + fn resolve_const_expr_signed( + ConstExpr(expr, _): &ConstExpr, + consts_signed: &HashMap, + ) -> i64 { + match expr { + ConstExprEnum::NumSigned(n, _) => *n, + ConstExprEnum::ExternalValue { party, identifier } => *consts_signed + .get(&format!("{party}::{identifier}")) + .unwrap(), + ConstExprEnum::Max(args) => { + let mut result = 0; + for arg in args { + result = max(result, resolve_const_expr_signed(arg, consts_signed)); + } + result + } + ConstExprEnum::Min(args) => { + let mut result = i64::MAX; + for arg in args { + result = min(result, resolve_const_expr_signed(arg, consts_signed)); + } + result + } + expr => panic!("Not an unsigned const expr: {expr:?}"), + } + } + for (const_name, const_def) in self.const_defs.iter() { + if let Type::Unsigned(UnsignedNumType::Usize) = const_def.ty { + if let ConstExpr(ConstExprEnum::ExternalValue { party, identifier }, _) = + &const_def.value + { + let identifier = format!("{party}::{identifier}"); + const_sizes.insert(const_name.clone(), *const_sizes.get(&identifier).unwrap()); + } + let n = resolve_const_expr_unsigned(&const_def.value, &consts_unsigned); + const_sizes.insert(const_name.clone(), n as usize); + } + } + + let mut errs = vec![]; + for (party, deps) in self.const_deps.iter() { + for (c, (ty, _)) in deps { + let Some(party_deps) = consts.get(party) else { + continue; + }; + let Some(literal) = party_deps.get(c) else { + continue; + }; + let identifier = format!("{party}::{c}"); + if literal.is_of_type(self, ty) { + let bits = literal + .as_bits(self, &const_sizes) + .iter() + .map(|b| *b as usize) + .collect(); + env.let_in_current_scope(identifier.clone(), bits); + } else { + errs.push(CompilerError::InvalidLiteralType( + literal.clone(), + ty.clone(), + )); + } + } + } + if !errs.is_empty() { + errs.sort(); + return Err(errs); + } let mut input_gates = vec![]; let mut wire = 2; - if let Some(fn_def) = self.fn_defs.get(fn_name) { - for param in fn_def.params.iter() { - let type_size = param.ty.size_in_bits_for_defs(self); - let mut wires = Vec::with_capacity(type_size); - for _ in 0..type_size { - wires.push(wire); - wire += 1; - } - input_gates.push(type_size); - env.let_in_current_scope(param.name.clone(), wires); - } - let mut circuit = CircuitBuilder::new(input_gates); - let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); - Ok((circuit.build(output_gates), fn_def)) - } else { - Err(CompilerError::FnNotFound(fn_name.to_string())) + let Some(fn_def) = self.fn_defs.get(fn_name) else { + return Err(vec![CompilerError::FnNotFound(fn_name.to_string())]); + }; + for param in fn_def.params.iter() { + let type_size = param.ty.size_in_bits_for_defs(self, &const_sizes); + let mut wires = Vec::with_capacity(type_size); + for _ in 0..type_size { + wires.push(wire); + wire += 1; + } + input_gates.push(type_size); + env.let_in_current_scope(param.name.clone(), wires); + } + let mut circuit = CircuitBuilder::new(input_gates, const_sizes.clone()); + for (const_name, const_def) in self.const_defs.iter() { + let ConstExpr(expr, _) = &const_def.value; + match expr { + ConstExprEnum::True => env.let_in_current_scope(const_name.clone(), vec![1]), + ConstExprEnum::False => env.let_in_current_scope(const_name.clone(), vec![0]), + ConstExprEnum::NumUnsigned(n, ty) => { + let ty = Type::Unsigned(*ty); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); + unsigned_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExprEnum::NumSigned(n, ty) => { + let ty = Type::Signed(*ty); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(self, circuit.const_sizes())); + signed_to_bits( + *n, + ty.size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExprEnum::ExternalValue { party, identifier } => { + let bits = env.get(&format!("{party}::{identifier}")).unwrap(); + env.let_in_current_scope(const_name.clone(), bits); + } + ConstExprEnum::Max(_) | ConstExprEnum::Min(_) => { + if let Type::Unsigned(_) = const_def.ty { + let result = + resolve_const_expr_unsigned(&const_def.value, &consts_unsigned); + let mut bits = Vec::with_capacity( + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), + ); + unsigned_to_bits( + result, + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } else { + let result = resolve_const_expr_signed(&const_def.value, &consts_signed); + let mut bits = Vec::with_capacity( + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), + ); + signed_to_bits( + result, + const_def + .ty + .size_in_bits_for_defs(self, circuit.const_sizes()), + &mut bits, + ); + let bits = bits.into_iter().map(|b| b as usize).collect(); + env.let_in_current_scope(const_name.clone(), bits); + } + } + } } + let output_gates = compile_block(&fn_def.body, self, &mut env, &mut circuit); + Ok((circuit.build(output_gates), fn_def, const_sizes)) } } @@ -99,12 +366,13 @@ impl TypedStmt { vec![] } StmtEnum::ArrayAssign(identifier, index, value) => { - let elem_bits = value.ty.size_in_bits_for_defs(prg); + let elem_bits = value.ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = env.get(identifier).unwrap(); let size = array.len() / elem_bits; let mut index = index.compile(prg, env, circuit); let value = value.compile(prg, env, circuit); - let index_bits = Type::Unsigned(UnsignedNumType::Usize).size_in_bits_for_defs(prg); + let index_bits = Type::Unsigned(UnsignedNumType::Usize) + .size_in_bits_for_defs(prg, circuit.const_sizes()); extend_to_bits( &mut index, &Type::Unsigned(UnsignedNumType::Usize), @@ -148,7 +416,9 @@ impl TypedStmt { } StmtEnum::ForEachLoop(var, array, body) => { let elem_in_bits = match &array.ty { - Type::Array(elem_ty, _size) => elem_ty.size_in_bits_for_defs(prg), + Type::Array(elem_ty, _) | Type::ArrayConst(elem_ty, _) => { + elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()) + } _ => panic!("Found a non-array value in an array access expr"), }; env.push(); @@ -188,18 +458,29 @@ impl TypedExpr { vec![0] } ExprEnum::NumUnsigned(n, _) => { - let mut bits = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); - unsigned_to_bits(*n, ty.size_in_bits_for_defs(prg), &mut bits); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); + unsigned_to_bits( + *n, + ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + &mut bits, + ); bits.into_iter().map(|b| b as usize).collect() } ExprEnum::NumSigned(n, _) => { - let mut bits = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); - signed_to_bits(*n, ty.size_in_bits_for_defs(prg), &mut bits); + let mut bits = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); + signed_to_bits( + *n, + ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + &mut bits, + ); bits.into_iter().map(|b| b as usize).collect() } ExprEnum::Identifier(s) => env.get(s).unwrap(), ExprEnum::ArrayLiteral(elems) => { - let mut wires = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); + let mut wires = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); for elem in elems { wires.extend(elem.compile(prg, env, circuit)); } @@ -208,23 +489,45 @@ impl TypedExpr { ExprEnum::ArrayRepeatLiteral(elem, size) => { let elem_ty = elem.ty.clone(); let mut elem = elem.compile(prg, env, circuit); - extend_to_bits(&mut elem, &elem_ty, elem_ty.size_in_bits_for_defs(prg)); - let bits = ty.size_in_bits_for_defs(prg); + extend_to_bits( + &mut elem, + &elem_ty, + elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + ); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = Vec::with_capacity(bits); for _ in 0..*size { array.extend_from_slice(&elem); } array } + ExprEnum::ArrayRepeatLiteralConst(elem, size) => { + let size = *circuit.const_sizes().get(size).unwrap(); + let elem_ty = elem.ty.clone(); + let mut elem = elem.compile(prg, env, circuit); + extend_to_bits( + &mut elem, + &elem_ty, + elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()), + ); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); + let mut array = Vec::with_capacity(bits); + for _ in 0..size { + array.extend_from_slice(&elem); + } + array + } ExprEnum::ArrayAccess(array, index) => { let num_elems = match &array.ty { Type::Array(_, size) => *size, + Type::ArrayConst(_, size) => *circuit.const_sizes().get(size).unwrap(), _ => panic!("Found a non-array value in an array access expr"), }; - let elem_bits = ty.size_in_bits_for_defs(prg); + let elem_bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = array.compile(prg, env, circuit); let mut index = index.compile(prg, env, circuit); - let index_bits = Type::Unsigned(UnsignedNumType::Usize).size_in_bits_for_defs(prg); + let index_bits = Type::Unsigned(UnsignedNumType::Usize) + .size_in_bits_for_defs(prg, circuit.const_sizes()); extend_to_bits( &mut index, &Type::Unsigned(UnsignedNumType::Usize), @@ -267,7 +570,8 @@ impl TypedExpr { } } ExprEnum::TupleLiteral(tuple) => { - let mut wires = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); + let mut wires = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); for value in tuple { wires.extend(value.compile(prg, env, circuit)); } @@ -278,9 +582,12 @@ impl TypedExpr { Type::Tuple(values) => { let mut wires_before = 0; for v in values[0..*index].iter() { - wires_before += v.size_in_bits_for_defs(prg); + wires_before += v.size_in_bits_for_defs(prg, circuit.const_sizes()); } - (wires_before, values[*index].size_in_bits_for_defs(prg)) + ( + wires_before, + values[*index].size_in_bits_for_defs(prg, circuit.const_sizes()), + ) } _ => panic!("Expected a tuple type, but found {:?}", tuple.meta), }; @@ -595,7 +902,7 @@ impl TypedExpr { ExprEnum::Cast(ty, expr) => { let ty_expr = &expr.ty; let mut expr = expr.compile(prg, env, circuit); - let size_after_cast = ty.size_in_bits_for_defs(prg); + let size_after_cast = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); match size_after_cast.cmp(&expr.len()) { std::cmp::Ordering::Equal => expr, @@ -608,7 +915,8 @@ impl TypedExpr { } ExprEnum::Range((from, elem_ty), (to, _)) => { let size = (to - from) as usize; - let elem_bits = Type::Unsigned(*elem_ty).size_in_bits_for_defs(prg); + let elem_bits = + Type::Unsigned(*elem_ty).size_in_bits_for_defs(prg, circuit.const_sizes()); let mut array = Vec::with_capacity(elem_bits * size); for i in *from..*to { for b in (0..elem_bits).rev() { @@ -617,12 +925,11 @@ impl TypedExpr { } array } - ExprEnum::EnumLiteral(identifier, variant) => { + ExprEnum::EnumLiteral(identifier, variant_name, variant) => { let enum_def = prg.enum_defs.get(identifier).unwrap(); let tag_size = enum_tag_size(enum_def); - let max_size = enum_max_size(enum_def, prg); + let max_size = enum_max_size(enum_def, prg, circuit.const_sizes()); let mut wires = vec![0; max_size]; - let VariantExpr(variant_name, variant, _) = variant.as_ref(); let tag_number = enum_tag_number(enum_def, variant_name); for (i, wire) in wires.iter_mut().enumerate().take(tag_size) { *wire = (tag_number >> (tag_size - i - 1)) & 1; @@ -641,7 +948,7 @@ impl TypedExpr { wires } ExprEnum::Match(expr, clauses) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let expr = expr.compile(prg, env, circuit); let mut has_prev_match = 0; let mut muxed_ret_expr = vec![0; bits]; @@ -681,7 +988,8 @@ impl TypedExpr { let struct_def = prg.struct_defs.get(name.as_str()).unwrap(); let mut bits = 0; for (field_name, field_ty) in struct_def.fields.iter() { - let bits_of_field = field_ty.size_in_bits_for_defs(prg); + let bits_of_field = + field_ty.size_in_bits_for_defs(prg, circuit.const_sizes()); if field_name == field { return struct_expr[bits..bits + bits_of_field].to_vec(); } @@ -695,7 +1003,8 @@ impl TypedExpr { ExprEnum::StructLiteral(struct_name, fields) => { let fields: HashMap<_, _> = fields.iter().cloned().collect(); let struct_def = prg.struct_defs.get(struct_name.as_str()).unwrap(); - let mut wires = Vec::with_capacity(ty.size_in_bits_for_defs(prg)); + let mut wires = + Vec::with_capacity(ty.size_in_bits_for_defs(prg, circuit.const_sizes())); for (field_name, _) in struct_def.fields.iter() { let value = fields.get(field_name).unwrap(); wires.extend(value.compile(prg, env, circuit)); @@ -729,7 +1038,7 @@ impl TypedPattern { circuit.push_not(match_expr[0]) } PatternEnum::NumUnsigned(n, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let n = unsigned_as_wires(*n, bits); let mut acc = 1; for i in 0..bits { @@ -739,7 +1048,7 @@ impl TypedPattern { acc } PatternEnum::NumSigned(n, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let n = signed_as_wires(*n, bits); let mut acc = 1; for i in 0..bits { @@ -749,7 +1058,7 @@ impl TypedPattern { acc } PatternEnum::UnsignedInclusiveRange(min, max, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let min = unsigned_as_wires(*min, bits); let max = unsigned_as_wires(*max, bits); let signed = is_signed(ty); @@ -762,7 +1071,7 @@ impl TypedPattern { circuit.push_and(not_lt_min, not_gt_max) } PatternEnum::SignedInclusiveRange(min, max, _) => { - let bits = ty.size_in_bits_for_defs(prg); + let bits = ty.size_in_bits_for_defs(prg, circuit.const_sizes()); let min = signed_as_wires(*min, bits); let max = signed_as_wires(*max, bits); let signed = is_signed(ty); @@ -779,7 +1088,7 @@ impl TypedPattern { let mut w = 0; for field in fields { let Pattern(_, _, field_type) = field; - let field_bits = field_type.size_in_bits_for_defs(prg); + let field_bits = field_type.size_in_bits_for_defs(prg, circuit.const_sizes()); let match_expr = &match_expr[w..w + field_bits]; let is_field_match = field.compile(match_expr, prg, env, circuit); is_match = circuit.push_and(is_match, is_field_match); @@ -794,7 +1103,7 @@ impl TypedPattern { let mut is_match = 1; let mut w = 0; for (field_name, field_type) in struct_def.fields.iter() { - let field_bits = field_type.size_in_bits_for_defs(prg); + let field_bits = field_type.size_in_bits_for_defs(prg, circuit.const_sizes()); if let Some(field_pattern) = fields.get(field_name) { let match_expr = &match_expr[w..w + field_bits]; let is_field_match = field_pattern.compile(match_expr, prg, env, circuit); @@ -829,7 +1138,8 @@ impl TypedPattern { .types() .unwrap_or_default(); for (field, field_type) in fields.iter().zip(field_types) { - let field_bits = field_type.size_in_bits_for_defs(prg); + let field_bits = + field_type.size_in_bits_for_defs(prg, circuit.const_sizes()); let match_expr = &match_expr[w..w + field_bits]; let is_field_match = field.compile(match_expr, prg, env, circuit); is_match = circuit.push_and(is_match, is_field_match); @@ -845,7 +1155,11 @@ impl TypedPattern { } impl Type { - pub(crate) fn size_in_bits_for_defs(&self, prg: &TypedProgram) -> usize { + pub(crate) fn size_in_bits_for_defs( + &self, + prg: &TypedProgram, + const_sizes: &HashMap, + ) -> usize { match self { Type::Bool => 1, Type::Unsigned(UnsignedNumType::Usize) => USIZE_BITS, @@ -853,17 +1167,20 @@ impl Type { Type::Unsigned(UnsignedNumType::U16) | Type::Signed(SignedNumType::I16) => 16, Type::Unsigned(UnsignedNumType::U32) | Type::Signed(SignedNumType::I32) => 32, Type::Unsigned(UnsignedNumType::U64) | Type::Signed(SignedNumType::I64) => 64, - Type::Array(elem, size) => elem.size_in_bits_for_defs(prg) * size, + Type::Array(elem, size) => elem.size_in_bits_for_defs(prg, const_sizes) * size, + Type::ArrayConst(elem, size) => { + elem.size_in_bits_for_defs(prg, const_sizes) * const_sizes.get(size).unwrap() + } Type::Tuple(values) => { let mut size = 0; for v in values { - size += v.size_in_bits_for_defs(prg) + size += v.size_in_bits_for_defs(prg, const_sizes) } size } Type::Fn(_, _) => panic!("Fn types cannot be directly mapped to bits"), - Type::Struct(name) => struct_size(prg.struct_defs.get(name).unwrap(), prg), - Type::Enum(name) => enum_max_size(prg.enum_defs.get(name).unwrap(), prg), + Type::Struct(name) => struct_size(prg.struct_defs.get(name).unwrap(), prg, const_sizes), + Type::Enum(name) => enum_max_size(prg.enum_defs.get(name).unwrap(), prg, const_sizes), Type::UntypedTopLevelDefinition(_, _) => { unreachable!("Untyped top level types should have been typechecked at this point") } @@ -871,10 +1188,14 @@ impl Type { } } -pub(crate) fn struct_size(struct_def: &StructDef, prg: &TypedProgram) -> usize { +pub(crate) fn struct_size( + struct_def: &StructDef, + prg: &TypedProgram, + const_sizes: &HashMap, +) -> usize { let mut total_size = 0; for (_, field_ty) in struct_def.fields.iter() { - total_size += field_ty.size_in_bits_for_defs(prg); + total_size += field_ty.size_in_bits_for_defs(prg, const_sizes); } total_size } @@ -896,12 +1217,16 @@ pub(crate) fn enum_tag_size(enum_def: &EnumDef) -> usize { bits } -pub(crate) fn enum_max_size(enum_def: &EnumDef, prg: &TypedProgram) -> usize { +pub(crate) fn enum_max_size( + enum_def: &EnumDef, + prg: &TypedProgram, + const_sizes: &HashMap, +) -> usize { let mut max = 0; for variant in enum_def.variants.iter() { let mut sum = 0; for field in variant.types().unwrap_or_default() { - sum += field.size_in_bits_for_defs(prg); + sum += field.size_in_bits_for_defs(prg, const_sizes); } if sum > max { max = sum; diff --git a/src/eval.rs b/src/eval.rs index 3a78cc7..8c46a74 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,6 +1,6 @@ //! Evaluates a [`crate::circuit::Circuit`] with inputs supplied by different parties. -use std::fmt::Debug; +use std::{collections::HashMap, fmt::Debug}; use crate::{ ast::Type, @@ -20,16 +20,23 @@ pub struct Evaluator<'a> { /// The compiled circuit. pub circuit: &'a Circuit, inputs: Vec>, + const_sizes: &'a HashMap, } impl<'a> Evaluator<'a> { /// Scans, parses, type-checks and then compiles a program for later evaluation. - pub fn new(program: &'a TypedProgram, main_fn: &'a TypedFnDef, circuit: &'a Circuit) -> Self { + pub fn new( + program: &'a TypedProgram, + main_fn: &'a TypedFnDef, + circuit: &'a Circuit, + const_sizes: &'a HashMap, + ) -> Self { Self { program, main_fn, circuit, inputs: vec![], + const_sizes, } } } @@ -111,6 +118,7 @@ impl<'a> Evaluator<'a> { program: self.program, main_fn: self.main_fn, output, + const_sizes: self.const_sizes.clone(), }) } @@ -183,12 +191,13 @@ impl<'a> Evaluator<'a> { pub fn set_literal(&mut self, literal: Literal) -> Result<(), EvalError> { if self.inputs.len() < self.main_fn.params.len() { let ty = &self.main_fn.params[self.inputs.len()].ty; - if literal.is_of_type(self.program, ty) { + let ty = resolve_const_type(ty, self.const_sizes); + if literal.is_of_type(self.program, &ty) { self.inputs.push(vec![]); self.inputs .last_mut() .unwrap() - .extend(literal.as_bits(self.program)); + .extend(literal.as_bits(self.program, self.const_sizes)); Ok(()) } else { Err(EvalError::InvalidLiteralType(literal, ty.clone())) @@ -202,8 +211,9 @@ impl<'a> Evaluator<'a> { pub fn parse_literal(&mut self, literal: &str) -> Result<(), EvalError> { if self.inputs.len() < self.main_fn.params.len() { let ty = &self.main_fn.params[self.inputs.len()].ty; + let ty = resolve_const_type(ty, self.const_sizes); let parsed = - Literal::parse(self.program, ty, literal).map_err(EvalError::LiteralParseError)?; + Literal::parse(self.program, &ty, literal).map_err(EvalError::LiteralParseError)?; self.set_literal(parsed)?; Ok(()) } else { @@ -212,12 +222,39 @@ impl<'a> Evaluator<'a> { } } +pub(crate) fn resolve_const_type(ty: &Type, const_sizes: &HashMap) -> Type { + match ty { + Type::Fn(params, ret_ty) => Type::Fn( + params + .iter() + .map(|ty| resolve_const_type(ty, const_sizes)) + .collect(), + Box::new(resolve_const_type(ret_ty, const_sizes)), + ), + Type::Array(elem_ty, size) => { + Type::Array(Box::new(resolve_const_type(elem_ty, const_sizes)), *size) + } + Type::ArrayConst(elem_ty, size) => Type::Array( + Box::new(resolve_const_type(elem_ty, const_sizes)), + *const_sizes.get(size).unwrap(), + ), + Type::Tuple(elems) => Type::Tuple( + elems + .iter() + .map(|ty| resolve_const_type(ty, const_sizes)) + .collect(), + ), + ty => ty.clone(), + } +} + /// The encoded result of a circuit evaluation. #[derive(Debug, Clone)] pub struct EvalOutput<'a> { program: &'a TypedProgram, main_fn: &'a TypedFnDef, output: Vec, + const_sizes: HashMap, } impl<'a> TryFrom> for bool { @@ -336,7 +373,7 @@ impl<'a> TryFrom> for Vec { impl<'a> EvalOutput<'a> { fn into_unsigned(self, ty: Type) -> Result { let output = EvalPanic::parse(&self.output)?; - let size = ty.size_in_bits_for_defs(self.program); + let size = ty.size_in_bits_for_defs(self.program, &self.const_sizes); if output.len() == size { let mut n = 0; for (i, output) in output.iter().copied().enumerate() { @@ -353,7 +390,7 @@ impl<'a> EvalOutput<'a> { fn into_signed(self, ty: Type) -> Result { let output = EvalPanic::parse(&self.output)?; - let size = ty.size_in_bits_for_defs(self.program); + let size = ty.size_in_bits_for_defs(self.program, &self.const_sizes); if output.len() == size { let mut n = 0; for (i, output) in output.iter().copied().enumerate() { @@ -376,6 +413,6 @@ impl<'a> EvalOutput<'a> { /// Decodes the evaluated result as a literal (with enums looked up in the program). pub fn into_literal(self) -> Result { let ret_ty = &self.main_fn.ty; - Literal::from_result_bits(self.program, ret_ty, &self.output) + Literal::from_result_bits(self.program, ret_ty, &self.output, &self.const_sizes) } } diff --git a/src/lib.rs b/src/lib.rs index 41d2f8c..bf49175 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,15 +41,18 @@ #![deny(missing_docs)] #![deny(rustdoc::broken_intra_doc_links)] -use ast::{Expr, FnDef, Pattern, Program, Stmt, Type, VariantExpr}; +use ast::{Expr, FnDef, Pattern, Program, Stmt, Type}; use check::TypeError; use circuit::Circuit; use compile::CompilerError; -use eval::{EvalError, Evaluator}; +use eval::{resolve_const_type, EvalError, Evaluator}; use literal::Literal; use parse::ParseError; use scan::{scan, ScanError}; -use std::fmt::{Display, Write as _}; +use std::{ + collections::HashMap, + fmt::{Display, Write as _}, +}; use token::MetaInfo; #[cfg(feature = "serde")] @@ -76,8 +79,6 @@ pub type TypedStmt = Stmt; pub type TypedExpr = Expr; /// [`crate::ast::Pattern`] after typechecking. pub type TypedPattern = Pattern; -/// [`crate::ast::VariantExpr`] after typechecking. -pub type TypedVariantExpr = VariantExpr; pub mod ast; pub mod check; @@ -104,6 +105,25 @@ pub fn compile(prg: &str) -> Result { program, main, circuit, + consts: HashMap::new(), + const_sizes: HashMap::new(), + }) +} + +/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit. +pub fn compile_with_constants( + prg: &str, + consts: HashMap>, +) -> Result { + let program = check(prg)?; + let (circuit, main, const_sizes) = program.compile_with_constants("main", consts.clone())?; + let main = main.clone(); + Ok(GarbleProgram { + program, + main, + circuit, + consts, + const_sizes, }) } @@ -117,16 +137,20 @@ pub struct GarbleProgram { pub main: TypedFnDef, /// The compilation output, as a circuit of boolean gates. pub circuit: Circuit, + /// The constants used for compiling the circuit. + pub consts: HashMap>, + /// The values of usize constants used for compiling the circuit. + pub const_sizes: HashMap, } /// An input argument for a Garble program and circuit. #[derive(Debug, Clone)] -pub struct GarbleArgument<'a>(Literal, &'a TypedProgram); +pub struct GarbleArgument<'a>(Literal, &'a TypedProgram, &'a HashMap); impl GarbleProgram { /// Returns an evaluator that can be used to run the compiled circuit. pub fn evaluator(&self) -> Evaluator<'_> { - Evaluator::new(&self.program, &self.main, &self.circuit) + Evaluator::new(&self.program, &self.main, &self.circuit, &self.const_sizes) } /// Type-checks and uses the literal as the circuit input argument with the given index. @@ -138,10 +162,11 @@ impl GarbleProgram { let Some(param) = self.main.params.get(arg_index) else { return Err(EvalError::InvalidArgIndex(arg_index)); }; - if !literal.is_of_type(&self.program, ¶m.ty) { - return Err(EvalError::InvalidLiteralType(literal, param.ty.clone())); + let ty = resolve_const_type(¶m.ty, &self.const_sizes); + if !literal.is_of_type(&self.program, &ty) { + return Err(EvalError::InvalidLiteralType(literal, ty)); } - Ok(GarbleArgument(literal, &self.program)) + Ok(GarbleArgument(literal, &self.program, &self.const_sizes)) } /// Tries to parse the string as the circuit input argument with the given index. @@ -155,19 +180,19 @@ impl GarbleProgram { }; let literal = Literal::parse(&self.program, ¶m.ty, literal) .map_err(EvalError::LiteralParseError)?; - Ok(GarbleArgument(literal, &self.program)) + Ok(GarbleArgument(literal, &self.program, &self.const_sizes)) } /// Tries to convert the circuit output back to a Garble literal. pub fn parse_output(&self, bits: &[bool]) -> Result { - Literal::from_result_bits(&self.program, &self.main.ty, bits) + Literal::from_result_bits(&self.program, &self.main.ty, bits, &self.const_sizes) } } impl GarbleArgument<'_> { /// Converts the argument to input bits for the compiled circuit. pub fn as_bits(&self) -> Vec { - self.0.as_bits(self.1) + self.0.as_bits(self.1, self.2) } /// Converts the argument to a Garble literal. @@ -192,7 +217,7 @@ pub enum CompileTimeError { /// Errors originating in the type-checking phase. TypeError(Vec), /// Errors originating in the compilation phase. - CompilerError(CompilerError), + CompilerError(Vec), } /// A generic error that combines compile-time and run-time errors. @@ -224,8 +249,8 @@ impl From> for CompileTimeError { } } -impl From for CompileTimeError { - fn from(e: CompilerError) -> Self { +impl From> for CompileTimeError { + fn from(e: Vec) -> Self { Self::CompilerError(e) } } @@ -285,39 +310,48 @@ impl CompileTimeError { match self { CompileTimeError::ScanErrors(errs) => { for ScanError(e, meta) in errs { - errs_for_display.push(("Scan error", format!("{e}"), *meta)); + errs_for_display.push(("Scan error", format!("{e}"), Some(*meta))); } } CompileTimeError::ParseError(errs) => { for ParseError(e, meta) in errs { - errs_for_display.push(("Parse error", format!("{e}"), *meta)); + errs_for_display.push(("Parse error", format!("{e}"), Some(*meta))); } } CompileTimeError::TypeError(errs) => { for TypeError(e, meta) in errs { - errs_for_display.push(("Type error", format!("{e}"), *meta)); + errs_for_display.push(("Type error", format!("{e}"), Some(*meta))); } } - CompileTimeError::CompilerError(e) => { - let meta = MetaInfo { - start: (0, 0), - end: (0, 0), - }; - errs_for_display.push(("Compiler error", format!("{e}"), meta)) + CompileTimeError::CompilerError(errs) => { + for e in errs { + match e { + CompilerError::MissingConstant(_, _, meta) => { + errs_for_display.push(("Compiler error", format!("{e}"), Some(*meta))) + } + e => errs_for_display.push(("Compiler error", format!("{e}"), None)), + } + } } } let mut msg = "".to_string(); for (err_type, err, meta) in errs_for_display { - writeln!( - msg, - "\n{} on line {}:{}.", - err_type, - meta.start.0 + 1, - meta.start.1 + 1 - ) - .unwrap(); + if let Some(meta) = meta { + writeln!( + msg, + "\n{} on line {}:{}.", + err_type, + meta.start.0 + 1, + meta.start.1 + 1 + ) + .unwrap(); + } else { + writeln!(msg, "\n{}:", err_type).unwrap(); + } writeln!(msg, "{err}:").unwrap(); - msg += &prettify_meta(prg, meta); + if let Some(meta) = meta { + msg += &prettify_meta(prg, meta); + } } msg } diff --git a/src/literal.rs b/src/literal.rs index 02b8941..88f6388 100644 --- a/src/literal.rs +++ b/src/literal.rs @@ -10,7 +10,7 @@ use std::{ use serde::{Deserialize, Serialize}; use crate::{ - ast::{Expr, ExprEnum, Type, Variant, VariantExpr, VariantExprEnum}, + ast::{Expr, ExprEnum, Type, Variant, VariantExprEnum}, check::{check_type, Defs, TopLevelTypes, TypeError, TypedFns}, circuit::EvalPanic, compile::{enum_max_size, enum_tag_number, enum_tag_size, signed_to_bits, unsigned_to_bits}, @@ -18,12 +18,12 @@ use crate::{ eval::EvalError, scan::scan, token::{SignedNumType, UnsignedNumType}, - CompileTimeError, TypedExpr, TypedProgram, TypedVariantExpr, + CompileTimeError, TypedExpr, TypedProgram, }; /// A subset of [`crate::ast::Expr`] that is used as input / output by an /// [`crate::eval::Evaluator`]. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Literal { /// Literal `true`. @@ -49,7 +49,7 @@ pub enum Literal { } /// A variant literal (either of unit type or containing fields), used by [`Literal::Enum`]. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum VariantLiteral { /// A unit variant, containing no fields. @@ -75,7 +75,12 @@ impl Literal { }; let mut env = Env::new(); let mut fns = TypedFns::new(); - let defs = Defs::new(&checked.struct_defs, &checked.enum_defs); + let const_types = checked + .const_defs + .iter() + .map(|(n, c)| (n.clone(), c.ty.clone())) + .collect(); + let defs = Defs::new(&const_types, &checked.struct_defs, &checked.enum_defs); let mut expr = scan(literal)? .parse_literal()? .type_check(&top_level_defs, &mut env, &mut fns, &defs) @@ -169,9 +174,10 @@ impl Literal { checked: &TypedProgram, ty: &Type, bits: &[bool], + const_sizes: &HashMap, ) -> Result { match EvalPanic::parse(bits) { - Ok(bits) => Literal::from_unwrapped_bits(checked, ty, bits), + Ok(bits) => Literal::from_unwrapped_bits(checked, ty, bits, const_sizes), Err(panic) => Err(EvalError::Panic(panic)), } } @@ -186,6 +192,7 @@ impl Literal { checked: &TypedProgram, ty: &Type, bits: &[bool], + const_sizes: &HashMap, ) -> Result { match ty { Type::Bool => { @@ -203,7 +210,7 @@ impl Literal { } } Type::Unsigned(unsigned_ty) => { - let size = ty.size_in_bits_for_defs(checked); + let size = ty.size_in_bits_for_defs(checked, const_sizes); if bits.len() == size { let mut n = 0; for (i, output) in bits.iter().copied().enumerate() { @@ -218,7 +225,7 @@ impl Literal { } } Type::Signed(signed_ty) => { - let size = ty.size_in_bits_for_defs(checked); + let size = ty.size_in_bits_for_defs(checked, const_sizes); if bits.len() == size { let mut n = 0; for (i, output) in bits.iter().copied().enumerate() { @@ -239,12 +246,34 @@ impl Literal { } } Type::Array(ty, size) => { - let ty_size = ty.size_in_bits_for_defs(checked); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); let mut elems = vec![]; let mut i = 0; for _ in 0..*size { let bits = &bits[i..i + ty_size]; - elems.push(Literal::from_unwrapped_bits(checked, ty, bits)?); + elems.push(Literal::from_unwrapped_bits( + checked, + ty, + bits, + const_sizes, + )?); + i += ty_size; + } + Ok(Literal::Array(elems)) + } + Type::ArrayConst(ty, size) => { + let size = const_sizes.get(size).unwrap(); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); + let mut elems = vec![]; + let mut i = 0; + for _ in 0..*size { + let bits = &bits[i..i + ty_size]; + elems.push(Literal::from_unwrapped_bits( + checked, + ty, + bits, + const_sizes, + )?); i += ty_size; } Ok(Literal::Array(elems)) @@ -253,9 +282,14 @@ impl Literal { let mut fields = vec![]; let mut i = 0; for ty in field_types { - let ty_size = ty.size_in_bits_for_defs(checked); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); let bits = &bits[i..i + ty_size]; - fields.push(Literal::from_unwrapped_bits(checked, ty, bits)?); + fields.push(Literal::from_unwrapped_bits( + checked, + ty, + bits, + const_sizes, + )?); i += ty_size; } Ok(Literal::Tuple(fields)) @@ -265,9 +299,9 @@ impl Literal { let mut i = 0; let struct_def = checked.struct_defs.get(struct_name).unwrap(); for (field_name, ty) in struct_def.fields.iter() { - let ty_size = ty.size_in_bits_for_defs(checked); + let ty_size = ty.size_in_bits_for_defs(checked, const_sizes); let bits = &bits[i..i + ty_size]; - let value = Literal::from_unwrapped_bits(checked, ty, bits)?; + let value = Literal::from_unwrapped_bits(checked, ty, bits, const_sizes)?; fields.push((field_name.clone(), value)); i += ty_size; } @@ -294,10 +328,11 @@ impl Literal { let field = Literal::from_unwrapped_bits( checked, ty, - &bits[i..i + ty.size_in_bits_for_defs(checked)], + &bits[i..i + ty.size_in_bits_for_defs(checked, const_sizes)], + const_sizes, )?; fields.push(field); - i += ty.size_in_bits_for_defs(checked); + i += ty.size_in_bits_for_defs(checked, const_sizes); } let variant = VariantLiteral::Tuple(fields); Ok(Literal::Enum( @@ -316,24 +351,28 @@ impl Literal { } /// Encodes the literal as bits, looking up enum defs in the program. - pub fn as_bits(&self, checked: &TypedProgram) -> Vec { + pub fn as_bits( + &self, + checked: &TypedProgram, + const_sizes: &HashMap, + ) -> Vec { match self { Literal::True => vec![true], Literal::False => vec![false], Literal::NumUnsigned(n, ty) => { - let size = Type::Unsigned(*ty).size_in_bits_for_defs(checked); + let size = Type::Unsigned(*ty).size_in_bits_for_defs(checked, const_sizes); let mut bits = vec![]; unsigned_to_bits(*n, size, &mut bits); bits } Literal::NumSigned(n, ty) => { - let size = Type::Signed(*ty).size_in_bits_for_defs(checked); + let size = Type::Signed(*ty).size_in_bits_for_defs(checked, const_sizes); let mut bits = vec![]; signed_to_bits(*n, size, &mut bits); bits } Literal::ArrayRepeat(elem, size) => { - let elem = elem.as_bits(checked); + let elem = elem.as_bits(checked, const_sizes); let elem_size = elem.len(); let mut bits = vec![false; elem_size * size]; for i in 0..*size { @@ -344,28 +383,28 @@ impl Literal { Literal::Array(elems) => { let mut bits = vec![]; for elem in elems { - bits.extend(elem.as_bits(checked)) + bits.extend(elem.as_bits(checked, const_sizes)) } bits } Literal::Tuple(fields) => { let mut bits = vec![]; for f in fields { - bits.extend(f.as_bits(checked)) + bits.extend(f.as_bits(checked, const_sizes)) } bits } Literal::Struct(_, fields) => { let mut bits = vec![]; for (_, f) in fields { - bits.extend(f.as_bits(checked)) + bits.extend(f.as_bits(checked, const_sizes)) } bits } Literal::Enum(enum_name, variant_name, variant) => { let enum_def = checked.enum_defs.get(enum_name).unwrap(); let tag_size = enum_tag_size(enum_def); - let max_size = enum_max_size(enum_def, checked); + let max_size = enum_max_size(enum_def, checked, const_sizes); let mut wires = vec![false; max_size]; let tag_number = enum_tag_number(enum_def, variant_name); for (i, wire) in wires.iter_mut().enumerate().take(tag_size) { @@ -376,7 +415,7 @@ impl Literal { VariantLiteral::Unit => {} VariantLiteral::Tuple(fields) => { for f in fields { - let f = f.as_bits(checked); + let f = f.as_bits(checked, const_sizes); wires[w..w + f.len()].copy_from_slice(&f); w += f.len(); } @@ -386,7 +425,7 @@ impl Literal { } Literal::Range((min, min_ty), (max, _)) => { let elems: Vec = (*min as usize..*max as usize).collect(); - let elem_size = Type::Unsigned(*min_ty).size_in_bits_for_defs(checked); + let elem_size = Type::Unsigned(*min_ty).size_in_bits_for_defs(checked, const_sizes); let mut bits = Vec::with_capacity(elems.len() * elem_size); for elem in elems { unsigned_to_bits(elem as u64, elem_size, &mut bits); @@ -506,9 +545,14 @@ impl TypedExpr { .map(|(name, value)| (name, value.into_literal())) .collect(), ), - ExprEnum::EnumLiteral(name, variant) => { - let VariantExpr(variant_name, _, _) = &variant.as_ref(); - Literal::Enum(name, variant_name.clone(), variant.into_literal()) + ExprEnum::EnumLiteral(name, variant_name, variant) => { + let variant = match variant { + VariantExprEnum::Unit => VariantLiteral::Unit, + VariantExprEnum::Tuple(fields) => VariantLiteral::Tuple( + fields.into_iter().map(|f| f.into_literal()).collect(), + ), + }; + Literal::Enum(name, variant_name.clone(), variant) } ExprEnum::Range(min, max) => Literal::Range(min, max), _ => unreachable!("This should result in a literal parse error instead"), @@ -516,18 +560,6 @@ impl TypedExpr { } } -impl TypedVariantExpr { - fn into_literal(self) -> VariantLiteral { - let VariantExpr(_, variant, _) = self; - match variant { - VariantExprEnum::Unit => VariantLiteral::Unit, - VariantExprEnum::Tuple(fields) => { - VariantLiteral::Tuple(fields.into_iter().map(|f| f.into_literal()).collect()) - } - } - } -} - impl From for Literal { fn from(b: bool) -> Self { if b { diff --git a/src/main.rs b/src/main.rs index f3c83e2..1b0ae3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{fs::File, io::Read, path::PathBuf, process::exit}; +use std::{collections::HashMap, fs::File, io::Read, path::PathBuf, process::exit}; use garble_lang::{check, eval::Evaluator, literal::Literal}; @@ -60,8 +60,10 @@ fn run(file: PathBuf, inputs: Vec, function: String) -> Result<(), std:: eprintln!("{}", e.prettify(&prg)); exit(65); }); - let (circuit, main_fn) = program.compile(&function).unwrap_or_else(|e| { - eprintln!("{e}"); + let (circuit, main_fn) = program.compile(&function).unwrap_or_else(|errs| { + for e in errs { + eprintln!("{e}"); + } exit(65); }); @@ -82,7 +84,8 @@ fn run(file: PathBuf, inputs: Vec, function: String) -> Result<(), std:: arguments.push(input); } - let mut evaluator = Evaluator::new(&program, main_fn, &circuit); + let const_sizes = HashMap::new(); + let mut evaluator = Evaluator::new(&program, main_fn, &circuit, &const_sizes); let main_params = &evaluator.main_fn.params; if main_params.len() != arguments.len() { eprintln!( diff --git a/src/parse.rs b/src/parse.rs index 5ca3ad0..be23571 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -4,8 +4,8 @@ use std::{collections::HashMap, iter::Peekable, vec::IntoIter}; use crate::{ ast::{ - EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, PatternEnum, Program, Stmt, - StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExpr, VariantExprEnum, + ConstDef, ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, FnDef, Op, ParamDef, Pattern, + PatternEnum, Program, Stmt, StmtEnum, StructDef, Type, UnaryOp, Variant, VariantExprEnum, }, scan::Tokens, token::{MetaInfo, SignedNumType, Token, TokenEnum, UnsignedNumType}, @@ -20,7 +20,7 @@ pub struct ParseError(pub ParseErrorEnum, pub MetaInfo); /// The different kinds of errors found during parsing. pub enum ParseErrorEnum { - /// The top level definition is not a valid enum or function declaration. + /// The top level definition is not a valid enum/struct/const/fn declaration. InvalidTopLevelDef, /// Arrays of the specified size are not supported. InvalidArraySize, @@ -30,6 +30,8 @@ pub enum ParseErrorEnum { InvalidPattern, /// The literal is not valid. InvalidLiteral, + /// Expected a const expr, but found a non-const or invalid expr. + InvalidConstExpr, /// Expected a type, but found a non-type token. ExpectedType, /// Expected a statement, but found a non-statement token. @@ -48,7 +50,7 @@ impl std::fmt::Display for ParseErrorEnum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ParseErrorEnum::InvalidTopLevelDef => { - f.write_str("Not a valid function or enum definition") + f.write_str("Not a valid top level declaration (struct/enum/const/fn)") } ParseErrorEnum::InvalidArraySize => { let max = usize::MAX; @@ -59,6 +61,7 @@ impl std::fmt::Display for ParseErrorEnum { ParseErrorEnum::InvalidRangeExpr => f.write_str("Invalid range expression"), ParseErrorEnum::InvalidPattern => f.write_str("Invalid pattern"), ParseErrorEnum::InvalidLiteral => f.write_str("Invalid literal"), + ParseErrorEnum::InvalidConstExpr => f.write_str("Invalid const expr"), ParseErrorEnum::ExpectedType => f.write_str("Expected a type"), ParseErrorEnum::ExpectedStmt => f.write_str("Expected a statement"), ParseErrorEnum::ExpectedExpr => f.write_str("Expected an expression"), @@ -115,7 +118,9 @@ impl Parser { TokenEnum::KeywordFn, TokenEnum::KeywordStruct, TokenEnum::KeywordEnum, + TokenEnum::KeywordConst, ]; + let mut const_defs = HashMap::new(); let mut struct_defs = HashMap::new(); let mut enum_defs = HashMap::new(); let mut fn_defs = HashMap::new(); @@ -125,6 +130,14 @@ impl Parser { TokenEnum::KeywordPub if is_pub.is_none() => { is_pub = Some(meta); } + TokenEnum::KeywordConst => { + if let Ok((const_name, const_def)) = self.parse_const_def(meta) { + const_defs.insert(const_name, const_def); + } else { + self.consume_until_one_of(&top_level_keywords); + } + is_pub = None; + } TokenEnum::KeywordStruct => { if let Ok((struct_name, struct_def)) = self.parse_struct_def(meta) { struct_defs.insert(struct_name, struct_def); @@ -158,6 +171,8 @@ impl Parser { } if self.errors.is_empty() { return Ok(Program { + const_deps: HashMap::new(), + const_defs, struct_defs, enum_defs, fn_defs, @@ -166,6 +181,76 @@ impl Parser { Err(self.errors) } + fn parse_const_def(&mut self, start: MetaInfo) -> Result<(String, ConstDef), ()> { + // const keyword was already consumed by the top-level parser + let (identifier, _) = self.expect_identifier()?; + + self.expect(&TokenEnum::Colon)?; + + let (ty, _) = self.parse_type()?; + + self.expect(&TokenEnum::Eq)?; + + let Ok(expr) = self.parse_primary() else { + self.push_error(ParseErrorEnum::InvalidTopLevelDef, start); + return Err(()); + }; + fn parse_const_expr( + expr: UntypedExpr, + ) -> Result> { + match expr.inner { + ExprEnum::True => Ok(ConstExpr(ConstExprEnum::True, expr.meta)), + ExprEnum::False => Ok(ConstExpr(ConstExprEnum::False, expr.meta)), + ExprEnum::NumUnsigned(n, ty) => { + Ok(ConstExpr(ConstExprEnum::NumUnsigned(n, ty), expr.meta)) + } + ExprEnum::NumSigned(n, ty) => { + Ok(ConstExpr(ConstExprEnum::NumSigned(n, ty), expr.meta)) + } + ExprEnum::EnumLiteral(party, identifier, VariantExprEnum::Unit) => Ok(ConstExpr( + ConstExprEnum::ExternalValue { party, identifier }, + expr.meta, + )), + ExprEnum::FnCall(f, args) if f == "max" || f == "min" => { + let mut const_exprs = vec![]; + let mut arg_errs = vec![]; + for arg in args { + match parse_const_expr(arg) { + Ok(value) => { + const_exprs.push(value); + } + Err(errs) => { + arg_errs.extend(errs); + } + } + } + if !arg_errs.is_empty() { + return Err(arg_errs); + } + if f == "max" { + Ok(ConstExpr(ConstExprEnum::Max(const_exprs), expr.meta)) + } else { + Ok(ConstExpr(ConstExprEnum::Min(const_exprs), expr.meta)) + } + } + _ => Err(vec![(ParseErrorEnum::InvalidConstExpr, expr.meta)]), + } + } + match parse_const_expr(expr) { + Ok(value) => { + let end = self.expect(&TokenEnum::Semicolon)?; + let meta = join_meta(start, end); + Ok((identifier, ConstDef { ty, value, meta })) + } + Err(errs) => { + for (e, meta) in errs { + self.push_error(e, meta); + } + Err(()) + } + } + } + fn parse_struct_def(&mut self, start: MetaInfo) -> Result<(String, StructDef), ()> { // struct keyword was already consumed by the top-level parser let (identifier, _) = self.expect_identifier()?; @@ -321,10 +406,7 @@ impl Parser { self.expect(&TokenEnum::Semicolon)?; return Ok(Stmt::new(StmtEnum::Let(pattern, binding), meta)); } else { - self.push_error_for_next(ParseErrorEnum::ExpectedStmt); self.consume_until_one_of(&[TokenEnum::Semicolon]); - self.advance(); - self.parse_expr()?; self.expect(&TokenEnum::Semicolon)?; } } @@ -401,7 +483,8 @@ impl Parser { fn parse_stmts(&mut self) -> Result, ()> { let mut stmts = vec![]; let mut has_error = false; - while !self.peek(&TokenEnum::RightBrace) + while self.tokens.peek().is_some() + && !self.peek(&TokenEnum::RightBrace) && !self.peek(&TokenEnum::Comma) && !self.peek(&TokenEnum::KeywordPub) && !self.peek(&TokenEnum::KeywordFn) @@ -1111,6 +1194,7 @@ impl Parser { _ => { if self.next_matches(&TokenEnum::DoubleColon).is_some() { let (variant_name, variant_meta) = self.expect_identifier()?; + let meta = join_meta(meta, variant_meta); let variant = if self.next_matches(&TokenEnum::LeftParen).is_some() { let mut fields = vec![]; if !self.peek(&TokenEnum::RightParen) { @@ -1132,13 +1216,15 @@ impl Parser { fields.push(child); } } - let end = self.expect(&TokenEnum::RightParen)?; - let variant_meta = join_meta(variant_meta, end); - VariantExpr(variant_name, VariantExprEnum::Tuple(fields), variant_meta) + self.expect(&TokenEnum::RightParen)?; + VariantExprEnum::Tuple(fields) } else { - VariantExpr(variant_name, VariantExprEnum::Unit, variant_meta) + VariantExprEnum::Unit }; - Expr::untyped(ExprEnum::EnumLiteral(identifier, Box::new(variant)), meta) + Expr::untyped( + ExprEnum::EnumLiteral(identifier, variant_name, variant), + meta, + ) } else if self.next_matches(&TokenEnum::LeftBrace).is_some() && self.struct_literals_allowed { @@ -1245,19 +1331,28 @@ impl Parser { }; if self.peek(&TokenEnum::Semicolon) { self.expect(&TokenEnum::Semicolon)?; - let size = if let Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) = - self.tokens.peek() - { - let n = *n; - self.advance(); - n as usize - } else { - self.push_error_for_next(ParseErrorEnum::InvalidArraySize); - return Err(()); - }; - let meta_end = self.expect(&TokenEnum::RightBracket)?; - let meta = join_meta(meta, meta_end); - Expr::untyped(ExprEnum::ArrayRepeatLiteral(Box::new(elem), size), meta) + match self.tokens.peek().cloned() { + Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) => { + self.advance(); + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + let size = n as usize; + Expr::untyped(ExprEnum::ArrayRepeatLiteral(Box::new(elem), size), meta) + } + Some(Token(TokenEnum::Identifier(n), _)) => { + self.advance(); + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + Expr::untyped( + ExprEnum::ArrayRepeatLiteralConst(Box::new(elem), n), + meta, + ) + } + _ => { + self.push_error_for_next(ParseErrorEnum::InvalidArraySize); + return Err(()); + } + } } else { let mut elems = vec![elem]; while self.next_matches(&TokenEnum::Comma).is_some() { @@ -1308,18 +1403,25 @@ impl Parser { } else if let Some(meta) = self.next_matches(&TokenEnum::LeftBracket) { let (ty, _) = self.parse_type()?; self.expect(&TokenEnum::Semicolon)?; - let size = if let Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) = self.tokens.peek() - { - let n = *n; - self.advance(); - n as usize - } else { - self.push_error_for_next(ParseErrorEnum::InvalidArraySize); - return Err(()); - }; - let meta_end = self.expect(&TokenEnum::RightBracket)?; - let meta = join_meta(meta, meta_end); - Ok((Type::Array(Box::new(ty), size), meta)) + match self.tokens.peek().cloned() { + Some(Token(TokenEnum::ConstantIndexOrSize(n), _)) => { + self.advance(); + let size = n as usize; + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + Ok((Type::Array(Box::new(ty), size), meta)) + } + Some(Token(TokenEnum::Identifier(n), _)) => { + self.advance(); + let meta_end = self.expect(&TokenEnum::RightBracket)?; + let meta = join_meta(meta, meta_end); + Ok((Type::ArrayConst(Box::new(ty), n), meta)) + } + _ => { + self.push_error_for_next(ParseErrorEnum::InvalidArraySize); + return Err(()); + } + } } else { let (ty, meta) = self.expect_identifier()?; let ty = match ty.as_str() { diff --git a/src/scan.rs b/src/scan.rs index 5227408..f74dfb7 100644 --- a/src/scan.rs +++ b/src/scan.rs @@ -274,6 +274,7 @@ impl<'a> Scanner<'a> { } let identifier: String = chars.into_iter().collect(); match identifier.as_str() { + "const" => self.push_token(TokenEnum::KeywordConst), "struct" => self.push_token(TokenEnum::KeywordStruct), "enum" => self.push_token(TokenEnum::KeywordEnum), "fn" => self.push_token(TokenEnum::KeywordFn), diff --git a/src/token.rs b/src/token.rs index a7d3d22..cd757ef 100644 --- a/src/token.rs +++ b/src/token.rs @@ -18,10 +18,12 @@ pub enum TokenEnum { UnsignedNum(u64, UnsignedNumType), /// Signed number. SignedNum(i64, SignedNumType), - /// `enum` keyword. - KeywordEnum, + /// `const` keyword. + KeywordConst, /// `struct` keyword. KeywordStruct, + /// `enum` keyword. + KeywordEnum, /// `fn` keyword. KeywordFn, /// `let` keyword. @@ -121,6 +123,7 @@ impl std::fmt::Display for TokenEnum { TokenEnum::ConstantIndexOrSize(num) => f.write_fmt(format_args!("{num}")), TokenEnum::UnsignedNum(num, suffix) => f.write_fmt(format_args!("{num}{suffix}")), TokenEnum::SignedNum(num, suffix) => f.write_fmt(format_args!("{num}{suffix}")), + TokenEnum::KeywordConst => f.write_str("const"), TokenEnum::KeywordStruct => f.write_str("struct"), TokenEnum::KeywordEnum => f.write_str("enum"), TokenEnum::KeywordFn => f.write_str("fn"), @@ -173,7 +176,7 @@ impl std::fmt::Display for TokenEnum { } /// A suffix indicating the explicit unsigned number type of the literal. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum UnsignedNumType { /// Unsigned integer type used to index arrays, length depends on the host platform. @@ -214,7 +217,7 @@ impl std::fmt::Display for UnsignedNumType { } /// A suffix indicating the explicit signed number type of the literal. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum SignedNumType { /// 8-bit signed integer type. diff --git a/tests/compile.rs b/tests/compile.rs index 9a9c0f3..3bb0b4c 100644 --- a/tests/compile.rs +++ b/tests/compile.rs @@ -1,4 +1,8 @@ -use garble_lang::{compile, Error}; +use std::collections::HashMap; + +use garble_lang::{ + compile, compile_with_constants, literal::Literal, token::UnsignedNumType, Error, +}; fn pretty_print>(e: E, prg: &str) -> Error { let e: Error = e.into(); @@ -1834,3 +1838,216 @@ pub fn main(_a: i32, _b: i32) -> () { assert_eq!(r.to_string(), "()"); Ok(()) } + +#[test] +fn compile_const() -> Result<(), Error> { + let prg = " +const MY_CONST: u16 = PARTY_0::MY_CONST; +pub fn main(x: u16) -> u16 { + x + MY_CONST +} +"; + let consts = HashMap::from_iter(vec![( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::U16), + )]), + )]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} + +#[test] +fn compile_const_literal() -> Result<(), Error> { + let prg = " +const MY_CONST: u16 = 2u16; +pub fn main(x: u16) -> u16 { + x + MY_CONST +} +"; + let consts = HashMap::from_iter(vec![]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} + +#[test] +fn compile_const_usize() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = PARTY_0::MY_CONST; +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + )]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} + +#[test] +fn compile_const_aggregated_max() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(1, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} + +#[test] +fn compile_const_aggregated_min() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = min(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(x: u16) -> u16 { + let array = [2u16; MY_CONST]; + x + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(3, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.set_u16(255); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!( + u16::try_from(output).map_err(|e| pretty_print(e, prg))?, + 257 + ); + Ok(()) +} + +#[test] +fn compile_const_size_in_fn_param() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(array: [u16; MY_CONST]) -> u16 { + array[1] +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(1, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.parse_literal("[7u16, 8u16]").unwrap(); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!(u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 8); + Ok(()) +} + +#[test] +fn compile_const_size_for_each_loop() -> Result<(), Error> { + let prg = " +const MY_CONST: usize = max(PARTY_0::MY_CONST, PARTY_1::MY_CONST); +pub fn main(array: [u16; MY_CONST]) -> u16 { + let mut result = 0u16; + for elem in array { + result = result + elem; + } + result +} +"; + let consts = HashMap::from_iter(vec![ + ( + "PARTY_0".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(1, UnsignedNumType::Usize), + )]), + ), + ( + "PARTY_1".to_string(), + HashMap::from_iter(vec![( + "MY_CONST".to_string(), + Literal::NumUnsigned(2, UnsignedNumType::Usize), + )]), + ), + ]); + let compiled = compile_with_constants(prg, consts).map_err(|e| pretty_print(e, prg))?; + let mut eval = compiled.evaluator(); + eval.parse_literal("[7u16, 8u16]").unwrap(); + let output = eval.run().map_err(|e| pretty_print(e, prg))?; + assert_eq!(u16::try_from(output).map_err(|e| pretty_print(e, prg))?, 15); + Ok(()) +} diff --git a/tests/credit_scoring_example.rs b/tests/credit_scoring_example.rs index 493fa50..ac0f940 100644 --- a/tests/credit_scoring_example.rs +++ b/tests/credit_scoring_example.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use garble_lang::{ast::Type, check, circuit::Circuit, eval::Evaluator, literal::Literal, Error}; #[test] @@ -126,7 +128,13 @@ fn credit_scoring_single_run() -> Result<(), Error> { )?; let (compute_score_circuit, compute_score_fn) = typed_prg.compile("compute_score")?; - let mut eval = Evaluator::new(&typed_prg, compute_score_fn, &compute_score_circuit); + let const_sizes = HashMap::new(); + let mut eval = Evaluator::new( + &typed_prg, + compute_score_fn, + &compute_score_circuit, + &const_sizes, + ); eval.set_literal(scoring_algorithm)?; eval.set_literal(user)?; diff --git a/tests/smart_cookie_example.rs b/tests/smart_cookie_example.rs index b274b93..9271a9e 100644 --- a/tests/smart_cookie_example.rs +++ b/tests/smart_cookie_example.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use garble_lang::{ check, circuit::Circuit, @@ -89,7 +91,8 @@ fn smart_cookie_simple_interaction() -> Result<(), Error> { )], ); - let mut init_eval = Evaluator::new(&program, init_fn, &init_circuit); + let const_sizes = HashMap::new(); + let mut init_eval = Evaluator::new(&program, init_fn, &init_circuit, &const_sizes); init_eval .set_literal(website_signing_key.clone()) .map_err(|e| pretty_print(e, smart_cookie))?; @@ -104,8 +107,12 @@ fn smart_cookie_simple_interaction() -> Result<(), Error> { for (i, interest) in interests.iter().enumerate() { println!(" {i}: logging '{interest}'"); - let mut log_interest_eval = - Evaluator::new(&program, log_interest_fn, &log_interest_circuit); + let mut log_interest_eval = Evaluator::new( + &program, + log_interest_fn, + &log_interest_circuit, + &const_sizes, + ); let interest = Literal::Enum( "UserInterest".to_string(), interest.to_string(), @@ -133,7 +140,8 @@ fn smart_cookie_simple_interaction() -> Result<(), Error> { .unwrap(); } - let mut decide_ad_eval = Evaluator::new(&program, decide_ad_fn, &decide_ad_circuit); + let mut decide_ad_eval = + Evaluator::new(&program, decide_ad_fn, &decide_ad_circuit, &const_sizes); decide_ad_eval .set_literal(website_signing_key) .map_err(|e| pretty_print(e, smart_cookie))?;