Skip to content

Commit

Permalink
Rewrite CompM combinators in Prelude.sawcore to use new tuple types.
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian Huffman committed Mar 18, 2022
1 parent 0b43f8f commit aec37ef
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions saw-core/prelude/Prelude.sawcore
Original file line number Diff line number Diff line change
Expand Up @@ -2085,15 +2085,24 @@ composeM : (a b c: sort 0) -> (a -> CompM b) -> (b -> CompM c) -> a -> CompM c;
composeM a b c f g x = bindM b c (f x) g;

-- Tuple a type onto the input and output types of a monadic function
tupleCompMFunBoth : (a b c: sort 0) -> (a -> CompM b) -> #(c, a) -> CompM #(c, b);
tupleCompMFunBoth a b c f =
\ (x : #(c, a)) ->
bindM b #(c, b) (f x.1) (\ (y:b) -> returnM #(c, b) (x.0, y));

-- Tuple a valu onto the output of a monadic function
tupleCompMFunOut : (a b c: sort 0) -> c -> (a -> CompM b) -> (a -> CompM #(c, b));
tupleCompMFunOut a b c x f =
\ (y:a) -> bindM b #(c, b) (f y) (\ (z:b) -> returnM #(c, b) (x, z));
tupleCompMFunBoth :
(a b : TypeList) ->
(c : sort 0) ->
(Tuple a -> CompM (Tuple b)) ->
Tuple (TypeCons c a) -> CompM (Tuple (TypeCons c b));
tupleCompMFunBoth a b c f x =
bindM (Tuple b) (Tuple (TypeCons c b)) (f (tailTuple c a x))
(\ (y : Tuple b) -> returnM (Tuple (TypeCons c b)) (consTuple c b (headTuple c a x) y));

-- Tuple a value onto the output of a monadic function
tupleCompMFunOut :
(a : sort 0) ->
(b : TypeList) ->
(c : sort 0) ->
c -> (a -> CompM (Tuple b)) -> (a -> CompM (Tuple (TypeCons c b)));
tupleCompMFunOut a b c x f y =
bindM (Tuple b) (Tuple (TypeCons c b)) (f y)
(\ (z : Tuple b) -> returnM (Tuple (TypeCons c b)) (consTuple c b x z));

-- Map a monadic function across a vector
mapM : (a :sort 0) -> (b : isort 0) -> (a -> CompM b) -> (n : Nat) -> Vec n a -> CompM (Vec n b);
Expand Down Expand Up @@ -2258,16 +2267,21 @@ lrtPi lrts b =
(\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> lrtToType lrt -> rest)
lrts;

-- Build the product type (lrtToType lrt1, ..., lrtToType lrtn) from the
-- Build the type list [lrtToType lrt1, ..., lrtToType lrtn] from the
-- LetRecTypes list [lrt1, ..., lrtn]
lrtTupleType : LetRecTypes -> sort 0;
lrtTupleType lrts =
lrtTypeList : LetRecTypes -> TypeList;
lrtTypeList lrts =
LetRecTypes#rec
(\ (lrts:LetRecTypes) -> sort 0)
#()
(\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> #(lrtToType lrt, rest))
(\ (lrts:LetRecTypes) -> TypeList)
TypeNil
(\ (lrt:LetRecType) (_:LetRecTypes) (rest:TypeList) -> TypeCons (lrtToType lrt) rest)
lrts;

-- Build the product type (lrtToType lrt1, ..., lrtToType lrtn) from the
-- LetRecTypes list [lrt1, ..., lrtn]
lrtTupleType : LetRecTypes -> sort 0;
lrtTupleType lrts = Tuple (lrtTypeList lrts);

-- NOTE: the following are needed to define letRecM instead of making it a
-- primitive, which we are keeping commented here in case that is needed
{-
Expand Down Expand Up @@ -2349,7 +2363,7 @@ letRecM1 : (a b c : sort 0) -> ((a -> CompM b) -> (a -> CompM b)) ->
letRecM1 a b c fn body =
letRecM
(LRT_Cons (LRT_Fun a (\ (_:a) -> LRT_Ret b)) LRT_Nil) c
(\ (f:a -> CompM b) -> (fn f, ()))
(\ (f:a -> CompM b) -> consTuple (a -> CompM b) TypeNil (fn f) ())
(\ (f:a -> CompM b) -> body f);

-- A single-argument fixed-point function
Expand All @@ -2359,7 +2373,7 @@ fixM : (a:sort 0) -> (b:a -> sort 0) ->
fixM a b f x =
letRecM (LRT_Cons (LRT_Fun a (\ (y:a) -> LRT_Ret (b y))) LRT_Nil)
(b x)
(\ (g: (y:a) -> CompM (b y)) -> (f g, ()))
(\ (g: (y:a) -> CompM (b y)) -> consTuple ((y:a) -> CompM (b y)) TypeNil (f g) ())
(\ (g: (y:a) -> CompM (b y)) -> g x);


Expand Down Expand Up @@ -2454,7 +2468,10 @@ multiFixM : (lrts:LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) ->
multiArgFixM : (lrt:LetRecType) -> (lrtToType lrt -> lrtToType lrt) ->
lrtToType lrt;
multiArgFixM lrt F =
(multiFixM (LRT_Cons lrt LRT_Nil) (\ (f:lrtToType lrt) -> (F f, ()))).0;
(multiFixM
(LRT_Cons lrt LRT_Nil)
(\ (f:lrtToType lrt) -> consTuple (lrtToType lrt) TypeNil (F f) ())
).0;


-- Test computations
Expand Down Expand Up @@ -2516,7 +2533,7 @@ test_fun6 x =
(Vec 64 Bool)
(\ (f1:(Vec 64 Bool -> CompM (Vec 64 Bool)))
(f2:(Vec 64 Bool -> CompM (Vec 64 Bool))) ->
(f2, (f1, ())))
(f2, f1))
(\ (f1:(Vec 64 Bool -> CompM (Vec 64 Bool)))
(f2:(Vec 64 Bool -> CompM (Vec 64 Bool))) ->
f1 x);
Expand Down

0 comments on commit aec37ef

Please sign in to comment.