Skip to content

Commit

Permalink
Lean: force all arithmetic operations to be on integers (#1088)
Browse files Browse the repository at this point in the history
This attempts to solve 1075 by just doing all of the operations in Int instead of `Nat.
  • Loading branch information
javra authored Feb 28, 2025
1 parent 5bdbdb2 commit 96ede10
Show file tree
Hide file tree
Showing 28 changed files with 109 additions and 111 deletions.
10 changes: 5 additions & 5 deletions lib/arith.sail
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ $include <flow.sail>

// ***** Addition *****

val add_atom = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "_lean_add", _: "add_int"} : forall 'n 'm.
val add_atom = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "_lean_addi", _: "add_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n + 'm)

val add_int = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "_lean_add", _: "add_int"} : (int, int) -> int
val add_int = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "_lean_addi", _: "add_int"} : (int, int) -> int

overload operator + = {add_atom, add_int}

Expand Down Expand Up @@ -85,10 +85,10 @@ overload negate = {negate_atom, negate_int}

// ***** Multiplication *****

val mult_atom = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "_lean_mul", _: "mult_int"} : forall 'n 'm.
val mult_atom = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "_lean_muli", _: "mult_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n * 'm)

val mult_int = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "_lean_mul", _: "mult_int"} : (int, int) -> int
val mult_int = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "_lean_muli", _: "mult_int"} : (int, int) -> int

overload operator * = {mult_atom, mult_int}

Expand All @@ -111,7 +111,7 @@ $endif
/*!
We have special support for raising values to the power of two. Any Sail expression `2 ^ x` will be compiled to this builtin.
*/
val pow2 = pure {lean : "_lean_pow2", _:"pow2"} : forall 'n. int('n) -> int(2 ^ 'n)
val pow2 = pure {lean : "_lean_pow2i", _:"pow2"} : forall 'n. int('n) -> int(2 ^ 'n)

// ***** Integer shifts *****

Expand Down
27 changes: 11 additions & 16 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ instance : HShiftLeft (BitVec w) Int (BitVec w) where
hShiftLeft b i :=
match i with
| .ofNat n => BitVec.shiftLeft b n
| .negSucc n => BitVec.ushiftRight b n
| .negSucc n => BitVec.ushiftRight b (n + 1)

instance : HShiftRight (BitVec w) Int (BitVec w) where
hShiftRight b i := b <<< (-i)
hShiftRight b i := b <<< (- i)

section Loops

Expand Down Expand Up @@ -596,19 +596,14 @@ instance : HOr (BitVec n) (BitVec m) (BitVec n) where
instance : HXor (BitVec n) (BitVec m) (BitVec n) where
hXor x y := x ^^^ y

instance : HPow Nat Int Nat where
hPow x n := x ^ n.toNat
def Int.zpow (m n : Int) : Int := m ^ n.toNat

instance : HPow Int Int Int where
hPow x n := x ^ n.toNat
infixl:65 " +i " => Int.add
infixl:65 " -i " => Int.sub
infixr:80 " ^i " => Int.pow
infixl:70 " *i " => Int.mul

instance : HSub Nat Nat Int where
hSub m n := (m : Int) - (n : Int)

instance : HPow Nat Int Int where
hPow m z := (m : Int) ^ z

instance : HSub Nat Int Int where
hSub m z := (m : Int) - z

infixl:65 " -i " => HSub.hSub (γ := Int)
macro_rules | `($x +i $y) => `(binop% Int.add $x $y)
macro_rules | `($x -i $y) => `(binop% Int.sub $x $y)
macro_rules | `($x ^i $y) => `(rightact% Int.zpow $x $y)
macro_rules | `($x *i $y) => `(binop% Int.mul $x $y)
3 changes: 3 additions & 0 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,11 @@ let rec doc_implicit_args ?(docs = []) ns ims d_args =
let op_of_id id =
match id with
| Some "_lean_add" -> `Binop "+"
| Some "_lean_addi" -> `Binop "+i"
| Some "_lean_sub" -> `Binop "-"
| Some "_lean_subi" -> `Binop "-i"
| Some "_lean_mul" -> `Binop "*"
| Some "_lean_muli" -> `Binop "*i"
| Some "_lean_div" -> `Binop "/"
| Some "_lean_app" -> `Binop "++"
| Some "_lean_bvand" -> `Binop "&&&"
Expand All @@ -542,6 +544,7 @@ let op_of_id id =
| Some "_lean_shiftl" -> `Binop "<<<"
| Some "_lean_shiftr" -> `Binop ">>>"
| Some "_lean_pow2" -> `Unnop "2 ^"
| Some "_lean_pow2i" -> `Unnop "2 ^i"
| _ -> `NotOp

let unnop_of_id id = match id with Some "_lean_pow2" -> Some "2 ^ " | _ -> None
Expand Down
4 changes: 2 additions & 2 deletions test/lean/SailTinyArm.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -622,14 +622,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down
6 changes: 3 additions & 3 deletions test/lean/append.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down Expand Up @@ -130,7 +130,7 @@ def unif_subrange_bits (x : (BitVec 16)) : (BitVec (17 - 10 + 1)) :=

/-- Type quantifiers: i : Nat, i ≥ 0 -/
def unif_vector_subrange (i : Nat) (v : (BitVec (8 * i + 8))) : (BitVec 8) :=
(Sail.BitVec.extractLsb v ((8 * i) + 7) (8 * i))
(Sail.BitVec.extractLsb v ((8 *i i) +i 7) (8 *i i))

def initialize_registers (_ : Unit) : Unit :=
()
Expand Down
4 changes: 2 additions & 2 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down
4 changes: 2 additions & 2 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down
36 changes: 18 additions & 18 deletions test/lean/early_return.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down Expand Up @@ -154,7 +154,7 @@ def foreach_earlyreturneffect (n : Nat) : SailM Bool := do
if (GT.gt i 5)
then return (early_return
(false : Bool))
else (pure (cont (← writeReg r ((← readReg r) + 1))))))
else (pure (cont (← writeReg r ((← readReg r) +i 1))))))
(pure (GT.gt (← readReg r) n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -169,7 +169,7 @@ def foreach_earlyreturnpure (n : Nat) : Bool := Id.run do
if (GT.gt i 5)
then return (early_return
(false : Bool))
else (cont (res + i))))
else (cont (res +i i))))
(pure (GT.gt res n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -186,7 +186,7 @@ def foreach_inner_earlyreturneffect (n : Nat) : SailM Bool := do
if (GT.gt i 5)
then return (early_return
(false : Bool))
else (pure (cont (← writeReg r ((← readReg r) + 1)))))))
else (pure (cont (← writeReg r ((← readReg r) +i 1)))))))
(pure (GT.gt (← readReg r) n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -205,7 +205,7 @@ def foreach_inner_earlyreturnpure (n : Nat) : Bool := Id.run do
if (GT.gt i 5)
then return (early_return
(false : Bool))
else (cont (res + 1)))))
else (cont (res +i 1)))))
(pure (GT.gt res n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -223,9 +223,9 @@ def foreach_inner_earlyreturneffect_catch (n : Nat) : SailM Bool := do
if (GT.gt i 5)
then return (early_return
(false : Bool))
else (pure (cont (← writeReg r ((← readReg r) + 1))))))
else (pure (cont (← writeReg r ((← readReg r) +i 1))))))
(pure (cont (← do
writeReg r ((← readReg r) * 2))))))
writeReg r ((← readReg r) *i 2))))))
(pure (GT.gt (← readReg r) n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -246,8 +246,8 @@ def foreach_inner_earlyreturnpure_catch (n : Nat) : Bool := Id.run do
if (GT.gt i 5)
then return (early_return
(false : Bool))
else (cont (res + 1))))
(cont (res * 2))))
else (cont (res +i 1))))
(cont (res *i 2))))
(pure (GT.gt res n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -258,7 +258,7 @@ def while_earlyreturneffect (n : Nat) : SailM Bool := do
if (GT.gt (← readReg r) 5)
then return (early_return
(false : Bool))
else (pure (cont (← writeReg r ((← readReg r) + 1))))))
else (pure (cont (← writeReg r ((← readReg r) +i 1))))))
(pure (GT.gt (← readReg r) n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -271,7 +271,7 @@ def while_earlyreturnpure (n : Nat) : Bool := Id.run do
if (GT.gt res 5)
then return (early_return
(false : Bool))
else (cont (res + 1))))
else (cont (res +i 1))))
(pure (GT.gt res n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -284,7 +284,7 @@ def while_inner_earlyreturneffect (n : Nat) : SailM Bool := do
if (GT.gt n 5)
then return (early_return
(false : Bool))
else (pure (cont (← writeReg r ((← readReg r) + 1)))))))
else (pure (cont (← writeReg r ((← readReg r) +i 1)))))))
(pure (GT.gt (← readReg r) n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -299,7 +299,7 @@ def while_inner_earlyreturnpure (n : Nat) : Bool := Id.run do
if (GT.gt n 5)
then return (early_return
(false : Bool))
else (cont (res + 1)))))
else (cont (res +i 1)))))
(pure (GT.gt res n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -313,9 +313,9 @@ def while_inner_earlyreturneffect_catch (n : Nat) : SailM Bool := do
if (GT.gt n 5)
then return (early_return
(false : Bool))
else (pure (cont (← writeReg r ((← readReg r) + 1))))))
else (pure (cont (← writeReg r ((← readReg r) +i 1))))))
(pure (cont (← do
writeReg r ((← readReg r) * 2))))))
writeReg r ((← readReg r) *i 2))))))
(pure (GT.gt (← readReg r) n))

/-- Type quantifiers: n : Nat, 0 ≤ n -/
Expand All @@ -332,8 +332,8 @@ def while_inner_earlyreturnpure_catch (n : Nat) : Bool := Id.run do
if (GT.gt n 5)
then return (early_return
(false : Bool))
else (cont (res + 1))))
(cont (res * 2))))
else (cont (res +i 1))))
(cont (res *i 2))))
(pure (GT.gt res n))

def match_early_return (x : E) : SailM E := do
Expand Down
4 changes: 2 additions & 2 deletions test/lean/enum.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down
4 changes: 2 additions & 2 deletions test/lean/errors.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down
8 changes: 4 additions & 4 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down Expand Up @@ -175,7 +175,7 @@ def sep_backwards_matches (arg_ : String) : SailM Bool := do
| _ => throw Error.Exit

def extern_add (_ : Unit) : Int :=
(5 + 4)
(5 +i 4)

def extern_sub (_ : Unit) : Int :=
(5 -i (-4))
Expand All @@ -187,7 +187,7 @@ def extern_negate (_ : Unit) : Int :=
(Neg.neg 5)

def extern_mult (_ : Unit) : Int :=
(5 * 4)
(5 *i 4)

def extern__shl8 (_ : Unit) : Int :=
(Int.shiftl 8 2)
Expand Down
4 changes: 2 additions & 2 deletions test/lean/extern_bitvec.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def _shr_int_general (m : Int) (n : Int) : Int :=
/-- Type quantifiers: m : Int, n : Int -/
def fdiv_int (n : Int) (m : Int) : Int :=
if (Bool.and (LT.lt n 0) (GT.gt m 0))
then ((Int.tdiv (n + 1) m) -i 1)
then ((Int.tdiv (n +i 1) m) -i 1)
else if (Bool.and (GT.gt n 0) (LT.lt m 0))
then ((Int.tdiv (n -i 1) m) -i 1)
else (Int.tdiv n m)

/-- Type quantifiers: m : Int, n : Int -/
def fmod_int (n : Int) (m : Int) : Int :=
(n -i (m * (fdiv_int n m)))
(n -i (m *i (fdiv_int n m)))

/-- Type quantifiers: k_a : Type -/
def is_none (opt : (Option k_a)) : Bool :=
Expand Down
Loading

0 comments on commit 96ede10

Please sign in to comment.