Skip to content

Commit

Permalink
Lean: Small fixes for the RISC-V model
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot committed Feb 18, 2025
1 parent c62e643 commit 63981b1
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 32 deletions.
2 changes: 1 addition & 1 deletion lib/flow.sail
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ therefore be included in just about every Sail specification.
val eq_unit = pure { lean : "BEq.beq", _ : "eq_unit" } : (unit, unit) -> bool(true)
function eq_unit(_, _) = true

val eq_bit = pure { lem : "eq", lean : "Eq", _ : "eq_bit" } : (bit, bit) -> bool
val eq_bit = pure { lem : "eq", lean : "BEq.beq", _ : "eq_bit" } : (bit, bit) -> bool

val not_bool = pure {coq: "negb", lean: "Bool.not", _: "not"} : forall ('p : Bool). bool('p) -> bool(not('p))
/* NB: There are special cases in Sail for effectful uses of and_bool and
Expand Down
2 changes: 1 addition & 1 deletion lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ val bitvector_update = pure {
interpreter: "update",
lem: "update_vec_dec",
coq: "update_vec_dec",
lean: "bitvectorUpdate",
lean: "BitVec.update",
_: "vector_update"
} : forall 'n 'm, 0 <= 'm < 'n. (bits('n), int('m), bit) -> bits('n)
$else
Expand Down
5 changes: 2 additions & 3 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def append' (x : BitVec n) (y : BitVec m) {mn}
(hmn : mn = n + m := by (conv => rhs; dsimp); try rfl) : BitVec mn :=
hmn ▸ x.append y

def update (x : BitVec m) (n : Nat) (b : BitVec 1) := updateSubrange' x n _ b

def toBin {w : Nat} (x : BitVec w) : String :=
List.asString (List.map (fun c => if c then '1' else '0') (List.ofFn (BitVec.getMsb' x)))

Expand Down Expand Up @@ -199,8 +201,6 @@ def reg_deref (reg_ref : @RegisterRef Register RegisterType α) : PreSailM Regis

def vectorAccess [Inhabited α] (v : Vector α m) (n : Nat) := v[n]!

def bitvectorUpdate (v : BitVec m) (n : Nat) (b : Bool) := v[n]! = b

def vectorUpdate (v : Vector α m) (n : Nat) (a : α) := v.set! n a

def assert (p : Bool) (s : String) : PreSailM RegisterType c ue Unit :=
Expand Down Expand Up @@ -326,7 +326,6 @@ def main_of_sail_main (initialState : SequentialState RegisterType c) (main : Un
| .error e _ => do
IO.println s!"Error while running the sail program!: {e.print}"


section Loops

def foreach_' (from' to step : Nat) (vars : Vars) (body : Nat -> Vars -> Vars) : Vars := Id.run do
Expand Down
63 changes: 36 additions & 27 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ let rec doc_match_clause (as_monadic : bool) ctx (Pat_aux (cl, l)) =

and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let env = env_of_tannot annot in
let d_of_arg arg =
let d_of_arg ctx arg =
let arg_monadic = effectful (effect_of arg) in
let wrap = match arg with E_aux (E_let _, _) | E_aux (E_internal_plet _, _) -> parens | _ -> fun x -> x in
wrap_with_left_arrow arg_monadic (wrap (doc_exp arg_monadic ctx arg))
Expand All @@ -515,11 +515,13 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
if Env.is_register id env then wrap_with_left_arrow (not as_monadic) (string "readReg " ^^ doc_id_ctor id)
else wrap_with_pure as_monadic (doc_id_ctor id)
| E_lit l -> wrap_with_pure as_monadic (doc_lit l)
| E_app (Id_aux (Id "None", _), _) -> string "none"
| E_app (Id_aux (Id "None", _), _) -> wrap_with_pure as_monadic (string "none")
| E_app (Id_aux (Id "Some", _), args) ->
let d_id = string "some" in
let d_args = List.map d_of_arg args in
nest 2 (parens (flow (break 1) (d_id :: d_args)))
wrap_with_pure as_monadic
(let d_id = string "some" in
let d_args = List.map (d_of_arg ctx) args in
nest 2 (parens (flow (break 1) (d_id :: d_args)))
)
| E_app (Id_aux (Id "foreach#", _), args) -> begin
let doc_loop_var (E_aux (e, (l, _)) as exp) =
match e with
Expand Down Expand Up @@ -572,7 +574,14 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| _ -> raise (Reporting.err_unreachable l __POS__ ("Unable to find loop variable in " ^ string_of_exp body))
in
let effects = effectful (effect_of body) in
let combinator = if as_monadic && effects then "foreach_M" else "foreach_" in
let early_return = has_early_return body in
let combinator, catch, as_monadic =
match (as_monadic && effects, early_return) with
| true, true -> ("foreach_ME", string "catchEarlyReturn", true)
| true, false -> ("foreach_M", empty, true)
| false, true -> ("foreach_E", string "catchEarlyReturnPure", false)
| false, false -> ("foreach_", empty, false)
in
let body_ctxt = add_single_kid_id_rename ctx loopvar (mk_kid ("loop_" ^ string_of_id loopvar)) in
let from_exp_pp, to_exp_pp, step_exp_pp =
(doc_exp false ctx from_exp, doc_exp false ctx to_exp, doc_exp false ctx step_exp)
Expand All @@ -584,11 +593,9 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
(* TODO: this should probably be construct_dep_pairs, but we would need
to change it to use the updated context. *)
let body_pp = doc_exp as_monadic body_ctxt body in
parens
((prefix 2 1)
((separate space) [string combinator; from_exp_pp; to_exp_pp; step_exp_pp; vartuple_pp])
(parens (prefix 2 1 (group body_lambda) body_pp))
)
let loop_head = flow (break 1) [string combinator; from_exp_pp; to_exp_pp; step_exp_pp; vartuple_pp] in
let full_loop = (prefix 2 1) loop_head (parens (prefix 2 1 (group body_lambda) body_pp)) in
if early_return then flow (break 1) [catch; parens full_loop] else full_loop
| _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator")
end
| E_app (f, args) ->
Expand All @@ -598,34 +605,35 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
if Env.is_extern f env "lean" then string (Env.get_extern f env "lean")
else doc_exp false ctx (E_aux (E_id f, (l, annot)))
in
let d_args = List.map d_of_arg args in
let d_args = List.map (d_of_arg ctx) args in
let d_args = List.map snd (List.filter (fun x -> not (fst x)) (List.combine implicits d_args)) in
let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in
nest 2
(wrap_with_left_arrow ((not as_monadic) && fn_monadic)
(wrap_with_pure (as_monadic && not fn_monadic) (parens (flow (break 1) (d_id :: d_args))))
)
| E_vector vals ->
string "#v" ^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map d_of_arg vals))))
string "#v"
^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map (d_of_arg ctx) vals))))
| E_typ (typ, e) ->
if effectful (effect_of e) then
parens (separate space [doc_exp as_monadic ctx e; colon; string "SailM"; doc_typ ctx typ])
if effectful (effect_of e) then doc_exp as_monadic ctx e
else wrap_with_pure as_monadic (parens (separate space [doc_exp false ctx e; colon; doc_typ ctx typ]))
| E_tuple es -> wrap_with_pure as_monadic (parens (separate_map (comma ^^ space) d_of_arg es))
| E_internal_plet (lpat, lexp, e) | E_let (LB_aux (LB_val (lpat, lexp), _), e) ->
| E_tuple es -> wrap_with_pure as_monadic (parens (separate_map (comma ^^ space) (d_of_arg ctx) es))
| E_let (LB_aux (LB_val (lpat, lexp), _), e') | E_internal_plet (lpat, lexp, e') ->
let arrow = match e with E_let _ -> string "" | _ -> string "← do" in
let id_typ =
match pat_is_plain_binder env lpat with
| Some (_, Some typ) -> doc_pat lpat ^^ space ^^ colon ^^ space ^^ doc_typ ctx typ
| _ -> doc_pat lpat
in
let pp_let_line_f l = group (nest 2 (flow (break 1) l)) in
let pp_let_line =
if effectful (effect_of lexp) then
if effectful (effect_of lexp) || has_early_return lexp then
if is_unit (typ_of lexp) && is_anonymous_pat lpat then doc_exp true ctx lexp
else pp_let_line_f [separate space [string "let"; id_typ; string "← do"]; doc_exp true ctx lexp]
else pp_let_line_f [separate space [string "let"; id_typ; arrow]; doc_exp true ctx lexp]
else pp_let_line_f [separate space [string "let"; id_typ; coloneq]; doc_exp false ctx lexp]
in
pp_let_line ^^ hardline ^^ doc_exp as_monadic ctx e
pp_let_line ^^ hardline ^^ doc_exp as_monadic ctx e'
| E_internal_return e -> doc_exp false ctx e (* ??? *)
| E_struct fexps ->
let args = List.map d_of_field fexps in
Expand All @@ -641,12 +649,13 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| E_match (discr, brs) ->
let cases = separate_map hardline (doc_match_clause as_monadic ctx) brs in
string (match_or_match_bv brs) ^^ doc_exp false ctx discr ^^ string " with" ^^ hardline ^^ cases
| E_assign ((LE_aux (le_act, tannot) as le), e) -> (
match le_act with
| LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e
| LE_deref e' -> string "writeRegRef " ^^ doc_exp false ctx e' ^^ space ^^ doc_exp false ctx e
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_assign ((LE_aux (le_act, tannot) as le), e) ->
wrap_with_left_arrow (not as_monadic)
( match le_act with
| LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e
| LE_deref e' -> string "writeRegRef " ^^ doc_exp false ctx e' ^^ space ^^ doc_exp false ctx e
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_if (i, t, e) ->
let statements_monadic = as_monadic || effectful (effect_of t) || effectful (effect_of e) in
nest 2 (string "if" ^^ space ^^ nest 1 (doc_exp false ctx i))
Expand All @@ -664,7 +673,7 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
^^ parens (doc_exp false ctx e)
^^ space
^^ parens (string "fun the_exception => " ^^ hardline ^^ cases)
| E_assert (e1, e2) -> string "assert " ^^ d_of_arg e1 ^^ space ^^ d_of_arg e2
| E_assert (e1, e2) -> string "assert " ^^ d_of_arg ctx e1 ^^ space ^^ d_of_arg ctx e2
| E_list es -> brackets (separate_map comma_sp (doc_exp as_monadic ctx) es)
| E_cons (hd_e, tl_e) -> parens (separate space [doc_exp false ctx hd_e; string "::"; doc_exp false ctx tl_e])
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")
Expand Down

0 comments on commit 63981b1

Please sign in to comment.