1
- use crate :: evm_context ;
1
+ use crate :: evaluator :: Evaluator ;
2
2
use crate :: execution_position:: ExecutionPositionManager ;
3
- use crate :: smt:: { self , SMTExpr , SMTFormat , SMTSort , SMTStatement , SMTVariable } ;
3
+ use crate :: smt:: { SMTExpr , SMTFormat , SMTSort , SMTStatement , SMTVariable } ;
4
4
use crate :: ssa_tracker:: SSATracker ;
5
+ use crate :: { evm_context, smt} ;
5
6
6
7
use yultsur:: dialect:: { Builtin , Dialect } ;
7
8
use yultsur:: visitor:: ASTVisitor ;
@@ -28,6 +29,7 @@ pub struct Encoder<InstructionsType> {
28
29
ssa_tracker : SSATracker ,
29
30
output : Vec < SMTStatement > ,
30
31
interpreter : InstructionsType ,
32
+ evaluator : Evaluator ,
31
33
loop_unroll : u64 ,
32
34
path_conditions : Vec < SMTExpr > ,
33
35
execution_position : ExecutionPositionManager ,
@@ -56,6 +58,25 @@ pub fn encode_revert_reachable<T: Instructions>(
56
58
encode_with_counterexamples ( & mut encoder, counterexamples)
57
59
}
58
60
61
+ pub fn encode_with_evaluator < T : Instructions > (
62
+ ast : & Block ,
63
+ loop_unroll : u64 ,
64
+ evaluator : Evaluator ,
65
+ ) -> ( String , Evaluator ) {
66
+ let mut encoder = Encoder :: < T > {
67
+ evaluator,
68
+ ..Default :: default ( )
69
+ } ;
70
+ encoder. encode ( ast, loop_unroll) ;
71
+ let query = encoder
72
+ . output
73
+ . iter ( )
74
+ . map ( |s| s. as_smt ( ) )
75
+ . collect :: < Vec < _ > > ( )
76
+ . join ( "\n " ) ;
77
+ ( query, std:: mem:: take ( & mut encoder. evaluator ) )
78
+ }
79
+
59
80
pub fn encode_solc_panic_reachable < T : Instructions > (
60
81
ast : & Block ,
61
82
loop_unroll : u64 ,
@@ -133,13 +154,16 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
133
154
pub fn encode_function (
134
155
& mut self ,
135
156
function : & FunctionDefinition ,
136
- arguments : & [ SMTVariable ] ,
157
+ arguments : & Vec < SMTVariable > ,
137
158
) -> Vec < SMTVariable > {
138
159
assert_eq ! ( function. parameters. len( ) , arguments. len( ) ) ;
139
160
for ( param, arg) in function. parameters . iter ( ) . zip ( arguments) {
140
161
let var = self . ssa_tracker . introduce_variable ( param) ;
141
- self . out ( smt:: define_const ( var, smt:: SMTExpr :: from ( arg. clone ( ) ) ) )
162
+ self . out ( smt:: define_const ( var, smt:: SMTExpr :: from ( arg. clone ( ) ) ) ) ;
163
+ self . evaluator . define_from_variable ( & var, arg) ;
142
164
}
165
+
166
+ let parameters = self . ssa_tracker . to_smt_variables ( & function. parameters ) ;
143
167
self . encode_variable_declaration ( & VariableDeclaration {
144
168
variables : function. returns . clone ( ) ,
145
169
value : None ,
@@ -199,7 +223,14 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
199
223
fn encode_if ( & mut self , expr : & yul:: If ) {
200
224
let cond = self . encode_expression ( & expr. condition ) ;
201
225
assert ! ( cond. len( ) == 1 ) ;
226
+ if self
227
+ . evaluator
228
+ . variable_known_equal ( & cond[ 0 ] , & "0" . to_string ( ) )
229
+ {
230
+ return ;
231
+ }
202
232
let prev_ssa = self . ssa_tracker . copy_current_ssa ( ) ;
233
+ let prev_eval = self . evaluator . copy_state ( ) ;
203
234
204
235
self . push_path_condition ( smt:: neq ( cond[ 0 ] . clone ( ) , 0 ) ) ;
205
236
self . encode_block ( & expr. body ) ;
@@ -208,6 +239,7 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
208
239
let output = self
209
240
. ssa_tracker
210
241
. join_branches ( smt:: eq ( cond[ 0 ] . clone ( ) , 0 ) , prev_ssa) ;
242
+ self . evaluator . join_with_old_state ( prev_eval) ;
211
243
self . out_vec ( output) ;
212
244
}
213
245
@@ -220,6 +252,8 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
220
252
let cond = self . encode_expression ( & for_loop. condition ) ;
221
253
assert ! ( cond. len( ) == 1 ) ;
222
254
let prev_ssa = self . ssa_tracker . copy_current_ssa ( ) ;
255
+ let prev_eval = self . evaluator . copy_state ( ) ;
256
+ // TODO the evaluator does not have path conditions - is that OK?
223
257
224
258
self . push_path_condition ( smt:: neq ( cond[ 0 ] . clone ( ) , 0 ) ) ;
225
259
self . encode_block ( & for_loop. body ) ;
@@ -229,6 +263,7 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
229
263
let output = self
230
264
. ssa_tracker
231
265
. join_branches ( smt:: eq ( cond[ 0 ] . clone ( ) , 0 ) , prev_ssa) ;
266
+ self . evaluator . join_with_old_state ( prev_eval) ;
232
267
self . out_vec ( output) ;
233
268
}
234
269
}
@@ -238,6 +273,7 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
238
273
assert ! ( discriminator. len( ) == 1 ) ;
239
274
let pre_switch_ssa = self . ssa_tracker . copy_current_ssa ( ) ;
240
275
let mut post_switch_ssa = self . ssa_tracker . take_current_ssa ( ) ;
276
+ let prev_eval = self . evaluator . copy_state ( ) ;
241
277
242
278
for Case {
243
279
literal,
@@ -246,6 +282,15 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
246
282
} in & switch. cases
247
283
{
248
284
let is_default = literal. is_none ( ) ;
285
+ // TODO this will always run the default case.
286
+ // We should first go through all case labels and see if we only need to execute a single one.
287
+ if !is_default
288
+ && self
289
+ . evaluator
290
+ . variable_known_unequal ( & discriminator[ 0 ] , & literal. as_ref ( ) . unwrap ( ) . literal )
291
+ {
292
+ continue ;
293
+ }
249
294
250
295
self . ssa_tracker . set_current_ssa ( pre_switch_ssa. clone ( ) ) ;
251
296
@@ -258,15 +303,17 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
258
303
. map ( |case| {
259
304
smt:: eq (
260
305
discriminator[ 0 ] . clone ( ) ,
261
- self . encode_literal_value ( case. literal . as_ref ( ) . unwrap ( ) ) ,
306
+ self . encode_literal_value ( case. literal . as_ref ( ) . unwrap ( ) )
307
+ . unwrap ( ) ,
262
308
)
263
309
} )
264
310
. collect :: < Vec < _ > > ( ) ,
265
311
)
266
312
} else {
267
313
smt:: neq (
268
314
discriminator[ 0 ] . clone ( ) ,
269
- self . encode_literal_value ( literal. as_ref ( ) . unwrap ( ) ) ,
315
+ self . encode_literal_value ( literal. as_ref ( ) . unwrap ( ) )
316
+ . unwrap ( ) ,
270
317
)
271
318
} ;
272
319
@@ -278,9 +325,13 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
278
325
. ssa_tracker
279
326
. join_branches ( skip_condition, post_switch_ssa) ;
280
327
self . out_vec ( output) ;
328
+ // TODO check if this is correct.
329
+ self . evaluator . join_with_old_state ( prev_eval. clone ( ) ) ;
281
330
post_switch_ssa = self . ssa_tracker . take_current_ssa ( ) ;
282
331
post_switch_ssa. retain ( |key, _| pre_switch_ssa. contains_key ( key) ) ;
283
332
}
333
+ // TODO we should actually reset thet state of the evaluator, because
334
+ // we do not know in which branch we ended up
284
335
285
336
self . ssa_tracker . set_current_ssa ( post_switch_ssa) ;
286
337
}
@@ -296,10 +347,12 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
296
347
fn encode_literal ( & mut self , literal : & Literal ) -> SMTVariable {
297
348
let sort = SMTSort :: BV ( 256 ) ;
298
349
let var = self . new_temporary_variable ( sort) ;
299
- self . out ( smt:: define_const (
300
- var. clone ( ) ,
301
- self . encode_literal_value ( literal) ,
302
- ) ) ;
350
+ self . evaluator . define_from_literal ( & var, literal) ;
351
+ if let Some ( value) = self . encode_literal_value ( literal) {
352
+ self . out ( smt:: define_const ( var. clone ( ) , value) ) ;
353
+ } else {
354
+ self . out ( smt:: declare_const ( var. clone ( ) ) )
355
+ }
303
356
var
304
357
}
305
358
@@ -308,7 +361,7 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
308
361
}
309
362
310
363
fn encode_function_call ( & mut self , call : & FunctionCall ) -> Vec < SMTVariable > {
311
- let arguments = call
364
+ let arguments: Vec < SMTVariable > = call
312
365
. arguments
313
366
. iter ( )
314
367
. rev ( )
@@ -343,6 +396,15 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
343
396
. map ( |_i| self . new_temporary_variable ( SMTSort :: BV ( 256 ) ) )
344
397
. collect ( ) ;
345
398
399
+ // TODO call evaluator first or interpreter first?
400
+ self . evaluator
401
+ . builtin_call ( builtin, & arguments, & return_vars) ;
402
+ if builtin. name == "call" {
403
+ // if let Some((ast, calldata)) = self.evaluator.is_call_to_knon_contract(arguments) {
404
+
405
+ // // TODO
406
+ // }
407
+ }
346
408
let result = self . interpreter . encode_builtin_call (
347
409
builtin,
348
410
arguments,
@@ -355,7 +417,13 @@ impl<InstructionsType: Instructions> Encoder<InstructionsType> {
355
417
}
356
418
IdentifierID :: Reference ( id) => {
357
419
let fun_def = self . function_definitions [ & id] . clone ( ) ;
358
- self . encode_function ( & fun_def, & arguments)
420
+ let function_vars = self . encode_function ( & fun_def, & arguments) ;
421
+ assert ! ( arguments. len( ) == function_vars. parameters. len( ) ) ;
422
+ arguments
423
+ . into_iter ( )
424
+ . zip ( function_vars. parameters )
425
+ . for_each ( |( arg, param) | self . out ( smt:: assert ( smt:: eq ( arg, param) ) ) ) ;
426
+ function_vars. returns
359
427
}
360
428
_ => panic ! (
361
429
"Unexpected reference in function call: {:?}" ,
@@ -393,21 +461,27 @@ impl<T> Encoder<T> {
393
461
394
462
for ( v, val) in variables. iter ( ) . zip ( values. into_iter ( ) ) {
395
463
let var = self . ssa_tracker . allocate_new_ssa_index ( v) ;
464
+ self . evaluator . define_from_variable ( & var, & val) ;
396
465
self . out ( smt:: define_const ( var, val. into ( ) ) ) ;
397
466
}
398
467
}
399
468
400
- fn encode_literal_value ( & self , literal : & Literal ) -> SMTExpr {
401
- if literal. literal . starts_with ( "0x" ) {
402
- smt:: literal ( format ! ( "{:0>64}" , & literal. literal[ 2 ..] ) , SMTSort :: BV ( 256 ) )
469
+ fn encode_literal_value ( & self , literal : & Literal ) -> Option < SMTExpr > {
470
+ if let Some ( hex) = literal. literal . strip_prefix ( "0x" ) {
471
+ Some ( smt:: literal ( format ! ( "{:0>64}" , hex) , SMTSort :: BV ( 256 ) ) )
472
+ } else if let Some ( string) = literal. literal . strip_prefix ( "\" " ) {
473
+ assert ! ( string. len( ) >= 2 && string. chars( ) . last( ) . unwrap( ) == '"' ) ;
474
+ // This is usually only used for references to data objects,
475
+ // so we do not encode it.
476
+ None
403
477
} else {
404
- smt:: literal (
478
+ Some ( smt:: literal (
405
479
format ! (
406
480
"{:064X}" ,
407
481
literal. literal. parse:: <num_bigint:: BigUint >( ) . unwrap( )
408
482
) ,
409
483
SMTSort :: BV ( 256 ) ,
410
- )
484
+ ) )
411
485
}
412
486
}
413
487
0 commit comments