Skip to content

Commit

Permalink
Typechecking for case expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
yav committed Jan 17, 2024
1 parent e19a4de commit 659030a
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 10 deletions.
5 changes: 5 additions & 0 deletions src/Cryptol/IR/FreeVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ instance FreeVars Expr where
ESel e _ -> freeVars e
ESet ty e _ v -> freeVars ty <> freeVars [e,v]
EIf e1 e2 e3 -> freeVars [e1,e2,e3]
ECase e as d -> freeVars e <> freeVars (Map.elems as)
<> maybe mempty freeVars d
EComp t1 t2 e mss -> freeVars [t1,t2] <> rmVals (defs mss) (freeVars e)
<> mconcat (map foldFree mss)
EVar x -> mempty { valDeps = Set.singleton x }
Expand All @@ -124,6 +126,9 @@ instance FreeVars Expr where
foldFree = foldr updateFree mempty
updateFree x rest = freeVars x <> rmVals (defs x) rest

instance FreeVars CaseAlt where
freeVars (CaseAlt xs e) = foldr (rmVal . fst) (freeVars e) xs

instance FreeVars Match where
freeVars m = case m of
From _ t1 t2 e -> freeVars t1 <> freeVars t2 <> freeVars e
Expand Down
9 changes: 9 additions & 0 deletions src/Cryptol/IR/TraverseNames.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ instance TraverseNames Expr where
EIf e1 e2 e3 -> EIf <$> traverseNamesIP e1
<*> traverseNamesIP e2
<*> traverseNamesIP e3
ECase e as d -> ECase <$> traverseNamesIP e
<*> traverse traverseNamesIP as
<*> traverse traverseNamesIP d

EComp t1 t2 e mss -> EComp <$> traverseNamesIP t1
<*> traverseNamesIP t2
Expand All @@ -82,6 +85,11 @@ instance TraverseNames Expr where
EPropGuards gs t -> EPropGuards <$> traverse doG gs <*> traverseNamesIP t
where doG (xs, e) = (,) <$> traverseNamesIP xs <*> traverseNamesIP e

instance TraverseNames CaseAlt where
traverseNamesIP (CaseAlt xs e) =
CaseAlt <$> traverse doPair xs <*> traverseNamesIP e
where doPair (x,y) = (,) <$> traverseNamesIP x <*> traverseNamesIP y

instance TraverseNames Match where
traverseNamesIP mat =
case mat of
Expand Down Expand Up @@ -168,6 +176,7 @@ instance TraverseNames TypeSource where
GeneratorOfListComp -> pure src
TypeErrorPlaceHolder -> pure src
CasedExpression -> pure src
ConPat -> pure src

instance TraverseNames ArgDescr where
traverseNamesIP arg = mk <$> traverseNamesIP (argDescrFun arg)
Expand Down
5 changes: 5 additions & 0 deletions src/Cryptol/Transform/MonoValues.hs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ rewE rews = go
ESel e s -> ESel <$> go e <*> return s
ESet ty e s v -> ESet ty <$> go e <*> return s <*> go v
EIf e1 e2 e3 -> EIf <$> go e1 <*> go e2 <*> go e3
ECase e as d -> ECase <$> go e <*> traverse (rewCase rews) as
<*> traverse (rewCase rews) d

EComp len t e mss -> EComp len t <$> go e <*> mapM (mapM (rewM rews)) mss
EVar _ -> return expr
Expand All @@ -204,6 +206,9 @@ rewE rews = go
EPropGuards guards ty -> EPropGuards <$> (\(props, e) -> (,) <$> pure props <*> go e) `traverse` guards <*> pure ty


rewCase :: RewMap -> CaseAlt -> M CaseAlt
rewCase rew (CaseAlt xs e) = CaseAlt xs <$> rewE rew e

rewM :: RewMap -> Match -> M Match
rewM rews ma =
case ma of
Expand Down
5 changes: 5 additions & 0 deletions src/Cryptol/Transform/Specialize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ specializeExpr expr =
ESel e s -> ESel <$> specializeExpr e <*> pure s
ESet ty e s v -> ESet ty <$> specializeExpr e <*> pure s <*> specializeExpr v
EIf e1 e2 e3 -> EIf <$> specializeExpr e1 <*> specializeExpr e2 <*> specializeExpr e3
ECase e as d -> ECase <$> specializeExpr e
<*> traverse specializeCaseAlt as
<*> traverse specializeCaseAlt d
EComp len t e mss -> EComp len t <$> specializeExpr e <*> traverse (traverse specializeMatch) mss
-- Bindings within list comprehensions always have monomorphic types.
EVar {} -> specializeConst expr
Expand Down Expand Up @@ -117,6 +120,8 @@ specializeExpr expr =
pm <- liftSpecT getPrimMap
pure $ eError pm ty "no constraint guard was satisfied"

specializeCaseAlt :: CaseAlt -> SpecM CaseAlt
specializeCaseAlt (CaseAlt xs e) = CaseAlt xs <$> specializeExpr e

specializeMatch :: Match -> SpecM Match
specializeMatch (From qn l t e) = From qn l t <$> specializeExpr e
Expand Down
23 changes: 22 additions & 1 deletion src/Cryptol/TypeCheck/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ data Expr = EList [Expr] Type -- ^ List value (with type of elements)
-- The included type gives the type of the record being updated

| EIf Expr Expr Expr -- ^ If-then-else
| ECase Expr (Map Ident CaseAlt) (Maybe CaseAlt)
-- ^ Case expression. The keys are the name of constructors
-- `Nothing` for default case, the expresssions are what to
-- do if the constructor matches. If the constructor binds
-- variables, then then the expr should be `EAbs`
| EComp Type Type Expr [[Match]]
-- ^ List comprehensions
-- The types cache the length of the
Expand Down Expand Up @@ -223,6 +228,10 @@ data Expr = EList [Expr] Type -- ^ List value (with type of elements)

deriving (Show, Generic, NFData)

-- | Used for case expressions. Similar to a lambda, the variables
-- are bound by the value examined in the case.
data CaseAlt = CaseAlt [(Name,Type)] Expr
deriving (Show, Generic, NFData)

data Match = From Name Type Type Expr
-- ^ Type arguments are the length and element
Expand Down Expand Up @@ -308,9 +317,17 @@ instance PP (WithNames Expr) where
$ sep [ text "if" <+> ppW e1
, text "then" <+> ppW e2
, text "else" <+> ppW e3 ]
ECase e arms dflt ->
optParens (prec > 0) $
vcat [ "case" <+> pp e <+> "of"
, indent 2 (vcat ppArms $$ ppDflt)
]
where
ppArms = [ pp i <+> pp c | (i,c) <- reverse (Map.toList arms) ]
ppDflt = maybe mempty pp dflt

EComp _ _ e mss -> let arm ms = text "|" <+> commaSep (map ppW ms)
in brackets $ ppW e <+> (align (vcat (map arm mss)))
in brackets $ ppW e <+> align (vcat (map arm mss))

EVar x -> ppPrefixName x

Expand Down Expand Up @@ -358,6 +375,10 @@ instance PP (WithNames Expr) where
ppW x = ppWithNames nm x
ppWP x = ppWithNamesPrec nm x

instance PP CaseAlt where
ppPrec _ (CaseAlt xs e) = hsep (map ppV xs) <+> "->" <+> pp e
where ppV (x,t) = parens (pp x <.> ":" <+> pp t)

ppLam :: NameMap -> Int -> [TParam] -> [Prop] -> [(Name,Type)] -> Expr -> Doc
ppLam nm prec [] [] [] e = nest 2 (ppWithNamesPrec nm prec e)
ppLam nm prec ts ps xs e =
Expand Down
16 changes: 16 additions & 0 deletions src/Cryptol/TypeCheck/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ data Error = KindMismatch (Maybe TypeSource) Kind Kind
-- 1) Number of parameters we have,
-- 2) Number of parameters we need.

| OverlappingPat (Maybe Ident) [Range]
-- ^ Overlapping patterns in a case

| TemporaryError Doc
-- ^ This is for errors that don't fit other cateogories.
-- We should not use it much, and is generally to be used
Expand Down Expand Up @@ -214,6 +217,7 @@ errorImportance err =
TypeMismatch {} -> 8
SchemaMismatch {} -> 7
InvalidConPat {} -> 7
OverlappingPat {} -> 3
RecursiveType {} -> 7
NotForAll {} -> 6
TypeVariableEscaped {} -> 5
Expand Down Expand Up @@ -288,6 +292,7 @@ instance TVars Error where
SchemaMismatch i !$ (apSubst su t1) !$ (apSubst su t2)
TypeMismatch src pa t1 t2 -> TypeMismatch src pa !$ (apSubst su t1) !$ (apSubst su t2)
InvalidConPat {} -> err
OverlappingPat {} -> err
RecursiveType src pa t1 t2 -> RecursiveType src pa !$ (apSubst su t1) !$ (apSubst su t2)
UnsolvedGoals gs -> UnsolvedGoals !$ apSubst su gs
UnsolvableGoals gs -> UnsolvableGoals !$ apSubst su gs
Expand Down Expand Up @@ -339,6 +344,7 @@ instance FVS Error where
SchemaMismatch _ t1 t2 -> fvs (t1,t2)
TypeMismatch _ _ t1 t2 -> fvs (t1,t2)
InvalidConPat {} -> Set.empty
OverlappingPat {} -> Set.empty
RecursiveType _ _ t1 t2 -> fvs (t1,t2)
UnsolvedGoals gs -> fvs gs
UnsolvableGoals gs -> fvs gs
Expand Down Expand Up @@ -485,6 +491,15 @@ instance PP (WithNames Error) where
, "but there are" <+> int have <.> "."
]

OverlappingPat mbCon rs ->
addTVarsDescsAfter names err $
nested ("Overlapping choices for" <+> what <.> ":") $
vcat [ "Pattern at" <+> pp r | r <- rs ]
where
what = case mbCon of
Just i -> "constructor" <+> pp i
Nothing -> "default case"

UnsolvableGoals gs -> explainUnsolvable names gs

UnsolvedGoals gs
Expand Down Expand Up @@ -590,6 +605,7 @@ instance PP (WithNames Error) where
NSValue -> "value"
NSType -> "type"
NSModule -> "module"
NSConstructor -> "constructor"

FunctorInstanceBadBacktick bad ->
case bad of
Expand Down
38 changes: 29 additions & 9 deletions src/Cryptol/TypeCheck/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -479,38 +479,56 @@ checkE expr tGoal =
P.ECase e as ->
do et <- newType CasedExpression KType
alts <- forM as \a -> checkCaseAlt a et tGoal
undefined
rng <- curRange
e1 <- checkE e (WithSource et CasedExpression (Just rng))
let mp1 = Map.fromListWith (++) [ (x,[(r,y)]) | (r,x,y) <- alts ]

forM_ (Map.toList mp1) \(mb,cs) ->
case cs of
[_] -> pure ()
_ -> recordError (OverlappingPat mb [ r | (r,_) <- cs ])

-- check that we only have 1 default
let dflt = do (_,k) : _ <- Map.lookup Nothing mp1
pure k

let arms = Map.fromList [ (i,a) | (_,Just i, a) <- alts ]
pure (ECase e1 arms dflt)

P.EParens e -> checkE e tGoal


checkCaseAlt ::
P.CaseAlt Name -> Type -> TypeWithSource -> InferM ()
P.CaseAlt Name -> Type -> TypeWithSource ->
InferM (Range, Maybe Ident, CaseAlt)
checkCaseAlt (P.CaseAlt pat e) srcT resT =
case pat of
P.PCon c ps ->
inRange (srcRange c) $
do (tArgs,pArgs,fTs,cresT) <- instantiatePCon (thing c)
do (_tArgs,_pArgs,fTs,cresT) <- instantiatePCon (thing c)
-- XXX: should we store these somewhere?

let have = length ps
need = length fTs
unless (have == need) (recordError (InvalidConPat have need))
let scresT = WithSource
{ twsType = cresT
let expect = WithSource
{ twsType = srcT
, twsRange = Just (srcRange c)
, twsSource = CasedExpression -- or make a new one?
, twsSource = ConPat
}
newGoals CtExactType =<< unify scresT srcT
newGoals CtExactType =<< unify expect cresT
let xs = [ (thing x, Located rng t)
| (v,t) <- zip ps fTs
, let x = isPVar v
, let rng = srcRange c
]
e1 <- withMonoTypes (Map.fromList xs) (checkE e resT)
undefined
pure (srcRange c, Just (nameIdent (thing c)), mkAlt xs e1)

P.PVar x ->
do let xty = (thing x, Located (srcRange x) srcT)
e1 <- withMonoType xty (checkE e resT)
undefined
pure (srcRange x, Nothing, mkAlt [xty] e1)

_ -> panic "checkCaseAlt" ["Unexpected pattern"]
where
Expand All @@ -519,6 +537,8 @@ checkCaseAlt (P.CaseAlt pat e) srcT resT =
P.PVar x -> x
_ -> panic "checkCaseAlt" ["Nested pattern is not PVar"]

mkAlt xs = CaseAlt [ (x, thing t) | (x,t) <- xs ]

checkRecUpd ::
Maybe (P.Expr Name) -> [ P.UpdField Name ] -> TypeWithSource -> InferM Expr
checkRecUpd mb fs tGoal =
Expand Down
4 changes: 4 additions & 0 deletions src/Cryptol/TypeCheck/ModuleBacktickInstance.hs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ instance RewVal Expr where
ESel e l -> ESel (rew e) l
ESet t e1 s e2 -> ESet (rewType t) (rew e1) s (rew e2)
EIf e1 e2 e3 -> EIf (rew e1) (rew e2) (rew e3)
ECase e as d -> ECase (rew e) (rew <$> as) (rew <$> d)
EComp t1 t2 e mss -> EComp (rewType t1) (rewType t2) (rew e) (rew mss)
EVar x -> tryVarApp
case Map.lookup x (pSubst ?vparams) of
Expand Down Expand Up @@ -388,6 +389,9 @@ instance RewVal Expr where
in evs
_ -> orElse

instance RewVal CaseAlt where
rew (CaseAlt xs e) = CaseAlt xs (rew e)


instance RewVal DeclGroup where
rew dg =
Expand Down
10 changes: 10 additions & 0 deletions src/Cryptol/TypeCheck/Parseable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module Cryptol.TypeCheck.Parseable
) where

import Data.Void
import qualified Data.Map as Map
import Prettyprinter

import Cryptol.TypeCheck.AST
Expand Down Expand Up @@ -52,6 +53,9 @@ instance ShowParseable Expr where
showParseable e <+> showParseable s
<+> showParseable v)
showParseable (EIf c t f) = parens (text "EIf" <+> showParseable c $$ showParseable t $$ showParseable f)
showParseable (ECase e as d) =
parens (text "ECase" <+> showParseable e $$ showParseable (Map.toList as)
$$ showParseable d)
showParseable (EComp _ _ e mss) = parens (text "EComp" $$ showParseable e $$ showParseable mss)
showParseable (EVar n) = parens (text "EVar" <+> showParseable n)
showParseable (EApp fe ae) = parens (text "EApp" $$ showParseable fe $$ showParseable ae)
Expand Down Expand Up @@ -89,6 +93,12 @@ instance ShowParseable Match where
showParseable (From n _ _ e) = parens (text "From" <+> showParseable n <+> showParseable e)
showParseable (Let d) = parens (text "MLet" <+> showParseable d)


instance ShowParseable CaseAlt where
showParseable (CaseAlt xs e) =
parens (text "CaseAlt" <+> showParseable xs <+> showParseable e)


instance ShowParseable Decl where
showParseable d = parens (text "Decl" <+> showParseable (dName d)
$$ showParseable (dDefinition d))
Expand Down
3 changes: 3 additions & 0 deletions src/Cryptol/TypeCheck/Sanity.hs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ exprSchema expr =

return $ tMono t1

-- XXX
-- ECase e as d ->

EComp len t e mss ->
do checkTypeIs KNum len
checkTypeIs KType t
Expand Down
7 changes: 7 additions & 0 deletions src/Cryptol/TypeCheck/Subst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,20 @@ instance TVars Expr where
ESel e s -> ESel !$ (go e) .$ s
EComp len t e mss -> EComp !$ (apSubst su len) !$ (apSubst su t) !$ (go e) !$ (apSubst su mss)
EIf e1 e2 e3 -> EIf !$ (go e1) !$ (go e2) !$ (go e3)
ECase e as d -> ECase !$ go e !$ (apSubst su <$> as)
!$ (apSubst su <$> d)

EWhere e ds -> EWhere !$ (go e) !$ (apSubst su ds)

EPropGuards guards ty -> EPropGuards
!$ (\(props, e) -> (apSubst su `fmap'` props, go e)) `fmap'` guards
!$ apSubst su ty

instance TVars CaseAlt where
apSubst su (CaseAlt xs e) = CaseAlt !$ [(x,apSubst su t) | (x,t) <- xs]
!$ apSubst su e
-- XXX: not as strict as the rest

instance TVars Match where
apSubst su (From x len t e) = From x !$ (apSubst su len) !$ (apSubst su t) !$ (apSubst su e)
apSubst su (Let b) = Let !$ (apSubst su b)
Expand Down
3 changes: 3 additions & 0 deletions src/Cryptol/TypeCheck/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ data TypeSource = TVFromModParam Name -- ^ Name of module parameter
| TypeFromUserAnnotation
| GeneratorOfListComp
| CasedExpression
| ConPat
| TypeErrorPlaceHolder
deriving (Show, Generic, NFData)

Expand Down Expand Up @@ -1316,6 +1317,7 @@ pickTVarName k src uni =
TypeFromUserAnnotation -> "user"
TypeErrorPlaceHolder -> "err"
CasedExpression -> "case"
ConPat -> "conp"
where
sh a = show (pp a)
using a = mk (sh a)
Expand Down Expand Up @@ -1366,6 +1368,7 @@ instance PP TypeSource where
FunApp -> "function call"
TypeErrorPlaceHolder -> "type error place-holder"
CasedExpression -> "cased expression"
ConPat -> "constructor pattern"

instance PP ModParamNames where
ppPrec _ ps =
Expand Down
Loading

0 comments on commit 659030a

Please sign in to comment.