diff --git a/brat/Brat/Checker.hs b/brat/Brat/Checker.hs index 28af8cbc..8a0ecd94 100644 --- a/brat/Brat/Checker.hs +++ b/brat/Brat/Checker.hs @@ -25,6 +25,7 @@ import Prelude hiding (filter) import Brat.Checker.Helpers import Brat.Checker.Monad import Brat.Checker.Quantity +import Brat.Checker.SolveHoles (typeEq) import Brat.Checker.SolvePatterns (argProblems, argProblemsWithLeftovers, solve) import Brat.Checker.Types import Brat.Constructors @@ -667,7 +668,13 @@ check' (Of n e) ((), unders) = case ?my of (elems, unders, rightUnders) <- getVecs len unders pure ((tgt, el):elems, (tgt, ty):unders, rightUnders) getVecs _ unders = pure ([], [], unders) - +check' Hope ((), (NamedPort hope _, ty):unders) = case (?my, ty) of + (Braty, Left _k) -> do + fc <- req AskFC + req (ANewHope hope fc) + pure (((), ()), ((), unders)) + (Braty, Right _ty) -> typeErr "Can only infer kinded things with !" + (Kerny, _) -> typeErr "Won't infer kernel typed !" check' tm _ = error $ "check' " ++ show tm @@ -1133,7 +1140,7 @@ run :: VEnv -> Namespace -> Checking a -> Either Error (a, ([TypedHole], Store, Graph)) -run ve initStore ns m = +run ve initStore ns m = do let ctx = Ctx { globalVEnv = ve , store = initStore -- TODO: fill with default constructors @@ -1141,5 +1148,19 @@ run ve initStore ns m = , kconstructors = kernelConstructors , typeConstructors = defaultTypeConstructors , aliasTable = M.empty - } in - (\(a,ctx,(holes, graph)) -> (a, (holes, store ctx, graph))) <$> handler (localNS ns m) ctx mempty + , hopes = M.empty + } + (a,ctx,(holes, graph)) <- handler (localNS ns m) ctx mempty + let tyMap = typeMap $ store ctx + -- If the `hopes` set has any remaining holes with kind Nat, we need to abort. + -- Even though we didn't need them for typechecking problems, our runtime + -- behaviour depends on the values of the holes, which we can't account for. + case M.toList $ M.filterWithKey (\e _ -> isNatKinded tyMap (InEnd e)) (hopes ctx) of + [] -> pure (a, (holes, store ctx, graph)) + -- Just use the FC of the first hole while we don't have the capacity to + -- show multiple error locations + hs@((_,fc):_) -> Left $ Err (Just fc) (RemainingNatHopes (show . fst <$> hs)) + where + isNatKinded tyMap e = case tyMap M.! e of + EndType Braty (Left Nat) -> True + _ -> False diff --git a/brat/Brat/Checker/Helpers.hs b/brat/Brat/Checker/Helpers.hs index 0c2ad01a..2d3b8419 100644 --- a/brat/Brat/Checker/Helpers.hs +++ b/brat/Brat/Checker/Helpers.hs @@ -37,9 +37,11 @@ import Hasochism import Util (log2) import Control.Monad (when) +import Control.Monad.State.Lazy (StateT(..), runStateT) import Control.Monad.Freer (req) import Data.Bifunctor import Data.Foldable (foldrM) +import Data.List (partition) import Data.Type.Equality (TestEquality(..), (:~:)(..)) import qualified Data.Map as M import Prelude hiding (last) @@ -100,35 +102,29 @@ pullPortsRow :: Show ty => [PortName] -> [(NamedPort e, ty)] -> Checking [(NamedPort e, ty)] -pullPortsRow = pullPorts portName showRow +pullPortsRow = pullPorts (portName . fst) showRow pullPortsSig :: Show ty => [PortName] -> [(PortName, ty)] -> Checking [(PortName, ty)] -pullPortsSig = pullPorts id showSig +pullPortsSig = pullPorts fst showSig -pullPorts :: forall a ty. Show ty - => (a -> PortName) -- A way to get a port name for each element - -> ([(a, ty)] -> String) -- A way to print the list +pullPorts :: forall a ty + . (a -> PortName) -- A way to get a port name for each element + -> ([a] -> String) -- A way to print the list -> [PortName] -- Things to pull to the front - -> [(a, ty)] -- The list to rearrange - -> Checking [(a, ty)] -pullPorts _ _ [] types = pure types -pullPorts toPort showFn (p:ports) types = do - (x, types) <- pull1Port p types - (x:) <$> pullPorts toPort showFn ports types + -> [a] -- The list to rearrange + -> Checking [a] +pullPorts toPort showFn to_pull types = + -- the "state" here is the things still available to be pulled + uncurry (++) <$> runStateT (mapM pull1Port to_pull) types where - pull1Port :: PortName - -> [(a, ty)] - -> Checking ((a, ty), [(a, ty)]) - pull1Port p [] = fail $ "Port not found: " ++ p ++ " in " ++ showFn types - pull1Port p (x@(a,_):xs) - | p == toPort a - = if p `elem` (toPort . fst <$> xs) - then err (AmbiguousPortPull p (showFn (x:xs))) - else pure (x, xs) - | otherwise = second (x:) <$> pull1Port p xs + pull1Port :: PortName -> StateT [a] Checking a + pull1Port p = StateT $ \available -> case partition ((== p) . toPort) available of + ([], _) -> err $ BadPortPull p (showFn available) + ([found], remaining) -> pure (found, remaining) + (_, _) -> err $ AmbiguousPortPull p (showFn available) ensureEmpty :: Show ty => String -> [(NamedPort e, ty)] -> Checking () ensureEmpty _ [] = pure () @@ -506,3 +502,8 @@ runArith (NumValue upl grol) Pow (NumValue upr gror) -- 2^(2^k * upr) + 2^(2^k * upr) * (full(2^(k + k') * mono)) = pure $ NumValue (upl ^ upr) (StrictMonoFun (StrictMono (l * upr) (Full (StrictMono (k + k') mono)))) runArith _ _ _ = Nothing + +buildConst :: SimpleTerm -> Val Z -> Checking Src +buildConst tm ty = do + (_, _, [(out,_)], _) <- next "buildConst" (Const tm) (S0, Some (Zy :* S0)) R0 (RPr ("value", ty) R0) + pure out diff --git a/brat/Brat/Checker/Monad.hs b/brat/Brat/Checker/Monad.hs index e993711b..baecc9e6 100644 --- a/brat/Brat/Checker/Monad.hs +++ b/brat/Brat/Checker/Monad.hs @@ -50,12 +50,15 @@ data CtxEnv = CtxEnv , locals :: VEnv } +type Hopes = M.Map InPort FC + data Context = Ctx { globalVEnv :: VEnv , store :: Store , constructors :: ConstructorMap Brat , kconstructors :: ConstructorMap Kernel , typeConstructors :: M.Map (Mode, QualName) [(PortName, TypeKind)] , aliasTable :: M.Map QualName Alias + , hopes :: Hopes } data CheckingSig ty where @@ -89,6 +92,9 @@ data CheckingSig ty where AskVEnv :: CheckingSig CtxEnv Declare :: End -> Modey m -> BinderType m -> CheckingSig () Define :: End -> Val Z -> CheckingSig () + ANewHope :: InPort -> FC -> CheckingSig () + AskHopes :: CheckingSig Hopes + RemoveHope :: InPort -> CheckingSig () localAlias :: (QualName, Alias) -> Checking v -> Checking v localAlias _ (Ret v) = Ret v @@ -267,6 +273,15 @@ handler (Req s k) ctx g M.lookup tycon tbl handler (k args) ctx g + ANewHope e fc -> handler (k ()) (ctx { hopes = M.insert e fc (hopes ctx) }) g + + AskHopes -> handler (k (hopes ctx)) ctx g + + RemoveHope e -> let hset = hopes ctx in + if M.member e hset + then handler (k ()) (ctx { hopes = M.delete e hset }) g + else Left (dumbErr (InternalError ("Trying to remove unknown Hope: " ++ show e))) + type Checking = Free CheckingSig instance Semigroup a => Semigroup (Checking a) where diff --git a/brat/Brat/Checker/SolveHoles.hs b/brat/Brat/Checker/SolveHoles.hs new file mode 100644 index 00000000..fd74afa7 --- /dev/null +++ b/brat/Brat/Checker/SolveHoles.hs @@ -0,0 +1,169 @@ +module Brat.Checker.SolveHoles (typeEq) where + +import Brat.Checker.Helpers (buildConst) +import Brat.Checker.Monad +import Brat.Checker.Types (kindForMode) +import Brat.Error (ErrorMsg(..)) +import Brat.Eval +import Brat.Syntax.CircuitProperties (eqProps) +import Brat.Syntax.Common +import Brat.Syntax.Simple (SimpleTerm(..)) +import Brat.Syntax.Value +import Control.Monad.Freer +import Bwd +import Hasochism + +import Control.Monad (when) +import Data.Bifunctor (second) +import Data.Foldable (sequenceA_) +import Data.Functor +import Data.Maybe (mapMaybe) +import qualified Data.Map as M +import Data.Type.Equality (TestEquality(..), (:~:)(..)) + +-- Demand that two closed values are equal, we're allowed to solve variables in the +-- hope set to make this true. +-- Raises a user error if the vals cannot be made equal. +typeEq :: String -- String representation of the term for error reporting + -> TypeKind -- The kind we're comparing at + -> Val Z -- Expected + -> Val Z -- Actual + -> Checking () +typeEq str = typeEq' str (Zy :* S0 :* S0) + + +-- Internal version of typeEq with environment for non-closed values +typeEq' :: String -- String representation of the term for error reporting + -> (Ny :* Stack Z TypeKind :* Stack Z Sem) n + -> TypeKind -- The kind we're comparing at + -> Val n -- Expected + -> Val n -- Actual + -> Checking () +typeEq' str stuff@(_ny :* _ks :* sems) k exp act = do + hopes <- req AskHopes + exp <- sem sems exp + act <- sem sems act + typeEqEta str stuff hopes k exp act + +isNumVar :: Sem -> Maybe SVar +isNumVar (SNum (NumValue 0 (StrictMonoFun (StrictMono 0 (Linear v))))) = Just v +isNumVar _ = Nothing + +-- Presumes that the hope set and the two `Sem`s are up to date. +typeEqEta :: String -- String representation of the term for error reporting + -> (Ny :* Stack Z TypeKind :* Stack Z Sem) n + -> Hopes -- A map from the hope set to corresponding FCs + -> TypeKind -- The kind we're comparing at + -> Sem -- Expected + -> Sem -- Actual + -> Checking () +typeEqEta tm (lvy :* kz :* sems) hopes (TypeFor m ((_, k):ks)) exp act = do + -- Higher kinded things + let nextSem = semLvl lvy + let xz = B0 :< nextSem + exp <- applySem exp xz + act <- applySem act xz + typeEqEta tm (Sy lvy :* (kz :<< k) :* (sems :<< nextSem)) hopes (TypeFor m ks) exp act +-- Not higher kinded - check for flex terms +-- (We don't solve under binders for now, so we only consider Zy here) +-- 1. "easy" flex cases +typeEqEta _tm (Zy :* _ks :* _sems) hopes k (SApp (SPar (InEnd e)) B0) act + | M.member e hopes = solveHope k e act +typeEqEta _tm (Zy :* _ks :* _sems) hopes k exp (SApp (SPar (InEnd e)) B0) + | M.member e hopes = solveHope k e exp +typeEqEta _ (Zy :* _ :* _) hopes Nat exp act + | Just (SPar (InEnd e)) <- isNumVar exp, M.member e hopes = solveHope Nat e act + | Just (SPar (InEnd e)) <- isNumVar act, M.member e hopes = solveHope Nat e exp +-- 2. harder cases, neither is in the hope set, so we can't define it ourselves +typeEqEta tm stuff@(ny :* _ks :* _sems) hopes k exp act = do + exp <- quote ny exp + act <- quote ny act + let ends = mapMaybe getEnd [exp,act] + -- sanity check: we've already dealt with either end being in the hopeset + when (or [M.member ie hopes | InEnd ie <- ends]) $ typeErr "ends were in hopeset" + case ends of + [] -> typeEqRigid tm stuff k exp act -- easyish, both rigid i.e. already defined + -- variables are trivially the same, even if undefined, but the values may + -- be different! E.g. X =? 1 + X + [_, _] | exp == act -> pure () + -- TODO: Once we have scheduling, we must wait for one or the other to become more defined, rather than failing + _ -> err (TypeMismatch tm (show exp) (show act)) + where + getEnd (VApp (VPar e) _) = Just e + getEnd (VNum n) = getNumVar n + getEnd _ = Nothing + +-- This will update the `hopes`, potentially invalidating things that have +-- been eval'd. +-- The Sem is closed, for now. +solveHope :: TypeKind -> InPort -> Sem -> Checking () +solveHope k hope v = quote Zy v >>= \v -> case doesntOccur (InEnd hope) v of + Right () -> do + defineEnd (InEnd hope) v + dangling <- case (k, v) of + (Nat, VNum _v) -> err $ Unimplemented "Nat hope solving" [] + (Nat, _) -> err $ InternalError "Head of Nat wasn't a VNum" + _ -> buildConst Unit TUnit + req (Wire (end dangling, kindType k, hope)) + req (RemoveHope hope) + Left msg -> case v of + VApp (VPar (InEnd end)) B0 | hope == end -> pure () + -- TODO: Not all occurrences are toxic. The end could be in an argument + -- to a hoping variable which isn't used. + -- E.g. h1 = h2 h1 - this is valid if h2 is the identity, or ignores h1. + _ -> err msg + +typeEqs :: String -> (Ny :* Stack Z TypeKind :* Stack Z Sem) n -> [TypeKind] -> [Val n] -> [Val n] -> Checking () +typeEqs _ _ [] [] [] = pure () +typeEqs tm stuff (k:ks) (exp:exps) (act:acts) = typeEqs tm stuff ks exps acts <* typeEq' tm stuff k exp act +typeEqs _ _ _ _ _ = typeErr "arity mismatch" + +typeEqRow :: Modey m + -> String -- The term we complain about in errors + -> (Ny :* Stack Z TypeKind :* Stack Z Sem) lv -- Next available level, the kinds of existing levels + -> Ro m lv top0 + -> Ro m lv top1 + -> Either ErrorMsg (Some ((Ny :* Stack Z TypeKind :* Stack Z Sem) -- The new stack of kinds and fresh level + :* ((:~:) top0 :* (:~:) top1)) -- Proofs both input rows have same length (quantified over by Some) + ,[Checking ()] -- subproblems to run in parallel + ) +typeEqRow _ _ stuff R0 R0 = pure (Some (stuff :* (Refl :* Refl)), []) +typeEqRow m tm stuff (RPr (_,ty1) ro1) (RPr (_,ty2) ro2) = typeEqRow m tm stuff ro1 ro2 <&> second + ((:) (typeEq' tm stuff (kindForMode m) ty1 ty2)) +typeEqRow m tm (ny :* kz :* semz) (REx (_,k1) ro1) (REx (_,k2) ro2) | k1 == k2 = typeEqRow m tm (Sy ny :* (kz :<< k1) :* (semz :<< semLvl ny)) ro1 ro2 +typeEqRow _ _ _ _ _ = Left $ TypeErr "Mismatched rows" + +-- Calls to typeEqRigid *must* start with rigid types to ensure termination +typeEqRigid :: String -- String representation of the term for error reporting + -> (Ny :* Stack Z TypeKind :* Stack Z Sem) n + -> TypeKind -- The kind we're comparing at + -> Val n -- Expected + -> Val n -- Actual + -> Checking () +typeEqRigid tm (_ :* _ :* semz) Nat exp act = do + -- TODO: What if there's hope in the numbers? + exp <- sem semz exp + act <- sem semz act + if getNum exp == getNum act + then pure () + else err $ TypeMismatch tm (show exp) (show act) +typeEqRigid tm stuff@(_ :* kz :* _) (TypeFor m []) (VApp f args) (VApp f' args') | f == f' = + svKind f >>= \case + TypeFor m' ks | m == m' -> typeEqs tm stuff (snd <$> ks) (args <>> []) (args' <>> []) + -- pattern should always match + _ -> err $ InternalError "quote gave a surprising result" + where + svKind (VPar e) = kindOf (VPar e) + svKind (VInx n) = pure $ proj kz n +typeEqRigid tm lvkz (TypeFor m []) (VCon c args) (VCon c' args') | c == c' = + req (TLup (m, c)) >>= \case + Just ks -> typeEqs tm lvkz (snd <$> ks) args args' + Nothing -> err $ TypeErr $ "Type constructor " ++ show c + ++ " undefined " ++ " at kind " ++ show (TypeFor m []) +typeEqRigid tm lvkz (Star []) (VFun m0 (FunTy ps0 ins0 outs0)) (VFun m1 (FunTy ps1 ins1 outs1)) + | Just Refl <- testEquality m0 m1 + , eqProps m0 ps0 ps1 = do + probs :: [Checking ()] <- throwLeft $ typeEqRow m0 tm lvkz ins0 ins1 >>= \case -- this is in Either ErrorMsg + (Some (lvkz :* (Refl :* Refl)), ps1) -> typeEqRow m0 tm lvkz outs0 outs1 <&> (ps1++) . snd + sequenceA_ probs -- uses Applicative (unlike sequence_ which uses Monad), hence parallelized +typeEqRigid tm _ _ v0 v1 = err $ TypeMismatch tm (show v0) (show v1) diff --git a/brat/Brat/Checker/SolvePatterns.hs b/brat/Brat/Checker/SolvePatterns.hs index 368c47aa..48e3261f 100644 --- a/brat/Brat/Checker/SolvePatterns.hs +++ b/brat/Brat/Checker/SolvePatterns.hs @@ -19,7 +19,6 @@ import Hasochism import Control.Monad (unless) import Data.Bifunctor (first) -import Data.Foldable (for_, traverse_) import qualified Data.Map as M import Data.Maybe (fromJust) import Data.Type.Equality ((:~:)(..), testEquality) @@ -80,7 +79,7 @@ solve my ((src, Lit tm):p) = do (Braty, Left Nat) | Num n <- tm -> do unless (n >= 0) $ typeErr "Negative Nat kind" - unifyNum (nConstant (fromIntegral n)) (nVar (VPar (ExEnd (end src)))) + unifyNum (nConstant (fromIntegral n)) (nVar (VPar (toEnd src))) (Braty, Right ty) -> do throwLeft (simpleCheck Braty ty tm) _ -> typeErr $ "Literal " ++ show tm ++ " isn't valid at this type" @@ -99,7 +98,7 @@ solve my ((src, PCon c abs):p) = do -- Special case for 0, so that we can call `unifyNum` instead of pattern -- matching using what's returned from `natConstructors` PrefixName [] "zero" -> do - unifyNum (nVar (VPar (ExEnd (end src)))) nZero + unifyNum (nVar (VPar (toEnd src))) nZero p <- argProblems [] (normaliseAbstractor abs) p (tests, sol) <- solve my p pure ((src, PrimLitTest (Num 0)):tests, sol) @@ -110,7 +109,8 @@ solve my ((src, PCon c abs):p) = do (REx ("inner", Nat) R0) unifyNum (nVar (VPar (ExEnd (end src)))) - (relationToInner (nVar (VPar (ExEnd (end dangling))))) + (relationToInner (nVar (VPar (toEnd dangling)))) + -- TODO also do wiring corresponding to relationToInner p <- argProblems [dangling] (normaliseAbstractor abs) p (tests, sol) <- solve my p -- When we get @-patterns, we shouldn't drop this anymore @@ -201,39 +201,6 @@ instantiateMeta e val = do defineEnd e val --- Be conservative, fail if in doubt. Not dangerous like being wrong while succeeding --- We can have bogus failures here because we're not normalising under lambdas --- N.B. the value argument is normalised. -doesntOccur :: End -> Val n -> Either ErrorMsg () -doesntOccur e (VNum nv) = for_ (getNumVar nv) (collision e) - where - getNumVar :: NumVal (VVar n) -> Maybe End - getNumVar (NumValue _ (StrictMonoFun (StrictMono _ mono))) = case mono of - Linear v -> case v of - VPar e -> Just e - _ -> Nothing - Full sm -> getNumVar (numValue sm) - getNumVar _ = Nothing -doesntOccur e (VApp var args) = case var of - VPar e' -> collision e e' *> traverse_ (doesntOccur e) args - _ -> pure () -doesntOccur e (VCon _ args) = traverse_ (doesntOccur e) args -doesntOccur e (VLam body) = doesntOccur e body -doesntOccur e (VFun my (FunTy _ ins outs)) = case my of - Braty -> doesntOccurRo my e ins *> doesntOccurRo my e outs - Kerny -> doesntOccurRo my e ins *> doesntOccurRo my e outs -doesntOccur e (VSum my rows) = traverse_ (\(Some ro) -> doesntOccurRo my e ro) rows - -collision :: End -> End -> Either ErrorMsg () -collision e v | e == v = Left . UnificationError $ - show e ++ " is cyclic" - | otherwise = pure () - -doesntOccurRo :: Modey m -> End -> Ro m i j -> Either ErrorMsg () -doesntOccurRo _ _ R0 = pure () -doesntOccurRo my e (RPr (_, ty) ro) = doesntOccur e ty *> doesntOccurRo my e ro -doesntOccurRo Braty e (REx _ ro) = doesntOccurRo Braty e ro - -- Need to keep track of which way we're solving - which side is known/unknown -- Things which are dynamically unknown must be Tgts - information flows from Srcs -- ...But we don't need to do any wiring here, right? @@ -375,7 +342,7 @@ argProblems srcs na p = argProblemsWithLeftovers srcs na p >>= \case _ -> err $ UnificationError "Pattern doesn't match expected length for constructor args" argProblemsWithLeftovers :: [Src] -> NormalisedAbstractor -> Problem -> Checking (Problem, [Src]) -argProblemsWithLeftovers srcs (NA (APull ps abs)) p = pullPorts portName show ps (map (, ()) srcs) >>= \srcs -> argProblemsWithLeftovers (fst <$> srcs) (NA abs) p +argProblemsWithLeftovers srcs (NA (APull ps abs)) p = pullPorts portName show ps srcs >>= \srcs -> argProblemsWithLeftovers srcs (NA abs) p argProblemsWithLeftovers (src:srcs) na p | Just (pat, na) <- unconsNA na = first ((src, pat):) <$> argProblemsWithLeftovers srcs na p argProblemsWithLeftovers srcs (NA AEmpty) p = pure (p, srcs) argProblemsWithLeftovers [] abst _ = err $ NothingToBind (show abst) diff --git a/brat/Brat/Compile/Hugr.hs b/brat/Brat/Compile/Hugr.hs index eeb1faea..477572c5 100644 --- a/brat/Brat/Compile/Hugr.hs +++ b/brat/Brat/Compile/Hugr.hs @@ -137,11 +137,11 @@ runCheckingInCompile (Req _ _) = error "Compile monad found a command it can't h -- To be called on top-level signatures which are already Inx-closed, but not -- necessarily normalised. -compileSig :: Modey m -> FunTy m Z -> Compile PolyFuncType +compileSig :: Modey m -> FunTy m Z -> Compile ([HugrType], [HugrType]) compileSig my cty = do runCheckingInCompile (evalFunTy S0 my cty) <&> compileFunTy -compileFunTy (FunTy _ ss ts) = PolyFuncType [] (FunctionType (compileRo ss) (compileRo ts)) +compileFunTy (FunTy _ ss ts) = (compileRo ss, compileRo ts) compileRo :: Ro m i j -- The Ro that we're processing -> [HugrType] -- The hugr type of the row @@ -165,15 +165,12 @@ compileType ty@(TCons _ _) = htTuple (tuple ty) tuple TNil = [] tuple ty = error $ "Found " ++ show ty ++ " in supposed tuple type" compileType TNil = htTuple [] -compileType (VSum my ros) = case my of - Braty -> error "Todo: compileTypeWorker for BRAT" - Kerny -> HTSum (SG (GeneralSum $ map (\(Some ro) -> compileRo ro) ros)) compileType (TVec el _) = hugrList (compileType el) compileType (TList el) = hugrList (compileType el) -- All variables are of kind `TypeFor m xs`, we already checked in `kindCheckRow` compileType (VApp _ _) = htTuple [] -- VFun is already evaluated here, so we don't need to call `compileSig` -compileType (VFun _ cty) = HTFunc $ compileFunTy cty +compileType (VFun _ cty) = let (ins, outs) = compileFunTy cty in HTFunc (PolyFuncType [] (FunctionType ins outs bratExts)) compileType ty = error $ "todo: compile type " ++ show ty compileGraphTypes :: Traversable t => t (Val Z) -> Compile (t HugrType) @@ -201,7 +198,7 @@ compileConst parent tm ty = do constId <- addNode "Const" (OpConst (ConstOp parent (valFromSimple tm))) loadId <- case ty of HTFunc poly@(PolyFuncType [] _) -> - addNode "LoadFunction" (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly]))) + addNode "LoadFunction" (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly] []))) HTFunc (PolyFuncType _ _) -> error "Trying to compile function with type args" _ -> addNode "LoadConst" (OpLoadConstant (LoadConstantOp parent ty)) addEdge (Port constId 0, Port loadId 0) @@ -212,13 +209,13 @@ compileArithNode parent op TNat = addNode (show op ++ "_Nat") $ OpCustom $ case Add -> binaryIntOp parent "iadd" Sub -> binaryIntOp parent "isub" Mul-> binaryIntOp parent "imul" - Div -> intOp parent "idiv_u" (FunctionType [hugrInt, hugrInt] [hugrInt]) [TANat intWidth, TANat intWidth] + Div -> intOp parent "idiv_u" [hugrInt, hugrInt] [hugrInt] [TANat intWidth, TANat intWidth] Pow -> error "TODO: Pow" -- Not defined in extension compileArithNode parent op TInt = addNode (show op ++ "_Int") $ OpCustom $ case op of Add -> binaryIntOp parent "iadd" Sub -> binaryIntOp parent "isub" Mul-> binaryIntOp parent "imul" - Div -> intOp parent "idiv_u" (FunctionType [hugrInt, hugrInt] [hugrInt]) [TANat intWidth, TANat intWidth] + Div -> intOp parent "idiv_u" [hugrInt, hugrInt] [hugrInt] [TANat intWidth, TANat intWidth] Pow -> error "TODO: Pow" -- Not defined in extension compileArithNode parent op TFloat = addNode (show op ++ "_Float") $ OpCustom $ case op of Add -> binaryFloatOp parent "fadd" @@ -275,7 +272,7 @@ compileClauses parent ins ((matchData, rhs) :| clauses) = do (ns, _) <- gets bratGraph -- RHS has to be a box, so it must have a function type outTys <- case nodeOuts (ns M.! rhs) of - [(_, VFun my cty)] -> compileSig my cty >>= (\(FunctionType _ outs) -> pure outs) . body + [(_, VFun my cty)] -> compileSig my cty >>= (\(_, outs) -> pure outs) _ -> error "Expected 1 kernel function type from rhs" -- Compile the match: testResult is the port holding the dynamic match result @@ -295,13 +292,13 @@ compileClauses parent ins ((matchData, rhs) :| clauses) = do didntMatch outTys parent ins = case nonEmpty clauses of Just clauses -> compileClauses parent ins clauses -- If there are no more clauses left to test, then the Hugr panics - Nothing -> let sig = FunctionType (snd <$> ins) outTys in - addNodeWithInputs "Panic" (OpCustom (CustomOp parent "brat" "panic" sig [])) ins outTys + Nothing -> let sig = FunctionType (snd <$> ins) outTys ["BRAT"] in + addNodeWithInputs "Panic" (OpCustom (CustomOp parent "BRAT" "panic" sig [])) ins outTys didMatch :: [HugrType] -> NodeId -> [TypedPort] -> Compile [TypedPort] didMatch outTys parent ins = gets bratGraph >>= \(ns,_) -> case ns M.! rhs of BratNode (Box _venv src tgt) _ _ -> do - dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType (snd <$> ins) outTys))) + dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType (snd <$> ins) outTys bratExts))) compileBox (src, tgt) dfgId for_ (zip (fst <$> ins) (Port dfgId <$> [0..])) addEdge pure $ zip (Port dfgId <$> [0..]) outTys @@ -337,13 +334,13 @@ compileWithInputs parent name = gets compiled >>= (\case let (funcDef, extra_call) = decls M.! name nod <- if extra_call then addNode ("direct_call(" ++ show funcDef ++ ")") - (OpCall (CallOp parent (FunctionType [] hTys))) + (OpCall (CallOp parent (FunctionType [] hTys bratExts))) -- We are loading idNode as a value (not an Eval'd thing), and it is a FuncDef directly -- corresponding to a Brat TLD (not that produces said TLD when eval'd) else case hTys of [HTFunc poly@(PolyFuncType [] _)] -> addNode ("load_thunk(" ++ show funcDef ++ ")") - (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly]))) + (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly] []))) [HTFunc (PolyFuncType args _)] -> error $ unwords ["Unexpected type args to" ,show funcDef ++ ":" ,show args @@ -376,7 +373,7 @@ compileWithInputs parent name = gets compiled >>= (\case Splice (Ex outNode _) -> default_edges <$> do ins <- compilePorts ins outs <- compilePorts outs - let sig = FunctionType ins outs + let sig = FunctionType ins outs bratExts case hasPrefix ["checking", "globals", "prim"] outNode of -- If we're evaling a Prim, we add it directly into the kernel graph Just suffix -> do @@ -401,11 +398,12 @@ compileWithInputs parent name = gets compiled >>= (\case let n = ext ++ ('_':op) let [] = ins let [(_, VFun Braty cty)] = outs - box_sig@(FunctionType inputTys outputTys) <- body <$> compileSig Braty cty - ((Port loadConst _, _ty), ()) <- compileConstDfg parent n box_sig $ \dfg_id -> do - ins <- addNodeWithInputs ("Inputs" ++ n) (OpIn (InputNode dfg_id inputTys)) [] inputTys - outs <- addNodeWithInputs n (OpCustom (CustomOp dfg_id ext op box_sig [])) ins outputTys - addNodeWithInputs ("Outputs" ++ n) (OpOut (OutputNode dfg_id outputTys)) outs [] + boxSig@(inputTys, outputTys) <- compileSig Braty cty + let boxFunTy = FunctionType inputTys outputTys bratExts + ((Port loadConst _, _ty), ()) <- compileConstDfg parent n boxSig $ \dfgId -> do + ins <- addNodeWithInputs ("Inputs" ++ n) (OpIn (InputNode dfgId inputTys)) [] inputTys + outs <- addNodeWithInputs n (OpCustom (CustomOp dfgId ext op boxFunTy [])) ins outputTys + addNodeWithInputs ("Outputs" ++ n) (OpOut (OutputNode dfgId outputTys)) outs [] pure () pure $ default_edges loadConst @@ -419,13 +417,13 @@ compileWithInputs parent name = gets compiled >>= (\case -- Callee is a Prim node, insert Hugr Op; first look up outNode in the BRAT graph to get the Prim data Just suffix -> default_edges <$> case M.lookup outNode ns of Just (BratNode (Prim (ext,op)) _ _) -> do - addNode (show suffix) (OpCustom (CustomOp parent ext op (FunctionType ins outs) [])) + addNode (show suffix) (OpCustom (CustomOp parent ext op (FunctionType ins outs [ext]) [])) x -> error $ "Expected a Prim node but got " ++ show x Nothing -> case hasPrefix ["checking", "globals"] outNode of -- Callee is a user-defined global def that, since it does not require an "extra" call, can be turned from IndirectCall to direct. Just _ | (funcDef, False) <- fromJust (M.lookup outNode decls) -> do callerId <- addNode ("direct_call(" ++ show funcDef ++ ")") - (OpCall (CallOp parent (FunctionType ins outs))) + (OpCall (CallOp parent (FunctionType ins outs bratExts))) -- Add the static edge from the FuncDefn node to the port *after* -- all of the dynamic arguments to the Call node. -- This is because in hugr, static edges (like the graph arg to a @@ -437,7 +435,7 @@ compileWithInputs parent name = gets compiled >>= (\case _ -> compileWithInputs parent outNode >>= \case Just calleeId -> do callerId <- addNode ("indirect_call(" ++ show calleeId ++ ")") - (OpCallIndirect (CallIndirectOp parent (FunctionType ins outs))) + (OpCallIndirect (CallIndirectOp parent (FunctionType ins outs bratExts {-[]-}))) -- for an IndirectCall, the callee (thunk, function value) is the *first* -- Hugr input. So move all the others along, and add that extra edge. pure $ Just (callerId, 1, [(Port calleeId outPort, 0)]) @@ -468,11 +466,11 @@ compileWithInputs parent name = gets compiled >>= (\case case outs of [(_, VCon tycon _)] -> do outs <- compilePorts outs - compileConstructor parent tycon c (FunctionType ins outs) + compileConstructor parent tycon c (FunctionType ins outs ["BRAT"]) PatternMatch cs -> default_edges <$> do ins <- compilePorts ins outs <- compilePorts outs - dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType ins outs))) + dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType ins outs bratExts))) inputNode <- addNode "PatternMatch.Input" (OpIn (InputNode dfgId ins)) ccOuts <- compileClauses dfgId (zip (Port inputNode <$> [0..]) ins) cs addNodeWithInputs "PatternMatch.Output" (OpOut (OutputNode dfgId (snd <$> ccOuts))) ccOuts [] @@ -483,7 +481,7 @@ compileWithInputs parent name = gets compiled >>= (\case ins <- compilePorts ins let [_, elemTy] = ins outs <- compilePorts outs - let sig = FunctionType ins outs + let sig = FunctionType ins outs bratExts addNode "Replicate" (OpCustom (CustomOp parent "BRAT" "Replicate" sig [TAType elemTy])) x -> error $ show x ++ " should have been compiled outside of compileNode" @@ -516,26 +514,28 @@ getOutPort parent p@(Ex srcNode srcPort) = do -- Execute a compilation (which takes a DFG parent) in a nested monad; -- produce a Const node containing the resulting Hugr, and a LoadConstant, -- and return the latter. -compileConstDfg :: NodeId -> String -> FunctionType -> (NodeId -> Compile a) -> Compile (TypedPort, a) -compileConstDfg parent desc box_sig contents = do +compileConstDfg :: NodeId -> String -> ([HugrType], [HugrType]) -> (NodeId -> Compile a) -> Compile (TypedPort, a) +compileConstDfg parent desc (inTys, outTys) contents = do st <- gets store g <- gets bratGraph -- First, we fork off a new namespace - (res, cs) <- desc -! do + ((funTy, a), cs) <- desc -! do ns <- gets nameSupply pure $ flip runState (emptyCS g ns st) $ do -- make a DFG node at the root. We can't use `addNode` since the -- DFG needs itself as parent dfg_id <- freshNode ("Box_" ++ show desc) - addOp (OpDFG $ DFG dfg_id box_sig) dfg_id - contents dfg_id + a <- contents dfg_id + let funTy = FunctionType inTys outTys bratExts + addOp (OpDFG $ DFG dfg_id funTy) dfg_id + pure (funTy, a) let nestedHugr = renameAndSortHugr (nodes cs) (edges cs) - let ht = HTFunc $ PolyFuncType [] box_sig + let ht = HTFunc $ PolyFuncType [] funTy constNode <- addNode ("ConstTemplate_" ++ desc) (OpConst (ConstOp parent (HVFunction nestedHugr))) lcPort <- head <$> addNodeWithInputs ("LoadTemplate_" ++ desc) (OpLoadConstant (LoadConstantOp parent ht)) [(Port constNode 0, ht)] [ht] - pure (lcPort, res) + pure (lcPort, a) -- Brat computations may capture some local variables. Thus, we need -- to lambda-lift, producing (as results) a Partial node and a list of @@ -549,12 +549,12 @@ compileBratBox parent name (venv, src, tgt) cty = do parmTys <- compileGraphTypes (map (binderToValue Braty . snd) params) -- Create a FuncDefn for the lambda that takes the params as first inputs - (FunctionType inputTys outputTys) <- body <$> compileSig Braty cty + (inputTys, outputTys) <- compileSig Braty cty let allInputTys = parmTys ++ inputTys - let box_sig = FunctionType allInputTys outputTys + let boxInnerSig = FunctionType allInputTys outputTys bratExts - (templatePort, _) <- compileConstDfg parent ("BB" ++ show name) box_sig $ \dfg_id -> do - src_id <- addNode ("LiftedCapturesInputs" ++ show name) (OpIn (InputNode dfg_id allInputTys)) + (templatePort, _) <- compileConstDfg parent ("BB" ++ show name) (allInputTys, outputTys) $ \dfgId -> do + src_id <- addNode ("LiftedCapturesInputs" ++ show name) (OpIn (InputNode dfgId allInputTys)) -- Now map ports in the BRAT Graph to their Hugr equivalents. -- Each captured value is read from an element of src_id, starting from 0 let lifted = [(src, Port src_id i) | ((src, _ty), i) <- zip params [0..]] @@ -563,10 +563,10 @@ compileBratBox parent name (venv, src, tgt) cty = do st <- get put $ st {liftedOutPorts = M.fromList lifted} -- no need to return any holes - compileWithInputs dfg_id tgt + compileWithInputs dfgId tgt -- Finally, we add a `Partial` node to supply the captured params. - partialNode <- addNode "Partial" (OpCustom $ partialOp parent box_sig (length params)) + partialNode <- addNode "Partial" (OpCustom $ partialOp parent boxInnerSig (length params)) addEdge (fst templatePort, Port partialNode 0) edge_srcs <- for (map fst params) $ getOutPort parent pure (partialNode, zip (map fromJust edge_srcs) [1..]) @@ -577,9 +577,9 @@ compileKernBox parent name contents cty = do -- compile kernel nodes only into a Hugr with "Holes" -- when we see a Splice, we'll record the func-port onto a list -- return a Hugr with holes - box_sig <- body <$> compileSig Kerny cty - let box_ty = HTFunc $ PolyFuncType [] box_sig - (templatePort, holelist) <- compileConstDfg parent ("KB" ++ show name) box_sig $ \dfg_id -> do + boxInnerSig@(inTys, outTys) <- compileSig Kerny cty + let boxTy = HTFunc $ PolyFuncType [] (FunctionType inTys outTys bratExts) + (templatePort, holelist) <- compileConstDfg parent ("KB" ++ show name) boxInnerSig $ \dfg_id -> do contents dfg_id gets holes @@ -591,11 +591,11 @@ compileKernBox parent name contents cty = do ins <- compilePorts ins outs <- compilePorts outs kernel_src <- compileWithInputs parent kernel_src <&> fromJust - pure (Port kernel_src port, HTFunc (PolyFuncType [] (FunctionType ins outs)))) + pure (Port kernel_src port, HTFunc (PolyFuncType [] (FunctionType ins outs bratExts)))) -- Add a substitute node to fill the holes in the template let hole_sigs = [ body poly | (_, HTFunc poly) <- hole_ports ] - head <$> addNodeWithInputs ("subst_" ++ show name) (OpCustom (substOp parent box_sig hole_sigs)) (templatePort : hole_ports) [box_ty] + head <$> addNodeWithInputs ("subst_" ++ show name) (OpCustom (substOp parent (FunctionType inTys outTys bratExts) hole_sigs)) (templatePort : hole_ports) [boxTy] -- We get a bunch of TypedPorts which are associated with Srcs in the BRAT graph. @@ -739,7 +739,7 @@ makeConditional parent discrim otherInputs cases = do outId <- addNode ("Output" ++ name) (OpOut (OutputNode caseId outTys)) for_ (zip (fst <$> outs) (Port outId <$> [0..])) addEdge - addOp (OpCase (ix, Case parent (FunctionType tys outTys))) caseId + addOp (OpCase (ix, Case parent (FunctionType tys outTys bratExts))) caseId pure outTys allRowsEqual :: [[HugrType]] -> Bool @@ -753,7 +753,7 @@ compilePrimTest :: NodeId -> Compile TypedPort compilePrimTest parent (port, ty) (PrimCtorTest c tycon unpackingNode outputs) = do let sumOut = HTSum (SG (GeneralSum [[ty], snd <$> outputs])) - let sig = FunctionType [ty] [sumOut] + let sig = FunctionType [ty] [sumOut] ["BRAT"] testId <- addNode ("PrimCtorTest " ++ show c) (OpCustom (CustomOp parent @@ -771,7 +771,7 @@ compilePrimTest parent port@(_, ty) (PrimLitTest tm) = do [(Port constId 0, ty)] [ty] -- Connect to a test node let sumOut = HTSum (SG (GeneralSum [[ty], []])) - let sig = FunctionType [ty, ty] [sumOut] + let sig = FunctionType [ty, ty] [sumOut] ["BRAT"] head <$> addNodeWithInputs ("PrimLitTest " ++ show tm) (OpCustom (CustomOp parent "BRAT" ("PrimLitTest::" ++ show ty) sig [])) [port, loadPort] @@ -786,7 +786,7 @@ undoPrimTest :: NodeId -> PrimTest HugrType -- The test to undo -> Compile TypedPort undoPrimTest parent inPorts outTy (PrimCtorTest c tycon _ _) = do - let sig = FunctionType (snd <$> inPorts) [outTy] + let sig = FunctionType (snd <$> inPorts) [outTy] ["BRAT"] head <$> addNodeWithInputs ("UndoCtorTest " ++ show c) (constructorOp parent tycon c sig) @@ -833,8 +833,8 @@ compileModule venv = do [(Ex input 0, _)] | Just (BratNode (Box _ src tgt) _ outs) <- M.lookup input ns -> case outs of [(_, VFun Braty cty)] -> do - sig <- compileSig Braty cty - pure (sig, False, compileBox (src, tgt)) + (inTys, outTys) <- compileSig Braty cty + pure (PolyFuncType [] (FunctionType inTys outTys bratExts), False, compileBox (src, tgt)) [(_, VFun Kerny cty)] -> do -- We're compiling, e.g. -- f :: { Qubit -o Qubit } @@ -842,7 +842,7 @@ compileModule venv = do -- Although this looks like a constant kernel, we'll have to compile the -- computation that produces this constant. We do so by making a FuncDefn -- that takes no arguments and produces the constant kernel graph value. - thunkTy <- HTFunc <$> compileSig Kerny cty + thunkTy <- HTFunc . PolyFuncType [] . (\(ins, outs) -> FunctionType ins outs bratExts) <$> compileSig Kerny cty pure (funcReturning [thunkTy], True, \parent -> withIO parent thunkTy $ compileKernBox parent input (compileBox (src, tgt)) cty) _ -> error "Box should have exactly one output of Thunk type" @@ -873,7 +873,7 @@ compileModule venv = do pure id funcReturning :: [HugrType] -> PolyFuncType - funcReturning outs = PolyFuncType [] (FunctionType [] outs) + funcReturning outs = PolyFuncType [] (FunctionType [] outs bratExts) compileNoun :: [HugrType] -> [OutPort] -> NodeId -> Compile () compileNoun outs srcPorts parent = do diff --git a/brat/Brat/Elaborator.hs b/brat/Brat/Elaborator.hs index 5e867a38..df65a0b8 100644 --- a/brat/Brat/Elaborator.hs +++ b/brat/Brat/Elaborator.hs @@ -91,6 +91,7 @@ elaborate (WC fc x) = do elaborate' :: Flat -> Either Error SomeRaw' elaborate' (FVar x) = pure $ SomeRaw' (RVar x) +elaborate' FHope = pure $ SomeRaw' RHope elaborate' (FArith op a b) = do (SomeRaw a) <- elaborate a (SomeRaw b) <- elaborate b diff --git a/brat/Brat/Error.hs b/brat/Brat/Error.hs index 32cea48a..dbccdbb1 100644 --- a/brat/Brat/Error.hs +++ b/brat/Brat/Error.hs @@ -1,5 +1,6 @@ module Brat.Error (ParseError(..) ,LengthConstraintF(..), LengthConstraint + ,BracketErrMsg(..) ,ErrorMsg(..) ,Error(..), showError ,SrcErr(..) @@ -9,6 +10,8 @@ module Brat.Error (ParseError(..) ) where import Brat.FC +import Data.Bracket +import Brat.Syntax.Port (PortName) import Data.List (intercalate) import System.Exit @@ -25,6 +28,28 @@ instance Show a => Show (LengthConstraintF a) where type LengthConstraint = LengthConstraintF Int +data BracketErrMsg + = EOFInBracket BracketType -- FC in enclosing `Err` should point to the open bracket + -- FC here is opening; closing FC in the enclosing `Err` + | OpenCloseMismatch (FC, BracketType) BracketType + | UnexpectedClose BracketType + +instance Show BracketErrMsg where + show (EOFInBracket b) = "File ended before this " ++ showOpen b ++ " was closed" + show (OpenCloseMismatch (openFC, bOpen) bClose) = unwords ["This" + ,showClose bClose + ,"doesn't match the" + ,showOpen bOpen + ,"at" + ,show openFC + ] + show (UnexpectedClose b) = unwords ["There is no" + ,showOpen b + ,"for this" + ,showClose b + ,"to close" + ] + data ErrorMsg = TypeErr String -- Term, Expected type, Actual type @@ -60,8 +85,8 @@ data ErrorMsg | FileNotFound String [String] | SymbolNotFound String String | InternalError String - | AmbiguousPortPull String String - | BadPortPull String + | AmbiguousPortPull PortName String + | BadPortPull PortName String | VConNotFound String | TyConNotFound String String | MatchingOnTypes @@ -82,6 +107,8 @@ data ErrorMsg -- The argument is the row of unused connectors | ThunkLeftOvers String | ThunkLeftUnders String + | BracketErr BracketErrMsg + | RemainingNatHopes [String] instance Show ErrorMsg where show (TypeErr x) = "Type error: " ++ x @@ -139,7 +166,7 @@ instance Show ErrorMsg where show (SymbolNotFound s i) = "Symbol `" ++ s ++ "` not found in `" ++ i ++ "`" show (InternalError x) = "Internal error: " ++ x show (AmbiguousPortPull p row) = "Port " ++ p ++ " is ambiguous in " ++ row - show (BadPortPull x) = "Port " ++ x ++ " can't be pulled because it depends on a previous port" + show (BadPortPull p row) = "Port not found: " ++ p ++ " in " ++ row show (VConNotFound x) = "Value constructor not recognised: " ++ x show (TyConNotFound ty v) = show v ++ " is not a valid constructor for type " ++ ty show MatchingOnTypes = "Trying to pattern match on a type" @@ -165,7 +192,8 @@ instance Show ErrorMsg where show UnreachableBranch = "Branch cannot be reached" show (ThunkLeftOvers overs) = "Expected function to address all inputs, but " ++ overs ++ " wasn't used" show (ThunkLeftUnders unders) = "Expected function to return additional values of type: " ++ unders - + show (BracketErr msg) = show msg + show (RemainingNatHopes hs) = unlines ("Expected to work out values for these holes:":((" " ++) <$> hs)) data Error = Err { fc :: Maybe FC , msg :: ErrorMsg diff --git a/brat/Brat/Eval.hs b/brat/Brat/Eval.hs index ad362f3b..82d19225 100644 --- a/brat/Brat/Eval.hs +++ b/brat/Brat/Eval.hs @@ -4,14 +4,20 @@ module Brat.Eval (EvMode(..) ,ValPat(..) ,NumPat(..) ,apply + ,applySem ,eval ,sem + ,semLvl + ,doesntOccur ,evalFunTy ,eqTest + ,getNum ,kindEq + ,kindOf ,kindType ,numVal - ,typeEq + ,quote + ,getNumVar ) where import Brat.Checker.Monad @@ -19,8 +25,8 @@ import Brat.Checker.Types (EndType(..), kindForMode) import Brat.Error (ErrorMsg(..)) import Brat.QualName (plain) import Brat.Syntax.CircuitProperties (eqProps) -import Brat.Syntax.Value import Brat.Syntax.Common +import Brat.Syntax.Value import Control.Monad.Freer (req) import Bwd import Hasochism @@ -30,6 +36,7 @@ import Data.Bifunctor (second) import Data.Functor import Data.Kind (Type) import Data.Type.Equality (TestEquality(..), (:~:)(..)) +import Data.Foldable (traverse_) kindType :: TypeKind -> Val Z kindType Nat = TNat @@ -82,7 +89,6 @@ sem ga (VApp f vz) = do f <- semVar ga f vz <- traverse (sem ga) vz applySem f vz -sem ga (VSum my ts) = pure $ SSum my ga ts semVar :: Stack Z Sem n -> VVar n -> Checking Sem semVar vz (VInx inx) = pure $ proj vz inx @@ -119,10 +125,6 @@ quote lvy (SLam stk body) = do VLam <$> quote (Sy lvy) body quote lvy (SFun my ga cty) = VFun my <$> quoteFunTy lvy my ga cty quote lvy (SApp f vz) = VApp (quoteVar lvy f) <$> traverse (quote lvy) vz -quote lvy (SSum my ga ts) = VSum my <$> traverse quoteVariant ts - where - quoteVariant (Some ro) = quoteRo my ga ro lvy >>= \case - (_, Some (ro :* _)) -> pure (Some ro) quoteFunTy :: Ny lv -> Modey m -> Stack Z Sem n -> FunTy m n -> Checking (FunTy m lv) quoteFunTy lvy my ga (FunTy ps ins outs) = quoteRo my ga ins lvy >>= \case @@ -192,14 +194,8 @@ kindOf (VPar e) = req (TypeOf e) >>= \case Kerny -> show ty kindOf (VInx n) = case n of {} --- We should have made sure that the two values share the given kind -typeEq :: String -- String representation of the term for error reporting - -> TypeKind -- The kind we're comparing at - -> Val Z -- Expected - -> Val Z -- Actual - -> Checking () -typeEq str k exp act = eqTest str k exp act >>= throwLeft - +-------- for SolvePatterns usage: not allowed to solve hopes, +-- and if pattern insoluble, it's not a type error (it's a "pattern match case unreachable") eqTest :: String -- String representation of the term for error reporting -> TypeKind -- The kind we're comparing at -> Val Z -- Expected @@ -263,10 +259,7 @@ eqWorker tm lvkz (TypeFor _ []) (SSum m0 stk0 rs0) (SSum m1 stk1 rs1) Just rs -> traverse eqVariant rs <&> sequence_ where eqVariant (Some r0, Some r1) = eqRowTest m0 tm lvkz (stk0,r0) (stk1,r1) <&> dropRight -eqWorker tm _ _ s0 s1 = do - v0 <- quote Zy s0 - v1 <- quote Zy s1 - pure . Left $ TypeMismatch tm (show v0) (show v1) +eqWorker tm _ _ v0 v1 = pure . Left $ TypeMismatch tm (show v0) (show v1) -- Type rows have bot0,bot1 dangling de Bruijn indices, which we instantiate with -- de Bruijn levels. As we go under binders in these rows, we add to the scope's @@ -307,3 +300,35 @@ eqTests tm lvkz = go Left e -> pure $ Left e go _ us vs = pure . Left . TypeErr $ "Arity mismatch in type constructor arguments:\n " ++ show us ++ "\n " ++ show vs + +getNumVar :: NumVal (VVar n) -> Maybe End +getNumVar (NumValue _ (StrictMonoFun (StrictMono _ mono))) = case mono of + Linear v -> case v of + VPar e -> Just e + _ -> Nothing + Full sm -> getNumVar (numValue sm) +getNumVar _ = Nothing + +-- Be conservative, fail if in doubt. Not dangerous like being wrong while succeeding +-- We can have bogus failures here because we're not normalising under lambdas +-- N.B. the value argument is normalised. +doesntOccur :: End -> Val n -> Either ErrorMsg () +doesntOccur e (VNum nv) = traverse_ (collision e) (getNumVar nv) +doesntOccur e (VApp var args) = case var of + VPar e' -> collision e e' *> traverse_ (doesntOccur e) args + _ -> pure () +doesntOccur e (VCon _ args) = traverse_ (doesntOccur e) args +doesntOccur e (VLam body) = doesntOccur e body +doesntOccur e (VFun my (FunTy _ ins outs)) = case my of + Braty -> doesntOccurRo my e ins *> doesntOccurRo my e outs + Kerny -> doesntOccurRo my e ins *> doesntOccurRo my e outs + +collision :: End -> End -> Either ErrorMsg () +collision e v | e == v = Left . UnificationError $ + show e ++ " is cyclic" + | otherwise = pure () + +doesntOccurRo :: Modey m -> End -> Ro m i j -> Either ErrorMsg () +doesntOccurRo _ _ R0 = pure () +doesntOccurRo my e (RPr (_, ty) ro) = doesntOccur e ty *> doesntOccurRo my e ro +doesntOccurRo Braty e (REx _ ro) = doesntOccurRo Braty e ro diff --git a/brat/Brat/FC.hs b/brat/Brat/FC.hs index 958df5d1..669b9506 100644 --- a/brat/Brat/FC.hs +++ b/brat/Brat/FC.hs @@ -35,3 +35,9 @@ fcOf (WC fc _) = fc -- TODO: Remove this dummyFC :: a -> WC a dummyFC = WC (FC (Pos 0 0) (Pos 0 0)) + +spanFC :: FC -> FC -> FC +spanFC afc bfc = FC (start afc) (end bfc) + +spanFCOf :: WC a -> WC b -> FC +spanFCOf (WC afc _) (WC bfc _) = spanFC afc bfc diff --git a/brat/Brat/Lexer/Bracketed.hs b/brat/Brat/Lexer/Bracketed.hs new file mode 100644 index 00000000..d668c1ed --- /dev/null +++ b/brat/Brat/Lexer/Bracketed.hs @@ -0,0 +1,98 @@ +module Brat.Lexer.Bracketed (BToken(..), brackets) where + +import Data.Bracket +import Brat.Error (BracketErrMsg(..), Error(Err), ErrorMsg(..)) +import Brat.FC +import Brat.Lexer.Token + +import Data.List.NonEmpty (NonEmpty(..)) +import Data.Bifunctor (first) +import Text.Megaparsec (PosState(..), SourcePos(..), TraversableStream(..), VisualStream(..)) +import Text.Megaparsec.Pos (mkPos) + +data OpenClose = Open BracketType | Close BracketType + +openClose :: Tok -> Maybe OpenClose +openClose LParen = Just (Open Paren) +openClose LSquare = Just (Open Square) +openClose LBrace = Just (Open Brace) +openClose RParen = Just (Close Paren) +openClose RSquare = Just (Close Square) +openClose RBrace = Just (Close Brace) +openClose _ = Nothing + +-- Well bracketed tokens +data BToken + = Bracketed FC BracketType [BToken] + | FlatTok Token + deriving (Eq, Ord) + +btokLen :: BToken -> Int +btokLen (FlatTok tok) = length (show tok) +btokLen (Bracketed _ _ bs) = sum (btokLen <$> bs) + 2 + +instance Show BToken where + show (FlatTok t) = show t + show (Bracketed _ b ts) = showOpen b ++ concatMap show ts ++ showClose b + +instance VisualStream [BToken] where + showTokens _ = concatMap show + tokensLength _ = sum . fmap btokLen + +instance TraversableStream [BToken] where + reachOffsetNoLine i pos = let fileName = sourceName (pstateSourcePos pos) + (Pos line col, rest) = skipChars (i - pstateOffset pos + 1) (pstateInput pos) + in pos + { pstateInput = rest + , pstateOffset = max (pstateOffset pos) i + , pstateSourcePos = SourcePos fileName (mkPos line) (mkPos col) + } + where + skipChars :: Int -> [BToken] -> (Pos, [BToken]) + skipChars 0 inp@(Bracketed fc _ _:_) = (start fc, inp) + skipChars 0 inp@(FlatTok t:_) = (start (fc t), inp) + skipChars i ((Bracketed fc b bts):rest) = + let Pos closeLine closeCol = end fc + closeFC = FC (Pos closeLine (closeCol - 1)) (Pos closeLine closeCol) + in skipChars (i - 1) (bts ++ [FlatTok (Token closeFC (closeTok b))] ++ rest) + skipChars i (FlatTok t:rest) + | i >= tokenLen t = skipChars (i - tokenLen t) rest + | otherwise = (start (fc t), FlatTok t:rest) + + closeTok Paren = RParen + closeTok Square = RSquare + closeTok Brace = RBrace + +eofErr :: (FC, BracketType) -> Error +eofErr (fc, b) = Err (Just fc) (BracketErr (EOFInBracket b)) + +openCloseMismatchErr :: (FC, BracketType) -> (FC, BracketType) -> Error +openCloseMismatchErr open (fcClose, bClose) + = Err (Just fcClose) (BracketErr (OpenCloseMismatch open bClose)) + +unexpectedCloseErr :: (FC, BracketType) -> Error +unexpectedCloseErr (fc, b) = Err (Just fc) (BracketErr (UnexpectedClose b)) + +brackets :: [Token] -> Either Error [BToken] +brackets ts = helper ts >>= \case + (res, Nothing) -> pure res + (_, Just (b, t:|_)) -> Left $ unexpectedCloseErr (fc t, b) + where + -- Given a list of tokens, either + -- (success) return [BToken] consisting of the prefix of the input [Token] in which all opened brackets are closed, + -- and any remaining [Token] beginning with a closer that does not match any opener in the input + -- (either Nothing = no remaining tokens; or tokens with the BracketType that the first token closes) + -- (failure) return an error, if a bracket opened in the input, is either not closed (EOF) or does not match the closer + helper :: [Token] -> Either Error ([BToken], Maybe (BracketType, NonEmpty Token)) + helper [] = pure ([], Nothing) + helper (t:ts) = case openClose (_tok t) of + Just (Open b) -> let openFC = fc t in helper ts >>= \case + (_, Nothing) -> Left $ eofErr (fc t, b) + (within, Just (b', r :| rs)) -> + let closeFC = fc r + enclosingFC = spanFC openFC closeFC + in if b == b' + then first (Bracketed enclosingFC b within:) <$> helper rs + else Left $ openCloseMismatchErr (openFC, b) (closeFC, b') + Just (Close b) -> pure ([], Just (b, t :| ts)) -- return closer for caller + Nothing -> first (FlatTok t:) <$> helper ts diff --git a/brat/Brat/Lexer/Flat.hs b/brat/Brat/Lexer/Flat.hs index 1f41d5ba..511e70d1 100644 --- a/brat/Brat/Lexer/Flat.hs +++ b/brat/Brat/Lexer/Flat.hs @@ -54,8 +54,8 @@ tok = try (char '(' $> LParen) <|> try (char ')' $> RParen) <|> try (char '{' $> LBrace) <|> try (char '}' $> RBrace) - <|> try (char '[' $> LBracket) - <|> try (char ']' $> RBracket) + <|> try (char '[' $> LSquare) + <|> try (char ']' $> RSquare) <|> try (Underscore <$ string "_") <|> try (Quoted <$> (char '"' *> printChar `manyTill` char '"')) <|> try (FloatLit <$> float) @@ -87,6 +87,7 @@ tok = try (char '(' $> LParen) <|> try (string "-" $> Minus) <|> try (string "$" $> Dollar) <|> try (string "|" $> Pipe) + <|> try (string "!" $> Bang) <|> try (K <$> try keyword) <|> try qualified <|> Ident <$> ident diff --git a/brat/Brat/Lexer/Token.hs b/brat/Brat/Lexer/Token.hs index d5f8842b..149bde62 100644 --- a/brat/Brat/Lexer/Token.hs +++ b/brat/Brat/Lexer/Token.hs @@ -1,4 +1,4 @@ -module Brat.Lexer.Token (Tok(..), Token(..), Keyword(..)) where +module Brat.Lexer.Token (Tok(..), Token(..), Keyword(..), tokenLen) where import Brat.FC @@ -21,8 +21,8 @@ data Tok | RParen | LBrace | RBrace - | LBracket - | RBracket + | LSquare + | RSquare | Semicolon | Into | Comma @@ -43,6 +43,7 @@ data Tok | Dollar | Underscore | Pipe + | Bang | Cons | Snoc | ConcatEqEven @@ -66,8 +67,8 @@ instance Show Tok where show RParen = ")" show LBrace = "{" show RBrace = "}" - show LBracket = "[" - show RBracket = "]" + show LSquare = "[" + show RSquare = "]" show Semicolon = ";" show Into = "|>" show Comma = "," @@ -88,6 +89,7 @@ instance Show Tok where show Dollar = "$" show Underscore = "_" show Pipe = "|" + show Bang = "!" show Cons = ",-" show Snoc = "-," show ConcatEqEven = "=,=" @@ -102,7 +104,8 @@ instance Eq Token where (Token fc t) == (Token fc' t') = t == t' && fc == fc' instance Show Token where - show (Token _ t) = show t ++ " " + show (Token _ t) = show t + instance Ord Token where compare (Token (FC st nd) _) (Token (FC st' nd') _) = if st == st' then compare nd nd' @@ -128,6 +131,8 @@ instance Show Keyword where tokLen :: Tok -> Int tokLen = length . show +tokenLen = tokLen . _tok + instance VisualStream [Token] where tokensLength _ = sum . fmap (\(Token _ t) -> tokLen t) showTokens _ = concatMap show diff --git a/brat/Brat/Parser.hs b/brat/Brat/Parser.hs index e46428f8..7bd3546b 100644 --- a/brat/Brat/Parser.hs +++ b/brat/Brat/Parser.hs @@ -4,6 +4,7 @@ import Brat.Constructors.Patterns import Brat.Error import Brat.FC import Brat.Lexer (lex) +import Brat.Lexer.Bracketed (BToken(..), brackets) import Brat.Lexer.Token (Keyword(..), Token(..), Tok(..)) import qualified Brat.Lexer.Token as Lexer import Brat.QualName ( plain, QualName(..) ) @@ -16,17 +17,19 @@ import Brat.Syntax.Concrete import Brat.Syntax.Raw import Brat.Syntax.Simple import Brat.Elaborator +import Data.Bracket import Util ((**^)) import Control.Monad (void) import Control.Monad.State (State, evalState, runState, get, put) import Data.Bifunctor -import Data.List (intercalate) -import Data.List.HT (chop, viewR) -import Data.List.NonEmpty (toList, NonEmpty(..), nonEmpty) import Data.Foldable (msum) import Data.Functor (($>), (<&>)) -import Data.Maybe (fromJust, isJust, maybeToList, fromMaybe) +import Data.List (intercalate, uncons) +import Data.List.HT (chop, viewR) +import Data.List.NonEmpty (toList, NonEmpty(..), nonEmpty) +import qualified Data.List.NonEmpty as NE +import Data.Maybe (fromJust, isJust, fromMaybe, maybeToList) import Data.Set (empty) import Prelude hiding (lex, round) import Text.Megaparsec hiding (Pos, Token, State, empty, match, ParseError, parse) @@ -35,112 +38,101 @@ import qualified Text.Megaparsec as M (parse) newtype CustomError = Custom String deriving (Eq, Ord) -- the State is the (FC) Position of the last token *consumed* -type Parser a = ParsecT CustomError [Token] (State Pos) a +type Parser a = ParsecT CustomError [BToken] (State Pos) a -parse :: Parser a -> String -> [Token] -> Either (ParseErrorBundle [Token] CustomError) a +parse :: Parser a -> String -> [BToken] -> Either (ParseErrorBundle [BToken] CustomError) a parse p s tks = evalState (runParserT p s tks) (Pos 0 0) instance ShowErrorComponent CustomError where showErrorComponent (Custom s) = s - -withFC :: Parser a -> Parser (WC a) -withFC p = do - (Token (FC start _) _) <- nextToken - thing <- p - end <- get - pure (WC (FC start end) thing) - -nextToken :: Parser Token -nextToken = lookAhead $ token Just empty - -token0 :: (Tok -> Maybe a) -> Parser a -token0 f = do - (fc, r) <- token (\(Token fc t) -> (fc,) <$> f t) empty - -- token matched condition f - put (end fc) - pure r +matchFC :: Tok -> Parser (WC ()) +matchFC tok = label (show tok) $ matchTok f + where + f :: Tok -> Maybe () + f t | t == tok = Just () + | otherwise = Nothing match :: Tok -> Parser () -match tok = label (show tok) $ token0 $ \t -> if t == tok then Just () else Nothing +match = fmap unWC . matchFC -kmatch :: Keyword -> Parser () -kmatch = match . K - -matchString :: String -> Parser () -matchString s = ident $ \x -> if x == s then Just () else Nothing - -ident :: (String -> Maybe a) -> Parser a -ident f = label "identifier" $ token0 $ \case - Ident str -> f str +matchTok :: (Tok -> Maybe a) -> Parser (WC a) +matchTok f = token (matcher f) empty + where + matcher :: (Tok -> Maybe a) -> BToken -> Maybe (WC a) + matcher f (FlatTok (Token fc t)) = WC fc <$> f t + -- Returns the FC at the beginning of the token + matcher f (Bracketed _ Paren [t]) = matcher f t + matcher _ _ = Nothing + +kmatch :: Keyword -> Parser (WC ()) +kmatch = matchFC . K + +matchString :: String -> Parser (WC ()) +matchString s = label (show s) $ matchTok $ \case + Ident ident | ident == s -> Just () _ -> Nothing -hole :: Parser String -hole = label "hole" $ token0 $ \case +hole :: Parser (WC String) +hole = label "hole" $ matchTok $ \case Hole h -> Just h _ -> Nothing -simpleName :: Parser String -simpleName = token0 $ \case +simpleName :: Parser (WC String) +simpleName = matchTok $ \case Ident str -> Just str _ -> Nothing -qualName :: Parser QualName -qualName = ( "name") $ try qualifiedName <|> (PrefixName [] <$> simpleName) - where - qualifiedName :: Parser QualName - qualifiedName = ( "qualified name") . token0 $ \case - QualifiedId prefix str -> Just (PrefixName (toList prefix) str) - _ -> Nothing - - - -round :: Parser a -> Parser a -round p = label "(...)" $ match LParen *> p <* match RParen - -square :: Parser a -> Parser a -square p = label "[...]" $ match LBracket *> p <* match RBracket +qualName :: Parser (WC QualName) +qualName = label "qualified name" $ matchTok $ \case + QualifiedId prefix str -> Just (PrefixName (toList prefix) str) + Ident str -> Just (PrefixName [] str) + _ -> Nothing -curly :: Parser a -> Parser a -curly p = label "{...}" $ match LBrace *> p <* match RBrace +inBrackets :: BracketType -> Parser a -> Parser a +inBrackets b p = unWC <$> inBracketsFC b p -inLet :: Parser a -> Parser a -inLet p = label "let ... in" $ kmatch KLet *> p <* kmatch KIn +inBracketsFC :: BracketType -> Parser a -> Parser (WC a) +inBracketsFC b p = contents >>= \(outerFC, toks) -> either (customFailure . Custom . errorBundlePretty) (pure . WC outerFC) (parse (p <* eof) "" toks) + where + contents = flip token empty $ \case + Bracketed fc b' xs | b == b' -> Just (fc, xs) + _ -> Nothing -number :: Parser Int -number = label "nat" $ token0 $ \case +number :: Parser (WC Int) +number = label "nat" $ matchTok $ \case Number n -> Just n _ -> Nothing -float :: Parser Double -float = label "float" $ token0 $ \case +float :: Parser (WC Double) +float = label "float" $ matchTok $ \case FloatLit x -> Just x _ -> Nothing -comment :: Parser () -comment = label "Comment" $ token0 $ \case +comment :: Parser (WC ()) +comment = label "Comment" $ matchTok $ \case Comment _ -> Just () _ -> Nothing -string :: Parser String -string = token0 $ \case +string :: Parser (WC String) +string = matchTok $ \case Quoted txt -> Just txt _ -> Nothing -var :: Parser Flat -var = FVar <$> qualName +var :: Parser (WC Flat) +var = fmap FVar <$> qualName +port :: Parser (WC String) port = simpleName comma :: Parser (WC Flat -> WC Flat -> WC Flat) -comma = token0 $ \case +comma = fmap unWC . matchTok $ \case Comma -> Just $ \a b -> - let fc = FC (start (fcOf a)) (end (fcOf b)) - in WC fc (FJuxt a b) + WC (spanFCOf a b) (FJuxt a b) _ -> Nothing arith :: ArithOp -> Parser (WC Flat -> WC Flat -> WC Flat) -arith op = token0 $ \tok -> case (op, tok) of +arith op = fmap unWC . matchTok $ \tok -> case (op, tok) of (Add, Plus) -> Just make (Sub, Minus) -> Just make (Mul, Asterisk) -> Just make @@ -161,109 +153,149 @@ chainl1 px pf = px >>= rest Just (f, y) -> rest (f x y) Nothing -> pure x -abstractor :: Parser Abstractor +abstractor :: Parser (WC Abstractor) abstractor = do ps <- many (try portPull) - xs <- binding `chainl1` try binderComma - pure $ if null ps then xs else APull ps xs + abs <- try (inBrackets Paren binders) <|> binders + pure $ if null ps + then abs + else let fc = spanFCOf (head ps) abs in WC fc (APull (unWC <$> ps) (unWC abs)) where - binding :: Parser Abstractor - binding = try (APat <$> bigPat) <|> round abstractor - vecPat = square (binding `sepBy` match Comma) >>= list2Cons + -- Minus port pulling + binders = try (joinBinders <$> ((:|) <$> binding <*> many (match Comma *> binding))) + where + joinBinders xs = let (abs, startFC, endFC) = joinBindersAux xs in WC (spanFC startFC endFC) abs + + joinBindersAux (WC fc x :| []) = (x, fc, fc) + joinBindersAux (WC fc x :| (y:ys)) = let (abs, _, endFC) = joinBindersAux (y :| ys) in + (x :||: abs, fc, endFC) + + binding :: Parser (WC Abstractor) + binding = try (fmap APat <$> bigPat) <|> inBrackets Paren abstractor + + vecPat :: Parser (WC Pattern) + vecPat = do + WC fc elems <- inBracketsFC Square ((unWC <$> binding) `sepBy` match Comma) + WC fc <$> list2Cons elems list2Cons :: [Abstractor] -> Parser Pattern list2Cons [] = pure PNil list2Cons (APat x:xs) = PCons x <$> list2Cons xs list2Cons _ = customFailure (Custom "Internal error list2Cons") - portPull = simpleName <* match PortColon - - binderComma :: Parser (Abstractor -> Abstractor -> Abstractor) - binderComma = match Comma $> (:||:) + portPull = port <* match PortColon -- For simplicity, we can say for now that all of our infix vector patterns have -- the same precedence and associate to the right - bigPat :: Parser Pattern + bigPat :: Parser (WC Pattern) bigPat = do - lhs <- weePat + WC lfc lhs <- weePat rest <- optional $ PCons lhs <$ match Cons <|> PSnoc lhs <$ match Snoc <|> PConcatEqEven lhs <$ match ConcatEqEven - <|> PConcatEqOdd lhs <$ match ConcatEqOddL <*> weePat <* match ConcatEqOddR + <|> PConcatEqOdd lhs <$ match ConcatEqOddL <*> (unWC <$> weePat) <* match ConcatEqOddR <|> PRiffle lhs <$ match Riffle case rest of - Just f -> f <$> bigPat - Nothing -> pure lhs + Just f -> do + WC rfc rhs <- bigPat + pure $ WC (spanFC lfc rfc) (f rhs) + Nothing -> pure (WC lfc lhs) - weePat :: Parser Pattern + weePat :: Parser (WC Pattern) weePat = try vecPat - <|> (match Underscore $> DontCare) - <|> try (Lit <$> simpleTerm) + <|> (fmap (const DontCare) <$> matchFC Underscore) + <|> try (fmap Lit <$> simpleTerm) <|> try constructorsWithArgs <|> try nullaryConstructors - <|> (Bind <$> simpleName) - <|> round bigPat + <|> (fmap Bind <$> simpleName) + <|> inBrackets Paren bigPat where - constructor :: Parser Abstractor -> String -> Parser Pattern - constructor pabs c = do - matchString c - PCon (plain c) <$> pabs + nullaryConstructor c = do + WC fc () <- matchString c + pure $ WC fc (PCon (plain c) AEmpty) - nullaryConstructors = msum (try . constructor (pure AEmpty) <$> ["zero", "nil", "none", "true", "false"]) + nullaryConstructors = msum (try . nullaryConstructor <$> ["zero", "nil", "none", "true", "false"]) - constructorsWithArgs = msum (try . constructor (round abstractor) <$> ["succ", "doub", "cons", "some"]) + constructorWithArgs :: String -> Parser (WC Pattern) + constructorWithArgs c = do + str <- matchString c + abs <- inBracketsFC Paren (unWC <$> abstractor) + pure $ WC (spanFCOf str abs) (PCon (plain c) (unWC abs)) -simpleTerm :: Parser SimpleTerm -simpleTerm = - (Text <$> string "string") - <|> try (Float . negate <$> (match Minus *> float) "float") - <|> try (Float <$> float "float") - <|> (Num . negate <$> (match Minus *> number) "nat") - <|> (Num <$> number "nat") + constructorsWithArgs = msum (try . constructorWithArgs <$> ["succ", "doub", "cons", "some"]) -outputs :: Parser [RawIO] -outputs = rawIO (unWC <$> vtype) +simpleTerm :: Parser (WC SimpleTerm) +simpleTerm = + (fmap Text <$> string "string") + <|> try (maybeNegative Float float "float") + <|> maybeNegative Num number "nat" -typekind :: Parser TypeKind -typekind = try (match Hash $> Nat) <|> kindHelper Lexer.Dollar Syntax.Dollar <|> kindHelper Asterisk Star where - kindHelper tok c = do - match tok - margs <- optional (round row) - pure $ c (concat $ maybeToList margs) - - row = (`sepBy` match Comma) $ do - p <- port + maybeNegative :: Num a => (a -> SimpleTerm) -> Parser (WC a) + -> Parser (WC SimpleTerm) + maybeNegative f p = do + minusFC <- fmap fcOf <$> optional (matchFC Minus) + WC nFC n <- p + pure $ case minusFC of + Nothing -> WC nFC (f n) + Just minusFC -> WC (spanFC minusFC nFC) (f (negate n)) + +typekind :: Parser (WC TypeKind) +typekind = try (fmap (const Nat) <$> matchFC Hash) <|> kindHelper Lexer.Dollar Syntax.Dollar <|> kindHelper Asterisk Star + where + kindHelper tok con = do + WC conFC () <- matchFC tok + margs <- optional (inBracketsFC Paren row) + let (fc, args) = maybe + (conFC, []) + (\(WC argsFC args) -> (FC (start conFC) (end argsFC), args)) + margs + pure $ WC fc (con args) + + + row :: Parser [(PortName, TypeKind)] + row = (`sepBy` match Comma) $ do + p <- unWC <$> port match TypeColon - (p,) <$> typekind + (p,) . unWC <$> typekind vtype :: Parser (WC (Raw Chk Noun)) vtype = cnoun (expr' PApp) -- Parse a row of type and kind parameters -- N.B. kinds must be named -rawIO :: Parser ty -> Parser (TypeRow (KindOr ty)) -rawIO tyP = rowElem `sepBy` void (try comma) +-- TODO: Update definitions so we can retain the FC info, instead of forgetting it +rawIOFC :: Parser (TypeRow (WC (KindOr RawVType))) +rawIOFC = rowElem `sepBy` void (try comma) where - rowElem = try (round rowElem') <|> rowElem' + rowElem :: Parser (TypeRowElem (WC (KindOr RawVType))) + rowElem = try (inBrackets Paren rowElem') <|> rowElem' - rowElem' = try namedKind <|> try namedType <|> (Anon . Right <$> tyP) + rowElem' :: Parser (TypeRowElem (WC (KindOr RawVType))) + rowElem' = try namedKind <|> try namedType <|> ((\(WC tyFC ty) -> Anon (WC tyFC (Right ty))) <$> vtype) + namedType :: Parser (TypeRowElem (WC (KindOr RawVType))) namedType = do - p <- port + WC pFC p <- port match TypeColon - Named p . Right <$> tyP + WC tyFC ty <- vtype + pure (Named p (WC (spanFC pFC tyFC) (Right ty))) + namedKind :: Parser (TypeRowElem (WC (KindOr ty))) namedKind = do - p <- port + WC pFC p <- port match TypeColon - Named p . Left <$> typekind + WC kFC k <- typekind + pure (Named p (WC (spanFC pFC kFC) (Left k))) + +rawIO :: Parser [RawIO] +rawIO = fmap (fmap unWC) <$> rawIOFC rawIO' :: Parser ty -> Parser (TypeRow ty) rawIO' tyP = rowElem `sepBy` void (try comma) where - rowElem = try (round rowElem') <|> rowElem' + rowElem = try (inBrackets Paren rowElem') <|> rowElem' -- Look out if we can find ::. If not, backtrack and just do tyP. -- For example, if we get an invalid primitive type (e.g. `Int` in @@ -271,71 +303,92 @@ rawIO' tyP = rowElem `sepBy` void (try comma) -- error message from tyP instead of complaining about a missing :: -- (since the invalid type can be parsed as a port name) rowElem' = optional (try $ port <* match TypeColon) >>= \case - Just p -> Named p <$> tyP + Just (WC _ p) -> Named p <$> tyP Nothing -> Anon <$> tyP functionType :: Parser RawVType functionType = try ctype <|> kernel where ctype = do - ins <- round $ rawIO (unWC <$> vtype) + ins <- inBrackets Paren $ rawIO match Arrow - outs <- rawIO (unWC <$> vtype) + outs <- rawIO pure (RFn (ins :-> outs)) kernel = do - ins <- round $ rawIO' (unWC <$> vtype) + ins <- inBrackets Paren $ rawIO' (unWC <$> vtype) match Lolly isWeird <- isJust <$> optional (match Hash) outs <- rawIO' (unWC <$> vtype) pure (RKernel (if isWeird then PNone else PControllable) (ins :-> outs)) -vec :: Parser Flat -vec = (\(WC fc x) -> unWC $ vec2Cons (end fc) x) <$> withFC (square elems) +spanningFC :: TypeRow (WC ty) -> Parser (WC (TypeRow ty)) +spanningFC [] = customFailure (Custom "Internal: RawIO shouldn't be empty") +spanningFC [x] = pure (WC (fcOf $ forgetPortName x) [unWC <$> x]) +spanningFC (x:xs) = pure (WC (spanFC (fcOf $ forgetPortName x) (fcOf . forgetPortName $ last xs)) (fmap unWC <$> (x:xs))) + +rawIOWithSpanFC :: Parser (WC [RawIO]) +rawIOWithSpanFC = spanningFC =<< rawIOFC + +vec :: Parser (WC Flat) +vec = (\(WC fc x) -> WC fc (unWC (vec2Cons fc x))) <$> inBracketsFC Square elems where elems = (element `chainl1` try vecComma) <|> pure [] vecComma = match Comma $> (++) - element = (:[]) <$> withFC (expr' (succ PJuxtPull)) + + element :: Parser [WC Flat] + element = (:[]) <$> expr' (succ PJuxtPull) + mkNil fc = FCon (plain "nil") (WC fc FEmpty) - vec2Cons :: Pos -> [WC Flat] -> WC Flat - -- The nil element gets as FC the closing ']' of the [li,te,ral] - vec2Cons end [] = let fc = FC end{col=col end-1} end in WC fc (mkNil fc) + vec2Cons :: FC -> [WC Flat] -> WC Flat + -- The nil element gets the FC of the `[]` expression. + -- N.B. this is also true in non-nil lists: the `nil` terminator of the list + -- `[1,2,3]` gets the file context of `[1,2,3]` + vec2Cons outerFC [] = WC outerFC (mkNil outerFC) + vec2Cons outerFC [x] = WC (fcOf x) $ FCon (plain "cons") (WC (fcOf x) (FJuxt x (WC outerFC (mkNil outerFC)))) -- We give each cell of the list an FC which starts with the FC -- of its head element and ends at the end of the list (the closing ']') - vec2Cons end (x:xs) = let fc = FC (start $ fcOf x) end in - WC fc $ FCon (plain "cons") (WC fc (FJuxt x (vec2Cons end xs))) + vec2Cons outerFC (x:xs) = let endFC = fcOf (last xs) + fc = spanFC (fcOf x) endFC + in WC fc $ + FCon (plain "cons") (WC fc (FJuxt x (vec2Cons outerFC xs))) -cthunk :: Parser Flat +cthunk :: Parser (WC Flat) cthunk = try bratFn <|> try kernel <|> thunk where - bratFn = curly $ do - ss <- rawIO (unWC <$> vtype) + bratFn = inBracketsFC Brace $ do + ss <- rawIO match Arrow - ts <- rawIO (unWC <$> vtype) + ts <- rawIO pure $ FFn (ss :-> ts) - kernel = curly $ do + kernel = inBracketsFC Brace $ do ss <- rawIO' (unWC <$> vtype) match Lolly isWeird <- isJust <$> optional (match Hash) ts <- rawIO' (unWC <$> vtype) pure (FKernel (if isWeird then PNone else PControllable) (ss :-> ts)) + -- Explicit lambda or brace section - thunk = FThunk <$> withFC (curly braceSection) + thunk :: Parser (WC Flat) + thunk = do + WC bracesFC th <- inBracketsFC Brace braceSection + pure (WC bracesFC (FThunk th)) + braceSection :: Parser (WC Flat) braceSection = do - e <- withFC expr + e <- expr -- Replace underscores with invented variable names '1, '2, '3 ... -- which are illegal for the user to use as variables case runState (replaceU e) 0 of - (e', 0) -> pure (unWC e') + (e', 0) -> pure e' -- If we don't have a `=>` at the start of a kernel, it could (and should) -- be a verb, not the RHS of a no-arg lambda - (e', n) -> let abs = braceSectionAbstractor [0..n-1] in - pure $ FLambda ((WC (fcOf e) abs, e') :| []) -- TODO: Which FC to use for the abstracor? + (e', n) -> let abs = braceSectionAbstractor [0..n-1] + in pure $ WC (fcOf e) $ FLambda ((WC (fcOf e) abs, e') :| []) replaceU :: WC Flat -> State Int (WC Flat) replaceU (WC fc x) = WC fc <$> replaceU' x @@ -362,6 +415,25 @@ cthunk = try bratFn <|> try kernel <|> thunk (\x -> APat (Bind ('\'': show x))) <$> ns +-- Expressions that can occur inside juxtapositions and vectors (i.e. everything with a higher +-- precedence than juxtaposition). Precedence table (loosest to tightest binding): +atomExpr :: Parser (WC Flat) +atomExpr = simpleExpr <|> inBracketsFC Paren (unWC <$> expr) + where + simpleExpr :: Parser (WC Flat) + simpleExpr = fmap FHole <$> hole + <|> try (fmap FSimple <$> simpleTerm) + <|> try fanin + <|> try fanout + <|> vec + <|> cthunk + <|> fmap (const FPass) <$> matchFC DotDot + <|> var + <|> fmap (const FUnderscore) <$> matchFC Underscore + <|> fmap (const FIdentity) <$> matchFC Pipe + <|> fmap (const FHope) <$> matchFC Bang + + {- Infix operator precedence table (See Brat.Syntax.Common.Precedence) (loosest to tightest binding): => @@ -378,12 +450,12 @@ cthunk = try bratFn <|> try kernel <|> thunk -} expr = expr' minBound -expr' :: Precedence -> Parser Flat +expr' :: Precedence -> Parser (WC Flat) expr' p = choice $ (try . getParser <$> enumFrom p) ++ [atomExpr] where - getParser :: Precedence -> Parser Flat + getParser :: Precedence -> Parser (WC Flat) getParser = \case - PLetIn -> letin "let ... in" + PLetIn -> letIn "let ... in" PLambda -> lambda "lambda" PInto -> (emptyInto <|> into) "into" PComp -> composition "composition" @@ -397,129 +469,145 @@ expr' p = choice $ (try . getParser <$> enumFrom p) ++ [atomExpr] PApp -> application "application" -- Take the precedence level and return a parser for everything with a higher precedence - subExpr :: Precedence -> Parser Flat + subExpr :: Precedence -> Parser (WC Flat) subExpr PApp = atomExpr subExpr p = choice $ (try . getParser <$> enumFrom (succ p)) ++ [atomExpr] -- Top level parser, looks for vector constructors with `atomExpr'`s as their -- elements. - vectorBuild :: Parser Flat + vectorBuild :: Parser (WC Flat) vectorBuild = do - lhs <- withFC (subExpr PVecPat) + lhs <- subExpr PVecPat rest <- optional $ (CCons, [lhs]) <$ match Cons <|> (CSnoc, [lhs]) <$ match Snoc <|> (CConcatEqEven, [lhs]) <$ match ConcatEqEven - <|> (CConcatEqOdd,) . ([lhs] ++) . (:[]) <$ match ConcatEqOddL <*> withFC (subExpr (succ PVecPat)) <* match ConcatEqOddR - <|> (CRiffle, [lhs]) <$ match Riffle + <|> (CConcatEqOdd,) . ([lhs] ++) . (:[]) <$ match ConcatEqOddL <*> subExpr (succ PVecPat) <* match ConcatEqOddR + <|> (CRiffle, [lhs]) <$ matchFC Riffle case rest of Just (c, args) -> do - rhs <- withFC vectorBuild - pure (FCon c (mkJuxt (args ++ [rhs]))) - Nothing -> pure (unWC lhs) + rhs <- vectorBuild + let juxtElems = case args of + [] -> rhs :| [] + (a:as) -> a :| (as ++ [rhs]) + pure (WC (spanFCOf lhs rhs) (FCon c (mkJuxt juxtElems))) + Nothing -> pure lhs - ofExpr :: Parser Flat + ofExpr :: Parser (WC Flat) ofExpr = do - lhs <- withFC (subExpr POf) + lhs <- subExpr POf optional (kmatch KOf) >>= \case - Nothing -> pure (unWC lhs) - Just () -> FOf lhs <$> withFC ofExpr + Nothing -> pure lhs + Just _ -> do + rhs <- ofExpr + pure (WC (spanFCOf lhs rhs) (lhs `FOf` rhs)) - mkJuxt [x] = x - mkJuxt (x:xs) = let rest = mkJuxt xs in WC (FC (start (fcOf x)) (end (fcOf rest))) (FJuxt x rest) + mkJuxt :: NonEmpty (WC Flat) -> WC Flat + mkJuxt (x :| []) = x + mkJuxt (x :| (y:ys)) = let rest = mkJuxt (y:|ys) in WC (FC (start (fcOf x)) (end (fcOf rest))) (FJuxt x rest) - application = withFC atomExpr >>= applied + application :: Parser (WC Flat) + application = atomExpr >>= applied where - applied :: WC Flat -> Parser Flat + applied :: WC Flat -> Parser (WC Flat) applied f = do - first <- withFC (round $ expr <|> pure FEmpty) - let one = FApp f first - let combinedFC = FC (start (fcOf f)) (end (fcOf first)) - optional (applied $ WC combinedFC one) <&> fromMaybe one + first <- inBracketsFC Paren $ (unWC <$> expr) <|> pure FEmpty + let one = WC (spanFCOf f first) (FApp f first) + optional (applied one) <&> fromMaybe one + + binary :: [ArithOp] -> Precedence -> Parser (WC Flat) + binary ops lvl = subExpr lvl `chainl1` choice (try . arith <$> ops) - binary ops lvl = unWC <$> withFC (subExpr lvl) `chainl1` choice (try . arith <$> ops) addSub = binary [Add, Sub] PAddSub mulDiv = binary [Mul, Div] PMulDiv pow = binary [Pow] PPow - annotation = FAnnotation <$> withFC (subExpr PAnn) <* match TypeColon <*> rawIO (unWC <$> vtype) - - letin = do - (lhs,rhs) <- inLet $ do - abs <- withFC abstractor + annotation :: Parser (WC Flat) + annotation = do + tm <- subExpr PAnn + colon <- matchFC TypeColon + WC (spanFCOf tm colon) . FAnnotation tm <$> rawIO + + letIn :: Parser (WC Flat) + letIn = label "let ... in" $ do + let_ <- kmatch KLet + (lhs, rhs) <- letInBinding + kmatch KIn + body <- expr + pure (WC (spanFCOf let_ body) (FLetIn lhs rhs body)) + where + letInBinding = do + abs <- abstractor match Equal - thing <- withFC expr + thing <- expr pure (abs, thing) - body <- withFC expr - pure $ FLetIn lhs rhs body -- Sequence of `abstractor => expr` separated by `|` + lambda :: Parser (WC Flat) lambda = do firstClause <- lambdaClause otherClauses <- many (match Pipe >> lambdaClause) - pure (FLambda (firstClause :| otherClauses)) + let endPos = case otherClauses of + [] -> end (fcOf (snd firstClause)) + _ -> end (fcOf (snd (last otherClauses))) + let fc = FC (start (fcOf (fst firstClause))) endPos + pure (WC fc (FLambda (firstClause :| otherClauses))) -- A single `abstractor => expr` + lambdaClause :: Parser (WC Abstractor, WC Flat) lambdaClause = do - abs <- withFC (try abstractor <|> pure AEmpty) - match FatArrow - body <- withFC expr + mabs <- try (Right <$> abstractor) <|> pure (Left AEmpty) + WC arrowFC () <- matchFC FatArrow + let abs = either (WC arrowFC) id mabs + body <- expr pure (abs, body) + emptyInto :: Parser (WC Flat) emptyInto = do -- It's tricky to come up with an FC for empty syntax - WC lhs () <- withFC $ match Into - rhs <- withFC (subExpr (pred PInto)) - pure $ FInto (WC lhs FEmpty) rhs + WC lhs () <- matchFC Into + rhs <- subExpr (pred PInto) + pure $ WC (spanFC lhs (fcOf rhs)) $ FInto (WC lhs FEmpty) rhs - into = unWC <$> withFC (subExpr PInto) `chainl1` divider Into FInto + into :: Parser (WC Flat) + into = subExpr PInto `chainl1` divider Into FInto - composition = unWC <$> withFC (subExpr PComp) `chainl1` divider Semicolon FCompose + composition :: Parser (WC Flat) + composition = subExpr PComp `chainl1` divider Semicolon FCompose divider :: Tok -> (WC Flat -> WC Flat -> Flat) -> Parser (WC Flat -> WC Flat -> WC Flat) - divider tok f = token0 $ \case + divider tok f = fmap unWC . matchTok $ \case t | t == tok -> Just $ \a b -> - let fc = FC (start (fcOf a)) (end (fcOf b)) - in WC fc (f a b) + WC (spanFCOf a b) (f a b) _ -> Nothing - pullAndJuxt = do - ports <- many (try (port <* match PortColon)) + ports <- many (try portPull) + let firstPortFC = fcOf . fst <$> uncons ports case ports of [] -> juxtRhsWithPull - _ -> FPull ports <$> withFC juxtRhsWithPull + _ -> (\juxt@(WC juxtFC _) -> WC (maybe juxtFC (`spanFC` juxtFC) firstPortFC) (FPull (unWC <$> ports) juxt)) <$> juxtRhsWithPull where + portPull :: Parser (WC String) + portPull = do + WC portFC portName <- port + WC colonFC _ <- matchFC PortColon + pure (WC (spanFC portFC colonFC) portName) + -- Juxtaposition here includes port pulling, since they have the same precedence juxtRhsWithPull = do - expr <- withFC (subExpr PJuxtPull) - rest <- optional (match Comma *> withFC pullAndJuxt) + expr <- subExpr PJuxtPull + rest <- optional (match Comma *> pullAndJuxt) pure $ case rest of - Nothing -> unWC expr - Just rest -> FJuxt expr rest + Nothing -> expr + Just rest@(WC restFC _) -> WC (spanFC (fcOf expr) restFC) (FJuxt expr rest) - fanout = square (FFanOut <$ match Slash <* match Backslash) - fanin = square (FFanIn <$ match Backslash <* match Slash) +fanout = inBracketsFC Square (FFanOut <$ match Slash <* match Backslash) +fanin = inBracketsFC Square (FFanIn <$ match Backslash <* match Slash) - -- Expressions which don't contain juxtaposition or operators - atomExpr :: Parser Flat - atomExpr = simpleExpr <|> round expr - where - simpleExpr = FHole <$> hole - <|> try (FSimple <$> simpleTerm) - <|> try fanout - <|> try fanin - <|> vec - <|> cthunk - <|> try (match DotDot $> FPass) - <|> var - <|> match Underscore $> FUnderscore - <|> match Pipe $> FIdentity - - -cnoun :: Parser Flat -> Parser (WC (Raw 'Chk 'Noun)) +cnoun :: Parser (WC Flat) -> Parser (WC (Raw 'Chk 'Noun)) cnoun pe = do - e <- withFC pe + e <- pe case elaborate e of Left err -> fail (showError err) Right (SomeRaw r) -> case do @@ -532,17 +620,17 @@ cnoun pe = do decl :: Parser FDecl decl = do - (WC fc (nm, ty, body)) <- withFC (do - nm <- simpleName - ty <- try (functionType <&> \ty -> [Named "thunk" (Right ty)]) - <|> (match TypeColon >> outputs) + (fc, nm, ty, body) <- do + WC startFC nm <- simpleName + WC _ ty <- declSignature let allow_clauses = case ty of [Named _ (Right t)] -> is_fun_ty t [Anon (Right t)] -> is_fun_ty t _ -> False - body <- if allow_clauses then (FClauses <$> clauses nm) <|> (FNoLhs <$> nbody nm) - else FNoLhs <$> nbody nm - pure (nm, ty, body)) + WC endFC body <- if allow_clauses + then declClauses nm <|> declNounBody nm + else declNounBody nm + pure (spanFC startFC endFC, nm, ty, body) pure $ FuncDecl { fnName = nm , fnLoc = fc @@ -556,12 +644,20 @@ decl = do is_fun_ty (RKernel _ _) = True is_fun_ty _ = False - nbody :: String -> Parser (WC Flat) - nbody nm = do + declClauses :: String -> Parser (WC FBody) + declClauses nm = do + cs <- clauses nm + let startFC = fcOf . fst $ NE.head cs + let endFC = fcOf . snd $ NE.last cs + pure (WC (spanFC startFC endFC) (FClauses cs)) + + declNounBody :: String -> Parser (WC FBody) + declNounBody nm = do label (nm ++ "(...) = ...") $ matchString nm match Equal - withFC expr + body@(WC fc _) <- expr + pure (WC fc (FNoLhs body)) class FCStream a where getFC :: Int -> PosState a -> FC @@ -580,10 +676,18 @@ instance FCStream [Token] where [] -> spToFC pstateSourcePos (Token fc _):_ -> fc +instance FCStream [BToken] where + getFC o PosState{..} = case drop (o - pstateOffset) pstateInput of + [] -> spToFC pstateSourcePos + (Bracketed fc _ _):_ -> fc + (FlatTok (Token fc _)):_ -> fc + + parseFile :: String -> String -> Either SrcErr ([Import], FEnv) parseFile fname contents = addSrcContext fname contents $ do toks <- first (wrapParseErr LexErr) (M.parse lex fname contents) - first (wrapParseErr ParseErr) (parse pfile fname toks) + btoks <- brackets toks + first (wrapParseErr ParseErr) (parse pfile fname btoks) where wrapParseErr :: (VisualStream t, FCStream t, ShowErrorComponent e) => (ParseError -> ErrorMsg) -> ParseErrorBundle t e -> Error @@ -599,19 +703,20 @@ parseFile fname contents = addSrcContext fname contents $ do clauses :: String -> Parser (NonEmpty (WC Abstractor, WC Flat)) clauses declName = label "clauses" (fromJust . nonEmpty <$> some (try branch)) where + branch :: Parser (WC Abstractor, WC Flat) branch = do label (declName ++ "(...) = ...") $ matchString declName - lhs <- withFC $ round (abstractor "binder") + lhs <- inBrackets Paren (abstractor "binder") match Equal - rhs <- withFC expr + rhs <- expr pure (lhs,rhs) pimport :: Parser Import pimport = do o <- open kmatch KImport - x <- withFC qualName + x <- qualName a <- alias Import x (not o) a <$> selection where @@ -623,7 +728,7 @@ pimport = do alias :: Parser (Maybe (WC String)) alias = optional (matchString "as") >>= \case Nothing -> pure Nothing - Just _ -> Just <$> withFC (ident Just) + Just _ -> Just <$> simpleName selection :: Parser ImportSelection selection = optional (try $ matchString "hiding") >>= \case @@ -633,7 +738,7 @@ pimport = do Just ss -> pure (ImportPartial ss) list :: Parser [WC String] - list = round $ ((:[]) <$> withFC (ident Just)) `chainl1` try (match Comma $> (++)) + list = inBrackets Paren $ ((:[]) <$> simpleName) `chainl1` try (match Comma $> (++)) pstmt :: Parser FEnv pstmt = ((comment "comment") <&> \_ -> ([] , [])) @@ -642,16 +747,16 @@ pstmt = ((comment "comment") <&> \_ -> ([] , [])) <|> ((decl "declaration") <&> \x -> ([x], [])) where alias :: Parser RawAlias - alias = withFC aliasContents <&> - \(WC fc (name, args, ty)) -> TypeAlias fc name args ty + alias = aliasContents <&> + \(fc, name, args, ty) -> TypeAlias fc name args ty - aliasContents :: Parser (QualName, [(String, TypeKind)], RawVType) + aliasContents :: Parser (FC, QualName, [(String, TypeKind)], RawVType) aliasContents = do - match (K KType) - alias <- qualName - args <- option [] $ round (simpleName `sepBy` match Comma) + WC startFC () <- matchFC (K KType) + WC _ alias <- qualName + args <- option [] $ inBrackets Paren $ (unWC <$> simpleName) `sepBy` match Comma {- future stuff - args <- option [] $ round $ (`sepBy` (match Comma)) $ do + args <- option [] $ inBrackets Paren $ (`sepBy` (match Comma)) $ do port <- port match TypeColon (port,) <$> typekind @@ -663,21 +768,21 @@ pstmt = ((comment "comment") <&> \_ -> ([] , [])) -- users to specify the kinds of variables in type aliases, like: -- type X(a :: *, b :: #, c :: *(x :: *, y :: #)) = ... -- See KARL-325 - pure (alias, (,Star []) <$> args, unWC ty) + pure (spanFC startFC (fcOf ty), alias, (,Star []) <$> args, unWC ty) extDecl :: Parser FDecl - extDecl = do (WC fc (fnName, ty, symbol)) <- withFC $ do - match (K KExt) - symbol <- string - fnName <- simpleName - ty <- try nDecl <|> vDecl + extDecl = do (fc, fnName, ty, symbol) <- do + WC startFC () <- matchFC (K KExt) + symbol <- unWC <$> string + fnName <- unWC <$> simpleName + WC tyFC ty <- declSignature -- When external ops are used, we expect it to be in the form: -- extension.op for the hugr extension used and the op name let bits = chop (=='.') symbol (ext, op) <- case viewR bits of Just (ext, op) -> pure (intercalate "." ext, op) Nothing -> fail $ "Malformed op name: " ++ symbol - pure (fnName, ty, (ext, op)) + pure (spanFC startFC tyFC, fnName, ty, (ext, op)) pure FuncDecl { fnName = fnName , fnSig = ty @@ -685,9 +790,31 @@ pstmt = ((comment "comment") <&> \_ -> ([] , [])) , fnLoc = fc , fnLocality = Extern symbol } - where - nDecl = match TypeColon >> outputs - vDecl = (:[]) . Named "thunk" . Right <$> functionType + +declSignature :: Parser (WC [RawIO]) +declSignature = try nDecl <|> vDecl where + nDecl = match TypeColon >> rawIOWithSpanFC + vDecl = functionSignature <&> fmap (\ty -> [Named "thunk" (Right ty)]) + + functionSignature :: Parser (WC RawVType) + functionSignature = try (fmap RFn <$> ctype) <|> (fmap (RKernel _) <$> kernel) + where + ctype :: Parser (WC RawCType) + ctype = do + WC startFC ins <- inBracketsFC Paren rawIO + match Arrow + WC endFC outs <- rawIOWithSpanFC + pure (WC (spanFC startFC endFC) (ins :-> outs)) + + kernel :: Parser (WC RawKType) + kernel = do + WC startFC ins <- inBracketsFC Paren $ rawIO' (unWC <$> vtype) + match Lolly + WC endFC outs <- spanningFC =<< rawIO' vtype + pure (WC (spanFC startFC endFC) (ins :-> outs)) + + + pfile :: Parser ([Import], FEnv) pfile = do diff --git a/brat/Brat/Syntax/Common.hs b/brat/Brat/Syntax/Common.hs index dacd8d1b..f90dbed6 100644 --- a/brat/Brat/Syntax/Common.hs +++ b/brat/Brat/Syntax/Common.hs @@ -2,7 +2,6 @@ module Brat.Syntax.Common (PortName, Dir(..), Kind(..), Diry(..), - DIRY(..), Kindy(..), CType'(..), Import(..), @@ -110,8 +109,16 @@ instance Eq ty => Eq (TypeRowElem ty) where Named _ ty == Anon ty' = ty == ty' Anon ty == Anon ty' = ty == ty' -data TypeKind = TypeFor Mode [(PortName, TypeKind)] | Nat | Row - deriving (Eq, Show) +data TypeKind = TypeFor Mode [(PortName, TypeKind)] | Nat + deriving Eq + +instance Show TypeKind where + show (TypeFor m args) = let argsStr = if null args then "" else "(" ++ intercalate ", " (show <$> args) ++ ")" + kindStr = case m of + Brat -> "*" + Kernel -> "$" + in kindStr ++ argsStr + show Nat = "#" pattern Star, Dollar :: [(PortName, TypeKind)] -> TypeKind pattern Star ks = TypeFor Brat ks diff --git a/brat/Brat/Syntax/Concrete.hs b/brat/Brat/Syntax/Concrete.hs index 41294115..e09bd528 100644 --- a/brat/Brat/Syntax/Concrete.hs +++ b/brat/Brat/Syntax/Concrete.hs @@ -23,6 +23,7 @@ type FEnv = ([FDecl], [RawAlias]) data Flat = FVar QualName + | FHope | FApp (WC Flat) (WC Flat) | FJuxt (WC Flat) (WC Flat) | FThunk (WC Flat) diff --git a/brat/Brat/Syntax/Core.hs b/brat/Brat/Syntax/Core.hs index befcf270..f495c94a 100644 --- a/brat/Brat/Syntax/Core.hs +++ b/brat/Brat/Syntax/Core.hs @@ -50,6 +50,7 @@ data Term :: Dir -> Kind -> Type where Pull :: [PortName] -> WC (Term Chk k) -> Term Chk k Var :: QualName -> Term Syn Noun -- Look up in noun (value) env Identity :: Term Syn UVerb + Hope :: Term Chk Noun Arith :: ArithOp -> WC (Term Chk Noun) -> WC (Term Chk Noun) -> Term Chk Noun Of :: WC (Term Chk Noun) -> WC (Term d Noun) -> Term d Noun @@ -114,8 +115,10 @@ instance Show (Term d k) where ,"of" ,bracket POf e ] + show (Var x) = show x show Identity = "|" + show Hope = "!" -- Nested applications should be bracketed too, hence 4 instead of 3 show (fun :$: arg) = bracket PApp fun ++ ('(' : show arg ++ ")") show (tm ::: ty) = bracket PAnn tm ++ " :: " ++ show ty diff --git a/brat/Brat/Syntax/Raw.hs b/brat/Brat/Syntax/Raw.hs index afe9194b..c73ed412 100644 --- a/brat/Brat/Syntax/Raw.hs +++ b/brat/Brat/Syntax/Raw.hs @@ -72,6 +72,7 @@ data Raw :: Dir -> Kind -> Type where RPull :: [PortName] -> WC (Raw Chk k) -> Raw Chk k RVar :: QualName -> Raw Syn Noun RIdentity :: Raw Syn UVerb + RHope :: Raw Chk Noun RArith :: ArithOp -> WC (Raw Chk Noun) -> WC (Raw Chk Noun) -> Raw Chk Noun ROf :: WC (Raw Chk Noun) -> WC (Raw d Noun) -> Raw d Noun (:::::) :: WC (Raw Chk Noun) -> [RawIO] -> Raw Syn Noun @@ -103,6 +104,7 @@ instance Show (Raw d k) where = unwords ["let", show abs, "=", show xs, "in", show body] show (RNHole name) = '?':name show (RVHole name) = '?':name + show RHope = "!" show (RSimple tm) = show tm show RPass = show "pass" show REmpty = "()" @@ -202,6 +204,7 @@ instance (Kindable k) => Desugarable (Raw d k) where -- TODO: holes need to know their arity for type checking desugar' (RNHole strName) = NHole . (strName,) <$> freshM strName desugar' (RVHole strName) = VHole . (strName,) <$> freshM strName + desugar' RHope = pure Hope desugar' RPass = pure Pass desugar' (RSimple simp) = pure $ Simple simp desugar' REmpty = pure Empty diff --git a/brat/Brat/Syntax/Value.hs b/brat/Brat/Syntax/Value.hs index 02497b1e..a3a04939 100644 --- a/brat/Brat/Syntax/Value.hs +++ b/brat/Brat/Syntax/Value.hs @@ -34,7 +34,8 @@ import Hasochism import Data.List (intercalate, minimumBy) import Data.Ord (comparing) import Data.Kind (Type) -import Data.Type.Equality ((:~:)(..)) +import Data.Maybe (isJust) +import Data.Type.Equality ((:~:)(..), testEquality) newtype VDecl = VDecl (FuncDecl (Some (Ro Brat Z)) (FunBody Term Noun)) @@ -158,7 +159,31 @@ data Val :: N -> Type where VLam :: Val (S n) -> Val n -- Just body (binds DeBruijn index n) VFun :: MODEY m => Modey m -> FunTy m n -> Val n VApp :: VVar n -> Bwd (Val n) -> Val n - VSum :: MODEY m => Modey m -> [Some (Ro m n)] -> Val n -- (Hugr-like) Sum types + +-- Define a naive version of equality, which only says whether the data +-- structures are on-the-nose equal +instance Eq (Val n) where + VNum a == VNum b = a == b + (VCon c xs) == (VCon d ys) = c == d && xs == ys + (VLam x) == (VLam y) = x == y + (VFun m cty) == (VFun m' cty') = case testEquality m m' of + Just Refl -> cty == cty' + Nothing -> False + (VApp v zx) == (VApp w zy) = v == w && zx == zy + _ == _ = False + +instance MODEY m => Eq (FunTy m i) where + (FunTy ps ss ts) == (FunTy qs us vs) = ((eqProps (modey @m) ps qs) &&) $ case roEq (modey @m) ss us of + Just Refl -> isJust (roEq (modey @m) ts vs) + Nothing -> False + where + roEq :: forall m i j k. Modey m -> Ro m i j -> Ro m i k -> Maybe (j :~: k) + roEq _ R0 R0 = Just Refl + roEq my (RPr x ro) (RPr y rp) | x == y = roEq my ro rp + roEq Braty (REx x ro) (REx y rp) | x == y = case roEq Braty ro rp of + Just Refl -> Just Refl + Nothing -> Nothing + roEq _ _ _ = Nothing data SVar = SPar End | SLvl Int deriving (Show, Eq) @@ -228,12 +253,6 @@ instance Show (Val n) where show (VFun m cty) = "{ " ++ modily m (show cty) ++ " }" show (VApp v ctx) = "VApp " ++ show v ++ " " ++ show ctx show (VLam body) = "VLam " ++ show body - show (VSum my ros) = case my of - Braty -> "VSum (" ++ intercalate " + " (helper <$> ros) ++ ")" - Kerny -> "VSum (" ++ intercalate " + " (helper <$> ros) ++ ")" - where - helper :: MODEY m => Some (Ro m n) -> String - helper (Some ro) = show ro ---------------------------------- Patterns ----------------------------------- pattern TNat, TInt, TFloat, TBool, TText, TUnit, TNil :: Val n @@ -529,9 +548,6 @@ instance DeBruijn Val where = VFun Braty $ changeVar vc cty changeVar vc (VFun Kerny cty) = VFun Kerny $ changeVar vc cty - changeVar vc (VSum my ros) - = VSum my (f <$> ros) - where f (Some ro) = case varChangerThroughRo vc ro of Some (_ :* ro) -> Some ro varChangerThroughRo :: VarChanger src tgt -> Ro m src src' diff --git a/brat/Brat/Unelaborator.hs b/brat/Brat/Unelaborator.hs index fcea2943..f2ce7148 100644 --- a/brat/Brat/Unelaborator.hs +++ b/brat/Brat/Unelaborator.hs @@ -38,6 +38,7 @@ unelab _ _ (Con c args) = FCon c (unelab Chky Nouny <$> args) unelab _ _ (C (ss :-> ts)) = FFn (toRawRo ss :-> toRawRo ts) unelab _ _ (K ps cty) = FKernel ps $ fmap (\(p, ty) -> Named p (toRaw ty)) cty unelab _ _ Identity = FIdentity +unelab _ _ Hope = FHope unelab _ _ FanIn = FFanIn unelab _ _ FanOut = FFanOut @@ -67,6 +68,7 @@ toRaw (Con c args) = RCon c (toRaw <$> args) toRaw (C (ss :-> ts)) = RFn (toRawRo ss :-> toRawRo ts) toRaw (K ps cty) = RKernel ps $ (\(p, ty) -> Named p (toRaw ty)) <$> cty toRaw Identity = RIdentity +toRaw Hope = RHope toRaw FanIn = RFanIn toRaw FanOut = RFanOut diff --git a/brat/Data/Bracket.hs b/brat/Data/Bracket.hs new file mode 100644 index 00000000..2b1ce93f --- /dev/null +++ b/brat/Data/Bracket.hs @@ -0,0 +1,13 @@ +module Data.Bracket where + +data BracketType = Paren | Square | Brace deriving (Eq, Ord) + +showOpen :: BracketType -> String +showOpen Paren = "(" +showOpen Square = "[" +showOpen Brace = "{" + +showClose :: BracketType -> String +showClose Paren = ")" +showClose Square = "]" +showClose Brace = "}" diff --git a/brat/Data/Hugr.hs b/brat/Data/Hugr.hs index 4eba111e..f3bb8075 100644 --- a/brat/Data/Hugr.hs +++ b/brat/Data/Hugr.hs @@ -9,10 +9,27 @@ module Data.Hugr where import Data.Aeson import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.Set as S import Data.Text (Text, pack) import Brat.Syntax.Simple +-- We should be able to work out exact extension requirements for our functions, +-- but instead we'll overapproximate. +bratExts :: [ExtensionId] +bratExts = + ["prelude" + ,"arithmetic.int_ops" + ,"arithmetic.int_types" + ,"arithmetic.float_ops" + ,"arithmetic.float_types" + ,"collections" + ,"logic" + ,"tket2.quantum" + ,"BRAT" + ] + + ------------------------------------- TYPES ------------------------------------ ------------------------- (Depends on HugrValue and Hugr) -------------------- @@ -26,6 +43,8 @@ data SumType = SU UnitSum | SG GeneralSum newtype SumOfRows = SoR [[HugrType]] deriving Show +type ExtensionId = String + -- Convert from a hugr sum of tuples to a SumOfRows sumOfRows :: HugrType -> SumOfRows sumOfRows (HTSum (SG (GeneralSum rows))) = SoR rows @@ -64,7 +83,7 @@ instance ToJSON HugrType where toJSON (HTFunc sig) = object ["t" .= ("G" :: Text) ,"input" .= input (body sig) ,"output" .= output (body sig) - ,"extension_reqs" .= ([] :: [Text]) + ,"extension_reqs" .= extensions (body sig) ] toJSON ty = error $ "todo: json of " ++ show ty @@ -88,9 +107,6 @@ data CustomTypeArg = CustomTypeArg } deriving (Eq, Show) data CustomType deriving (Eq, Show) -data ExtensionId deriving (Eq, Show) -instance ToJSON ExtensionId where - toJSON = undefined data TypeBound = TBEq | TBCopy | TBAny deriving (Eq, Ord, Show) @@ -131,13 +147,14 @@ instance ToJSON TypeParam where data FunctionType = FunctionType { input :: [HugrType] , output :: [HugrType] + , extensions :: [ExtensionId] } deriving (Eq, Show) instance ToJSON FunctionType where - toJSON (FunctionType ins outs) = object ["input" .= ins - ,"output" .= outs - ,"extension_reqs" .= ([] :: [Text]) - ] + toJSON (FunctionType ins outs exts) = object ["input" .= ins + ,"output" .= outs + ,"extension_reqs" .= exts + ] data Array = Array { ty :: HugrType @@ -428,17 +445,18 @@ instance ToJSON node => ToJSON (CallOp node) where ,"instantiation" .= signature_ ] -intOp :: node -> String -> FunctionType -> [TypeArg] -> CustomOp node -intOp parent = CustomOp parent "arithmetic.int" +intOp :: node -> String -> [HugrType] -> [HugrType] -> [TypeArg] -> CustomOp node +intOp parent opName ins outs = CustomOp parent "arithmetic.int_ops" opName (FunctionType ins outs ["arithmetic.int_ops"]) binaryIntOp :: node -> String -> CustomOp node -binaryIntOp parent name = intOp parent name (FunctionType [hugrInt, hugrInt] [hugrInt]) [TANat intWidth] +binaryIntOp parent name + = intOp parent name [hugrInt, hugrInt] [hugrInt] [TANat intWidth] -floatOp :: node -> String -> FunctionType -> [TypeArg] -> CustomOp node -floatOp parent = CustomOp parent "arithmetic.float" +floatOp :: node -> String -> [HugrType] -> [HugrType] -> [TypeArg] -> CustomOp node +floatOp parent opName ins outs = CustomOp parent "arithmetic.float_ops" opName (FunctionType ins outs ["arithmetic.float_ops"]) binaryFloatOp :: node -> String -> CustomOp node -binaryFloatOp parent name = floatOp parent name (FunctionType [hugrFloat, hugrFloat] [hugrFloat]) [] +binaryFloatOp parent name = floatOp parent name [hugrFloat, hugrFloat] [hugrFloat] [] data CallIndirectOp node = CallIndirectOp { parent :: node @@ -455,7 +473,8 @@ instance ToJSON node => ToJSON (CallIndirectOp node) where ] holeOp :: node -> Int -> FunctionType -> CustomOp node -holeOp parent idx sig = CustomOp parent "BRAT" "Hole" sig [TANat idx] +holeOp parent idx sig = CustomOp parent "BRAT" "Hole" sig + [TANat idx, TAType (HTFunc (PolyFuncType [] sig))] -- TYPE ARGS: -- * A length-2 sequence comprising: @@ -472,15 +491,16 @@ substOp :: node -> {- innerSigs :: -}[FunctionType]{- length n -} -> CustomOp node substOp parent outerSig innerSigs - = CustomOp parent "Brat" "Substitute" sig args + = CustomOp parent "BRAT" "Substitute" sig [toArg outerSig, TASequence (toArg <$> innerSigs)] where - sig = FunctionType (toFunc <$> (outerSig : innerSigs)) [toFunc outerSig] - args = [funcToSeq outerSig, TASequence (funcToSeq <$> innerSigs)] + fnExts (FunctionType _ _ exts) = S.fromList exts + combinedExts = S.toList $ foldr S.union (fnExts outerSig) (fnExts <$> innerSigs) - funcToSeq (FunctionType ins outs) = TASequence [toSeq ins, toSeq outs] + sig = FunctionType (toFunc <$> (outerSig : innerSigs)) [toFunc outerSig] combinedExts + toArg = TAType . HTFunc . PolyFuncType [] toFunc :: FunctionType -> HugrType -toFunc = HTFunc . PolyFuncType [] +toFunc ty = HTFunc (PolyFuncType [] ty) toSeq :: [HugrType] -> TypeArg toSeq tys = TASequence (TAType <$> tys) @@ -489,9 +509,13 @@ partialOp :: node -- Parent -> FunctionType -- Signature of the function that is partially evaluated -> Int -- Number of arguments that are evaluated -> CustomOp node -partialOp parent funcSig numSupplied = CustomOp parent "Brat" "Partial" sig args +partialOp parent funcSig numSupplied = CustomOp parent "BRAT" "Partial" sig args where - sig = FunctionType (toFunc funcSig : partialInputs) [toFunc $ FunctionType otherInputs (output funcSig)] + sig :: FunctionType + sig = FunctionType + (toFunc funcSig : partialInputs) + [toFunc (FunctionType otherInputs (output funcSig) (extensions funcSig))] + ["BRAT"] args = [toSeq partialInputs, toSeq otherInputs, toSeq (output funcSig)] partialInputs = take numSupplied (input funcSig) diff --git a/brat/brat.cabal b/brat/brat.cabal index 494ed8f1..199105b8 100644 --- a/brat/brat.cabal +++ b/brat/brat.cabal @@ -56,7 +56,9 @@ common warning-flags library import: haskell, warning-flags default-language: GHC2021 - other-modules: Brat.Lexer.Flat, + other-modules: Data.Bracket, + Brat.Lexer.Bracketed, + Brat.Lexer.Flat, Brat.Lexer.Token exposed-modules: Brat.Checker.Quantity, @@ -66,6 +68,7 @@ library Brat.Checker.Helpers, Brat.Checker.Helpers.Nodes, Brat.Checker.Monad, + Brat.Checker.SolveHoles, Brat.Checker.SolvePatterns, Brat.Checker.Types, Brat.Compile.Hugr, diff --git a/brat/examples/infer.brat b/brat/examples/infer.brat new file mode 100644 index 00000000..e10ee44e --- /dev/null +++ b/brat/examples/infer.brat @@ -0,0 +1,8 @@ +map(X :: *, Y :: *, { X -> Y }, List(X)) -> List(Y) +map(_, _, _, []) = [] +map(_, _, f, x ,- xs) = f(x) ,- map(!, !, f, xs) + +-- TODO: Make BRAT solve for the # kinded args +mapVec(X :: *, Y :: *, { X -> Y }, n :: #, Vec(X, n)) -> Vec(Y, n) +mapVec(_, _, _, _, []) = [] +mapVec(_, _, f, succ(n), x ,- xs) = f(x) ,- mapVec(!, !, f, n, xs) diff --git a/brat/examples/karlheinz.brat b/brat/examples/karlheinz.brat index e43814c8..f4a9fbf1 100644 --- a/brat/examples/karlheinz.brat +++ b/brat/examples/karlheinz.brat @@ -89,7 +89,7 @@ answer = energy(results) evaluate(obs :: Observable ,q :: Quantity ,a :: Ansatz - ,rs :: List Real + ,rs :: List(Real) ) -> Real evaluate = ?eval diff --git a/brat/examples/let.brat b/brat/examples/let.brat index 69a6b400..b927a87c 100644 --- a/brat/examples/let.brat +++ b/brat/examples/let.brat @@ -32,3 +32,13 @@ nums' = let xs = map(inc, [0,2,3]) in xs nums'' :: List(Int) nums'' = let i2 = {inc; inc} in map(i2, xs) + +dyad :: Int, Bool +dyad = 42, true + +bind2 :: Bool +bind2 = let i, b = dyad in b + +-- It shouldn't matter if we put brackets in the binding sites +bind2' :: Bool +bind2' = let (i, b) = dyad in b diff --git a/brat/test/Test/Checking.hs b/brat/test/Test/Checking.hs index cdcc1302..b576421a 100644 --- a/brat/test/Test/Checking.hs +++ b/brat/test/Test/Checking.hs @@ -12,6 +12,7 @@ import Test.Tasty.HUnit import Test.Tasty.Silver expectedCheckingFails = map ("examples" ) ["nested-abstractors.brat" + ,"karlheinz.brat" ,"karlheinz_alias.brat" ,"hea.brat" ] diff --git a/brat/test/Test/Parsing.hs b/brat/test/Test/Parsing.hs index 02841f5a..efe6afc8 100644 --- a/brat/test/Test/Parsing.hs +++ b/brat/test/Test/Parsing.hs @@ -15,9 +15,7 @@ testParse file = testCase (show file) $ do Left err -> assertFailure (show err) Right _ -> return () -- OK -expectedParsingFails = map ("examples" ) [ - "karlheinz.brat", - "thin.brat"] +expectedParsingFails = ["examples" "thin.brat"] parseXF = expectFailForPaths expectedParsingFails testParse diff --git a/brat/test/golden/binding/cons.brat.golden b/brat/test/golden/binding/cons.brat.golden index 5742b0a3..150bf837 100644 --- a/brat/test/golden/binding/cons.brat.golden +++ b/brat/test/golden/binding/cons.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/binding/cons.brat on line 7: badUncons(cons(stuff)) = stuff - ^^^^^^^^^^^^^ + ^^^^^^^^^^^ Unification error: Pattern doesn't match expected length for constructor args diff --git a/brat/test/golden/error/fanin-diff-types.brat.golden b/brat/test/golden/error/fanin-diff-types.brat.golden index f97d0b45..1e4f402f 100644 --- a/brat/test/golden/error/fanin-diff-types.brat.golden +++ b/brat/test/golden/error/fanin-diff-types.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanin-diff-types.brat on line 2: f = { [\/] } - ^^^^^^^^ + ^^^^ Type mismatch when checking [\/] Expected: Qubit diff --git a/brat/test/golden/error/fanin-dynamic-length.brat.golden b/brat/test/golden/error/fanin-dynamic-length.brat.golden index 993c4357..0c4e0242 100644 --- a/brat/test/golden/error/fanin-dynamic-length.brat.golden +++ b/brat/test/golden/error/fanin-dynamic-length.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanin-dynamic-length.brat on line 2: f(n) = { [\/] } - ^^^^^^^^ + ^^^^ Type error: Can't fanout a Vec with non-constant length: VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 0 diff --git a/brat/test/golden/error/fanin-list.brat.golden b/brat/test/golden/error/fanin-list.brat.golden index 6a8a2989..cf210abf 100644 --- a/brat/test/golden/error/fanin-list.brat.golden +++ b/brat/test/golden/error/fanin-list.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanin-list.brat on line 2: f = { [\/] } - ^^^^^^^^ + ^^^^ Type error: Fanin ([\/]) only applies to Vec diff --git a/brat/test/golden/error/fanin-not-enough-overs.brat.golden b/brat/test/golden/error/fanin-not-enough-overs.brat.golden index 82d88b1c..0fa9c943 100644 --- a/brat/test/golden/error/fanin-not-enough-overs.brat.golden +++ b/brat/test/golden/error/fanin-not-enough-overs.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanin-not-enough-overs.brat on line 2: f = { [\/] } - ^^^^^^^^ + ^^^^ Type error: Not enough inputs to make a vector of size 3 diff --git a/brat/test/golden/error/fanout-diff-types.brat.golden b/brat/test/golden/error/fanout-diff-types.brat.golden index 1f52499e..abde4a2f 100644 --- a/brat/test/golden/error/fanout-diff-types.brat.golden +++ b/brat/test/golden/error/fanout-diff-types.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanout-diff-types.brat on line 2: f = { [/\] } - ^^^^^^^^ + ^^^^ Type mismatch when checking [/\] Expected: (b1 :: Bit) diff --git a/brat/test/golden/error/fanout-dynamic-length.brat.golden b/brat/test/golden/error/fanout-dynamic-length.brat.golden index 2d79c6e4..4c87893c 100644 --- a/brat/test/golden/error/fanout-dynamic-length.brat.golden +++ b/brat/test/golden/error/fanout-dynamic-length.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanout-dynamic-length.brat on line 2: f(n) = { [/\] } - ^^^^^^^^ + ^^^^ Type error: Can't fanout a Vec with non-constant length: VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 0 diff --git a/brat/test/golden/error/fanout-list.brat.golden b/brat/test/golden/error/fanout-list.brat.golden index 8214f883..95895a9c 100644 --- a/brat/test/golden/error/fanout-list.brat.golden +++ b/brat/test/golden/error/fanout-list.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanout-list.brat on line 2: f = { [/\] } - ^^^^^^^^ + ^^^^ Type error: Fanout ([/\]) only applies to Vec diff --git a/brat/test/golden/error/fanout-too-many-overs.brat.golden b/brat/test/golden/error/fanout-too-many-overs.brat.golden index bcf8c6b5..465ea12b 100644 --- a/brat/test/golden/error/fanout-too-many-overs.brat.golden +++ b/brat/test/golden/error/fanout-too-many-overs.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/fanout-too-many-overs.brat on line 2: f = { [/\] } - ^^^^^^^^ + ^^^^ Type error: No unders but overs: (head :: Nat) for [/\] diff --git a/brat/test/golden/error/kbadvec4.brat b/brat/test/golden/error/kbadvec4.brat index 6079c080..7b1ebc8c 100644 --- a/brat/test/golden/error/kbadvec4.brat +++ b/brat/test/golden/error/kbadvec4.brat @@ -1,2 +1,2 @@ f :: { Vec(Bool, 3) -> Bool } -f = { [1,2] => true } +f = { [1, 2] => true } diff --git a/brat/test/golden/error/kbadvec4.brat.golden b/brat/test/golden/error/kbadvec4.brat.golden index 9b17e59f..8a03a084 100644 --- a/brat/test/golden/error/kbadvec4.brat.golden +++ b/brat/test/golden/error/kbadvec4.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/kbadvec4.brat on line 2: -f = { [1,2] => true } - ^^^^^ +f = { [1, 2] => true } + ^^^^^^ Type error: Expected something of type `Bool` but got `1` diff --git a/brat/test/golden/error/noovers.brat.golden b/brat/test/golden/error/noovers.brat.golden index 8acaa5c5..369fcaaf 100644 --- a/brat/test/golden/error/noovers.brat.golden +++ b/brat/test/golden/error/noovers.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/noovers.brat on line 2: f(a, b) = [] - ^^^^^^ + ^^^^ Nothing to bind to: b diff --git a/brat/test/golden/error/remaining_hopes.brat b/brat/test/golden/error/remaining_hopes.brat new file mode 100644 index 00000000..164b8190 --- /dev/null +++ b/brat/test/golden/error/remaining_hopes.brat @@ -0,0 +1,5 @@ +f(n :: #) -> Nat +f(n) = n + +g :: Nat +g = f(!) diff --git a/brat/test/golden/error/remaining_hopes.brat.golden b/brat/test/golden/error/remaining_hopes.brat.golden new file mode 100644 index 00000000..80d15436 --- /dev/null +++ b/brat/test/golden/error/remaining_hopes.brat.golden @@ -0,0 +1,8 @@ +Error in test/golden/error/remaining_hopes.brat on line 5: +g = f(!) + ^^^ + + Expected to work out values for these holes: + In checking_check_defs_1_g_1_Eval 0 + + diff --git a/brat/test/golden/error/toplevel-leftovers3.brat.golden b/brat/test/golden/error/toplevel-leftovers3.brat.golden index 23727f22..84290ebb 100644 --- a/brat/test/golden/error/toplevel-leftovers3.brat.golden +++ b/brat/test/golden/error/toplevel-leftovers3.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/toplevel-leftovers3.brat on line 2: f(x) = x - ^^^ + ^ Type error: Inputs (b1 :: Bool) weren't used diff --git a/brat/test/golden/error/unmatched_bracket.brat.golden b/brat/test/golden/error/unmatched_bracket.brat.golden index f61aa5d2..ebc76650 100644 --- a/brat/test/golden/error/unmatched_bracket.brat.golden +++ b/brat/test/golden/error/unmatched_bracket.brat.golden @@ -1,8 +1,6 @@ Error in test/golden/error/unmatched_bracket.brat on line 1: f(n, Vec([], n) -> Vec([], n) -- First bracket never closed - ^^ - - Parse error unexpected -> -expecting (...) or ) + ^ + File ended before this ( was closed diff --git a/brat/test/golden/error/vec_length.brat b/brat/test/golden/error/vec_length.brat new file mode 100644 index 00000000..09652ecc --- /dev/null +++ b/brat/test/golden/error/vec_length.brat @@ -0,0 +1,2 @@ +f(X :: *, n :: #, Vec(X, 1 + n)) -> Vec(X, n) +f(_, _, xs) = xs diff --git a/brat/test/golden/error/vec_length.brat.golden b/brat/test/golden/error/vec_length.brat.golden new file mode 100644 index 00000000..6fda6b03 --- /dev/null +++ b/brat/test/golden/error/vec_length.brat.golden @@ -0,0 +1,9 @@ +Error in test/golden/error/vec_length.brat on line 2: +f(_, _, xs) = xs + ^^ + + Type mismatch when checking xs +Expected: (a1 :: Vec(VApp VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 0 B0, VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 1)) +But got: (xs :: Vec(VApp VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 0 B0, 1 + VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 1)) + + diff --git a/brat/test/golden/error/vecpat.brat.golden b/brat/test/golden/error/vecpat.brat.golden index 34f98e75..335d7fd1 100644 --- a/brat/test/golden/error/vecpat.brat.golden +++ b/brat/test/golden/error/vecpat.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/vecpat.brat on line 3: fst3(nil) = none - ^^^^^ + ^^^ Unification error: Couldn't force 3 to be 0 diff --git a/brat/test/golden/error/vecpat2.brat.golden b/brat/test/golden/error/vecpat2.brat.golden index 9009111f..ecb3ff7f 100644 --- a/brat/test/golden/error/vecpat2.brat.golden +++ b/brat/test/golden/error/vecpat2.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/vecpat2.brat on line 3: fst3(some(x)) = none - ^^^^^^^^^ + ^^^^^^^ "some" is not a valid constructor for type Vec diff --git a/brat/test/golden/error/vecpat3.brat.golden b/brat/test/golden/error/vecpat3.brat.golden index 3eedcbce..e1ba8ab9 100644 --- a/brat/test/golden/error/vecpat3.brat.golden +++ b/brat/test/golden/error/vecpat3.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/error/vecpat3.brat on line 3: fst3([a,b]) = none - ^^^^^^^ + ^^^^^ Unification error: Couldn't force 1 to be 0 diff --git a/brat/test/golden/kernel/copy.brat.golden b/brat/test/golden/kernel/copy.brat.golden index 1938ca9a..9657dbbe 100644 --- a/brat/test/golden/kernel/copy.brat.golden +++ b/brat/test/golden/kernel/copy.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/kernel/copy.brat on line 2: copy = { q => q, q } - ^^^^^^^^^^^^^ + ^^^^^^^^^ Type error: q has already been used diff --git a/brat/test/golden/kernel/delete.brat.golden b/brat/test/golden/kernel/delete.brat.golden index 3734c119..2aea3dc3 100644 --- a/brat/test/golden/kernel/delete.brat.golden +++ b/brat/test/golden/kernel/delete.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/kernel/delete.brat on line 2: deleteFst = { q0, q1 => q1 } - ^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^ Type error: Variable(s) q0 haven't been used diff --git a/brat/test/golden/kernel/deleteFst.brat.golden b/brat/test/golden/kernel/deleteFst.brat.golden index 050cdbab..2477d48f 100644 --- a/brat/test/golden/kernel/deleteFst.brat.golden +++ b/brat/test/golden/kernel/deleteFst.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/kernel/deleteFst.brat on line 2: deleteFst = { q0, q1 => q1 } - ^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^ Type error: Variable(s) q0 haven't been used diff --git a/brat/test/golden/kernel/deleteSnd.brat.golden b/brat/test/golden/kernel/deleteSnd.brat.golden index b4297b53..ea70a16f 100644 --- a/brat/test/golden/kernel/deleteSnd.brat.golden +++ b/brat/test/golden/kernel/deleteSnd.brat.golden @@ -1,6 +1,6 @@ Error in test/golden/kernel/deleteSnd.brat on line 2: deleteSnd = { q0, q1 => q0 } - ^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^ Type error: Variable(s) q1 haven't been used diff --git a/hugr_extension/Cargo.toml b/hugr_extension/Cargo.toml index 7b6cc2ea..02266229 100644 --- a/hugr_extension/Cargo.toml +++ b/hugr_extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "brat_extension" -version = "0.2.0" +version = "0.4.0" edition = "2021" [lib] @@ -9,7 +9,7 @@ bench = false path = "src/lib.rs" [dependencies] -hugr = "0.6.0" +hugr = "0.9.0" serde = "1.0" serde_json = "1.0.97" diff --git a/hugr_extension/src/ctor.rs b/hugr_extension/src/ctor.rs index 401bdb53..c6c615c0 100644 --- a/hugr_extension/src/ctor.rs +++ b/hugr_extension/src/ctor.rs @@ -2,9 +2,7 @@ use enum_iterator::Sequence; use hugr::{ ops::NamedOp, std_extensions::{arithmetic::int_types, collections}, - types::{ - type_param::TypeParam, CustomType, FunctionType, PolyFuncType, Type, TypeArg, TypeBound, - }, + types::{type_param::TypeParam, CustomType, PolyFuncType, Signature, Type, TypeArg, TypeBound}, }; use smol_str::{format_smolstr, SmolStr}; use std::str::FromStr; @@ -82,8 +80,8 @@ impl Ctor for BratCtor { impl Ctor for NatCtor { fn signature(self) -> PolyFuncType { match self { - NatCtor::zero => FunctionType::new(vec![], vec![nat_type()]).into(), - NatCtor::succ => FunctionType::new(vec![nat_type()], vec![nat_type()]).into(), + NatCtor::zero => Signature::new(vec![], vec![nat_type()]).into(), + NatCtor::succ => Signature::new(vec![nat_type()], vec![nat_type()]).into(), } } } @@ -94,11 +92,11 @@ impl Ctor for VecCtor { let ta = Type::new_var_use(0, TypeBound::Any); match self { VecCtor::nil => { - PolyFuncType::new(vec![tp], FunctionType::new(vec![], vec![vec_type(&ta)])) + PolyFuncType::new(vec![tp], Signature::new(vec![], vec![vec_type(&ta)])) } VecCtor::cons => PolyFuncType::new( vec![tp], - FunctionType::new(vec![ta.clone(), vec_type(&ta)], vec![vec_type(&ta)]), + Signature::new(vec![ta.clone(), vec_type(&ta)], vec![vec_type(&ta)]), ), } } diff --git a/hugr_extension/src/defs.rs b/hugr_extension/src/defs.rs index 82702980..c35c73be 100644 --- a/hugr_extension/src/defs.rs +++ b/hugr_extension/src/defs.rs @@ -6,11 +6,13 @@ use hugr::{ extension::{ prelude::USIZE_T, simple_op::{MakeOpDef, OpLoadError}, - OpDef, SignatureError, SignatureFromArgs, SignatureFunc, + ExtensionId, OpDef, SignatureError, SignatureFromArgs, SignatureFunc, }, ops::NamedOp, std_extensions::collections::list_type, - types::{type_param::TypeParam, FunctionType, PolyFuncType, Type, TypeArg, TypeBound}, + types::{ + type_param::TypeParam, FuncValueType, PolyFuncTypeRV, Type, TypeArg, TypeBound, TypeEnum, + }, }; use lazy_static::lazy_static; @@ -68,7 +70,7 @@ impl FromStr for BratOpDef { impl MakeOpDef for BratOpDef { fn from_def(op_def: &OpDef) -> Result { - hugr::extension::simple_op::try_from_name(op_def.name()) + hugr::extension::simple_op::try_from_name(op_def.name(), &super::EXTENSION_ID) } fn signature(&self) -> SignatureFunc { @@ -83,14 +85,17 @@ impl MakeOpDef for BratOpDef { let sig = ctor.signature(); let input = sig.body().output(); // Ctor output is input for the test let output = Type::new_sum(vec![input.clone(), sig.body().input().clone()]); - PolyFuncType::new(sig.params(), FunctionType::new(input.clone(), vec![output])) - .into() + PolyFuncTypeRV::new( + sig.params(), + FuncValueType::new(input.clone(), vec![output]), + ) + .into() } - Replicate => PolyFuncType::new( + Replicate => PolyFuncTypeRV::new( [TypeParam::Type { b: TypeBound::Copyable, }], - FunctionType::new( + FuncValueType::new( vec![USIZE_T, Type::new_var_use(0, TypeBound::Copyable)], vec![list_type(Type::new_var_use(0, TypeBound::Copyable))], ), @@ -98,17 +103,24 @@ impl MakeOpDef for BratOpDef { .into(), } } + + fn extension(&self) -> ExtensionId { + super::EXTENSION_ID.clone() + } } /// Binary compute_signature function for the `Hole` op struct HoleSigFun(); impl SignatureFromArgs for HoleSigFun { - fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { + fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { // The Hole op expects a nat identifier and two type sequences specifiying // the signature of the hole match arg_values { - [TypeArg::BoundedNat { n: _ }, input, output] => { - Ok(FunctionType::new(row_from_arg(input)?, row_from_arg(output)?).into()) + [TypeArg::BoundedNat { n: _ }, TypeArg::Type { ty: fun_ty }] => { + let TypeEnum::Function(sig) = fun_ty.as_type_enum().clone() else { + return Err(SignatureError::InvalidTypeArgs); + }; + Ok(PolyFuncTypeRV::new([], *sig)) } _ => Err(SignatureError::InvalidTypeArgs), } @@ -116,8 +128,8 @@ impl SignatureFromArgs for HoleSigFun { fn static_params(&self) -> &[TypeParam] { lazy_static! { - static ref PARAMS: [TypeParam; 3] = - [TypeParam::max_nat(), list_of_type(), list_of_type()]; + static ref PARAMS: [TypeParam; 2] = + [TypeParam::max_nat(), TypeParam::Type { b: TypeBound::Any }]; } PARAMS.as_slice() } @@ -126,16 +138,18 @@ impl SignatureFromArgs for HoleSigFun { /// Binary compute_signature function for the `Substitute` op struct SubstituteSigFun(); impl SignatureFromArgs for SubstituteSigFun { - fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { + fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { // The Substitute op expects a function signature and a list of hole signatures match arg_values { - [fun_sig, TypeArg::Sequence { elems: hole_sigs }] => { - let fun_ty = Type::new_function(sig_from_arg(fun_sig)?); - let mut inputs = vec![fun_ty.clone()]; + [TypeArg::Type { ty: outer_fun_ty }, TypeArg::Sequence { elems: hole_sigs }] => { + let mut inputs = vec![outer_fun_ty.clone()]; for sig in hole_sigs { - inputs.push(Type::new_function(sig_from_arg(sig)?)) + let TypeArg::Type { ty: inner_fun_ty } = sig else { + return Err(SignatureError::InvalidTypeArgs); + }; + inputs.push(inner_fun_ty.clone()) } - Ok(FunctionType::new(inputs, vec![fun_ty]).into()) + Ok(FuncValueType::new(inputs, vec![outer_fun_ty.clone()]).into()) } _ => Err(SignatureError::InvalidTypeArgs), } @@ -144,9 +158,11 @@ impl SignatureFromArgs for SubstituteSigFun { fn static_params(&self) -> &[TypeParam] { lazy_static! { static ref PARAMS: [TypeParam; 2] = [ - tuple_of_list_of_type(), + // The signature of outer functions + TypeParam::Type { b: TypeBound::Any }, + // A list of signatures for the inner functions which fill in holes TypeParam::List { - param: Box::new(tuple_of_list_of_type()) + param: Box::new(TypeParam::Type { b: TypeBound::Any }), }, ]; } @@ -157,7 +173,7 @@ impl SignatureFromArgs for SubstituteSigFun { /// Binary compute_signature function for the `Partial` op struct PartialSigFun(); impl SignatureFromArgs for PartialSigFun { - fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { + fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { // The Partial op expects a type sequence specifying the supplied partial inputs, a type // sequence specifiying the remaining inputs and a type sequence for the function outputs. match arg_values { @@ -166,13 +182,13 @@ impl SignatureFromArgs for PartialSigFun { let other_inputs = row_from_arg(other_inputs)?; let outputs = row_from_arg(outputs)?; let res_func = - Type::new_function(FunctionType::new(other_inputs.clone(), outputs.clone())); - let mut inputs = vec![Type::new_function(FunctionType::new( + Type::new_function(FuncValueType::new(other_inputs.clone(), outputs.clone())); + let mut inputs = vec![Type::new_function(FuncValueType::new( [partial_inputs.clone(), other_inputs].concat(), outputs, ))]; inputs.extend(partial_inputs); - Ok(FunctionType::new(inputs, vec![res_func]).into()) + Ok(FuncValueType::new(inputs, vec![res_func]).into()) } _ => Err(SignatureError::InvalidTypeArgs), } @@ -189,11 +205,11 @@ impl SignatureFromArgs for PartialSigFun { /// Binary compute_signature function for the `Panic` op struct PanicSigFun(); impl SignatureFromArgs for PanicSigFun { - fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { + fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { // The Panic op expects two type sequences specifiying the signature of the op match arg_values { [input, output] => { - Ok(FunctionType::new(row_from_arg(input)?, row_from_arg(output)?).into()) + Ok(FuncValueType::new(row_from_arg(input)?, row_from_arg(output)?).into()) } _ => Err(SignatureError::InvalidTypeArgs), } @@ -227,24 +243,8 @@ fn row_from_arg(arg: &TypeArg) -> Result, SignatureError> { } } -fn sig_from_arg(arg: &TypeArg) -> Result { - match arg { - TypeArg::Sequence { elems } if elems.len() == 2 => Ok(FunctionType::new( - row_from_arg(&elems[0])?, - row_from_arg(&elems[1])?, - )), - _ => Err(SignatureError::InvalidTypeArgs), - } -} - fn list_of_type() -> TypeParam { TypeParam::List { param: Box::new(TypeParam::Type { b: TypeBound::Any }), } } - -fn tuple_of_list_of_type() -> TypeParam { - TypeParam::Tuple { - params: vec![list_of_type(), list_of_type()], - } -} diff --git a/hugr_extension/src/ops.rs b/hugr_extension/src/ops.rs index b66565f7..f9bfea8b 100644 --- a/hugr_extension/src/ops.rs +++ b/hugr_extension/src/ops.rs @@ -4,7 +4,7 @@ use hugr::{ SignatureError, }, ops::{custom::ExtensionOp, NamedOp, OpTrait}, - types::{FunctionType, TypeArg, TypeEnum, TypeRow}, + types::{Signature, TypeArg, TypeEnum, TypeRow}, }; use smol_str::{format_smolstr, SmolStr}; @@ -16,18 +16,18 @@ use crate::{ctor::BratCtor, defs::BratOpDef}; pub enum BratOp { Hole { idx: u64, - sig: FunctionType, + sig: Signature, }, Substitute { - func_sig: FunctionType, - hole_sigs: Vec, + func_sig: Signature, + hole_sigs: Vec, }, Partial { inputs: TypeRow, - output_sig: FunctionType, + output_sig: Signature, }, Panic { - sig: FunctionType, + sig: Signature, }, Ctor { ctor: BratCtor, @@ -78,9 +78,20 @@ impl MakeExtensionOp for BratOp { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); + let closed_sig = Signature::try_from(*func_sig.clone()) + .map_err(|_| SignatureError::InvalidTypeArgs)?; + + let closed_hole_sigs: Result, SignatureError> = hole_sigs? + .iter() + .map(|a| { + Signature::try_from(a.clone()) + .map_err(|_| SignatureError::InvalidTypeArgs) + }) + .collect(); + Ok(BratOp::Substitute { - func_sig: *func_sig.clone(), - hole_sigs: hole_sigs?, + func_sig: closed_sig, + hole_sigs: closed_hole_sigs?, }) } _ => Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)), @@ -92,7 +103,8 @@ impl MakeExtensionOp for BratOp { }; Ok(BratOp::Partial { inputs: partial_inputs.to_vec().into(), - output_sig: *output_sig.clone(), + output_sig: Signature::try_from(*output_sig.clone()) + .expect("Invalid type arg to Partial"), }) } _ => Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)), @@ -138,10 +150,12 @@ impl MakeExtensionOp for BratOp { ], BratOp::Partial { inputs, output_sig } => vec![ arg_from_row(inputs), - arg_from_row(output_sig.input()), - arg_from_row(output_sig.output()), + arg_from_row(output_sig.input().into()), + arg_from_row(output_sig.output().into()), ], - BratOp::Panic { sig } => vec![arg_from_row(sig.input()), arg_from_row(sig.output())], + BratOp::Panic { sig } => { + vec![arg_from_row(sig.input().into()), arg_from_row(sig.output())] + } BratOp::Ctor { args, .. } => args.clone(), BratOp::PrimCtorTest { args, .. } => args.clone(), BratOp::Replicate(arg) => vec![arg.clone()], diff --git a/hugr_validator/Cargo.toml b/hugr_validator/Cargo.toml index 9cc80749..f3cfa792 100644 --- a/hugr_validator/Cargo.toml +++ b/hugr_validator/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "hugr_validator" -version = "0.2.0" +version = "0.4.0" edition = "2021" [dependencies] -hugr = "0.6.0" +hugr = "0.9.0" serde_json = "*" -brat_extension = { path = "../hugr_extension" } \ No newline at end of file +brat_extension = { path = "../hugr_extension" }