forked from rems-project/sail
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsail2_state_monad.lem
289 lines (234 loc) · 11.5 KB
/
sail2_state_monad.lem
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
open import Pervasives_extra
open import Sail2_instr_kinds
open import Sail2_values
(* Throughout this file, the type variable 'a denotes the result value type
   of a monadic computation. *)
(* Memory contents: a finite map from natural-number addresses to memory bytes. *)
type memstate = map nat memory_byte
(* Tag contents: a finite map from natural-number addresses to one bitU each. *)
type tagstate = map nat bitU
(* type regstate = map string (vector bitU) *)
(* The state threaded through the monad defined below: a register state of
   caller-chosen type 'regs together with the memory map and the tag map. *)
type sequential_state 'regs =
<| regstate : 'regs;
memstate : memstate;
tagstate : tagstate |>
(* Build an initial state from the given register state; the memory map and
   the tag map both start out empty. *)
val init_state : forall 'regs. 'regs -> sequential_state 'regs
let init_state = fun regs ->
  <| regstate = regs;
     memstate = Map.empty;
     tagstate = Map.empty |>
(* Exceptional outcome: either a failure carrying a message string, or a
   thrown exception value of type 'e. *)
type ex 'e =
| Failure of string
| Throw of 'e
(* Result of a single outcome: a normal value of type 'a, or an exceptional
   outcome of type ex 'e. *)
type state_result 'a 'e =
| Value of 'a
| Ex of (ex 'e)
(* State, nondeterminism and exception monad with result value type 'a
   and exception type 'e.  A computation is a function from an input state to
   a set of outcomes, each outcome pairing a result (a value or an exception)
   with a state. *)
type monadS 'regs 'a 'e = sequential_state 'regs -> set (state_result 'a 'e * sequential_state 'regs)
(* Inject a plain value into the monad: the only outcome is Value a paired
   with the unchanged input state. *)
val returnS : forall 'regs 'a 'e. 'a -> monadS 'regs 'a 'e
let returnS a = fun s -> {(Value a, s)}
(* Monadic bind.  Runs m on the input state; for every outcome that is a
   value, runs the continuation f on that value and the resulting state; every
   exceptional outcome is passed through unchanged.  The result is the union
   of all outcome sets obtained this way. *)
val bindS : forall 'regs 'a 'b 'e. monadS 'regs 'a 'e -> ('a -> monadS 'regs 'b 'e) -> monadS 'regs 'b 'e
let bindS m f (s : sequential_state 'regs) =
  let continue_with outcome =
    match outcome with
    | (Value a, s') -> f a s'
    | (Ex e, s') -> {(Ex e, s')}
    end
  in
  Set.bigunion (Set.map continue_with (m s))
(* Sequencing: run m (which yields unit), discard that unit and then run n. *)
val seqS: forall 'regs 'b 'e. monadS 'regs unit 'e -> monadS 'regs 'b 'e -> monadS 'regs 'b 'e
let seqS m n =
  let then_n (_ : unit) = n in
  bindS m then_n
(* Infix aliases declared with let inline: the first operator stands for
   bindS and the second for seqS. *)
let inline (>>$=) = bindS
let inline (>>$) = seqS
(* Apply f to the current state and return its result as a value; the state
   itself is left unchanged. *)
val readS : forall 'regs 'a 'e. (sequential_state 'regs -> 'a) -> monadS 'regs 'a 'e
let readS f s = returnS (f s) s
(* Replace the current state by f applied to it, and return unit. *)
val updateS : forall 'regs 'e. (sequential_state 'regs -> sequential_state 'regs) -> monadS 'regs unit 'e
let updateS f s =
  let updated = f s in
  returnS () updated
(* Fail with the given message: the only outcome is Ex (Failure msg) paired
   with the unchanged input state. *)
val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e
let failS msg = fun s -> {(Ex (Failure msg), s)}
(* Exit is modelled as a failure with the fixed message "exit". *)
val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e
let exitS () = fun s -> failS "exit" s
(* Throw the exception value e: the only outcome is Ex (Throw e) paired with
   the unchanged input state. *)
val throwS : forall 'regs 'a 'e. 'e -> monadS 'regs 'a 'e
let throwS e = fun s -> {(Ex (Throw e), s)}
(* Run m and handle thrown exceptions with h.  Value outcomes are returned
   as they are, each thrown exception is given to the handler together with
   the state at the point of the throw, and failures are passed through
   unchanged (the handler never sees them).  The handler may use a different
   exception type 'e2. *)
val try_catchS : forall 'regs 'a 'e1 'e2. monadS 'regs 'a 'e1 -> ('e1 -> monadS 'regs 'a 'e2) -> monadS 'regs 'a 'e2
let try_catchS m h s =
  let handle outcome =
    match outcome with
    | (Value a, s') -> returnS a s'
    | (Ex (Throw e), s') -> h e s'
    | (Ex (Failure msg), s') -> {(Ex (Failure msg), s')}
    end
  in
  Set.bigunion (Set.map handle (m s))
(* Assertion: return unit if exp holds, otherwise fail with msg. *)
val assert_expS : forall 'regs 'e. bool -> string -> monadS 'regs unit 'e
let assert_expS exp msg s =
  if exp then returnS () s else failS msg s
(* Early return is encoded by throwing and catching the return value.
   The exception type is "either 'r 'e", where "Right e" represents a proper
   exception and "Left r" represents an early return of the value "r". *)
type monadRS 'regs 'a 'r 'e = monadS 'regs 'a (either 'r 'e)
(* Early return of r, implemented as throwing Left r. *)
val early_returnS : forall 'regs 'a 'r 'e. 'r -> monadRS 'regs 'a 'r 'e
let early_returnS r s = throwS (Left r) s
(* Close the scope of early returns: a thrown Left a becomes the ordinary
   value a, while a thrown Right e is rethrown as the plain exception e. *)
val catch_early_returnS : forall 'regs 'a 'e. monadRS 'regs 'a 'a 'e -> monadS 'regs 'a 'e
let catch_early_returnS m =
  let unwrap thrown =
    match thrown with
    | Left a -> returnS a
    | Right e -> throwS e
    end
  in
  try_catchS m unwrap
(* Lift a computation into the monad with early return: every thrown
   exception e is rethrown wrapped as Right e. *)
val liftRS : forall 'a 'r 'regs 'e. monadS 'regs 'a 'e -> monadRS 'regs 'a 'r 'e
let liftRS m =
  let wrap e = throwS (Right e) in
  try_catchS m wrap
(* Exception handling in the presence of early returns: a thrown Left r
   (an early return) is rethrown untouched, and only a thrown Right e is
   passed to the handler h. *)
val try_catchRS : forall 'regs 'a 'r 'e1 'e2. monadRS 'regs 'a 'r 'e1 -> ('e1 -> monadRS 'regs 'a 'r 'e2) -> monadRS 'regs 'a 'r 'e2
let try_catchRS m h =
  let dispatch thrown =
    match thrown with
    | Left r -> throwS (Left r)
    | Right e -> h e
    end
  in
  try_catchS m dispatch
(* Turn an optional value into a computation: Just a is returned as a value,
   Nothing becomes a failure with the given message. *)
val maybe_failS : forall 'regs 'a 'e. string -> maybe 'a -> monadS 'regs 'a 'e
let maybe_failS msg opt =
  match opt with
  | Nothing -> failS msg
  | Just a -> returnS a
  end
(* Nondeterministic choice: one Value outcome per element of xs, each paired
   with the unchanged input state. *)
val chooseS : forall 'regs 'a 'e. SetType 'a => set 'a -> monadS 'regs 'a 'e
let chooseS xs s =
  let as_outcome x = (Value x, s) in
  Set.map as_outcome xs
(* Nondeterministic choice between the two booleans. *)
val choose_boolS : forall 'regs 'e. unit -> monadS 'regs bool 'e
let choose_boolS () =
  let both = {false; true} in
  chooseS both
(* Return the first element of a list; fail on the empty list.
   NOTE(review): the failure message reads "headM" rather than "headS";
   presumably this is kept in step with a counterpart defined elsewhere,
   confirm before changing it. *)
val headS : forall 'rv 'a 'e. list 'a -> monadS 'rv 'a 'e
let headS xs =
  match xs with
  | [] -> failS "headM"
  | x :: _ -> returnS x
  end
(* Return a list without its first element; fail on the empty list.
   NOTE(review): the failure message reads "tailM" rather than "tailS";
   presumably this is kept in step with a counterpart defined elsewhere,
   confirm before changing it. *)
val tailS : forall 'rv 'a 'e. list 'a -> monadS 'rv (list 'a) 'e
let tailS xs =
  match xs with
  | [] -> failS "tailM"
  | _ :: rest -> returnS rest
  end
(* Read the tag stored at a bitvector address.  Fails with "nat_of_bv" if the
   address cannot be converted by nat_of_bv; an address with no entry in the
   tag map reads as B0. *)
val read_tagS : forall 'regs 'a 'e. Bitvector 'a => 'a -> monadS 'regs bitU 'e
let read_tagS addr =
  let tag_at a = readS (fun s -> fromMaybe B0 (Map.lookup a s.tagstate)) in
  bindS (maybe_failS "nat_of_bv" (nat_of_bv addr)) tag_at
(* Read sz bytes starting at addr from memory and return them in little
   endian order, together with a combined tag.  The combined tag is the
   and_bit fold, starting from B1, over the tags of all addresses read, where
   an address without a tag entry counts as B0.  The byte lookups are combined
   with just_list; presumably the result is Nothing as soon as one address has
   no entry in the memory map (confirm against just_list). *)
val get_mem_bytes : forall 'regs. nat -> nat -> sequential_state 'regs -> maybe (list memory_byte * bitU)
let get_mem_bytes addr sz s =
  let addrs = genlist (fun offset -> addr + offset) sz in
  let byte_at a = Map.lookup a s.memstate in
  let tag_at a = Map.findWithDefault a B0 s.tagstate in
  let with_tag mem_val =
    (mem_val, List.foldl and_bit B1 (List.map tag_at addrs))
  in
  Maybe.map with_tag (just_list (List.map byte_at addrs))
(* Monadic read of sz bytes and their combined tag starting at addr.  The
   read kind argument is ignored.  Fails with "read_memS" when get_mem_bytes
   yields Nothing. *)
val read_memt_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte * bitU) 'e
let read_memt_bytesS _ addr sz =
  bindS (readS (get_mem_bytes addr sz)) (maybe_failS "read_memS")
(* Like read_memt_bytesS, but the tag is dropped and only the bytes are
   returned. *)
val read_mem_bytesS : forall 'regs 'e. read_kind -> nat -> nat -> monadS 'regs (list memory_byte) 'e
let read_mem_bytesS rk addr sz =
  let drop_tag (bytes, _) = returnS bytes in
  bindS (read_memt_bytesS rk addr sz) drop_tag
(* Read sz bytes at the bitvector address a and return them as a bitvector
   together with the combined tag.  Three steps, each of which may fail:
   convert the address (failure "nat_of_bv"), read the bytes and tag, and
   convert the bytes to the result type (failure "bits_of_mem_bytes"). *)
val read_memtS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b => read_kind -> 'a -> integer -> monadS 'regs ('b * bitU) 'e
let read_memtS rk a sz =
  bindS (maybe_failS "nat_of_bv" (nat_of_bv a)) (fun addr ->
  bindS (read_memt_bytesS rk addr (nat_of_int sz)) (fun (bytes, tag) ->
  bindS (maybe_failS "bits_of_mem_bytes" (of_bits (bits_of_mem_bytes bytes))) (fun mem_val ->
  returnS (mem_val, tag))))
(* Like read_memtS, but the tag is dropped and only the value read is
   returned.  The addr_size argument is not used by this definition. *)
val read_memS : forall 'regs 'e 'a 'b 'addrsize. Bitvector 'a, Bitvector 'b => read_kind -> 'addrsize -> 'a -> integer -> monadS 'regs 'b 'e
let read_memS rk addr_size a sz =
  bindS (read_memtS rk a sz) (fun (value, _) -> returnS value)
(* Result of an exclusive operation, modelled as a nondeterministic boolean.
   TODO: This used to be more deterministic, checking a flag in the state to
   see whether an exclusive load had occurred before.  However, that does not
   seem very precise; it might be safer to overapproximate the possible
   behaviours by always making a nondeterministic choice. *)
val excl_resultS : forall 'regs 'e. unit -> monadS 'regs bool 'e
let excl_resultS () = choose_boolS ()
(* Write a little-endian list of bytes v to the sz addresses starting at
   addr, and set the tag of each of those sz addresses to tag.
   NOTE(review): bytes are paired with addresses via List.zip, so if the
   length of v differs from sz the byte writes presumably cover only the
   shorter of the two, while the tag is still written at all sz addresses;
   confirm that callers always pass matching lengths. *)
val put_mem_bytes : forall 'regs. nat -> nat -> list memory_byte -> bitU -> sequential_state 'regs -> sequential_state 'regs
let put_mem_bytes addr sz v tag s =
  let addrs = genlist (fun offset -> addr + offset) sz in
  let store_byte mem (a, b) = Map.insert a b mem in
  let store_tag tags a = Map.insert a tag tags in
  let new_mem = List.foldl store_byte s.memstate (List.zip addrs v) in
  let new_tags = List.foldl store_tag s.tagstate addrs in
  <| s with memstate = new_mem; tagstate = new_tags |>
(* Monadic write of bytes v and tag t via put_mem_bytes.  The write kind
   argument is ignored and the returned boolean is always true. *)
val write_memt_bytesS : forall 'regs 'e. write_kind -> nat -> nat -> list memory_byte -> bitU -> monadS 'regs bool 'e
let write_memt_bytesS _ addr sz v t =
  seqS (updateS (put_mem_bytes addr sz v t)) (returnS true)
(* Write bytes without an explicit tag: the tag written is B0. *)
val write_mem_bytesS : forall 'regs 'e. write_kind -> nat -> nat -> list memory_byte -> monadS 'regs bool 'e
let write_mem_bytesS wk addr sz v =
  let cleared_tag = B0 in
  write_memt_bytesS wk addr sz v cleared_tag
(* Write the bitvector v with tag t at the bitvector address addr.  Fails
   with "write_mem" if either the address conversion (nat_of_bv) or the
   value conversion (mem_bytes_of_bits) yields Nothing. *)
val write_memtS : forall 'regs 'e 'a 'b. Bitvector 'a, Bitvector 'b =>
write_kind -> 'a -> integer -> 'b -> bitU -> monadS 'regs bool 'e
let write_memtS wk addr sz v t =
  match nat_of_bv addr with
  | Nothing -> failS "write_mem"
  | Just a ->
    match mem_bytes_of_bits v with
    | Nothing -> failS "write_mem"
    | Just bytes -> write_memt_bytesS wk a (nat_of_int sz) bytes t
    end
  end
(* Write a bitvector without an explicit tag: the tag written is B0.  The
   addr_size argument is not used by this definition. *)
val write_memS : forall 'regs 'e 'a 'b 'addrsize. Bitvector 'a, Bitvector 'b =>
write_kind -> 'addrsize -> 'a -> integer -> 'b -> monadS 'regs bool 'e
let write_memS wk addr_size addr sz v =
  let cleared_tag = B0 in
  write_memtS wk addr sz v cleared_tag
(* Read a register through its reference: applies the read_from accessor of
   reg to the register state. *)
val read_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> monadS 'regs 'a 'e
let read_regS reg =
  let project s = reg.read_from s.regstate in
  readS project
(* TODO
let read_reg_range reg i j state =
let v = slice (get_reg state (name_of_reg reg)) i j in
[(Value (vec_to_bvec v),state)]
let read_reg_bit reg i state =
let v = access (get_reg state (name_of_reg reg)) i in
[(Value v,state)]
let read_reg_field reg regfield =
let (i,j) = register_field_indices reg regfield in
read_reg_range reg i j
let read_reg_bitfield reg regfield =
let (i,_) = register_field_indices reg regfield in
read_reg_bit reg i *)
(* Read a register by name using the first component of the accessor pair.
   Fails with "read_regvalS " followed by the register name when the accessor
   yields Nothing. *)
val read_regvalS : forall 'regs 'rv 'e.
register_accessors 'regs 'rv -> string -> monadS 'regs 'rv 'e
let read_regvalS (read, _) reg =
  let lookup s = read reg s.regstate in
  bindS (readS lookup) (maybe_failS ("read_regvalS " ^ reg))
(* Write a register by name using the second component of the accessor pair.
   The accessor is applied to the current register state; if it yields a new
   register state, that state is installed, otherwise the computation fails
   with "write_regvalS " followed by the register name. *)
val write_regvalS : forall 'regs 'rv 'e.
register_accessors 'regs 'rv -> string -> 'rv -> monadS 'regs unit 'e
let write_regvalS (_, write) reg v =
  let attempt s = write reg v s.regstate in
  let commit result =
    match result with
    | Nothing -> failS ("write_regvalS " ^ reg)
    | Just rs' -> updateS (fun s -> <| s with regstate = rs' |>)
    end
  in
  bindS (readS attempt) commit
(* Write a register through its reference: the new register state is the
   write_to accessor of reg applied to v and the current register state. *)
val write_regS : forall 'regs 'rv 'a 'e. register_ref 'regs 'rv 'a -> 'a -> monadS 'regs unit 'e
let write_regS reg v =
  let store s =
    let regs' = reg.write_to v s.regstate in
    <| s with regstate = regs' |>
  in
  updateS store
(* TODO
val update_reg : forall 'regs 'rv 'a 'b 'e. register_ref 'regs 'rv 'a -> ('a -> 'b -> 'a) -> 'b -> monadS 'regs unit 'e
let update_reg reg f v state =
let current_value = get_reg state reg in
let new_value = f current_value v in
[(Value (), set_reg state reg new_value)]
let write_reg_field reg regfield = update_reg reg regfield.set_field
val update_reg_range : forall 'regs 'rv 'a 'b. Bitvector 'a, Bitvector 'b => register_ref 'regs 'rv 'a -> integer -> integer -> 'a -> 'b -> 'a
let update_reg_range reg i j reg_val new_val = set_bits (reg.is_inc) reg_val i j (bits_of new_val)
let write_reg_range reg i j = update_reg reg (update_reg_range reg i j)
let update_reg_pos reg i reg_val x = update_list reg.is_inc reg_val i x
let write_reg_pos reg i = update_reg reg (update_reg_pos reg i)
let update_reg_bit reg i reg_val bit = set_bit (reg.is_inc) reg_val i (to_bitU bit)
let write_reg_bit reg i = update_reg reg (update_reg_bit reg i)
let update_reg_field_range regfield i j reg_val new_val =
let current_field_value = regfield.get_field reg_val in
let new_field_value = set_bits (regfield.field_is_inc) current_field_value i j (bits_of new_val) in
regfield.set_field reg_val new_field_value
let write_reg_field_range reg regfield i j = update_reg reg (update_reg_field_range regfield i j)
let update_reg_field_pos regfield i reg_val x =
let current_field_value = regfield.get_field reg_val in
let new_field_value = update_list regfield.field_is_inc current_field_value i x in
regfield.set_field reg_val new_field_value
let write_reg_field_pos reg regfield i = update_reg reg (update_reg_field_pos regfield i)
let update_reg_field_bit regfield i reg_val bit =
let current_field_value = regfield.get_field reg_val in
let new_field_value = set_bit (regfield.field_is_inc) current_field_value i (to_bitU bit) in
regfield.set_field reg_val new_field_value
let write_reg_field_bit reg regfield i = update_reg reg (update_reg_field_bit regfield i)*)
(* Render a result as a string.  The payload of a value or of a thrown
   exception is not printed; only a failure shows its message.
   TODO Add Show typeclass for value and exception type *)
val show_result : forall 'a 'e. state_result 'a 'e -> string
let show_result r =
  match r with
  | Value _ -> "Value ()"
  | Ex ex ->
    match ex with
    | Failure msg -> "Failure " ^ msg
    | Throw _ -> "Throw"
    end
  end
(* Pass the show_result rendering of every result in the set to
   prerr_endline, ignoring the state component.  The mapped set is discarded;
   Set.map is used only to visit each element. *)
val prerr_results : forall 'a 'e 's. SetType 's => set (state_result 'a 'e * 's) -> unit
let prerr_results rs =
  let report (r, _) =
    let _ = prerr_endline (show_result r) in
    ()
  in
  let _ = Set.map report rs in
  ()