Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore support for explicit type signatures on polymorphic functions #2205

Merged
merged 4 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions intTests/test2204/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set -e

$SAW test1.saw
$SAW test2.saw
6 changes: 6 additions & 0 deletions intTests/test2204/test1.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
let f (x : a) : a = x;
let g (x : a) (y : b) : (a, b) = (x, y);

let h (x : a) : m a = do {
return x;
};
4 changes: 4 additions & 0 deletions intTests/test2204/test2.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
rec f (x : a) : b = g x
and g (y : a) : b = h y
and h (z : a) : b = f z
;
12 changes: 12 additions & 0 deletions intTests/test_type_errors/err045.log.good
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Loading file "err045.saw"
err045.saw:2:26-2:27: Type mismatch.
Mismatch of type constructors. Expected: String but got Int
err045.saw:2:17-2:23: The type String arises from this type annotation
err045.saw:2:26-2:27: The type Int arises from the type of this term

Expected: String
Found: Int

within "f" (err045.saw:2:6-2:7)

FAILED
2 changes: 2 additions & 0 deletions intTests/test_type_errors/err045.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// You'd expect this to be a parse error, wouldn't you...
let (f : Int) : String = 3;
171 changes: 140 additions & 31 deletions src/SAWScript/MGU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,93 @@ matches t1 t2 = do
-- }}}


------------------------------------------------------------
-- Inspect for free type variables {{{

-- We want to allow declaring polymorphic functions by introducing
-- type variables in the function header (rather than requiring an
-- explicit forall binding), like Haskell does.
--
-- This means that free type variables in a function header (but not
-- elsewhere) should be accepted, collected, and handed off to
-- generalize for insertion in the resultant type scheme.
--
-- It turns out that, because of the way the AST represents functions
-- in let-bindings, this is highly unpleasant to do on the fly
-- while typechecking. So instead we extract the free type variables
-- separately.
--
-- A function header comes through like this:
-- Decl _pos <function-name-pattern> Nothing <expr>
--
-- where <expr> is
-- zero or more times, Function _pos <arg-pattern> <expr'>
-- optionally, TSig _pos <expr''> <return-type>
--
-- so we need any free type variables in
-- - <function-name-pattern>
-- - <return-type>
-- - all <arg-pattern>
--
-- On the plus side this will also then work when people write
-- otherwise annoying things like
-- let f (x: a) = \(y: b) -> (x, y)
--
-- We extract the type variables with the position of their
-- initial mention.

-- Collect the free type variables mentioned in a Type: every TyVar
-- whose name has no entry in the current type environment, mapped to
-- the position of that mention. Unification variables contribute
-- nothing. When a name occurs more than once, M.unions keeps the
-- entry from the leftmost sub-result.
inspectTypeFTVs :: Type -> TI (Map Name Pos)
inspectTypeFTVs (TyCon _ _ params) =
  M.unions <$> traverse inspectTypeFTVs params
inspectTypeFTVs (TyRecord _ fieldTys) =
  M.unions <$> traverse inspectTypeFTVs fieldTys
inspectTypeFTVs (TyUnifyVar _ _) =
  pure M.empty
inspectTypeFTVs (TyVar loc name) = do
  known <- TI $ asks tyEnv
  -- A name that is already bound in the type environment is not free.
  pure $ maybe (M.singleton name loc) (const M.empty) (M.lookup name known)

-- Collect the free type variables of an optional type annotation.
-- A missing annotation (Nothing) yields the empty map; otherwise this
-- defers to inspectTypeFTVs.
inspectMaybeTypeFTVs :: Maybe Type -> TI (Map Name Pos)
inspectMaybeTypeFTVs = maybe (pure M.empty) inspectTypeFTVs

-- Collect the free type variables mentioned in the type annotations
-- of a Pattern. Wildcard and variable patterns contribute whatever
-- their optional annotation mentions; tuple patterns contribute the
-- union over their sub-patterns.
inspectPatternFTVs :: Pattern -> TI (Map Name Pos)
inspectPatternFTVs (PWild _ annot) =
  inspectMaybeTypeFTVs annot
inspectPatternFTVs (PVar _ _ annot) =
  inspectMaybeTypeFTVs annot
inspectPatternFTVs (PTuple _ elems) =
  M.unions <$> traverse inspectPatternFTVs elems

-- Walk down a chain of nested Function expressions, collecting the
-- free type variables of each argument pattern along the way.
--
-- Returns the first expression that is not a Function (the innermost
-- body of the chain, handed back so the caller can inspect it
-- further) together with the collected variables. An expression that
-- is not a Function at all is returned unchanged with an empty map.
inspectLambdaFTVs :: Expr -> TI (Expr, Map Name Pos)
inspectLambdaFTVs (Function _ argPat body) = do
  argFTVs <- inspectPatternFTVs argPat
  (core, innerFTVs) <- inspectLambdaFTVs body
  -- M.union is left-biased: an outer argument's entry wins over a
  -- later mention of the same name.
  pure (core, M.union argFTVs innerFTVs)
inspectLambdaFTVs other =
  pure (other, M.empty)

-- Collect the free type variables mentioned in the header of a Decl:
-- those in the bound pattern, those in the argument patterns of the
-- leading chain of Function expressions, and those in the type of a
-- TSig sitting directly inside that chain (if there is one). The
-- Decl's own optional type field is not examined.
inspectDeclFTVs :: Decl -> TI (Map Name Pos)
inspectDeclFTVs (Decl _ namePat _ rhs) = do
  fromName <- inspectPatternFTVs namePat
  (core, fromArgs) <- inspectLambdaFTVs rhs
  fromRet <- case core of
    TSig _ _ retTy -> inspectTypeFTVs retTy
    _ -> pure M.empty
  -- Left-biased unions: on a repeated name the entry from the bound
  -- pattern wins, then the arguments, then the return type.
  pure $ M.union fromName (M.union fromArgs fromRet)


-- }}}


------------------------------------------------------------
-- Main recursive pass {{{

Expand Down Expand Up @@ -901,6 +988,13 @@ withDeclGroup :: DeclGroup -> TI a -> TI a
withDeclGroup (NonRecursive d) m = withDecl d m
withDeclGroup (Recursive ds) m = foldr withDecl m ds

-- Run the action m in a type environment extended so that every name
-- in vars is bound to AbstractType. An existing binding for the same
-- name is overridden while m runs. The positions stored in vars are
-- not used here.
withAbstractTyVars :: Map Name Pos -> TI a -> TI a
withAbstractTyVars vars m =
  TI $ local extendEnv $ unTI m
  where
    abstractBindings = AbstractType <$ vars
    -- M.union is left-biased, so the abstract bindings take precedence
    -- over whatever the environment already held for those names.
    extendEnv ro = ro { tyEnv = M.union abstractBindings (tyEnv ro) }

--
-- Infer the type for an expression.
--
Expand Down Expand Up @@ -1078,15 +1172,11 @@ inferPattern pat =

-- Check the type of a pattern, by inferring and then unifying the
-- result.
--
-- XXX: it doesn't seem like there's any guarantee that fresh tyvars
-- produced by inferPattern will necessarily be resolved by the
-- unification, and therefore it seems that dropping the possibly
-- updated pattern is a bug.
checkPattern :: LName -> Type -> Pattern -> TI ()
checkPattern :: LName -> Type -> Pattern -> TI Pattern
checkPattern ln t pat =
do (pt, _pat') <- inferPattern pat
do (pt, pat') <- inferPattern pat
unify ln t (getPos pat) pt
return pat'

--
-- statements
Expand Down Expand Up @@ -1343,8 +1433,11 @@ inferSingleStmt ln pos ctx s = do
--
-- (This creates names for any remaining unification vars, so
-- potentially updates the expression.)
generalize :: [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize es0 ts0 = do
--
-- The "foralls" argument is a set of tyvars that were mentioned
-- explicitly and should be forall-bound.
generalize :: Map Name Pos -> [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize foralls es0 ts0 = do
-- first, substitute away any resolved unification variables
-- in both the expressions and types.
es <- applyCurrentSubst es0
Expand All @@ -1368,10 +1461,9 @@ generalize es0 ts0 = do
-- (on the left-hand side) of the type environment. Those are
-- already defined.
--
-- There should be no other named variables involved. We don't
-- allow referring to random unbound type names any more, and
-- we don't yet (though probably should, FUTURE) have a way of
-- explicitly forall-binding type variables in declarations.
-- The only other named variables involved should be the set we
-- explicitly intend to be forall-bound as passed in. Insert
-- those, and favor their positions.
--
-- It would be handy for scaling if we didn't have to examine
-- the entire variable environment (on the grounds that there
Expand All @@ -1387,11 +1479,19 @@ generalize es0 ts0 = do
-- table. There is no longer any need for such hackery, and
-- undefined type names are not allowed to appear in the variable
-- environment.
--
-- FUTURE: we end up replacing the user's forall-bound names with
-- generated names, and I'm not sure why. It seems like it
-- shouldn't be possible the way the code is structured. But the
-- type signatures are coming out correct (which they wouldn't if
-- something were seriously wrong) and we aren't inappropriately
-- unifying these vars with each other or with other things, so
-- I'm not going to stress over it right now.

envUnifyVars <- unifyVarsInEnvs
knownNamedVars <- namedVarDefinitions
let is1 = is0 M.\\ envUnifyVars
let bs1 = M.withoutKeys bs0 knownNamedVars
let bs1 = M.union foralls $ M.withoutKeys bs0 knownNamedVars

-- convert to lists
let is2 = M.toList is1
Expand Down Expand Up @@ -1422,12 +1522,14 @@ generalize es0 ts0 = do
-- it from the parser; if there's an explicit type annotation on the
-- declaration that shows up as a type signature in the expression.
inferDecl :: Decl -> TI Decl
inferDecl (Decl pos pat _ e) = do
inferDecl d@(Decl pos pat _ e) = do
let n = patternLName pat
(e',t) <- inferExpr (n, e)
checkPattern n t pat
~[(e1,s)] <- generalize [e'] [t]
return (Decl pos pat (Just s) e1)
foralls <- inspectDeclFTVs d
withAbstractTyVars foralls $ do
(e',t) <- inferExpr (n, e)
pat' <- checkPattern n t pat
~[(e1,s)] <- generalize foralls [e'] [t]
return (Decl pos pat' (Just s) e1)

-- Type inference for a system of mutually recursive declarations.
--
Expand All @@ -1438,23 +1540,30 @@ inferDecl (Decl pos pat _ e) = do
inferRecDecls :: [Decl] -> TI [Decl]
inferRecDecls ds =
do let pats = map dPat ds
pat =
firstPat =
case pats of
p:_ -> p
[] -> panic
"inferRecDecls"
["Empty list of declarations in recursive group"]
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds
]
sequence_ $ zipWith (checkPattern (patternLName pat)) ts pats'
ess <- generalize es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats ess
]
foralls <- M.unions <$> mapM inspectDeclFTVs ds
withAbstractTyVars foralls $ do
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds
]

-- pats' has already been checked once, which will have inserted
-- unification vars for any missing types. Running it through
-- again will have no further effect, so we can ignore the
-- theoretically-updated-again patterns returned by checkPattern.
sequence_ $ zipWith (checkPattern (patternLName firstPat)) ts pats'
ess <- generalize foralls es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats' ess
]

-- Type inference for a decl group.
inferDeclGroup :: DeclGroup -> TI DeclGroup
Expand Down
Loading