diff --git a/solver/passes/template_pass.py b/solver/passes/template_pass.py index 9b8a6cb..afd55d3 100644 --- a/solver/passes/template_pass.py +++ b/solver/passes/template_pass.py @@ -1,6 +1,7 @@ from parser.ast import Expr from parser.visitor import Visitor + def run_pass(ast: Expr) -> Expr: # Sample visitor v: TemplateVisitor = TemplateVisitor() diff --git a/solver/passes/z3_pass b/solver/passes/z3_pass.py similarity index 70% rename from solver/passes/z3_pass rename to solver/passes/z3_pass.py index a85de67..59638f2 100644 --- a/solver/passes/z3_pass +++ b/solver/passes/z3_pass.py @@ -1,12 +1,14 @@ from typing import override -from parser.ast import Expr + +import html, re + from parser.ast import AndExpr, OrExpr, Expr, NotExpr, ParenExpr, Var, VarExpr from parser.visitor import Visitor, RetVisitor -import z3 - from parser.lex import Lexer from parser.parse import Parser +import z3 + def run_pass(ast: Expr) -> Expr: v: Z3MappingVisitor = Z3MappingVisitor() @@ -15,7 +17,28 @@ def run_pass(ast: Expr) -> Expr: t: TranslateToZ3 = TranslateToZ3(v.symbols) p: z3.ExprRef = ast.acceptRet(t) simplifiedExpr: z3.ExprRef = z3.simplify(p) + + # Quick hack to force z3 into html mode + # so we can parse the simplified expression + # https://ericpony.github.io/z3py-tutorial/advanced-examples.htm + z3.set_option(html_mode=True) + simplifiedStr: str = str(simplifiedExpr) + simplifiedStr = ( + html.unescape(simplifiedStr) + .replace(chr(8744), "|") + .replace(chr(8743), "&") + .replace(chr(172), "!") + ) + # z3's XOR pretty print does not print cleanly so + # this is a hack to fix that + + # Pattern to match "Xor(A, B)" + pattern = r"Xor\((\w+), (\w+)\)" + # Replacement string using backreferences to capture groups + replacement = r"\1 ^ \2" + # Performing the replacement + simplifiedStr = re.sub(pattern, replacement, simplifiedStr) l: Lexer = Lexer() l.lex(simplifiedStr) @@ -74,11 +97,13 @@ def visitVar(self, va: Var) -> z3.ExprRef: if __name__ == "__main__": - prog: str = "A | B & C" + prog: str = "!(!(B | !C))" l: Lexer = Lexer() l.lex(prog) p: Parser = Parser() ast: Expr = p.parse(l.tokens) - run_pass(ast) + ast = run_pass(ast) + + assert str(ast) == "B | !C"