Skip to content

Commit

Permalink
Lean: add some definitions and annotations (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
ineol authored Feb 24, 2025
1 parent 80049ce commit 07baef1
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 31 deletions.
1 change: 1 addition & 0 deletions lib/smt.sail
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ val abs_int_atom = pure {
interpreter: "abs_int",
lem: "integerAbs",
coq: "Z.abs",
lean: "Int.natAbs",
_: "abs_int"
} : forall 'n. int('n) -> int(abs('n))

Expand Down
95 changes: 65 additions & 30 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,6 @@ import Std.Data.DHashMap
import Std.Data.HashMap
namespace Sail

instance : CoeT Int x Nat where
coe := x.toNat

instance : CoeT (BitVec n) x (BitVec m) where
coe := x.setWidth m

instance : HAdd (BitVec n) (BitVec m) (BitVec n) where
hAdd x y := x + y

instance : HSub (BitVec n) (BitVec m) (BitVec n) where
hSub x y := x - y

instance : HAnd (BitVec n) (BitVec m) (BitVec n) where
hAnd x y := x &&& y

instance : HOr (BitVec n) (BitVec m) (BitVec n) where
hOr x y := x ||| y

instance : HXor (BitVec n) (BitVec m) (BitVec n) where
hXor x y := x ^^^ y

namespace BitVec

def length {w : Nat} (_ : BitVec w) : Nat := w
Expand Down Expand Up @@ -68,10 +47,6 @@ def append' (x : BitVec n) (y : BitVec m) {mn}
(hmn : mn = n + m := by (conv => rhs; simp); try rfl) : BitVec mn :=
(x.append y).cast hmn.symm

def extractLsbUnif {w : Nat} (x : BitVec w) (hi lo : Nat) {w}
(hw : w = hi - lo + 1 := by (conv => rhs; simp [Nat.add_sub_cancel_left]); try rfl) : BitVec w :=
(extractLsb x hi lo).cast hw.symm

def update (x : BitVec m) (n : Nat) (b : BitVec 1) := updateSubrange' x n _ b

def toBin {w : Nat} (x : BitVec w) : String :=
Expand All @@ -91,6 +66,36 @@ instance : Coe (BitVec (1 * n)) (BitVec n) where

end BitVec

def charToHex (c : Char) : BitVec 4 :=
match c.toLower with
| '0' => 0 | '1' => 1 | '2' => 2 | '3' => 3 | '4' => 4
| '5' => 5 | '6' => 6 | '7' => 7 | '8' => 8 | '9' => 9
| 'a' => 10 | 'b' => 11 | 'c' => 12 | 'd' => 13
| 'e' => 14 | 'f' => 15 | _ => 0

def parse_hex_bits (n : Nat) (str : String) : BitVec n :=
if h : n < 4 then BitVec.zero n else
let bv := parse_hex_bits (n-4) (str.drop 1)
let c := str.get! ⟨0⟩ |> charToHex
BitVec.append c bv |>.cast (by omega)

def valid_hex_bits (n : Nat) (str : String) : Bool := str.length = n ∧ str.all fun x =>
x.toLower ∈ ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']

def shift_bits_left (bv : BitVec n) (sh : BitVec m) : BitVec n :=
bv <<< sh

def shift_bits_right (bv : BitVec n) (sh : BitVec m) : BitVec n :=
bv >>> sh

def shiftl (bv : BitVec n) (m : Nat) : BitVec n :=
bv <<< m

def shiftr (bv : BitVec n) (m : Nat) : BitVec n :=
bv >>> m

def pow2 := (2 ^ ·)

section Regs

variable {Register : Type} {RegisterType : Register → Type} [DecidableEq Register] [Hashable Register]
Expand Down Expand Up @@ -164,7 +169,8 @@ structure SequentialState (RegisterType : Register → Type) (c : ChoiceSource)
choiceState : c.α
mem : Std.HashMap Nat (BitVec 8)
tags : Unit
sail_output : Array String -- TODO: be able to use the IO monad to run
cycleCount : Nat -- Part of the concurrency interface. See `{get_}cycle_count`
sailOutput : Array String -- TODO: be able to use the IO monad to run

inductive RegisterRef (RegisterType : Register → Type) : TypeType where
| Reg (r : Register) : RegisterRef _ (RegisterType r)
Expand Down Expand Up @@ -336,10 +342,16 @@ def read_ram (addr_size data_size : Nat) (_hex_ram addr : BitVec addr_size) : Pr

def sail_barrier (_ : α) : PreSailM RegisterType c ue Unit := pure ()

def cycle_count (_ : Unit) : PreSailM RegisterType c ue Unit :=
modify fun s => { s with cycleCount := s.cycleCount + 1 }

def get_cycle_count (_ : Unit) : PreSailM RegisterType c ue Nat := do
pure (← get).cycleCount

end ConcurrencyInterface

def print_effect (str : String) : PreSailM RegisterType c ue Unit :=
modify fun s ↦ { s with sail_output := s.sail_output.push str }
modify fun s ↦ { s with sailOutput := s.sailOutput.push str }

def print_int_effect (str : String) (n : Int) : PreSailM RegisterType c ue Unit :=
print_effect s!"{str}{n}\n"
Expand All @@ -354,11 +366,11 @@ def main_of_sail_main (initialState : SequentialState RegisterType c) (main : Un
let res := main () |>.run initialState
match res with
| .ok _ s => do
for m in s.sail_output do
for m in s.sailOutput do
IO.print m
return 0
| .error e s => do
for m in s.sail_output do
for m in s.sailOutput do
IO.print m
IO.eprintln s!"Error while running the sail program!: {e.print}"
return 1
Expand All @@ -368,7 +380,7 @@ section Loops
def foreach_' (from' to step : Nat) (vars : Vars) (body : Nat -> Vars -> Vars) : Vars := Id.run do
let mut vars := vars
let step := 1 + (step - 1)
let range := Std.Range.mk from' to step (by omega)
let range := Std.Range.mk from' (to + 1) step (by omega)
for i in range do
vars := body i vars
pure vars
Expand Down Expand Up @@ -482,3 +494,26 @@ instance : HShiftRight (BitVec w) Int (BitVec w) where
hShiftRight b i := b <<< (-i)

end Sail
instance : CoeT Int x Nat where
coe := x.toNat

instance : CoeT (BitVec n) x (BitVec m) where
coe := x.setWidth m

instance : HAdd (BitVec n) (BitVec m) (BitVec n) where
hAdd x y := x + y

instance : HSub (BitVec n) (BitVec m) (BitVec n) where
hSub x y := x - y

instance : HAnd (BitVec n) (BitVec m) (BitVec n) where
hAnd x y := x &&& y

instance : HOr (BitVec n) (BitVec m) (BitVec n) where
hOr x y := x ||| y

instance : HXor (BitVec n) (BitVec m) (BitVec n) where
hXor x y := x ^^^ y

instance : HPow Nat Int Int where
hPow x n := x ^ n.toNat
2 changes: 1 addition & 1 deletion src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ let main_function_stub has_registers =
(separate hardline
[
string "def main (_ : List String) : IO UInt32 := do";
Printf.ksprintf string "main_of_sail_main ⟨default, (), default, default, default⟩ %s" main_call;
Printf.ksprintf string "main_of_sail_main ⟨default, (), default, default, default, default⟩ %s" main_call;
empty;
]
)
Expand Down

0 comments on commit 07baef1

Please sign in to comment.