Skip to content

Commit

Permalink
Lean feat: use match_bv when possible (#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-adjedj authored Feb 13, 2025
1 parent 1d02546 commit a434a0a
Show file tree
Hide file tree
Showing 31 changed files with 329 additions and 12 deletions.
5 changes: 4 additions & 1 deletion src/bin/dune
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,7 @@
src/gen_lib/sail2_values.lem)
(%{workspace_root}/src/sail_lean_backend/Sail/Sail.lean
as
src/sail_lean_backend/Sail/Sail.lean)))
src/sail_lean_backend/Sail/Sail.lean)
(%{workspace_root}/src/sail_lean_backend/Sail/BitVec.lean
as
src/sail_lean_backend/Sail/BitVec.lean)))
212 changes: 212 additions & 0 deletions src/sail_lean_backend/Sail/BitVec.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/-
Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Shilpi Goel, Siddharth Bhat
-/

-- Taken from https://github.com/leanprover/LNSym/blob/main/Arm/BitVec.lean

import Lean.Elab.Term
import Lean.Meta.Reduce
import Std.Tactic.BVDecide

open BitVec

/- Bitvector pattern component syntax category, originally written by
Leonardo de Moura. -/
declare_syntax_cat bvpat_comp
syntax num : bvpat_comp
syntax ident (":" num)? : bvpat_comp
syntax "_" ":" num : bvpat_comp

/--
Bitvector pattern syntax category.
Example: [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5]
-/
declare_syntax_cat bvpat
syntax "[" bvpat_comp,* "]" : bvpat

open Lean

abbrev BVPatComp := TSyntax `bvpat_comp
abbrev BVPat := TSyntax `bvpat

/-- Return the number of bits in a bit-vector component pattern. -/
def BVPatComp.length (c : BVPatComp) : Nat := Id.run do
match c with
| `(bvpat_comp| $n:num) =>
let some str := n.raw.isLit? `num | pure 0
return str.length
| `(bvpat_comp| $_:ident : $n:num) =>
return n.raw.toNat
| `(bvpat_comp| $_:ident ) =>
return 1
| `(bvpat_comp| _ : $n:num) =>
return n.raw.toNat
| _ =>
return 0

/--
If the pattern component is a bitvector literal, convert it into a bit-vector term
denoting it.
-/
def BVPatComp.toBVLit? (c : BVPatComp) : MacroM (Option Term) := do
match c with
| `(bvpat_comp| $n:num) =>
let len := c.length
let some str := n.raw.isLit? `num | Macro.throwErrorAt c "invalid bit-vector literal"
let bs := str.toList
let mut val := 0
for b in bs do
if b = '1' then
val := 2*val + 1
else if b = '0' then
val := 2*val
else
Macro.throwErrorAt c "invalid bit-vector literal, '0'/'1's expected"
let r ← `(BitVec.ofNat $(quote len) $(quote val))
return some r
| _ => return none

/--
If the pattern component is a pattern variable of the form `<id>:<size>` return
`some id`.
-/
def BVPatComp.toBVVar? (c : BVPatComp) : MacroM (Option (TSyntax `ident)) := do
match c with
| `(bvpat_comp| $x:ident $[: $_:num]?) =>
return some x
| _ => return none

def BVPat.getComponents (p : BVPat) : Array BVPatComp :=
match p with
| `(bvpat| [$comp,*]) => comp.getElems.reverse
| _ => #[]

/--
Return the number of bits in a bit-vector pattern.
-/
def BVPat.length (p : BVPat) : Nat := Id.run do
let mut sz := 0
for c in p.getComponents do
sz := sz + c.length
return sz

/--
Return a term that evaluates to `true` if `var` is an instance of the pattern `pat`.
-/
def genBVPatMatchTest (vars : Array Term) (pats : Array BVPat) : MacroM Term := do
if vars.size != pats.size then
Macro.throwError "incorrect number of patterns"
let mut result ← `(true)

for (pat, var) in pats.zip vars do
let mut shift := 0
for c in pat.getComponents do
let len := c.length
if let some bv ← c.toBVLit? then
let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv)
result ← `($result && $test)
shift := shift + len
return result

/--
Given a variable `var` representing a term that matches the pattern `pat`, and a term `rhs`,
return a term of the form
```
let y₁ := var.extract ..
...
let yₙ := var.extract ..
rhs
```
where `yᵢ`s are the pattern variables in `pat`.
-/
def declBVPatVars (vars : Array Term) (pats : Array BVPat) (rhs : Term) : MacroM Term := do
let mut result := rhs
for (pat, var) in pats.zip vars do
let mut shift := 0
for c in pat.getComponents do
let len := c.length
if let some y ← c.toBVVar? then
let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var)
result ← `(let $y := $rhs; $result)
shift := shift + len
return result

/--
Define the `match_bv .. with | bvpat => rhs | _ => rhs`.
The last entry is the `else`-case since we currently do not check whether
the patterns are exhaustive or not.
-/
syntax (name := matchBv) "match_bv " term,+ "with" (atomic("| " bvpat,+) " => " term)* ("| " "_ " " => " term)? : term

open Lean
open Elab
open Term

def checkBVPatLengths (lens : Array (Option Nat)) (pss : Array (Array BVPat)) : TermElabM Unit := do
for (len, i) in lens.zipWithIndex do
let mut patLen := none
for ps in pss do
unless ps.size == lens.size do
throwError "Expected {lens.size} patterns, found {ps.size}"
let p := ps[i]!
let pLen := p.length

-- compare the length to that of the type of the discriminant
if let some pLen' := len then
unless pLen == pLen' do
throwErrorAt p "Exprected pattern of length {pLen}, found {pLen'} instead"

-- compare the lengths of the patterns
if let some pLen' := patLen then
unless pLen == pLen' do
throwErrorAt p "patterns have differrent lengths"
else
patLen := some pLen

-- We use this to gather all the conditions expressing that the
-- previous pattern matches failed. This allows in turn to prove
-- exaustivity of the pattern matching.
abbrev dite_gather {α : Sort u} {old : Prop} (c : Prop) [h : Decidable c]
(t : old ∧ c → α) (e : old ∧ ¬ c → α) (ho : old) : α :=
h.casesOn (λ hc => e (And.intro ho hc)) (λ hc => t (And.intro ho hc))

@[term_elab matchBv]
partial
def elabMatchBv : TermElab := fun stx typ? =>
match stx with
| `(match_bv $[$discrs:term],* with
$[ | $[$pss:bvpat],* => $rhss:term ]*
$[| _ => $rhsElse?:term]?) => do
let xs := discrs

-- try to get the length of the BV to error-out
-- if a pattern has the wrong length
-- TODO: is it the best way to do that?
let lens ← discrs.mapM (fun x => do
let x ← elabTerm x none
let typ ← Meta.inferType x
match_expr typ with
| BitVec n =>
let n ← Meta.reduce n
match n with
| .lit (.natVal n) => return some n
| _ => return none
| _ => return none)

checkBVPatLengths lens pss

let mut result :=
if let some rhsElse := rhsElse? then
`(Function.const _ $rhsElse)
else
`(fun _ => by bv_decide)

for ps in pss.reverse, rhs in rhss.reverse do
let test ← liftMacroM <| genBVPatMatchTest xs ps
let rhs ← liftMacroM <| declBVPatVars xs ps rhs
result ← `(dite_gather $test (Function.const _ $rhs) $result)
let res ← liftMacroM <| `($result True.intro)
elabTerm res typ?
| _ => throwError "invalid syntax"
31 changes: 27 additions & 4 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ let doc_lit (L_aux (lit, l)) =
| L_string s -> utf8string ("\"" ^ lean_escape_string s ^ "\"")
| L_real s -> utf8string s (* TODO test if this is really working *)

let doc_vec_lit (L_aux (lit, _) as l) =
match lit with
| L_zero -> string "0"
| L_one -> string "1"
| _ -> failwith "Unexpected litteral found in vector: " ^^ doc_lit l

let string_of_exp_con (E_aux (e, _)) =
match e with
| E_block _ -> "E_block"
Expand Down Expand Up @@ -362,17 +368,25 @@ let string_of_pat_con (P_aux (p, _)) =
let fixup_match_id (Id_aux (id, l) as id') =
match id with Id id -> Id_aux (Id (match id with "Some" -> "some" | "None" -> "none" | _ -> id), l) | _ -> id'

let rec doc_pat (P_aux (p, (l, annot)) as pat) =
let rec doc_pat ?(in_vector = false) (P_aux (p, (l, annot)) as pat) =
match p with
| P_wild -> underscore
| P_lit lit when in_vector -> doc_vec_lit lit
| P_lit lit -> doc_lit lit
| P_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _), p) when in_vector -> doc_pat p ^^ string ":1"
| P_typ (Typ_aux (Typ_app (Id_aux (Id id, _), [A_aux (A_nexp (Nexp_aux (Nexp_constant i, _)), _)]), _), p)
when in_vector && (id = "bits" || id = "bitvector") ->
doc_pat p ^^ string ":" ^^ doc_big_int i
| P_typ (ptyp, p) -> doc_pat p
| P_id id -> fixup_match_id id |> doc_id_ctor
| P_tuple pats -> separate (string ", ") (List.map doc_pat pats) |> parens
| P_list pats -> separate (string ", ") (List.map doc_pat pats) |> brackets
| P_vector pats -> concat (List.map (doc_pat ~in_vector:true) pats)
| P_vector_concat pats -> separate (string ",") (List.map (doc_pat ~in_vector:true) pats) |> brackets
| P_app (Id_aux (Id "None", _), p) -> string "none"
| P_app (cons, pats) -> doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")
| P_as (pat, id) -> doc_pat pat
| _ -> failwith ("Doc Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

(* Copied from the Coq PP *)
let rebind_cast_pattern_vars pat typ exp =
Expand Down Expand Up @@ -412,6 +426,13 @@ let get_fn_implicits (Typ_aux (t, _)) : bool list =
in
match t with Typ_fn (args, cod) -> List.map arg_implicit args | _ -> []

let rec is_bitvector_pattern (P_aux (pat, _)) =
match pat with P_vector _ | P_vector_concat _ -> true | P_as (pat, _) -> is_bitvector_pattern pat | _ -> false

let match_or_match_bv brs =
if List.exists (function Pat_aux (Pat_exp (pat, _), _) -> is_bitvector_pattern pat | _ -> false) brs then "match_bv "
else "match "

let rec doc_match_clause (as_monadic : bool) ctx (Pat_aux (cl, l)) =
match cl with
| Pat_exp (pat, branch) ->
Expand Down Expand Up @@ -493,8 +514,10 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
wrap_with_pure as_monadic
(braces (space ^^ doc_exp false ctx exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space))
| E_match (discr, brs) ->
let cases = separate_map hardline (fun br -> doc_match_clause as_monadic ctx br) brs in
string "match " ^^ doc_exp (effectful (effect_of discr)) ctx discr ^^ string " with" ^^ hardline ^^ cases
let cases = separate_map hardline (doc_match_clause as_monadic ctx) brs in
string (match_or_match_bv brs)
^^ doc_exp (effectful (effect_of discr)) ctx discr
^^ string " with" ^^ hardline ^^ cases
| E_assign ((LE_aux (le_act, tannot) as le), e) -> (
match le_act with
| LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e
Expand Down
11 changes: 6 additions & 5 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ let lean_rewrites =
("move_termination_measures", []);
("instantiate_outcomes", [String_arg "coq"]);
("realize_mappings", []);
("remove_vector_subrange_pats", []);
(* ("remove_vector_subrange_pats", []); *)
("remove_duplicate_valspecs", []);
("toplevel_string_append", []);
("pat_string_append", []);
Expand All @@ -107,8 +107,8 @@ let lean_rewrites =
("tuple_assignments", []);
("vector_concat_assignments", []);
("simple_assignments", []);
("remove_vector_concat", []);
("remove_bitvector_pats", []);
(* ("remove_vector_concat", []); *)
(* ("remove_bitvector_pats", []); *)
(* ("remove_numeral_pats", []); *)
(* ("pattern_literals", [Literal_arg "lem"]); *)
("guarded_pats", []);
Expand All @@ -129,7 +129,7 @@ let lean_rewrites =
(* We need to do the exhaustiveness check before merging, because it may
introduce new wildcard clauses *)
("recheck_defs", []);
("make_cases_exhaustive", []);
(* ("make_cases_exhaustive", []); *)
(* merge funcls before adding the measure argument so that it doesn't
disappear into an internal pattern match *)
("merge_function_clauses", []);
Expand Down Expand Up @@ -185,7 +185,8 @@ let start_lean_output (out_name : string) default_sail_dir =
("cp -r " ^ Filename.quote (sail_dir ^ "/src/sail_lean_backend/Sail") ^ " " ^ Filename.quote lean_src_dir)
in
let main_file = open_out (Filename.concat project_dir (out_name_camel ^ ".lean")) in
output_string main_file ("import " ^ out_name_camel ^ ".Sail.Sail\n\n");
output_string main_file ("import " ^ out_name_camel ^ ".Sail.Sail\n");
output_string main_file ("import " ^ out_name_camel ^ ".Sail.BitVec\n\n");
output_string main_file "open Sail\n\n";
let lakefile = open_out (Filename.concat project_dir "lakefile.toml") in
{ out_name; out_name_camel; sail_dir; main_file; lakefile }
Expand Down
1 change: 1 addition & 0 deletions test/c/hello_world.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
1 change: 1 addition & 0 deletions test/lean/atom_bool.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
1 change: 1 addition & 0 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
1 change: 1 addition & 0 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
1 change: 1 addition & 0 deletions test/lean/enum.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
5 changes: 3 additions & 2 deletions test/lean/errors.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand All @@ -15,13 +16,13 @@ instance : Inhabited (RegisterRef RegisterType (BitVec 1)) where
default := .Reg dummy
abbrev SailM := PreSailM RegisterType trivialChoiceSource

/-- Type quantifiers: k_ex824# : Bool -/
/-- Type quantifiers: k_ex809# : Bool -/
def test_exit (b : Bool) : SailM Unit := do
if b
then throw Error.Exit
else (pure ())

/-- Type quantifiers: k_ex826# : Bool -/
/-- Type quantifiers: k_ex811# : Bool -/
def test_assert (b : Bool) : SailM (BitVec 1) := do
assert b "b is false"
(pure 1#1)
Expand Down
1 change: 1 addition & 0 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
1 change: 1 addition & 0 deletions test/lean/extern_bitvec.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
1 change: 1 addition & 0 deletions test/lean/implicit.expected.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Out.Sail.Sail
import Out.Sail.BitVec

open Sail

Expand Down
Loading

0 comments on commit a434a0a

Please sign in to comment.