Skip to content

Commit

Permalink
Merge pull request #2205 from GaloisInc/2204-polymorphic-signatures
Browse files Browse the repository at this point in the history
Restore support for explicit type signatures on polymorphic functions
  • Loading branch information
sauclovian-g authored Jan 29, 2025
2 parents 3baa726 + 4d87626 commit 98ce44f
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 31 deletions.
4 changes: 4 additions & 0 deletions intTests/test2204/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Driver for the test2204 directory: run each SAW script in turn and
# abort on the first one that fails.
set -e

for script in test1.saw test2.saw; do
    $SAW "$script"
done
6 changes: 6 additions & 0 deletions intTests/test2204/test1.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Non-recursive declarations whose signatures mention type variables
// (a, b, m) that are introduced only by the function header, with no
// explicit forall binding.

// One type variable, used for both the argument and the result.
let f (x : a) : a = x;
// Two independent type variables, combined in the result type.
let g (x : a) (y : b) : (a, b) = (x, y);

// The result type applies a further type variable (m) to a.
let h (x : a) : m a = do {
return x;
};
4 changes: 4 additions & 0 deletions intTests/test2204/test2.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// A mutually recursive group (f -> g -> h -> f) in which every member
// carries an explicit signature using the header-introduced type
// variables a and b.
rec f (x : a) : b = g x
and g (y : a) : b = h y
and h (z : a) : b = f z
;
12 changes: 12 additions & 0 deletions intTests/test_type_errors/err045.log.good
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Loading file "err045.saw"
err045.saw:2:26-2:27: Type mismatch.
Mismatch of type constructors. Expected: String but got Int
err045.saw:2:17-2:23: The type String arises from this type annotation
err045.saw:2:26-2:27: The type Int arises from the type of this term

Expected: String
Found: Int

within "f" (err045.saw:2:6-2:7)

FAILED
2 changes: 2 additions & 0 deletions intTests/test_type_errors/err045.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// You'd expect this to be a parse error, wouldn't you...
let (f : Int) : String = 3;
171 changes: 140 additions & 31 deletions src/SAWScript/MGU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,93 @@ matches t1 t2 = do
-- }}}


------------------------------------------------------------
-- Inspect for free type variables {{{

-- We want to allow declaring polymorphic functions by introducing
-- type variables in the function header (rather than requiring an
-- explicit forall binding), like Haskell does.
--
-- This means that free type variables in a function header (but not
-- elsewhere) should be accepted, collected, and handed off to
-- generalize for insertion in the resultant type scheme.
--
-- It turns out that, because of the way the AST represents functions
-- in let-bindings, this is highly unpleasant to do on the fly
-- while typechecking. So instead we extract the free type variables
-- separately.
--
-- A function header comes through like this:
-- Decl _pos <function-name-pattern> Nothing <expr>
--
-- where <expr> is
-- zero or more times, Function _pos <arg-pattern> <expr'>
-- optionally, TSig _pos <expr''> <return-type>
--
-- so we need any free type variables in
-- - <function-name-pattern>
-- - <return-type>
-- - all <arg-pattern>
--
-- On the plus side this will also then work when people write
-- otherwise annoying things like
-- let f (x: a) = \(y: b) -> (x, y)
--
-- We extract the type variables with the position of their
-- initial mention.

-- Collect the named type variables occurring in a Type that have no
-- entry in the current type environment, each mapped to the position
-- recorded on its TyVar node. Unification variables are never
-- reported. Results from subterms are merged with M.unions.
inspectTypeFTVs :: Type -> TI (Map Name Pos)
inspectTypeFTVs (TyCon _ _ args) =
  M.unions <$> mapM inspectTypeFTVs args
inspectTypeFTVs (TyRecord _ fields) =
  M.unions <$> traverse inspectTypeFTVs fields
inspectTypeFTVs (TyUnifyVar _ _) =
  return M.empty
inspectTypeFTVs (TyVar pos x) = do
  tyenv <- TI $ asks tyEnv
  -- A name already present in the type environment is not free.
  return $ if M.member x tyenv then M.empty else M.singleton x pos

-- Like inspectTypeFTVs, but for an optional type: an absent type
-- contributes no variables.
inspectMaybeTypeFTVs :: Maybe Type -> TI (Map Name Pos)
inspectMaybeTypeFTVs = maybe (return M.empty) inspectTypeFTVs

-- Collect the free type variables mentioned by the (optional) type
-- annotations inside a Pattern, descending into tuple patterns.
inspectPatternFTVs :: Pattern -> TI (Map Name Pos)
inspectPatternFTVs (PWild _ mty) = inspectMaybeTypeFTVs mty
inspectPatternFTVs (PVar _ _ mty) = inspectMaybeTypeFTVs mty
inspectPatternFTVs (PTuple _ subpats) =
  fmap M.unions (mapM inspectPatternFTVs subpats)

-- Walk down a chain of nested Function expressions, collecting the
-- free type variables of each argument pattern. The first expression
-- that is not a Function is returned alongside the collected
-- variables so the caller can examine it further.
inspectLambdaFTVs :: Expr -> TI (Expr, Map Name Pos)
inspectLambdaFTVs (Function _ argPat body) = do
  argFTVs <- inspectPatternFTVs argPat
  (innermost, restFTVs) <- inspectLambdaFTVs body
  -- Outer arguments are the left operand of the union.
  return (innermost, argFTVs `M.union` restFTVs)
inspectLambdaFTVs other =
  return (other, M.empty)

-- Collect the free type variables of a declaration header: those in
-- the name pattern, in every argument pattern of the Function chain,
-- and in the return-type signature if the expression under the chain
-- is a TSig. The schema slot of the Decl is not consulted.
inspectDeclFTVs :: Decl -> TI (Map Name Pos)
inspectDeclFTVs (Decl _ namePat _ rhs) = do
  fromName <- inspectPatternFTVs namePat
  (body, fromArgs) <- inspectLambdaFTVs rhs
  fromResult <- resultFTVs body
  return $ M.unions [fromName, fromArgs, fromResult]
  where
    -- Only a signature sitting directly under the lambdas counts as
    -- the declared return type.
    resultFTVs (TSig _ _ ty) = inspectTypeFTVs ty
    resultFTVs _ = return M.empty


-- }}}


------------------------------------------------------------
-- Main recursive pass {{{

Expand Down Expand Up @@ -901,6 +988,13 @@ withDeclGroup :: DeclGroup -> TI a -> TI a
-- Wrap the action m with withDecl for the single declaration of a
-- non-recursive group, or for every declaration of a recursive group
-- (folded from the right).
withDeclGroup dg m = case dg of
  NonRecursive d -> withDecl d m
  Recursive ds -> foldr withDecl m ds

-- Run the action m with every name in vars entered into the type
-- environment as an AbstractType. An existing entry for the same name
-- is overridden for the duration of m; the positions carried in vars
-- are not used here.
withAbstractTyVars :: Map Name Pos -> TI a -> TI a
withAbstractTyVars vars m =
  TI $ local extend $ unTI m
  where
    abstractVars = AbstractType <$ vars
    -- Left-biased union: the abstract entries win over existing ones.
    extend ro = ro { tyEnv = M.union abstractVars (tyEnv ro) }

--
-- Infer the type for an expression.
--
Expand Down Expand Up @@ -1078,15 +1172,11 @@ inferPattern pat =

-- Check the type of a pattern, by inferring and then unifying the
-- result.
--
-- XXX: it doesn't seem like there's any guarantee that fresh tyvars
-- produced by inferPattern will necessarily be resolved by the
-- unification, and therefore it seems that dropping the possibly
-- updated pattern is a bug.
--
-- NOTE(review): this span reads as a rendered diff in which the
-- previous and the revised lines are interleaved (two signatures, two
-- opening "do" lines). The revised form returns the pattern produced
-- by inferPattern instead of ().
checkPattern :: LName -> Type -> Pattern -> TI ()
checkPattern :: LName -> Type -> Pattern -> TI Pattern
checkPattern ln t pat =
do (pt, _pat') <- inferPattern pat
do (pt, pat') <- inferPattern pat
-- Unify the given type t with the inferred pattern type pt; the
-- position of the original pattern is passed along to unify.
unify ln t (getPos pat) pt
-- Hand back the pattern returned by inferPattern rather than
-- discarding it.
return pat'

--
-- statements
Expand Down Expand Up @@ -1343,8 +1433,11 @@ inferSingleStmt ln pos ctx s = do
--
-- (This creates names for any remaining unification vars, so
-- potentially updates the expression.)
generalize :: [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize es0 ts0 = do
--
-- The "foralls" argument is a set of tyvars that were mentioned
-- explicitly and should be forall-bound.
generalize :: Map Name Pos -> [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize foralls es0 ts0 = do
-- first, substitute away any resolved unification variables
-- in both the expressions and types.
es <- applyCurrentSubst es0
Expand All @@ -1368,10 +1461,9 @@ generalize es0 ts0 = do
-- (on the left-hand side) of the type environment. Those are
-- already defined.
--
-- There should be no other named variables involved. We don't
-- allow referring to random unbound type names any more, and
-- we don't yet (though probably should, FUTURE) have a way of
-- explicitly forall-binding type variables in declarations.
-- The only other named variables involved should be the set we
-- explicitly intend to be forall-bound as passed in. Insert
-- those, and favor their positions.
--
-- It would be handy for scaling if we didn't have to examine
-- the entire variable environment (on the grounds that there
Expand All @@ -1387,11 +1479,19 @@ generalize es0 ts0 = do
-- table. There is no longer any need for such hackery, and
-- undefined type names are not allowed to appear in the variable
-- environment.
--
-- FUTURE: we end up replacing the user's forall-bound names with
-- generated names, and I'm not sure why. It seems like it
-- shouldn't be possible the way the code is structured. But the
-- type signatures are coming out correct (which they wouldn't if
-- something were seriously wrong) and we aren't inappropriately
-- unifying these vars with each other or with other things, so
-- I'm not going to stress over it right now.

envUnifyVars <- unifyVarsInEnvs
knownNamedVars <- namedVarDefinitions
let is1 = is0 M.\\ envUnifyVars
let bs1 = M.withoutKeys bs0 knownNamedVars
let bs1 = M.union foralls $ M.withoutKeys bs0 knownNamedVars

-- convert to lists
let is2 = M.toList is1
Expand Down Expand Up @@ -1422,12 +1522,14 @@ generalize es0 ts0 = do
-- it from the parser; if there's an explicit type annotation on the
-- declaration that shows up as a type signature in the expression.
--
-- NOTE(review): this span reads as a rendered diff in which the
-- previous and the revised lines are interleaved (two equation heads,
-- two copies of the body). The revised body is the part that starts
-- at "foralls <- inspectDeclFTVs d".
inferDecl :: Decl -> TI Decl
inferDecl (Decl pos pat _ e) = do
inferDecl d@(Decl pos pat _ e) = do
-- The name derived from the declaration's pattern is passed to both
-- inferExpr and checkPattern.
let n = patternLName pat
(e',t) <- inferExpr (n, e)
checkPattern n t pat
~[(e1,s)] <- generalize [e'] [t]
return (Decl pos pat (Just s) e1)
-- Gather the type variables mentioned in the declaration header that
-- are not already in the type environment. This runs before the
-- environment is extended below.
foralls <- inspectDeclFTVs d
-- Typecheck with those variables registered as abstract types.
withAbstractTyVars foralls $ do
(e',t) <- inferExpr (n, e)
-- Keep the pattern returned by checkPattern; it is the one stored in
-- the resulting Decl.
pat' <- checkPattern n t pat
-- One expression and one type go in; the lazy pattern binds the single
-- (expression, schema) pair that is expected to come back.
~[(e1,s)] <- generalize foralls [e'] [t]
return (Decl pos pat' (Just s) e1)

-- Type inference for a system of mutually recursive declarations.
--
Expand All @@ -1438,23 +1540,30 @@ inferDecl (Decl pos pat _ e) = do
--
-- NOTE(review): this span reads as a rendered diff in which the
-- previous and the revised lines are interleaved ("pat =" next to
-- "firstPat =", and two copies of the body). The revised body is the
-- part that starts at "foralls <- M.unions <$> ...".
inferRecDecls :: [Decl] -> TI [Decl]
inferRecDecls ds =
do let pats = map dPat ds
-- The pattern of the first declaration supplies the name handed to
-- checkPattern below; an empty group is reported through panic.
pat =
firstPat =
case pats of
p:_ -> p
[] -> panic
"inferRecDecls"
["Empty list of declarations in recursive group"]
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds
]
sequence_ $ zipWith (checkPattern (patternLName pat)) ts pats'
ess <- generalize es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats ess
]
-- Union of the header type variables of every declaration in the
-- group; the same set is used for the whole group.
foralls <- M.unions <$> mapM inspectDeclFTVs ds
-- Typecheck the group with those variables registered as abstract
-- types.
withAbstractTyVars foralls $ do
-- Run inferPattern on every declared pattern and keep the returned
-- patterns; the inferred types are not used directly.
(_ts, pats') <- unzip <$> mapM inferPattern pats
-- Infer every right-hand side, with the whole batch wrapped in
-- withPattern for each of the patterns.
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds
]

-- pats' has already been checked once, which will have inserted
-- unification vars for any missing types. Running it through
-- again will have no further effect, so we can ignore the
-- theoretically-updated-again patterns returned by checkPattern.
sequence_ $ zipWith (checkPattern (patternLName firstPat)) ts pats'
ess <- generalize foralls es ts
-- Rebuild each declaration from its original position, the pattern
-- returned by inferPattern (pats'), and the generalized expression
-- and schema.
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats' ess
]

-- Type inference for a decl group.
inferDeclGroup :: DeclGroup -> TI DeclGroup
Expand Down

0 comments on commit 98ce44f

Please sign in to comment.