1
1
//! Compiles a [`crate::ast::Program`] to a [`crate::circuit::Circuit`].
2
2
3
- use std:: { cmp:: max, collections:: HashMap } ;
3
+ use std:: {
4
+ cmp:: { max, min} ,
5
+ collections:: HashMap ,
6
+ } ;
4
7
5
8
use crate :: {
6
9
ast:: {
@@ -56,14 +59,26 @@ impl TypedProgram {
56
59
) -> Result < ( Circuit , & TypedFnDef ) , CompilerError > {
57
60
let mut env = Env :: new ( ) ;
58
61
let mut const_sizes = HashMap :: new ( ) ;
62
+ let mut consts_unsigned = HashMap :: new ( ) ;
63
+ let mut consts_signed = HashMap :: new ( ) ;
59
64
for ( party, deps) in self . const_deps . iter ( ) {
60
- for ( c, ( identifier , ty ) ) in deps {
65
+ for ( c, ty ) in deps {
61
66
let Some ( party_deps) = consts. get ( party) else {
62
67
todo ! ( "missing party dep for {party}" ) ;
63
68
} ;
64
69
let Some ( literal) = party_deps. get ( c) else {
65
70
todo ! ( "missing value {party}::{c}" ) ;
66
71
} ;
72
+ let identifier = format ! ( "{party}::{c}" ) ;
73
+ match literal {
74
+ Literal :: NumUnsigned ( n, _) => {
75
+ consts_unsigned. insert ( identifier. clone ( ) , * n) ;
76
+ }
77
+ Literal :: NumSigned ( n, _) => {
78
+ consts_signed. insert ( identifier. clone ( ) , * n) ;
79
+ }
80
+ _ => { }
81
+ }
67
82
if literal. is_of_type ( self , ty) {
68
83
let bits = literal
69
84
. as_bits ( self , & const_sizes)
@@ -72,7 +87,7 @@ impl TypedProgram {
72
87
. collect ( ) ;
73
88
env. let_in_current_scope ( identifier. clone ( ) , bits) ;
74
89
if let Literal :: NumUnsigned ( size, UnsignedNumType :: Usize ) = literal {
75
- const_sizes. insert ( identifier. clone ( ) , * size as usize ) ;
90
+ const_sizes. insert ( identifier, * size as usize ) ;
76
91
}
77
92
} else {
78
93
return Err ( CompilerError :: InvalidLiteralType (
@@ -84,58 +99,153 @@ impl TypedProgram {
84
99
}
85
100
let mut input_gates = vec ! [ ] ;
86
101
let mut wire = 2 ;
87
- if let Some ( fn_def) = self . fn_defs . get ( fn_name) {
88
- for param in fn_def. params . iter ( ) {
89
- let type_size = param. ty . size_in_bits_for_defs ( self , & const_sizes) ;
90
- let mut wires = Vec :: with_capacity ( type_size) ;
91
- for _ in 0 ..type_size {
92
- wires. push ( wire) ;
93
- wire += 1 ;
102
+ let Some ( fn_def) = self . fn_defs . get ( fn_name) else {
103
+ return Err ( CompilerError :: FnNotFound ( fn_name. to_string ( ) ) ) ;
104
+ } ;
105
+ for param in fn_def. params . iter ( ) {
106
+ let type_size = param. ty . size_in_bits_for_defs ( self , & const_sizes) ;
107
+ let mut wires = Vec :: with_capacity ( type_size) ;
108
+ for _ in 0 ..type_size {
109
+ wires. push ( wire) ;
110
+ wire += 1 ;
111
+ }
112
+ input_gates. push ( type_size) ;
113
+ env. let_in_current_scope ( param. name . clone ( ) , wires) ;
114
+ }
115
+ fn resolve_const_expr_unsigned (
116
+ expr : & ConstExpr ,
117
+ consts_unsigned : & HashMap < String , u64 > ,
118
+ ) -> u64 {
119
+ match expr {
120
+ ConstExpr :: NumUnsigned ( n, _) => * n,
121
+ ConstExpr :: ExternalValue { party, identifier } => * consts_unsigned
122
+ . get ( & format ! ( "{party}::{identifier}" ) )
123
+ . unwrap ( ) ,
124
+ ConstExpr :: Max ( args) => {
125
+ let mut result = 0 ;
126
+ for arg in args {
127
+ result = max ( result, resolve_const_expr_unsigned ( arg, consts_unsigned) ) ;
128
+ }
129
+ result
130
+ }
131
+ ConstExpr :: Min ( args) => {
132
+ let mut result = u64:: MAX ;
133
+ for arg in args {
134
+ result = min ( result, resolve_const_expr_unsigned ( arg, consts_unsigned) ) ;
135
+ }
136
+ result
137
+ }
138
+ expr => panic ! ( "Not an unsigned const expr: {expr:?}" ) ,
139
+ }
140
+ }
141
+ fn resolve_const_expr_signed (
142
+ expr : & ConstExpr ,
143
+ consts_signed : & HashMap < String , i64 > ,
144
+ ) -> i64 {
145
+ match expr {
146
+ ConstExpr :: NumSigned ( n, _) => * n,
147
+ ConstExpr :: ExternalValue { party, identifier } => * consts_signed
148
+ . get ( & format ! ( "{party}::{identifier}" ) )
149
+ . unwrap ( ) ,
150
+ ConstExpr :: Max ( args) => {
151
+ let mut result = 0 ;
152
+ for arg in args {
153
+ result = max ( result, resolve_const_expr_signed ( arg, consts_signed) ) ;
154
+ }
155
+ result
156
+ }
157
+ ConstExpr :: Min ( args) => {
158
+ let mut result = i64:: MAX ;
159
+ for arg in args {
160
+ result = min ( result, resolve_const_expr_signed ( arg, consts_signed) ) ;
161
+ }
162
+ result
163
+ }
164
+ expr => panic ! ( "Not an unsigned const expr: {expr:?}" ) ,
165
+ }
166
+ }
167
+ for ( const_name, const_def) in self . const_defs . iter ( ) {
168
+ if let Type :: Unsigned ( UnsignedNumType :: Usize ) = const_def. ty {
169
+ if let ConstExpr :: ExternalValue { party, identifier } = & const_def. value {
170
+ let identifier = format ! ( "{party}::{identifier}" ) ;
171
+ const_sizes. insert ( const_name. clone ( ) , * const_sizes. get ( & identifier) . unwrap ( ) ) ;
94
172
}
95
- input_gates . push ( type_size ) ;
96
- env . let_in_current_scope ( param . name . clone ( ) , wires ) ;
173
+ let n = resolve_const_expr_unsigned ( & const_def . value , & consts_unsigned ) ;
174
+ const_sizes . insert ( const_name . clone ( ) , n as usize ) ;
97
175
}
98
- let mut circuit = CircuitBuilder :: new ( input_gates, const_sizes) ;
99
- for ( identifier, const_def) in self . const_defs . iter ( ) {
100
- match & const_def. value {
101
- ConstExpr :: True => env. let_in_current_scope ( identifier. clone ( ) , vec ! [ 1 ] ) ,
102
- ConstExpr :: False => env. let_in_current_scope ( identifier. clone ( ) , vec ! [ 0 ] ) ,
103
- ConstExpr :: NumUnsigned ( n, ty) => {
104
- let ty = Type :: Unsigned ( * ty) ;
176
+ }
177
+ let mut circuit = CircuitBuilder :: new ( input_gates, const_sizes) ;
178
+ for ( const_name, const_def) in self . const_defs . iter ( ) {
179
+ match & const_def. value {
180
+ ConstExpr :: True => env. let_in_current_scope ( const_name. clone ( ) , vec ! [ 1 ] ) ,
181
+ ConstExpr :: False => env. let_in_current_scope ( const_name. clone ( ) , vec ! [ 0 ] ) ,
182
+ ConstExpr :: NumUnsigned ( n, ty) => {
183
+ let ty = Type :: Unsigned ( * ty) ;
184
+ let mut bits =
185
+ Vec :: with_capacity ( ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ) ;
186
+ unsigned_to_bits (
187
+ * n,
188
+ ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
189
+ & mut bits,
190
+ ) ;
191
+ let bits = bits. into_iter ( ) . map ( |b| b as usize ) . collect ( ) ;
192
+ env. let_in_current_scope ( const_name. clone ( ) , bits) ;
193
+ }
194
+ ConstExpr :: NumSigned ( n, ty) => {
195
+ let ty = Type :: Signed ( * ty) ;
196
+ let mut bits =
197
+ Vec :: with_capacity ( ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ) ;
198
+ signed_to_bits (
199
+ * n,
200
+ ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
201
+ & mut bits,
202
+ ) ;
203
+ let bits = bits. into_iter ( ) . map ( |b| b as usize ) . collect ( ) ;
204
+ env. let_in_current_scope ( const_name. clone ( ) , bits) ;
205
+ }
206
+ ConstExpr :: ExternalValue { party, identifier } => {
207
+ let bits = env. get ( & format ! ( "{party}::{identifier}" ) ) . unwrap ( ) ;
208
+ env. let_in_current_scope ( const_name. clone ( ) , bits) ;
209
+ }
210
+ expr @ ( ConstExpr :: Max ( _) | ConstExpr :: Min ( _) ) => {
211
+ if let Type :: Unsigned ( _) = const_def. ty {
212
+ let result = resolve_const_expr_unsigned ( expr, & consts_unsigned) ;
105
213
let mut bits = Vec :: with_capacity (
106
- ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
214
+ const_def
215
+ . ty
216
+ . size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
107
217
) ;
108
218
unsigned_to_bits (
109
- * n,
110
- ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
219
+ result,
220
+ const_def
221
+ . ty
222
+ . size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
111
223
& mut bits,
112
224
) ;
113
225
let bits = bits. into_iter ( ) . map ( |b| b as usize ) . collect ( ) ;
114
- env. let_in_current_scope ( identifier. clone ( ) , bits) ;
115
- }
116
- ConstExpr :: NumSigned ( n, ty) => {
117
- let ty = Type :: Signed ( * ty) ;
226
+ env. let_in_current_scope ( const_name. clone ( ) , bits) ;
227
+ } else {
228
+ let result = resolve_const_expr_signed ( expr, & consts_signed) ;
118
229
let mut bits = Vec :: with_capacity (
119
- ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
230
+ const_def
231
+ . ty
232
+ . size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
120
233
) ;
121
234
signed_to_bits (
122
- * n,
123
- ty. size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
235
+ result,
236
+ const_def
237
+ . ty
238
+ . size_in_bits_for_defs ( self , circuit. const_sizes ( ) ) ,
124
239
& mut bits,
125
240
) ;
126
241
let bits = bits. into_iter ( ) . map ( |b| b as usize ) . collect ( ) ;
127
- env. let_in_current_scope ( identifier . clone ( ) , bits) ;
242
+ env. let_in_current_scope ( const_name . clone ( ) , bits) ;
128
243
}
129
- ConstExpr :: ExternalValue { .. } => { }
130
- ConstExpr :: Max ( _) => todo ! ( "compile max" ) ,
131
- ConstExpr :: Min ( _) => todo ! ( "compile min" ) ,
132
244
}
133
245
}
134
- let output_gates = compile_block ( & fn_def. body , self , & mut env, & mut circuit) ;
135
- Ok ( ( circuit. build ( output_gates) , fn_def) )
136
- } else {
137
- Err ( CompilerError :: FnNotFound ( fn_name. to_string ( ) ) )
138
246
}
247
+ let output_gates = compile_block ( & fn_def. body , self , & mut env, & mut circuit) ;
248
+ Ok ( ( circuit. build ( output_gates) , fn_def) )
139
249
}
140
250
}
141
251
0 commit comments