Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Consolidate monad definitions to extlib #270

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions examples/IntroductionSolutions.v
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ From Coq Require Import
From ExtLib Require Import
Monad
Traversable
Data.List.
Data.List
Data.Monads.StateMonad.

From ITree Require Import
Simple.
Expand All @@ -30,6 +31,8 @@ Import ListNotations.
Import ITreeNotations.
Import MonadNotation.
Open Scope monad_scope.

Existing Instance Monad_stateT.
(* end hide *)

(** * Events *)
Expand Down Expand Up @@ -75,25 +78,25 @@ Definition write_one : itree ioE unit :=
- [void1] is the empty event (so the resulting ITree can trigger
no event). *)

Compute Monads.stateT (list nat) (itree void1) unit.
Compute stateT (list nat) (itree void1) unit.
Print void1.

Definition handle_io
: forall R, ioE R -> Monads.stateT (list nat) (itree void1) R
:= fun R e log =>
match e with
| Input => ret (log, [0])
| Output o => ret (log ++ o, tt)
end.
: forall R, ioE R -> stateT (list nat) (itree void1) R
:= fun R e => mkStateT (fun log =>
match e in ioE R return itree void1 (R * list nat) with
| Input => ret ([0], log)
| Output o => ret (tt, log ++ o)
end).

(** [interp] lifts any handler into an _interpreter_, of type
[forall R, itree ioE R -> M R]. *)
Definition interp_io
: forall R, itree ioE R -> itree void1 (list nat * R)
:= fun R t => Monads.run_stateT (interp handle_io t) [].
: forall R, itree ioE R -> itree void1 (R * list nat)
:= fun R t => runStateT (interp handle_io t) [].

(** We can now interpret [write_one]. *)
Definition interpreted_write_one : itree void1 (list nat * unit)
Definition interpreted_write_one : itree void1 (unit * list nat)
:= interp_io _ write_one.

(** Intuitively, [interp_io] replaces every [ITree.trigger] in the
Expand Down
2 changes: 1 addition & 1 deletion extra/Dijkstra/DelaySpecMonad.v
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Delimit Scope delayspec_scope with delayspec.
Notation "a ∈ b" := (proj1_sig (A := _ -> Prop) b a) (at level 70) : delayspec_scope.
Notation "a ∋ b" := (proj1_sig (A := _ -> Prop) a b) (at level 70, only parsing) : delayspec_scope.

Definition Delay (A : Type) := itree void1 A.
Definition Delay := itree void1.

#[global] Instance EqMDelay : Eq1 Delay := @ITreeMonad.Eq1_ITree void1.
#[global] Instance MonadDelay : Monad Delay := @Monad_itree void1.
Expand Down
59 changes: 30 additions & 29 deletions extra/Dijkstra/StateDelaySpec.v
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
From ExtLib Require Import
Data.List
Data.Monads.StateMonad
Structures.Monad.

From Paco Require Import paco.
Expand Down Expand Up @@ -42,21 +43,21 @@ Section StateDelaySpec.

Definition StateDelayObs := EffectObsStateT St DelaySpec Delay.

Definition StateDelayMonadMorph := MonadMorphimStateT St DelaySpec Delay.
Definition StateDelayMonadMorph := MonadMorphismStateT St DelaySpec Delay.

Definition PrePost A : Type := (Delay (St * A) -> Prop ) * (St -> Prop).
Definition PrePost A : Type := (Delay (A * St) -> Prop ) * (St -> Prop).

Definition PrePostRef {A : Type} (m : StateDelay A) (pp : PrePost A) : Prop :=
let '(post,pre) := pp in
forall s, pre s -> post (m s).
forall s, pre s -> post (runStateT m s).

Program Definition encode {A : Type} (pp : PrePost A) : StateDelaySpec A :=
let '(post,pre) := pp in
fun s p => pre s /\ (forall r, post r -> p r).
mkStateT (fun s p => pre s /\ (forall r, post r -> p r)).

Definition verify_cond {A : Type} := DijkstraProp StateDelay StateDelaySpec StateDelayObs A.

Lemma encode_correct : forall (A : Type) (pre : St -> Prop) (post : Delay (St * A) -> Prop)
Lemma encode_correct : forall (A : Type) (pre : St -> Prop) (post : Delay (A * St) -> Prop)
(m : StateDelay A),
resp_eutt post -> (PrePostRef m (post,pre) <-> verify_cond (encode (post,pre)) m).
Proof.
Expand All @@ -65,28 +66,28 @@ Section StateDelaySpec.
- repeat red. simpl. intros. destruct p as [p Hp]. simpl in H1. destruct H1 as [Hpre Himp].
auto.
- repeat red in H0. simpl in H0.
set (exist _ post H) as p. enough ((m s) ∈ p); auto.
set (exist _ post H) as p. enough ((runStateT m s) ∈ p); auto.
apply H0. auto.
Qed.

Definition PrePostPair A : Type := PrePost A * PrePost A.

Definition PrePostPairRef {A : Type} (pppp : PrePostPair A) (m : StateDelay A) :=
let '((post0, pre0), (post1, pre1)) := pppp in
forall s, (pre0 s -> post0 (m s)) /\ (pre1 s -> post1 (m s)) .
forall s, (pre0 s -> post0 (runStateT m s)) /\ (pre1 s -> post1 (runStateT m s)) .

Program Definition encode_pair {A : Type} (pppp : PrePostPair A) : StateDelaySpec A:=
let '((post0, pre0), (post1, pre1)) := pppp in
fun s (p : DelaySpecInput (St * A)) =>
(pre0 s /\ (forall r, post0 r -> p r)) \/ (pre1 s /\ forall r, post1 r -> p r).
mkStateT (fun s (p : DelaySpecInput (A * St)) =>
(pre0 s /\ (forall r, post0 r -> p r)) \/ (pre1 s /\ forall r, post1 r -> p r)).
Next Obligation.
destruct H0 as [H0 | H1].
- destruct H0 as [Hp Hr]. left. auto.
- destruct H1 as [Hp Hr]. right. auto.
Qed.

Lemma encode_pair_correct : forall (A : Type) (pre0 pre1 : St -> Prop)
(post0 post1 : Delay (St * A) -> Prop ) (m : StateDelay A),
(post0 post1 : Delay (A * St) -> Prop ) (m : StateDelay A),
let pp : PrePostPair A := ((post0,pre0),(post1,pre1)) in
resp_eutt post0 -> resp_eutt post1 ->
(PrePostPairRef pp m <-> verify_cond (encode_pair pp) m).
Expand All @@ -97,20 +98,20 @@ Section StateDelaySpec.
destruct H2 as [ [Hs Hp] | [Hs Hp] ]; simpl in *; auto.
- repeat red in H1. simpl in *.
split; intros.
+ set (exist _ post0 H) as p. enough ((m s) ∈ p ); auto.
+ set (exist _ post0 H) as p. enough ((runStateT m s) ∈ p ); auto.
apply H1. left. split; auto.
+ set (exist _ post1 H0) as p. enough ((m s) ∈ p ); auto.
+ set (exist _ post1 H0) as p. enough ((runStateT m s) ∈ p ); auto.
apply H1. right. split; auto.
Qed.

Definition PrePostList A : Type := list (PrePost A).

Definition PrePostListRef {A : Type} (ppl : PrePostList A) (m : StateDelay A) :=
forall s, List.Forall (fun pp : PrePost A=> let (post,pre) := pp in pre s -> post (m s) ) ppl.
forall s, List.Forall (fun pp : PrePost A=> let (post,pre) := pp in pre s -> post (runStateT m s) ) ppl.

Program Definition encode_list {A : Type} (ppl : PrePostList A) : StateDelaySpec A :=
fun s (p : DelaySpecInput (St * A) ) =>
List.Exists (fun pp : PrePost A => let (post,pre) := pp in pre s /\ forall r, post r -> p r) ppl.
mkStateT (fun s (p : DelaySpecInput (A * St) ) =>
List.Exists (fun pp : PrePost A => let (post,pre) := pp in pre s /\ forall r, post r -> p r) ppl).
Next Obligation.
induction H0; eauto.
destruct x as [post pre]. destruct H0 as [Hs Hr]. left. auto.
Expand All @@ -129,7 +130,7 @@ Section StateDelaySpec.
+ destruct a as [post pre].
inversion H1; subst.
* destruct H3. auto.
assert ((pre s -> post (m s)) ); auto.
assert ((pre s -> post (runStateT m s)) ); auto.
intros. inversion Hrefine; subst; auto.
* apply IHppl; auto.
-- inversion H; auto.
Expand All @@ -142,24 +143,24 @@ Section StateDelaySpec.
{ inversion H. auto. }
set (exist _ post Heutt) as p. specialize (Henc p) as Hencp.
constructor; intros.
+ enough ((m s) ∈ p ); auto. apply Hencp.
+ enough ((runStateT m s) ∈ p ); auto. apply Hencp.
left. split; auto.
+ apply IHppl; auto.
* inversion H. auto.
* clear IHppl. intros. apply H0. eauto.
Qed.

Definition DynPrePost A : Type := (St -> Prop) * (St -> Delay (St * A) -> Prop).
Definition DynPrePost A : Type := (St -> Prop) * (St -> Delay (A * St) -> Prop).

Definition DynPrePostRef {A : Type} (pp : DynPrePost A) (m : StateDelay A) :=
let (pre,post) := pp in
forall s, pre s -> post s (m s).
forall s, pre s -> post s (runStateT m s).

Program Definition encode_dyn {A : Type} (pp : DynPrePost A) : StateDelaySpec A :=
let (pre,post) := pp in
fun s p => pre s /\ forall r, post s r -> p r.
mkStateT (fun s p => pre s /\ forall r, post s r -> p r).

Lemma encode_dyn_correct : forall (A : Type) (pre : St -> Prop) (post : St -> Delay (St * A) -> Prop ) (m : StateDelay A),
Lemma encode_dyn_correct : forall (A : Type) (pre : St -> Prop) (post : St -> Delay (A * St) -> Prop ) (m : StateDelay A),
(forall s, resp_eutt (post s)) -> (DynPrePostRef (pre,post) m <-> verify_cond (encode_dyn (pre,post) ) m).
Proof.
intros. unfold verify_cond, DijkstraProp. split; intros.
Expand All @@ -174,7 +175,7 @@ Section StateDelaySpec.
Forall (fun pp => DynPrePostRef pp m) ppl.

Program Definition encode_list_dyn {A : Type} (ppl : list (DynPrePost A)) : StateDelaySpec A :=
fun s p => List.Exists (fun pp : DynPrePost A => let (pre,post) := pp in pre s /\ forall r, post s r -> p r ) ppl.
mkStateT (fun s p => List.Exists (fun pp : DynPrePost A => let (pre,post) := pp in pre s /\ forall r, post s r -> p r ) ppl).
Next Obligation.
induction H0; eauto. left. destruct x as [pre post]. destruct H0 as [Hs Hr].
split; auto.
Expand All @@ -192,7 +193,7 @@ Section StateDelaySpec.
+ destruct a as [pre post].
inversion H1; subst.
* destruct H2.
assert ((pre s -> post s (m s)) ); auto.
assert ((pre s -> post s (runStateT m s)) ); auto.
intros. inversion Hrefine; subst; auto.
* apply IHppl; auto.
-- inversion H; auto.
Expand All @@ -204,7 +205,7 @@ Section StateDelaySpec.
{ inversion H. auto. }
constructor; intros.
+ red. intros. set (exist _ (post s) (Heutt s)) as p.
specialize (H0 s p). enough ((m s) ∈ p); auto. apply H0.
specialize (H0 s p). enough ((runStateT m s) ∈ p); auto. apply H0.
left. split; auto.
+ apply IHppl; auto.
* inversion H. auto.
Expand All @@ -213,19 +214,19 @@ Section StateDelaySpec.
Qed.

Lemma combine_prepost_aux : forall (A B : Type) (pre1 pre2 : St -> Prop)
(post1 : Delay (St * A) -> Prop ) (post2 : Delay (St * B) -> Prop)
(post1 : Delay (A * St) -> Prop ) (post2 : Delay (B * St) -> Prop)
(m : StateDelay A) (f : A -> StateDelay B),
verify_cond (encode (post1,pre1) ) m ->
(forall (a : A) (s : St), (* this condition is not exactly what i want*)
post1 (Ret (s,a) ) -> post2 (f a s) ) ->
post1 (Ret (a,s) ) -> post2 (runStateT (f a) s) ) ->
(post1 ITree.spin -> post2 ITree.spin) ->
resp_eutt post1 ->
verify_cond (encode (post2, pre1) ) (bind m f).
Proof.
intros. repeat red in H. repeat red. intros.
destruct p as [p Hp]. simpl in *.
destruct H3.
destruct (eutt_reta_or_div (m s)); basic_solve.
destruct (eutt_reta_or_div (runStateT m s)); basic_solve.
- destruct a as [s' a].
cbn in H5. rewrite <- H5, bind_ret_l; cbn. apply H4, H0. rewrite H5.
apply (H s (exist _ post1 H2)); auto.
Expand All @@ -234,10 +235,10 @@ Section StateDelaySpec.
Qed.

Lemma combine_prepost : forall (A B : Type) (pre1 pre2 : St -> Prop)
(post1 : Delay (St * A) -> Prop ) (post2 : Delay (St * B) -> Prop)
(post1 : Delay (A * St) -> Prop ) (post2 : Delay (B * St) -> Prop)
(m : StateDelay A) (f : A -> StateDelay B),
verify_cond (encode (post1,pre1) ) m ->
(forall a s, post1 (Ret (s,a)) -> pre2 s) ->
(forall a s, post1 (Ret (a,s)) -> pre2 s) ->
(forall a, verify_cond (encode (post2,pre2) ) (f a) ) ->
(post1 ITree.spin -> post2 ITree.spin) ->
resp_eutt post1 ->
Expand Down
40 changes: 23 additions & 17 deletions extra/Dijkstra/StateIOTrace.v
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ From ExtLib Require Import
Data.String
Structures.Monad
Core.RelDec
Data.Map.FMapAList.
Data.Map.FMapAList
Data.Monads.StateMonad.

From Paco Require Import paco.

Expand Down Expand Up @@ -55,17 +56,17 @@ Definition SIOSpecEq := StateSpecTEq env (TraceSpec IO).

Definition SIOObs := EffectObsStateT env (TraceSpec IO) (itree IO).

Definition SIOMorph :=MonadMorphimStateT env (TraceSpec IO) (itree IO).
Definition SIOMorph :=MonadMorphismStateT env (TraceSpec IO) (itree IO).

Definition verify_cond {A : Type} := DijkstraProp (stateT env (itree IO)) StateIOSpec SIOObs A.

(*Predicate on initial state and initial log*)
Definition StateIOSpecPre : Type := env -> ev_list IO -> Prop.
(*Predicate on final log and possible return value*)
Definition StateIOSpecPost (A : Type) : Type := itrace IO (env * A) -> Prop.
Definition StateIOSpecPost (A : Type) : Type := itrace IO (A * env) -> Prop.

Program Definition encode {A} (pre : StateIOSpecPre) (post : StateIOSpecPost A) : StateIOSpec A :=
fun s log p => pre s log /\ (forall tr, post tr -> p tr).
mkStateT (fun s log p => pre s log /\ (forall tr, post tr -> p tr)).


Section PrintMults.
Expand Down Expand Up @@ -114,13 +115,12 @@ Section PrintMults.
alist_add _ V v s.

Definition handleIOStateE (A : Type) (ev : (StateE +' IO) A) : stateT env (itree IO) A :=
fun s =>
match ev with
| inl1 ev' =>
match ev' with
| GetE V => Ret (s, lookup_default V 0 s)
| PutE V v => Ret (Maps.add V v s, tt) end
| inr1 ev' => Vis ev' (fun x => Ret (s,x) )
| GetE V => mkStateT (fun s => Ret (lookup_default V 0 s, s))
| PutE V v => mkStateT (fun s => Ret (tt, Maps.add V v s)) end
| inr1 ev' => mkStateT (fun s => Vis ev' (fun x => Ret (x,s)))
end.

Ltac unf_res := unfold resum, ReSum_id, id_, Id_IFun in *.
Expand Down Expand Up @@ -174,6 +174,8 @@ Section PrintMults.
let H' := fresh H in
match type of H with ?P -> _ => assert (H' : P); try (specialize (H H'); clear H') end.

Arguments interp_state : simpl never.

Lemma print_mults_sats_spec :
verify_cond (encode print_mults_pre print_mults_post) (interp_state handleIOStateE print_mults).
Proof.
Expand Down Expand Up @@ -208,11 +210,11 @@ Section PrintMults.
assert (RAnsRef IO unit nat (evans nat Read n) tt Read n); auto with itree.
apply H6 in H. pclearbot. auto.
}
clear Href ev. subst. rewrite bind_ret_l in H. simpl in *. rewrite interp_state_bind in H.
rewrite interp_state_trigger in H. simpl in *. rewrite bind_ret_l in H.
simpl in *.
clear Href ev. subst. rewrite bind_ret_l in H. cbn in *. rewrite interp_state_bind in H.
rewrite interp_state_trigger in H. cbn in *. rewrite bind_ret_l in H.
cbn in *.
specialize (@interp_state_iter' (StateE +' IO) ) as Hiter.
unfold state_eq in Hiter. rewrite Hiter in H. clear Hiter.
rewrite Hiter in H; eauto. clear Hiter.

remember (Maps.add X n s) as si.
assert (si = alist_add RelDec_string X n s); try (subst; auto; fail).
Expand Down Expand Up @@ -240,7 +242,8 @@ Section PrintMults.

(*This block shows how to proceed through the loop body*)
rename H0 into H.
unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree in H.
unfold Basics.iter, MonadIter_stateT, Basics.iter, MonadIter_itree in H.
cbn in H.
rewrite unfold_iter in H.
match type of H with _ ⊑ ITree.bind _ ?k0 => remember k0 as k end.

Expand All @@ -255,7 +258,9 @@ Section PrintMults.
setoid_rewrite bind_ret_l in H.
unf_res.
punfold H. red in H. cbn in *.
dependent induction H.
remember (interp_state handleIOStateE) as h eqn:Eh.
revert Eh.
dependent induction H; intros Eh.
2:{ rewrite <- x. constructor; auto. eapply IHruttF; eauto; reflexivity. }
inversion H; ddestruction; subst; ddestruction; try contradiction.
subst. specialize (H0 tt tt).
Expand All @@ -270,7 +275,8 @@ Section PrintMults.
remember (lookup_default Y 0 si) as m.
eapply CIH with (Maps.add Y (n + m) si); try apply lookup_eq.
2: { rewrite lookup_neq; subst; auto. }
rewrite tau_eutt in Hk1. setoid_rewrite bind_trigger in Hk1.
rewrite tau_eutt in Hk1.
setoid_rewrite bind_trigger in Hk1.
setoid_rewrite interp_state_vis in Hk1. cbn in *.
rewrite bind_ret_l in Hk1. rewrite tau_eutt in Hk1.
setoid_rewrite bind_vis in Hk1.
Expand All @@ -280,11 +286,11 @@ Section PrintMults.
rewrite interp_state_ret in Hk1. rewrite bind_ret_l in Hk1.
cbn in *.
rewrite tau_eutt in Hk1.
unfold Basics.iter, MonadIter_stateT0, Basics.iter, MonadIter_itree.
unfold Basics.iter, MonadIter_itree.
match goal with
H : _ ⊑ ITree.iter _ (?s1, _) |- _ ⊑ ITree.iter _ (?s2, _) =>
enough (Hseq : s2 = s1) end; try rewrite Hseq; auto.
subst. rewrite Nat.add_comm. auto.
Qed.
Qed.

End PrintMults.
Loading