Skip to content

Commit

Permalink
Lean: handling foreach Loops (#990)
Browse files Browse the repository at this point in the history
Co-authored-by: Léo Stefanesco <leo.lveb@gmail.com>
  • Loading branch information
lfrenot and ineol authored Feb 13, 2025
1 parent fb7249f commit f2432d8
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 68 deletions.
31 changes: 31 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,37 @@ def main_of_sail_main (initialState : SequentialState RegisterType c) (main : Un
| .error e _ => do
IO.println s!"Error while running the sail program!: {e.print}"


section Loops

def foreach_' (from' to step : Nat) (vars : Vars) (body : Nat -> Vars -> Vars) : Vars := Id.run do
let mut vars := vars
let step := 1 + (step - 1)
let range := Std.Range.mk from' to step (by omega)
for i in range do
vars := body i vars
pure vars

def foreach_ (from' to step : Nat) (vars : Vars) (body : Nat -> Vars -> Vars) : Vars :=
if from' < to
then foreach_' from' to step vars body
else foreach_' to from' step vars body

def foreach_M' (from' to step : Nat) (vars : Vars) (body : Nat -> Vars -> PreSailM RegisterType c Vars) : PreSailM RegisterType c Vars := do
let mut vars := vars
let step := 1 + (step - 1)
let range := Std.Range.mk from' to step (by omega)
for i in range do
vars ← body i vars
pure vars

def foreach_M (from' to step : Nat) (vars : Vars) (body : Nat -> Vars -> PreSailM RegisterType c Vars) : PreSailM RegisterType c Vars :=
if from' < to
then foreach_M' from' to step vars body
else foreach_M' to from' step vars body

end Loops

end Regs

namespace BitVec
Expand Down
81 changes: 77 additions & 4 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ let rec doc_pat ?(in_vector = false) (P_aux (p, (l, annot)) as pat) =
| P_vector_concat pats -> separate (string ",") (List.map (doc_pat ~in_vector:true) pats) |> brackets
| P_app (Id_aux (Id "None", _), p) -> string "none"
| P_app (cons, pats) -> doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
| P_var (p, _) -> doc_pat p
| P_as (pat, id) -> doc_pat pat
| _ -> failwith ("Doc Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

Expand Down Expand Up @@ -495,6 +496,77 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let d_id = string "some" in
let d_args = List.map d_of_arg args in
nest 2 (parens (flow (break 1) (d_id :: d_args)))
| E_app (Id_aux (Id "foreach#", _), args) -> begin
let doc_loop_var (E_aux (e, (l, _)) as exp) =
match e with
| E_id id ->
let id_pp = doc_id_ctor id in
let typ = typ_of exp in
(id_pp, id_pp)
| E_lit (L_aux (L_unit, _)) -> (string "()", underscore)
| _ -> raise (Reporting.err_unreachable l __POS__ ("Bad expression for variable in loop: " ^ string_of_exp exp))
in
let make_loop_vars extra_binders varstuple =
match varstuple with
| E_aux (E_tuple vs, _) ->
let vs = List.map doc_loop_var vs in
let mkpp f vs = separate (string ", ") (List.map f vs) in
let tup_pp = mkpp (fun (pp, _) -> pp) vs in
let match_pp = mkpp (fun (_, pp) -> pp) vs in
(parens tup_pp, separate space ((string "λ" :: extra_binders) @ [parens match_pp; string "=>"]))
| _ ->
let exp_pp, match_pp = doc_loop_var varstuple in
(exp_pp, separate space ((string "λ" :: extra_binders) @ [match_pp; string "=>"]))
in
match args with
| [from_exp; to_exp; step_exp; ord_exp; vartuple; body] ->
let loopvar, body =
match body with
| E_aux
( E_if
( _,
E_aux
( E_let
( LB_aux
( LB_val
( ( P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), _), _)), _)
| P_aux (P_var (P_aux (P_id id, _), _), _)
| P_aux (P_id id, _) ),
_
),
_
),
body
),
_
),
_
),
_
) ->
(id, body)
| _ -> raise (Reporting.err_unreachable l __POS__ ("Unable to find loop variable in " ^ string_of_exp body))
in
let effects = effectful (effect_of body) in
let combinator = if as_monadic && effects then "foreach_M" else "foreach_" in
let body_ctxt = add_single_kid_id_rename ctx loopvar (mk_kid ("loop_" ^ string_of_id loopvar)) in
let from_exp_pp, to_exp_pp, step_exp_pp =
(doc_exp false ctx from_exp, doc_exp false ctx to_exp, doc_exp false ctx step_exp)
in
(* The body has the right type for deciding whether a proof is necessary *)
(* let vartuple_retyped = check_exp env (strip_exp vartuple) (typ_of body) in *)
let vartuple_pp, body_lambda = make_loop_vars [doc_id_ctor loopvar] vartuple in
let body_lambda = if effects then body_lambda ^^ string " do" else body_lambda in
(* TODO: this should probably be construct_dep_pairs, but we would need
to change it to use the updated context. *)
let body_pp = doc_exp as_monadic body_ctxt body in
parens
((prefix 2 1)
((separate space) [string combinator; from_exp_pp; to_exp_pp; step_exp_pp; vartuple_pp])
(parens (prefix 2 1 (group body_lambda) body_pp))
)
| _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator")
end
| E_app (f, args) ->
let _, f_typ = Env.get_val_spec f env in
let implicits = get_fn_implicits f_typ in
Expand Down Expand Up @@ -522,13 +594,14 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| Some (_, Some typ) -> doc_pat lpat ^^ space ^^ colon ^^ space ^^ doc_typ ctx typ
| _ -> doc_pat lpat
in
let pp_let_line_f l = group (nest 2 (flow (break 1) l)) in
let pp_let_line =
if effectful (effect_of lexp) then
if is_unit (typ_of lexp) && is_anonymous_pat lpat then [doc_exp true ctx lexp]
else [separate space [string "let"; id_typ; string ""]; doc_exp true ctx lexp]
else [separate space [string "let"; id_typ; coloneq]; doc_exp false ctx lexp]
if is_unit (typ_of lexp) && is_anonymous_pat lpat then doc_exp true ctx lexp
else pp_let_line_f [separate space [string "let"; id_typ; string " do"]; doc_exp true ctx lexp]
else pp_let_line_f [separate space [string "let"; id_typ; coloneq]; doc_exp false ctx lexp]
in
group (nest 2 (flow (break 1) pp_let_line)) ^^ hardline ^^ doc_exp as_monadic ctx e
pp_let_line ^^ hardline ^^ doc_exp as_monadic ctx e
| E_internal_return e -> doc_exp false ctx e (* ??? *)
| E_struct fexps ->
let args = List.map d_of_field fexps in
Expand Down
4 changes: 3 additions & 1 deletion src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,12 @@ let lean_rewrites =
("attach_effects", []);
("remove_blocks", []);
("attach_effects", []);
(*("letbind_effects", []);*)
(* ("letbind_effects", []); *)
("remove_e_assign", []);
(* ^^^^ replace loops by dummy function calls *)
("attach_effects", []);
("internal_lets", []);
(* ^^^^ transforms var into let *)
("remove_superfluous_letbinds", []);
("remove_superfluous_returns", []);
("bit_lists_to_lits", []);
Expand Down
Loading

0 comments on commit f2432d8

Please sign in to comment.