Skip to content

Commit

Permalink
Adds several Lean backends for built-in SAIL functions (mostly arithm…
Browse files Browse the repository at this point in the history
…etic and equality) (#954)

* Adds support for all primitives in arith.sail and flow.sail
* Adds support for eq_anything
  • Loading branch information
benjaminselfridge authored Feb 11, 2025
1 parent c9a9da8 commit b689188
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 32 deletions.
27 changes: 19 additions & 8 deletions lib/arith.sail
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ $include <flow.sail>
val add_atom = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "HAdd.hAdd", _: "add_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n + 'm)

val add_int = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "Int.add", _: "add_int"} : (int, int) -> int
val add_int = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "HAdd.hAdd", _: "add_int"} : (int, int) -> int

overload operator + = {add_atom, add_int}

Expand All @@ -63,7 +63,7 @@ overload operator + = {add_atom, add_int}
val sub_atom = pure {ocaml: "sub_int", interpreter: "sub_int", lem: "integerMinus", coq: "Z.sub", lean: "HSub.hSub", _: "sub_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n - 'm)

val sub_int = pure {ocaml: "sub_int", interpreter: "sub_int", lem: "integerMinus", coq: "Z.sub", lean: "Int.sub", _: "sub_int"} : (int, int) -> int
val sub_int = pure {ocaml: "sub_int", interpreter: "sub_int", lem: "integerMinus", coq: "Z.sub", lean: "HSub.hSub", _: "sub_int"} : (int, int) -> int

overload operator - = {sub_atom, sub_int}

Expand All @@ -77,18 +77,18 @@ val sub_nat = pure {

// ***** Negation *****

val negate_atom = pure {ocaml: "negate", interpreter: "negate", lem: "integerNegate", coq: "Z.opp", _: "neg_int"} : forall 'n. int('n) -> int(- 'n)
val negate_atom = pure {ocaml: "negate", interpreter: "negate", lem: "integerNegate", coq: "Z.opp", lean: "Neg.neg", _: "neg_int"} : forall 'n. int('n) -> int(- 'n)

val negate_int = pure {ocaml: "negate", interpreter: "negate", lem: "integerNegate", coq: "Z.opp", lean: "Int.neg", _: "neg_int"} : int -> int
val negate_int = pure {ocaml: "negate", interpreter: "negate", lem: "integerNegate", coq: "Z.opp", lean: "Neg.neg", _: "neg_int"} : int -> int

overload negate = {negate_atom, negate_int}

// ***** Multiplication *****

val mult_atom = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", _: "mult_int"} : forall 'n 'm.
val mult_atom = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "HMul.hMul", _: "mult_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n * 'm)

val mult_int = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "Int.mul", _: "mult_int"} : (int, int) -> int
val mult_int = pure {ocaml: "mult", interpreter: "mult", lem: "integerMult", coq: "Z.mul", lean: "HMul.hMul", _: "mult_int"} : (int, int) -> int

overload operator * = {mult_atom, mult_int}

Expand Down Expand Up @@ -195,15 +195,26 @@ val abs_int_plain = pure {
interpreter: "abs_int",
lem: "integerAbs",
coq: "Z.abs",
lean: "Sail.Int.intAbs",
_: "abs_int"
} : int -> int

overload abs_int = {abs_int_plain}

val max_int = pure {lem: "integerMax", coq: "Z.max", _: "max_int"} : forall 'x 'y.
val max_int = pure {
lem: "integerMax",
coq: "Z.max",
lean: "Max.max",
_: "max_int"
} : forall 'x 'y.
(int('x), int('y)) -> {'z, ('x >= 'y & 'z == 'x) | ('x < 'y & 'z == 'y). int('z)}

val min_int = pure {lem: "integerMin", coq: "Z.min", _: "min_int"} : forall 'x 'y.
val min_int = pure {
lem: "integerMin",
coq: "Z.min",
lean: "Min.min",
_: "min_int"
} : forall 'x 'y.
(int('x), int('y)) -> {'z, ('x < 'y & 'z == 'x) | ('x >= 'y & 'z == 'y). int('z)}

$endif
15 changes: 7 additions & 8 deletions lib/flow.sail
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@ therefore be included in just about every Sail specification.

*/

val eq_unit : (unit, unit) -> bool(true)
val eq_unit = pure { lean : "Eq", _ : "eq_unit" } : (unit, unit) -> bool(true)
function eq_unit(_, _) = true

val eq_bit = pure { lem : "eq", lean : "Eq", _ : "eq_bit" } : (bit, bit) -> bool

function eq_unit(_, _) = true

val not_bool = pure {coq: "negb", _: "not"} : forall ('p : Bool). bool('p) -> bool(not('p))
val not_bool = pure {coq: "negb", lean: "Bool.not", _: "not"} : forall ('p : Bool). bool('p) -> bool(not('p))
/* NB: There are special cases in Sail for effectful uses of and_bool and
or_bool that are not shown here. */

Expand All @@ -81,10 +80,10 @@ function neq_int (x, y) = not_bool(eq_int(x, y))
val neq_bool : (bool, bool) -> bool
function neq_bool (x, y) = not_bool(eq_bool(x, y))

val lteq_int = pure {coq: "Z.leb", _:"lteq"} : forall 'n 'm. (int('n), int('m)) -> bool('n <= 'm)
val gteq_int = pure {coq: "Z.geb", _:"gteq"} : forall 'n 'm. (int('n), int('m)) -> bool('n >= 'm)
val lt_int = pure {coq: "Z.ltb", _:"lt"} : forall 'n 'm. (int('n), int('m)) -> bool('n < 'm)
val gt_int = pure {coq: "Z.gtb", _:"gt"} : forall 'n 'm. (int('n), int('m)) -> bool('n > 'm)
val lteq_int = pure {coq: "Z.leb", lean: "LE.le", _:"lteq"} : forall 'n 'm. (int('n), int('m)) -> bool('n <= 'm)
val gteq_int = pure {coq: "Z.geb", lean: "GE.ge",_:"gteq"} : forall 'n 'm. (int('n), int('m)) -> bool('n >= 'm)
val lt_int = pure {coq: "Z.ltb", lean: "LT.lt", _:"lt"} : forall 'n 'm. (int('n), int('m)) -> bool('n < 'm)
val gt_int = pure {coq: "Z.gtb", lean: "GT.gt", _:"gt"} : forall 'n 'm. (int('n), int('m)) -> bool('n > 'm)

overload operator == = {eq_int, eq_bit, eq_bool, eq_unit}
overload operator != = {neq_int, neq_bool}
Expand Down
4 changes: 2 additions & 2 deletions lib/generic_equality.sail
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ $define _GENERIC_EQUALITY

$include <flow.sail>

val eq_anything = pure {ocaml: "(fun (x, y) -> x = y)", lem: "eq", coq: "generic_eq", _: "eq_anything"} : forall ('a : Type). ('a, 'a) -> bool
val eq_anything = pure {ocaml: "(fun (x, y) -> x = y)", lem: "eq", coq: "generic_eq", lean: "BEq.beq", _: "eq_anything"} : forall ('a : Type). ('a, 'a) -> bool

overload operator == = {eq_anything}

val neq_anything = pure {lem: "neq", coq: "generic_neq"} : forall ('a : Type). ('a, 'a) -> bool
val neq_anything = pure {lem: "neq", lean: "bneq", coq: "generic_neq"} : forall ('a : Type). ('a, 'a) -> bool

function neq_anything(x, y) = not_bool(eq_anything(x, y))

Expand Down
6 changes: 6 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,10 @@ def addInt {w : Nat} (x : BitVec w) (i : Int) : BitVec w :=
x + BitVec.ofInt w i

end BitVec

namespace Int

def intAbs (x : Int) : Int := Int.ofNat (Int.natAbs x)

end Int
end Sail
53 changes: 45 additions & 8 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@ open Sail
abbrev SailM := PreSailM PEmpty.elim trivialChoiceSource

def extern_add (_ : Unit) : Int :=
(Int.add 5 4)
(HAdd.hAdd 5 4)

def extern_sub (_ : Unit) : Int :=
(Int.sub 5 (-4))
(HSub.hSub 5 (-4))

def extern_sub_nat (_ : Unit) : Nat :=
(HSub.hSub 5 4)

def extern_negate (_ : Unit) : Int :=
(Neg.neg 5)

def extern_mult (_ : Unit) : Int :=
(HMul.hMul 5 4)

def extern_tdiv (_ : Unit) : Int :=
(Int.tdiv 5 4)
Expand All @@ -19,11 +28,24 @@ def extern_tmod (_ : Unit) : Int :=
def extern_tmod_positive (_ : Unit) : Int :=
(Int.tmod 5 4)

def extern_negate (_ : Unit) : Int :=
(Int.neg (-5))
def extern_max (_ : Unit) : Int :=
(Max.max 5 4)

def extern_mult (_ : Unit) : Int :=
(Int.mul 5 (-4))
def extern_min (_ : Unit) : Int :=
(Min.min 5 4)

def extern_abs_int_plain (_ : Unit) : Int :=
let x : Int := (-5)
(Sail.Int.intAbs x)

def extern_eq_unit (_ : Unit) : Bool :=
(Eq () ())

def extern_eq_bit (_ : Unit) : Bool :=
(Eq 0#1 1#1)

def extern_not (_ : Unit) : Bool :=
(Bool.not true)

def extern_and (_ : Unit) : Bool :=
(Bool.and true false)
Expand All @@ -37,8 +59,23 @@ def extern_or (_ : Unit) : Bool :=
def extern_eq_bool (_ : Unit) : Bool :=
(Eq true false)

def extern_eq_bit (_ : Unit) : Bool :=
(Eq 0#1 1#1)
def extern_eq_int (_ : Unit) : Bool :=
(Eq 5 4)

def extern_lteq_int (_ : Unit) : Bool :=
(LE.le 5 4)

def extern_gteq_int (_ : Unit) : Bool :=
(GE.ge 5 4)

def extern_lt_int (_ : Unit) : Bool :=
(LT.lt 5 4)

def extern_gt_int (_ : Unit) : Bool :=
(GT.gt 5 4)

def extern_eq_anything (_ : Unit) : Bool :=
(BEq.beq true true)

def initialize_registers (_ : Unit) : Unit :=
()
Expand Down
67 changes: 61 additions & 6 deletions test/lean/extern.sail
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
default Order dec

$include <prelude.sail>
$include <generic_equality.sail>

/* Testing arith.sail */

function extern_add() -> int = {
return add_int(5, 4)
Expand All @@ -10,6 +13,18 @@ function extern_sub() -> int = {
return sub_int(5, -4)
}

function extern_sub_nat() -> nat = {
return sub_nat(5, 4)
}

function extern_negate() -> int = {
return negate_int(5)
}

function extern_mult() -> int = {
return mult_int(5, 4)
}

function extern_tdiv() -> int = {
return tdiv_int(5, 4)
}
Expand All @@ -22,12 +37,31 @@ function extern_tmod_positive() -> int = {
return _tmod_int_positive(5, 4)
}

function extern_negate() -> int = {
return negate_int(-5)
function extern_max() -> int = {
return max_int(5, 4)
}

function extern_mult() -> int = {
return mult_int(5, -4)
function extern_min() -> int = {
return min_int(5, 4)
}

function extern_abs_int_plain() -> int = {
let x: int = -5;
return abs_int_plain(x)
}

/* Testing flow.sail */

function extern_eq_unit() -> bool = {
return eq_unit((),())
}

function extern_eq_bit() -> bool = {
return eq_bit(bitzero, bitone)
}

function extern_not() -> bool = {
return not_bool(true)
}

function extern_and() -> bool = {
Expand All @@ -46,7 +80,28 @@ function extern_eq_bool() -> bool = {
return eq_bool(true, false)
}

function extern_eq_bit() -> bool = {
return eq_bit(bitzero, bitone)
function extern_eq_int() -> bool = {
return eq_int(5, 4)
}

function extern_lteq_int() -> bool = {
return lteq_int(5, 4)
}

function extern_gteq_int() -> bool = {
return gteq_int(5, 4)
}

function extern_lt_int() -> bool = {
return lt_int(5, 4)
}

function extern_gt_int() -> bool = {
return gt_int(5, 4)
}

/* Testing generic_equality.sail */

function extern_eq_anything() -> bool = {
return eq_anything(true, true)
}

0 comments on commit b689188

Please sign in to comment.