Skip to content


typechecker: forall-bind free type variables appearing in function headers
Browse files Browse the repository at this point in the history

This restores the ability to write functions of the form
   let f (x: a) (y: a) = (x, y);
which were rejected after #2077 was fixed.

Closes #2204.
  • Loading branch information
sauclovian-g committed Jan 29, 2025
1 parent 0e5c820 commit e583163
Showing 1 changed file with 134 additions and 28 deletions.
162 changes: 134 additions & 28 deletions src/SAWScript/MGU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,91 @@ matches t1 t2 = do
-- }}}

-- Inspect for free type variables {{{

-- We want to allow declaring polymorphic functions by introducing
-- type variables in the function header (rather than requiring an
-- explicit forall binding), like Haskell does.
-- This means that free type variables in a function header (but not
-- elsewhere) should be accepted, collected, and handed off to
-- generalize for insertion in the resultant type scheme.
-- It turns out that, because of the way the AST represents functions
-- in let-bindings, this is highly unpleasant to do on the fly
-- while typechecking. So instead we extract the free type variables
-- separately.
-- A function header comes through like this:
-- Decl _pos <function-name-pattern> Nothing <expr>
-- where <expr> is
-- zero or more times, Function _pos <arg-pattern> <expr'>
-- optionally, TSig _pos <expr''> <return-type>
-- so we need any free type variables in
-- - <function-name-pattern>
-- - <return-type>
-- - all <arg-pattern>
-- On the plus side this will also then work when people write
-- otherwise annoying things like
-- let f (x: a) = \(y: b) -> (x, y)
-- We extract the type variables with the position of their
-- initial mention.

-- Collect the free type variables mentioned in a Type, each mapped to
-- the position carried by its TyVar node. A named type variable is
-- treated as free when it has no entry in the current type environment
-- (tyEnv); unification variables never contribute anything.
inspectTypeFTVs :: Type -> TI (Map Name Pos)
inspectTypeFTVs (TyCon _pos _ctor args) =
    -- Constructor application: merge whatever the arguments mention.
    fmap M.unions (mapM inspectTypeFTVs args)
inspectTypeFTVs (TyRecord _pos fields) =
    -- Record type: merge whatever the field types mention.
    fmap M.unions (traverse inspectTypeFTVs fields)
inspectTypeFTVs (TyUnifyVar _pos _x) =
    return M.empty
inspectTypeFTVs (TyVar pos x) = do
    env <- TI $ asks tyEnv
    -- Names already present in the type environment are bound, not free.
    return $ if M.member x env then M.empty else M.singleton x pos

-- Collect the free type variables of an optional Type. An absent type
-- (Nothing) mentions no variables; a present one is handed to
-- inspectTypeFTVs.
inspectMaybeTypeFTVs :: Maybe Type -> TI (Map Name Pos)
inspectMaybeTypeFTVs = maybe (return M.empty) inspectTypeFTVs

-- Collect the free type variables mentioned in the (optional) type
-- annotations of a Pattern, descending into tuple patterns.
inspectPatternFTVs :: Pattern -> TI (Map Name Pos)
inspectPatternFTVs (PWild _pos mty) = inspectMaybeTypeFTVs mty
inspectPatternFTVs (PVar _pos _x mty) = inspectMaybeTypeFTVs mty
inspectPatternFTVs (PTuple _pos subpats) =
    -- A tuple pattern has no annotation of its own here; merge the
    -- results from each component pattern.
    fmap M.unions (mapM inspectPatternFTVs subpats)

-- Walk down a chain of nested Function expressions, collecting the free
-- type variables of every argument pattern along the way. Returns the
-- first expression that is not a Function (i.e. what is left after
-- peeling off all the lambdas) together with the collected variables.
inspectLambdaFTVs :: Expr -> TI (Expr, Map Name Pos)
inspectLambdaFTVs (Function _fpos pat body) = do
    patFTVs <- inspectPatternFTVs pat
    (rest, innerFTVs) <- inspectLambdaFTVs body
    -- M.union is left-biased, so an outer argument's position wins over
    -- an inner argument's position for the same name.
    return (rest, M.union patFTVs innerFTVs)
inspectLambdaFTVs other =
    return (other, M.empty)

-- Get the free type variables found in the header of a Decl: those in
-- the name pattern, those in the argument patterns of the leading chain
-- of Function expressions, and those in the return type if the
-- expression under the lambdas is a TSig. The Decl's own position and
-- its schema slot are not consulted.
inspectDeclFTVs :: Decl -> TI (Map Name Pos)
inspectDeclFTVs (Decl _dpos pat _mty e0) = do
    nameFTVs <- inspectPatternFTVs pat
    -- Peel off the lambdas; body is whatever expression is left.
    (body, argFTVs) <- inspectLambdaFTVs e0
    retFTVs <- case body of
        -- Only the signature's type matters; the annotated expression
        -- is deliberately ignored (and must not shadow the outer name).
        TSig _tspos _e ty -> inspectTypeFTVs ty
        _ -> return M.empty
    -- M.unions is left-biased: name pattern first, then arguments,
    -- then return type, when the same variable appears more than once.
    return $ M.unions [nameFTVs, argFTVs, retFTVs]

-- }}}

-- Main recursive pass {{{

Expand Down Expand Up @@ -901,6 +986,13 @@ withDeclGroup :: DeclGroup -> TI a -> TI a
-- Non-recursive group: there is a single declaration, so wrap m with it directly.
withDeclGroup (NonRecursive d) m = withDecl d m
-- Recursive group: right-fold withDecl over the declarations so that m ends up wrapped by every one of them.
withDeclGroup (Recursive ds) m = foldr withDecl m ds

-- Run the action m with the type environment (tyEnv) extended by one
-- AbstractType entry for every type variable name in vars. The
-- extension is scoped to m via local.
withAbstractTyVars :: Map Name Pos -> TI a -> TI a
withAbstractTyVars vars m =
    TI $ local extendEnv $ unTI m
  where
    -- Only the names matter here; the recorded positions are ignored.
    addAbstract name _pos env = M.insert name AbstractType env
    -- Insert every variable from vars into the reader's type environment.
    extendEnv ro = ro { tyEnv = M.foldrWithKey addAbstract (tyEnv ro) vars }

-- Infer the type for an expression.
Expand Down Expand Up @@ -1339,8 +1431,11 @@ inferSingleStmt ln pos ctx s = do
-- (This creates names for any remaining unification vars, so
-- potentially updates the expression.)
generalize :: [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize es0 ts0 = do
-- The "foralls" argument is a set of tyvars that were mentioned
-- explicitly and should be forall-bound.
generalize :: Map Name Pos -> [OutExpr] -> [Type] -> TI [(OutExpr,Schema)]
generalize foralls es0 ts0 = do
-- first, substitute away any resolved unification variables
-- in both the expressions and types.
es <- applyCurrentSubst es0
Expand All @@ -1364,10 +1459,9 @@ generalize es0 ts0 = do
-- (on the left-hand side) of the type environment. Those are
-- already defined.
-- There should be no other named variables involved. We don't
-- allow referring to random unbound type names any more, and
-- we don't yet (though probably should, FUTURE) have a way of
-- explicitly forall-binding type variables in declarations.
-- The only other named variables involved should be the set we
-- explicitly intend to be forall-bound as passed in. Insert
-- those, and favor their positions.
-- It would be handy for scaling if we didn't have to examine
-- the entire variable environment (on the grounds that there
Expand All @@ -1383,11 +1477,19 @@ generalize es0 ts0 = do
-- table. There is no longer any need for such hackery, and
-- undefined type names are not allowed to appear in the variable
-- environment.
-- FUTURE: we end up replacing the user's forall-bound names with
-- generated names, and I'm not sure why. It seems like it
-- shouldn't be possible the way the code is structured. But the
-- type signatures are coming out correct (which they wouldn't if
-- something were seriously wrong) and we aren't inappropriately
-- unifying these vars with each other or with other things, so
-- I'm not going to stress over it right now.

envUnifyVars <- unifyVarsInEnvs
knownNamedVars <- namedVarDefinitions
let is1 = is0 M.\\ envUnifyVars
let bs1 = M.withoutKeys bs0 knownNamedVars
let bs1 = M.union foralls $ M.withoutKeys bs0 knownNamedVars

-- convert to lists
let is2 = M.toList is1
Expand Down Expand Up @@ -1418,12 +1520,14 @@ generalize es0 ts0 = do
-- it from the parser; if there's an explicit type annotation on the
-- declaration that shows up as a type signature in the expression.
-- Infer the type of a single (non-recursive) declaration and return the
-- declaration with its pattern checked and its generalized schema filled in.
inferDecl :: Decl -> TI Decl
-- NOTE(review): this span is a rendered diff. The next line looks like the
-- pre-change head and the one after it the post-change head (which also
-- binds the whole Decl as d) -- confirm against the repository.
inferDecl (Decl pos pat _ e) = do
inferDecl d@(Decl pos pat _ e) = do
let n = patternLName pat
-- Pre-change body (appears to be removed by this commit): infer, check
-- the pattern, generalize with no explicit forall set.
(e',t) <- inferExpr (n, e)
pat' <- checkPattern n t pat
~[(e1,s)] <- generalize [e'] [t]
return (Decl pos pat' (Just s) e1)
-- Post-change body: collect the free type variables of the declaration
-- header, run inference with them entered as abstract types, and pass
-- the same set to generalize so they are forall-bound.
foralls <- inspectDeclFTVs d
withAbstractTyVars foralls $ do
(e',t) <- inferExpr (n, e)
pat' <- checkPattern n t pat
-- The lazy pattern assumes generalize returns exactly one result for
-- the one expression/type pair passed in.
~[(e1,s)] <- generalize foralls [e'] [t]
return (Decl pos pat' (Just s) e1)

-- Type inference for a system of mutually recursive declarations.
Expand All @@ -1440,22 +1544,24 @@ inferRecDecls ds =
[] -> panic
["Empty list of declarations in recursive group"]
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds

-- pats' has already been checked once, which will have inserted
-- unification vars for any missing types. Running it through
-- again will have no further effect, so we can ignore the
-- theoretically-updated-again patterns returned by checkPattern.
sequence_ $ zipWith (checkPattern (patternLName firstPat)) ts pats'
ess <- generalize es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats' ess
foralls <- M.unions <$> mapM inspectDeclFTVs ds
withAbstractTyVars foralls $ do
(_ts, pats') <- unzip <$> mapM inferPattern pats
(es, ts) <- fmap unzip
$ flip (foldr withPattern) pats'
$ sequence [ inferExpr (patternLName p, e)
| Decl _pos p _ e <- ds

-- pats' has already been checked once, which will have inserted
-- unification vars for any missing types. Running it through
-- again will have no further effect, so we can ignore the
-- theoretically-updated-again patterns returned by checkPattern.
sequence_ $ zipWith (checkPattern (patternLName firstPat)) ts pats'
ess <- generalize foralls es ts
return [ Decl pos p (Just s) e1
| (pos, p, (e1, s)) <- zip3 (map getPos ds) pats' ess

-- Type inference for a decl group.
inferDeclGroup :: DeclGroup -> TI DeclGroup
Expand Down

0 comments on commit e583163

Please sign in to comment.