Skip to content

Commit

Permalink
typechecker: forall-bind free type variables appearing in function he…
Browse files Browse the repository at this point in the history
…aders.

This restores the ability to write functions of the form
   let f (x: a) (y: a) = (x, y);
which were rejected after #2077 was fixed.

Closes #2204.
  • Loading branch information
sauclovian-g committed Jan 29, 2025
1 parent 0e5c820 commit e583163
Showing 1 changed file with 134 additions and 28 deletions.
162 changes: 134 additions & 28 deletions src/SAWScript/MGU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,91 @@ matches t1 t2 = do
-- }}}


------------------------------------------------------------
-- Inspect for free type variables {{{

-- We want to allow declaring polymorphic functions by introducing
-- type variables in the function header (rather than requiring an
-- explicit forall binding), like Haskell does.
--
-- This means that free type variables in a function header (but not
-- elsewhere) should be accepted, collected, and handed off to
-- generalize for insertion in the resultant type scheme.
--
-- It turns out that because of the way the AST represents functions
-- in let-bindings that this is highly unpleasant to do on the fly
-- while typechecking. So instead extract the free type variables
-- separately.
--
-- A function header comes through like this:
-- Decl _pos <function-name-pattern> Nothing <expr>
--
-- where <expr> is
-- zero or more times, Function _pos <arg-pattern> <expr'>
-- optionally, TSig _pos <expr''> <return-type>
--
-- so we need any free type variables in
-- - <function-name-pattern>
-- - <return-type>
-- - all <arg-pattern>
--
-- On the plus side this will also then work when people write
-- otherwise annoying things like
-- let f (x: a) = \(y: b) -> (a, b)
--
-- We extract the type variables with the position of their
-- initial mention.

-- Get the free type variables found in a Type.
inspectTypeFTVs :: Type -> TI (Map Name Pos)
inspectTypeFTVs ty = case ty of
TyCon _pos _ctor args -> M.unions <$> mapM inspectTypeFTVs args
TyRecord _pos fields -> M.unions <$> traverse inspectTypeFTVs fields
TyUnifyVar _pos _x -> return M.empty
TyVar pos x -> do
tyenv <- TI $ asks tyEnv
case M.lookup x tyenv of
Nothing -> return $ M.singleton x pos
Just _ -> return $ M.empty

-- Get the free type variables found in a Maybe Type.
inspectMaybeTypeFTVs :: Maybe Type -> TI (Map Name Pos)
inspectMaybeTypeFTVs mty = case mty of
Nothing -> return M.empty
Just ty -> inspectTypeFTVs ty

-- Get the free type variables found in a Pattern.
inspectPatternFTVs :: Pattern -> TI (Map Name Pos)
inspectPatternFTVs pat = case pat of
PWild _pos mty -> inspectMaybeTypeFTVs mty
PVar _pos _x mty -> inspectMaybeTypeFTVs mty
PTuple _pos subpats ->
M.unions <$> mapM inspectPatternFTVs subpats

-- Get the free type variables found in a chain of Function Exprs.
inspectLambdaFTVs :: Expr -> TI (Expr, Map Name Pos)
inspectLambdaFTVs e0 = case e0 of
Function _fpos pat e1 -> do
hereFTVs <- inspectPatternFTVs pat
(e1', moreFTVs) <- inspectLambdaFTVs e1
return (e1', M.union hereFTVs moreFTVs)
_ ->
return (e0, M.empty)

-- Get the free type variables found in a Decl.
inspectDeclFTVs :: Decl -> TI (Map Name Pos)
inspectDeclFTVs (Decl dpos pat _mty e0) = do
nameFTVs <- inspectPatternFTVs pat
(e1, argFTVs) <- inspectLambdaFTVs e0
retFTVs <- case e1 of
TSig _tspos e1 ty -> inspectTypeFTVs ty
_ -> return M.empty
return $ M.unions [nameFTVs, argFTVs, retFTVs]


-- }}}


------------------------------------------------------------
-- Main recursive pass {{{

Expand Down Expand Up @@ -901,6 +986,13 @@ withDeclGroup :: DeclGroup -> TI a -> TI a
withDeclGroup (NonRecursive d) m = withDecl d m
withDeclGroup (Recursive ds) m = foldr withDecl m ds

-- wrap the action m with some abstract type variables.
withAbstractTyVars :: Map Name Pos -> TI a -> TI a
withAbstractTyVars vars m = do
let insertOne x _pos tyenv = M.insert x AbstractType tyenv
insertAll tyenv = M.foldrWithKey insertOne tyenv vars
TI $ local (\ro -> ro { tyEnv = insertAll $ tyEnv ro }) $ unTI m

--
-- Infer the type for an expression.
--
Expand Down Expand Up @@ -1339,8 +1431,11 @@ inferSingleStmt ln pos ctx s = do
--
-- (This creates names for any remaining unification vars, so
-- potentially updates the expression.)
generalize :: [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize es0 ts0 = do
--
-- The "foralls" argument is a set of tyvars that were mentioned
-- explicitly and should be forall-bound.
generalize :: Map Name Pos -> [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize foralls es0 ts0 = do
-- first, substitute away any resolved unification variables
-- in both the expressions and types.
es <- applyCurrentSubst es0
Expand All @@ -1364,10 +1459,9 @@ generalize es0 ts0 = do
-- (on the left-hand side) of the type environment. Those are
-- already defined.
--
-- There should be no other named variables involved. We don't
-- allow referring to random unbound type names any more, and
-- we don't yet (though probably should, FUTURE) have a way of
-- explicitly forall-binding type variables in declarations.
-- The only other named variables involved should be the set we
-- explicitly intend to be forall-bound as passed in. Insert
-- those, and favor their positions.
--
-- It would be handy for scaling if we didn't have to examine
-- the entire variable environment (on the grounds that there
Expand All @@ -1383,11 +1477,19 @@ generalize es0 ts0 = do
-- table. There is no longer any need for such hackery, and
-- undefined type names are not allowed to appear in the variable
-- environment.
--
-- FUTURE: we end up replacing the user's forall-bound names with
-- generated names, and I'm not sure why. It seems like it
-- shouldn't be possible the way the code is structured. But the
-- type signatures are coming out correct (which they wouldn't if
-- something were seriously wrong) and we aren't inappropriately
-- unifying these vars with each other or with other things, so
-- I'm not going to stress over it right now.

envUnifyVars <- unifyVarsInEnvs
knownNamedVars <- namedVarDefinitions
let is1 = is0 M.\\ envUnifyVars
let bs1 = M.withoutKeys bs0 knownNamedVars
let bs1 = M.union foralls $ M.withoutKeys bs0 knownNamedVars

-- convert to lists
let is2 = M.toList is1
Expand Down Expand Up @@ -1418,12 +1520,14 @@ generalize es0 ts0 = do
-- it from the parser; if there's an explicit type annotation on the
-- declaration that shows up as a type signature in the expression.
inferDecl :: Decl -> TI Decl
inferDecl (Decl pos pat _ e) = do
inferDecl d@(Decl pos pat _ e) = do
let n = patternLName pat
(e',t) <- inferExpr (n, e)
pat' <- checkPattern n t pat
~[(e1,s)] <- generalize [e'] [t]
return (Decl pos pat' (Just s) e1)
foralls <- inspectDeclFTVs d
withAbstractTyVars foralls $ do
(e',t) <- inferExpr (n, e)
pat' <- checkPattern n t pat
~[(e1,s)] <- generalize foralls [e'] [t]
return (Decl pos pat' (Just s) e1)

-- Type inference for a system of mutually recursive declarations.
--
Expand All @@ -1440,22 +1544,24 @@ inferRecDecls ds =
[] -> panic
"inferRecDecls"
["Empty list of declarations in recursive group"]
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds
]

-- pats' has already been checked once, which will have inserted
-- unification vars for any missing types. Running it through
-- again will have no further effect, so we can ignore the
-- theoretically-updated-again patterns returned by checkPattern.
sequence_ $ zipWith (checkPattern (patternLName firstPat)) ts pats'
ess <- generalize es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats' ess
]
foralls <- M.unions <$> mapM inspectDeclFTVs ds
withAbstractTyVars foralls $ do
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds
]

-- pats' has already been checked once, which will have inserted
-- unification vars for any missing types. Running it through
-- again will have no further effect, so we can ignore the
-- theoretically-updated-again patterns returned by checkPattern.
sequence_ $ zipWith (checkPattern (patternLName firstPat)) ts pats'
ess <- generalize foralls es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats' ess
]

-- Type inference for a decl group.
inferDeclGroup :: DeclGroup -> TI DeclGroup
Expand Down

0 comments on commit e583163

Please sign in to comment.