diff --git a/.gitignore b/.gitignore index 2dc9bad21..eec9590ea 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ /docs/_build *.hi *.o + +hie.yaml diff --git a/accelerate.cabal b/accelerate.cabal index 0b95607e4..2e64e1e1f 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -402,6 +402,7 @@ library Data.Array.Accelerate.Classes.RealFloat Data.Array.Accelerate.Classes.RealFrac Data.Array.Accelerate.Classes.ToFloating + Data.Array.Accelerate.Classes.Vector Data.Array.Accelerate.Debug.Internal.Clock Data.Array.Accelerate.Debug.Internal.Flags Data.Array.Accelerate.Debug.Internal.Graph diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index ff1729f27..e2543c6ae 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -310,6 +310,9 @@ module Data.Array.Accelerate ( -- ** SIMD vectors Vec, VecElt, + Vectoring(..), + vecOfList, + listOfVec, -- ** Type classes -- *** Basic type classes @@ -317,6 +320,7 @@ module Data.Array.Accelerate ( Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_, Enum, succ, pred, Bounded, minBound, maxBound, + -- Functor(..), (<$>), ($>), void, -- Monad(..), @@ -445,6 +449,7 @@ import Data.Array.Accelerate.Classes.Rational import Data.Array.Accelerate.Classes.RealFloat import Data.Array.Accelerate.Classes.RealFrac import Data.Array.Accelerate.Classes.ToFloating +import Data.Array.Accelerate.Classes.Vector import Data.Array.Accelerate.Data.Either import Data.Array.Accelerate.Data.Maybe import Data.Array.Accelerate.Language diff --git a/src/Data/Array/Accelerate/AST.hs b/src/Data/Array/Accelerate/AST.hs index c84f5723f..d3a26353e 100644 --- a/src/Data/Array/Accelerate/AST.hs +++ b/src/Data/Array/Accelerate/AST.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -7,6 +8,7 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_HADDOCK hide #-} @@ -559,6 +561,21 @@ data OpenExp env aenv t where -> OpenExp env aenv (Vec n s) -> OpenExp env aenv tup + VecIndex :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv i + -> OpenExp env aenv s + + VecWrite :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> OpenExp env aenv (Vec n s) + -> OpenExp env aenv i + -> OpenExp env aenv s + -> OpenExp env aenv (Vec n s) + -- Array indices & shapes IndexSlice :: SliceIndex slix sl co sh -> OpenExp env aenv slix @@ -655,7 +672,6 @@ data PrimConst ty where -- constant from Floating PrimPi :: FloatingType a -> PrimConst a - -- |Primitive scalar operations -- data PrimFun sig where @@ -814,6 +830,8 @@ expType = \case Nil -> TupRunit VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR VecUnpack vecR _ -> vecRtuple vecR + VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s + VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT IndexSlice si _ _ -> shapeType $ sliceShapeR si IndexFull si _ _ -> shapeType $ sliceDomainR si ToIndex{} -> TupRsingle scalarTypeInt @@ -825,7 +843,7 @@ expType = \case While _ (Lam lhs _) _ -> lhsToTupR lhs While{} -> error "What's the matter, you're running in the shadows" Const tR _ -> TupRsingle tR - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index (Var repr _) _ -> arrayRtype repr LinearIndex (Var repr _) _ -> arrayRtype repr @@ -834,17 +852,17 @@ expType = \case Undef tR -> TupRsingle tR Coerce _ tR _ -> TupRsingle tR -primConstType :: PrimConst a -> SingleType a +primConstType :: PrimConst a -> ScalarType a primConstType = \case PrimMinBound t -> bounded t PrimMaxBound t -> bounded t PrimPi t -> floating t where - bounded :: BoundedType a -> SingleType a - bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t + bounded :: BoundedType a -> ScalarType a + bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t - floating :: FloatingType t -> SingleType t - floating = NumSingleType . FloatingNumType + floating :: FloatingType t -> ScalarType t + floating = SingleScalarType . NumSingleType . FloatingNumType primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b) primFunType = \case @@ -1073,6 +1091,8 @@ rnfOpenExp topExp = Nil -> () VecPack vecr e -> rnfVecR vecr `seq` rnfE e VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e + VecIndex vt it v i -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i + VecWrite vt it v i e -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i `seq` rnfE e IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix @@ -1100,9 +1120,9 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf = rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b rnfPrimConst :: PrimConst c -> () -rnfPrimConst (PrimMinBound t) = rnfBoundedType t -rnfPrimConst (PrimMaxBound t) = rnfBoundedType t -rnfPrimConst (PrimPi t) = rnfFloatingType t +rnfPrimConst (PrimMinBound t) = rnfBoundedType t +rnfPrimConst (PrimMaxBound t) = rnfBoundedType t +rnfPrimConst (PrimPi t) = rnfFloatingType t rnfPrimFun :: PrimFun f -> () rnfPrimFun (PrimAdd t) = rnfNumType t @@ -1293,6 +1313,8 @@ liftOpenExp pexp = Nil -> [|| Nil ||] VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||] VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||] + VecIndex vt it v i -> [|| VecIndex $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) ||] + VecWrite vt it v i e -> [|| VecWrite $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) $$(liftE e) ||] IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||] IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||] ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||] @@ -1326,9 +1348,9 @@ liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||] liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||] liftPrimConst :: PrimConst c -> CodeQ (PrimConst c) -liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] -liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] -liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] +liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||] +liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||] +liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||] liftPrimFun :: PrimFun f -> CodeQ (PrimFun f) liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||] @@ -1440,6 +1462,8 @@ formatExpOp = later $ \case Nil{} -> "Nil" VecPack{} -> "VecPack" VecUnpack{} -> "VecUnpack" + VecIndex{} -> "VecIndex" + VecWrite{} -> "VecWrite" IndexSlice{} -> "IndexSlice" IndexFull{} -> "IndexFull" ToIndex{} -> "ToIndex" diff --git a/src/Data/Array/Accelerate/Analysis/Hash.hs b/src/Data/Array/Accelerate/Analysis/Hash.hs index 75625b9ec..964a5f11a 100644 --- a/src/Data/Array/Accelerate/Analysis/Hash.hs +++ b/src/Data/Array/Accelerate/Analysis/Hash.hs @@ -320,6 +320,8 @@ encodeOpenExp exp = Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2 VecPack _ e -> intHost $(hashQ "VecPack") <> travE e VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e + VecIndex _ _ v i -> intHost $(hashQ "VecIndex") <> travE v <> travE i + VecWrite _ _ v i e -> intHost $(hashQ "VecWrite") <> travE v <> travE i <> travE e Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec diff --git a/src/Data/Array/Accelerate/Classes/Enum.hs b/src/Data/Array/Accelerate/Classes/Enum.hs index 84b344273..10e946ee5 100644 --- a/src/Data/Array/Accelerate/Classes/Enum.hs +++ b/src/Data/Array/Accelerate/Classes/Enum.hs @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum" preludeError :: String -> a preludeError x = error - $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x - , "" + $ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , "" , "These Prelude.Enum instances are present only to fulfil superclass" , "constraints for subsequent classes in the standard Haskell numeric hierarchy." ] diff --git a/src/Data/Array/Accelerate/Classes/Vector.hs b/src/Data/Array/Accelerate/Classes/Vector.hs new file mode 100644 index 000000000..21c7a7be2 --- /dev/null +++ b/src/Data/Array/Accelerate/Classes/Vector.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MonoLocalBinds #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE GADTs #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} +-- | +-- Module : Data.Array.Accelerate.Classes.Vector +-- Copyright : [2016..2020] The Accelerate Team +-- License : BSD3 +-- +-- Maintainer : Trevor L. McDonell +-- Stability : experimental +-- Portability : non-portable (GHC extensions) +-- +module Data.Array.Accelerate.Classes.Vector where + +import GHC.TypeLits +import Data.Array.Accelerate.Sugar.Vec +import Data.Array.Accelerate.Smart +import Data.Primitive.Vec + + + +instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where + type IndexType (Exp (Vec n a)) = Exp Int + vecIndex = mkVectorIndex + vecWrite = mkVectorWrite + vecEmpty = undef + + diff --git a/src/Data/Array/Accelerate/Interpreter.hs b/src/Data/Array/Accelerate/Interpreter.hs index 5b8e6401a..aee68443f 100644 --- a/src/Data/Array/Accelerate/Interpreter.hs +++ b/src/Data/Array/Accelerate/Interpreter.hs @@ -69,6 +69,7 @@ import qualified Data.Array.Accelerate.Sugar.Array as Sugar import qualified Data.Array.Accelerate.Sugar.Elt as Sugar import qualified Data.Array.Accelerate.Trafo.Delayed as AST +import GHC.TypeLits import Control.DeepSeq import Control.Exception import Control.Monad @@ -1168,6 +1169,12 @@ evalLOr (x, y) = fromBool (toBool x || toBool y) evalLNot :: PrimBool -> PrimBool evalLNot = fromBool . not . toBool +evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a +evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i) + +evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a +evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a + evalFromIntegral :: IntegralType a -> NumType b -> a -> b evalFromIntegral ta (IntegralNumType tb) | IntegralDict <- integralDict ta @@ -1213,6 +1220,9 @@ evalMaxBound (IntegralBoundedType ty) evalPi :: FloatingType a -> a evalPi ty | FloatingDict <- floatingDict ty = pi +evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a +evalVectorCreate (VectorType n _) = vecEmpty + evalSin :: FloatingType a -> (a -> a) evalSin ty | FloatingDict <- floatingDict ty = sin diff --git a/src/Data/Array/Accelerate/Representation/Vec.hs b/src/Data/Array/Accelerate/Representation/Vec.hs index 35eac3b6c..bd37c7f18 100644 --- a/src/Data/Array/Accelerate/Representation/Vec.hs +++ b/src/Data/Array/Accelerate/Representation/Vec.hs @@ -41,6 +41,7 @@ data VecR (n :: Nat) single tuple where VecRnil :: SingleType s -> VecR 0 s () VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s) + vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s) vecRvector = uncurry VectorType . go where @@ -48,6 +49,9 @@ vecRvector = uncurry VectorType . go go (VecRnil tp) = (0, tp) go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp) +vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s +vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s + vecRtuple :: VecR n s tuple -> TypeR tuple vecRtuple = snd . go where diff --git a/src/Data/Array/Accelerate/Smart.hs b/src/Data/Array/Accelerate/Smart.hs index 8fa577f41..ccb38e7ab 100644 --- a/src/Data/Array/Accelerate/Smart.hs +++ b/src/Data/Array/Accelerate/Smart.hs @@ -1,5 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE CPP #-} + {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -12,6 +12,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE PolyKinds #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Smart @@ -71,6 +72,10 @@ module Data.Array.Accelerate.Smart ( -- ** Smart constructors for type coercion functions mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..), + -- ** Smart constructors for vector operations + mkVectorIndex, + mkVectorWrite, + -- ** Auxiliary functions ($$), ($$$), ($$$$), ($$$$$), ApplyAcc(..), @@ -83,6 +88,7 @@ module Data.Array.Accelerate.Smart ( ) where +import Data.Proxy import Data.Array.Accelerate.AST.Idx import Data.Array.Accelerate.Error import Data.Array.Accelerate.Representation.Array @@ -95,6 +101,7 @@ import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Representation.Vec import Data.Array.Accelerate.Sugar.Array ( Arrays ) import Data.Array.Accelerate.Sugar.Elt +import Data.Array.Accelerate.Sugar.Vec import Data.Array.Accelerate.Sugar.Foreign import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) ) import Data.Array.Accelerate.Type @@ -520,6 +527,21 @@ data PreSmartExp acc exp t where -> exp (Vec n s) -> PreSmartExp acc exp tup + VecIndex :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> exp (Vec n s) + -> exp i + -> PreSmartExp acc exp s + + VecWrite :: (KnownNat n, v ~ Vec n s) + => VectorType v + -> IntegralType i + -> exp (Vec n s) + -> exp i + -> exp s + -> PreSmartExp acc exp (Vec n s) + ToIndex :: ShapeR sh -> exp sh -> exp sh @@ -853,13 +875,15 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where Prj _ _ -> error "I never joke about my work" VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR VecUnpack vecR _ -> vecRtuple vecR + VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s + VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT ToIndex _ _ _ -> TupRsingle scalarTypeInt FromIndex shr _ _ -> shapeType shr Case _ ((_,c):_) -> typeR c Case{} -> internalError "encountered empty case" Cond _ e _ -> typeR e While t _ _ _ -> t - PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c + PrimConst c -> TupRsingle $ primConstType c PrimApp f _ -> snd $ primFunType f Index tp _ _ -> tp LinearIndex tp _ _ -> tp @@ -1172,6 +1196,16 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil where x = SmartExp $ Prj PairIdxLeft a + +inferNat :: forall n. KnownNat n => Int +inferNat = fromInteger $ natVal (Proxy @n) + +mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a +mkVectorIndex (Exp v) (Exp i) = mkExp $ VecIndex (VectorType (inferNat @n) singleType) integralType v i + +mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a) +mkVectorWrite (Exp v) (Exp i) (Exp el) = mkExp $ VecWrite (VectorType (inferNat @n) singleType) integralType v i el + -- Numeric conversions mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b @@ -1259,6 +1293,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b) +mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d +mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c))) + mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary diff --git a/src/Data/Array/Accelerate/Trafo/Algebra.hs b/src/Data/Array/Accelerate/Trafo/Algebra.hs index 9cfea36ae..807ffe474 100644 --- a/src/Data/Array/Accelerate/Trafo/Algebra.hs +++ b/src/Data/Array/Accelerate/Trafo/Algebra.hs @@ -33,12 +33,14 @@ import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Pretty.Print ( primOperator, isInfix, opName ) import Data.Array.Accelerate.Trafo.Environment import Data.Array.Accelerate.Type +import Data.Array.Accelerate.Classes.Vector import qualified Data.Array.Accelerate.Debug.Internal.Stats as Stats import Data.Bits import Data.Monoid import Data.Text ( Text ) +import Data.Primitive.Vec import Data.Text.Prettyprint.Doc import Data.Text.Prettyprint.Doc.Render.Text import GHC.Float ( float2Double, double2Float ) diff --git a/src/Data/Array/Accelerate/Trafo/Sharing.hs b/src/Data/Array/Accelerate/Trafo/Sharing.hs index 67ead04f0..9a740cb06 100644 --- a/src/Data/Array/Accelerate/Trafo/Sharing.hs +++ b/src/Data/Array/Accelerate/Trafo/Sharing.hs @@ -764,6 +764,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2) VecPack vec e -> AST.VecPack vec (cvt e) VecUnpack vec e -> AST.VecUnpack vec (cvt e) + VecIndex vt it v i -> AST.VecIndex vt it (cvt v) (cvt i) + VecWrite vt it v i e -> AST.VecWrite vt it (cvt v) (cvt i) (cvt e) ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix) FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e) Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs) @@ -1841,37 +1843,39 @@ makeOccMapSharingExp config accOccMap expOccMap = travE return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height) reconstruct $ case pexp of - Tag tp i -> return (Tag tp i, 0) -- height is 0! - Const tp c -> return (Const tp c, 1) - Undef tp -> return (Undef tp, 1) - Nil -> return (Nil, 1) - Pair e1 e2 -> travE2 Pair e1 e2 - Prj i e -> travE1 (Prj i) e - VecPack vec e -> travE1 (VecPack vec) e - VecUnpack vec e -> travE1 (VecUnpack vec) e - ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix - FromIndex shr sh e -> travE2 (FromIndex shr) sh e - Match t e -> travE1 (Match t) e - Case e rhs -> do - (e', h1) <- travE lvl e - (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] - return (Case e' rhs', h1 `max` maximum h2 + 1) - Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 - While t p iter init -> do - (p' , h1) <- traverseFun1 lvl t p - (iter', h2) <- traverseFun1 lvl t iter - (init', h3) <- travE lvl init - return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) - PrimConst c -> return (PrimConst c, 1) - PrimApp p e -> travE1 (PrimApp p) e - Index tp a e -> travAE (Index tp) a e - LinearIndex tp a i -> travAE (LinearIndex tp) a i - Shape shr a -> travA (Shape shr) a - ShapeSize shr e -> travE1 (ShapeSize shr) e - Foreign tp ff f e -> do - (e', h) <- travE lvl e - return (Foreign tp ff f e', h+1) - Coerce t1 t2 e -> travE1 (Coerce t1 t2) e + Tag tp i -> return (Tag tp i, 0) -- height is 0! + Const tp c -> return (Const tp c, 1) + Undef tp -> return (Undef tp, 1) + Nil -> return (Nil, 1) + Pair e1 e2 -> travE2 Pair e1 e2 + Prj i e -> travE1 (Prj i) e + VecPack vec e -> travE1 (VecPack vec) e + VecUnpack vec e -> travE1 (VecUnpack vec) e + VecIndex vt ti v i -> travE2 (VecIndex vt ti) v i + VecWrite vt ti v i e -> travE3 (VecWrite vt ti) v i e + ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix + FromIndex shr sh e -> travE2 (FromIndex shr) sh e + Match t e -> travE1 (Match t) e + Case e rhs -> do + (e', h1) <- travE lvl e + (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ] + return (Case e' rhs', h1 `max` maximum h2 + 1) + Cond e1 e2 e3 -> travE3 Cond e1 e2 e3 + While t p iter init -> do + (p' , h1) <- traverseFun1 lvl t p + (iter', h2) <- traverseFun1 lvl t iter + (init', h3) <- travE lvl init + return (While t p' iter' init', h1 `max` h2 `max` h3 + 1) + PrimConst c -> return (PrimConst c, 1) + PrimApp p e -> travE1 (PrimApp p) e + Index tp a e -> travAE (Index tp) a e + LinearIndex tp a i -> travAE (LinearIndex tp) a i + Shape shr a -> travA (Shape shr) a + ShapeSize shr e -> travE1 (ShapeSize shr) e + Foreign tp ff f e -> do + (e', h) <- travE lvl e + return (Foreign tp ff f e', h+1) + Coerce t1 t2 e -> travE1 (Coerce t1 t2) e where traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int) @@ -2755,6 +2759,8 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp Prj i e -> travE1 (Prj i) e VecPack vec e -> travE1 (VecPack vec) e VecUnpack vec e -> travE1 (VecUnpack vec) e + VecIndex vt it v i -> travE2 (VecIndex vt it) v i + VecWrite vt it v i e -> travE3 (VecWrite vt it) v i e ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix FromIndex shr sh e -> travE2 (FromIndex shr) sh e Match t e -> travE1 (Match t) e diff --git a/src/Data/Array/Accelerate/Trafo/Shrink.hs b/src/Data/Array/Accelerate/Trafo/Shrink.hs index 574747865..636043113 100644 --- a/src/Data/Array/Accelerate/Trafo/Shrink.hs +++ b/src/Data/Array/Accelerate/Trafo/Shrink.hs @@ -293,6 +293,8 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE Pair x y -> Pair <$> shrinkE x <*> shrinkE y VecPack vec e -> VecPack vec <$> shrinkE e VecUnpack vec e -> VecUnpack vec <$> shrinkE e + VecIndex vt it v i -> VecIndex vt it <$> shrinkE v <*> shrinkE i + VecWrite vt it v i e -> VecWrite vt it <$> shrinkE v <*> shrinkE i <*> shrinkE e IndexSlice x ix sh -> IndexSlice x <$> shrinkE ix <*> shrinkE sh IndexFull x ix sl -> IndexFull x <$> shrinkE ix <*> shrinkE sl ToIndex shr sh ix -> ToIndex shr <$> shrinkE sh <*> shrinkE ix @@ -494,6 +496,8 @@ usesOfExp range = countE Pair e1 e2 -> countE e1 <> countE e2 VecPack _ e -> countE e VecUnpack _ e -> countE e + VecIndex _ _ v i -> countE v <> countE i + VecWrite _ _ v i e -> countE v <> countE i <> countE e IndexSlice _ ix sh -> countE ix <> countE sh IndexFull _ ix sl -> countE ix <> countE sl FromIndex _ sh i -> countE sh <> countE i @@ -581,6 +585,8 @@ usesOfPreAcc withShape countAcc idx = count Pair x y -> countE x + countE y VecPack _ e -> countE e VecUnpack _ e -> countE e + VecIndex _ _ v i -> countE v + countE i + VecWrite _ _ v i e -> countE v + countE i + countE e IndexSlice _ ix sh -> countE ix + countE sh IndexFull _ ix sl -> countE ix + countE sl ToIndex _ sh ix -> countE sh + countE ix diff --git a/src/Data/Array/Accelerate/Trafo/Simplify.hs b/src/Data/Array/Accelerate/Trafo/Simplify.hs index 71be5aad3..6fe611f7a 100644 --- a/src/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/Simplify.hs @@ -226,6 +226,8 @@ simplifyOpenExp env = first getAny . cvtE Pair e1 e2 -> Pair <$> cvtE e1 <*> cvtE e2 VecPack vec e -> VecPack vec <$> cvtE e VecUnpack vec e -> VecUnpack vec <$> cvtE e + VecIndex vt it v i -> VecIndex vt it <$> cvtE v <*> cvtE i + VecWrite vt it v i e -> VecWrite vt it <$> cvtE v <*> cvtE i <*> cvtE e IndexSlice x ix sh -> IndexSlice x <$> cvtE ix <*> cvtE sh IndexFull x ix sl -> IndexFull x <$> cvtE ix <*> cvtE sl ToIndex shr sh ix -> toIndex shr (cvtE sh) (cvtE ix) @@ -548,6 +550,8 @@ summariseOpenExp = (terms +~ 1) . goE Pair e1 e2 -> travE e1 +++ travE e2 & terms +~ 1 VecPack _ e -> travE e VecUnpack _ e -> travE e + VecIndex _ _ v i -> travE v +++ travE i + VecWrite _ _ v i e -> travE v +++ travE i +++ travE e IndexSlice _ slix sh -> travE slix +++ travE sh & terms +~ 1 -- +1 for sliceIndex IndexFull _ slix sl -> travE slix +++ travE sl & terms +~ 1 -- +1 for sliceIndex ToIndex _ sh ix -> travE sh +++ travE ix diff --git a/src/Data/Array/Accelerate/Trafo/Substitution.hs b/src/Data/Array/Accelerate/Trafo/Substitution.hs index e1aa1176b..7debd6d07 100644 --- a/src/Data/Array/Accelerate/Trafo/Substitution.hs +++ b/src/Data/Array/Accelerate/Trafo/Substitution.hs @@ -149,29 +149,31 @@ inlineVars lhsBound expr bound substitute k1 k2 vars topExp = case topExp of Let lhs e1 e2 | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 - Evar (Var t ix) -> Evar . Var t <$> k1 ix - Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 - Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 - Nil -> Just Nil - VecPack vec e1 -> VecPack vec <$> travE e1 - VecUnpack vec e1 -> VecUnpack vec <$> travE e1 - IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 - IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 - ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 - FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 - Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def - Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 - While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 - Const t c -> Just $ Const t c - PrimConst c -> Just $ PrimConst c - PrimApp p e1 -> PrimApp p <$> travE e1 - Index a e1 -> Index a <$> travE e1 - LinearIndex a e1 -> LinearIndex a <$> travE e1 - Shape a -> Just $ Shape a - ShapeSize shr e1 -> ShapeSize shr <$> travE e1 - Undef t -> Just $ Undef t - Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 + -> Let lhs' <$> travE e1 <*> substitute (strengthenAfter lhs lhs' k1) (weakenWithLHS lhs' .> k2) (weakenWithLHS lhs `weakenVars` vars) e2 + Evar (Var t ix) -> Evar . Var t <$> k1 ix + Foreign tp asm f e1 -> Foreign tp asm f <$> travE e1 + Pair e1 e2 -> Pair <$> travE e1 <*> travE e2 + Nil -> Just Nil + VecPack vec e1 -> VecPack vec <$> travE e1 + VecUnpack vec e1 -> VecUnpack vec <$> travE e1 + VecIndex vt it v i -> VecIndex vt it <$> travE v <*> travE i + VecWrite vt it v i e -> VecWrite vt it <$> travE v <*> travE i <*> travE e + IndexSlice si e1 e2 -> IndexSlice si <$> travE e1 <*> travE e2 + IndexFull si e1 e2 -> IndexFull si <$> travE e1 <*> travE e2 + ToIndex shr e1 e2 -> ToIndex shr <$> travE e1 <*> travE e2 + FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2 + Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def + Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3 + While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1 + Const t c -> Just $ Const t c + PrimConst c -> Just $ PrimConst c + PrimApp p e1 -> PrimApp p <$> travE e1 + Index a e1 -> Index a <$> travE e1 + LinearIndex a e1 -> LinearIndex a <$> travE e1 + Shape a -> Just $ Shape a + ShapeSize shr e1 -> ShapeSize shr <$> travE e1 + Undef t -> Just $ Undef t + Coerce t1 t2 e1 -> Coerce t1 t2 <$> travE e1 where travE :: OpenExp env1 aenv s -> Maybe (OpenExp env2 aenv s) @@ -546,31 +548,33 @@ rebuildOpenExp -> f (OpenExp env' aenv' t) rebuildOpenExp v av@(ReindexAvar reindex) exp = case exp of - Const t c -> pure $ Const t c - PrimConst c -> pure $ PrimConst c - Undef t -> pure $ Undef t - Evar var -> expOut <$> v var + Const t c -> pure $ Const t c + PrimConst c -> pure $ PrimConst c + Undef t -> pure $ Undef t + Evar var -> expOut <$> v var Let lhs a b | Exists lhs' <- rebuildLHS lhs - -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b - Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 - Nil -> pure Nil - VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e - VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e - IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh - IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl - ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix - Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def - Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e - While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x - PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x - Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh - LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i - Shape a -> Shape <$> reindex a - ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh - Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e - Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e + -> Let lhs' <$> rebuildOpenExp v av a <*> rebuildOpenExp (shiftE' lhs lhs' v) av b + Pair e1 e2 -> Pair <$> rebuildOpenExp v av e1 <*> rebuildOpenExp v av e2 + Nil -> pure Nil + VecPack vec e -> VecPack vec <$> rebuildOpenExp v av e + VecUnpack vec e -> VecUnpack vec <$> rebuildOpenExp v av e + VecIndex vt it v' i -> VecIndex vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i + VecWrite vt it v' i e -> VecWrite vt it <$> rebuildOpenExp v av v' <*> rebuildOpenExp v av i <*> rebuildOpenExp v av e + IndexSlice x ix sh -> IndexSlice x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sh + IndexFull x ix sl -> IndexFull x <$> rebuildOpenExp v av ix <*> rebuildOpenExp v av sl + ToIndex shr sh ix -> ToIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v av sh <*> rebuildOpenExp v av ix + Case e rhs def -> Case <$> rebuildOpenExp v av e <*> sequenceA [ (t,) <$> rebuildOpenExp v av c | (t,c) <- rhs ] <*> rebuildMaybeExp v av def + Cond p t e -> Cond <$> rebuildOpenExp v av p <*> rebuildOpenExp v av t <*> rebuildOpenExp v av e + While p f x -> While <$> rebuildFun v av p <*> rebuildFun v av f <*> rebuildOpenExp v av x + PrimApp f x -> PrimApp f <$> rebuildOpenExp v av x + Index a sh -> Index <$> reindex a <*> rebuildOpenExp v av sh + LinearIndex a i -> LinearIndex <$> reindex a <*> rebuildOpenExp v av i + Shape a -> Shape <$> reindex a + ShapeSize shr sh -> ShapeSize shr <$> rebuildOpenExp v av sh + Foreign tp ff f e -> Foreign tp ff f <$> rebuildOpenExp v av e + Coerce t1 t2 e -> Coerce t1 t2 <$> rebuildOpenExp v av e {-# INLINEABLE rebuildFun #-} rebuildFun diff --git a/src/Data/Primitive/Vec.hs b/src/Data/Primitive/Vec.hs index 34a77635b..a50f643c2 100644 --- a/src/Data/Primitive/Vec.hs +++ b/src/Data/Primitive/Vec.hs @@ -1,15 +1,21 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE UnboxedTuples #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TupleSections #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Primitive.Vec @@ -32,11 +38,16 @@ module Data.Primitive.Vec ( Vec16, pattern Vec16, listOfVec, + vecOfList, liftVec, + Vectoring(..) ) where +import Data.Kind +import Data.Proxy import Control.Monad.ST +import Control.Monad.Reader import Data.Primitive.ByteArray import Data.Primitive.Types import Data.Text.Prettyprint.Doc @@ -83,6 +94,36 @@ import GHC.Word -- data Vec (n :: Nat) a = Vec ByteArray# +class Vectoring vector a | vector -> a where + type IndexType vector :: Data.Kind.Type + vecIndex :: vector -> IndexType vector -> a + vecWrite :: vector -> IndexType vector -> a -> vector + vecEmpty :: vector + +instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where + type IndexType (Vec n a) = Int + vecIndex (Vec ba#) i@(I# iu#) = let + n :: Int + n = fromIntegral $ natVal $ Proxy @n + in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n) + vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do + let n :: Int + n = fromIntegral $ natVal $ Proxy @n + mba <- unsafeThawByteArray (ByteArray ba#) + writeByteArray mba i v + ByteArray nba# <- unsafeFreezeByteArray mba + return $! Vec nba# + vecEmpty = mkVec + + +mkVec :: forall n a. (KnownNat n, Prim a) => Vec n a +mkVec = runST $ do + let n :: Int = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + + type role Vec nominal representational instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where @@ -93,6 +134,14 @@ instance (Show a, Prim a, KnownNat n) => Show (Vec n a) where . group . encloseSep (flatAlt "< " "<") (flatAlt " >" ">") ", " . map viaShow +vecOfList :: forall n a. (KnownNat n, Prim a) => [a] -> Vec n a +vecOfList vs = runST $ do + let n :: Int = fromIntegral $ natVal $ Proxy @n + mba <- newByteArray (n * sizeOf (undefined :: a)) + zipWithM_ (writeByteArray mba) [0..n-1] vs + ByteArray ba# <- unsafeFreezeByteArray mba + return $! Vec ba# + listOfVec :: forall a n. (Prim a, KnownNat n) => Vec n a -> [a] listOfVec (Vec ba#) = go 0# where @@ -259,6 +308,7 @@ packVec16 a b c d e f g h i j k l m n o p = runST $ do ByteArray ba# <- unsafeFreezeByteArray mba return $! Vec ba# + -- O(n) at runtime to copy from the Addr# to the ByteArray#. We should be able -- to do this without copying, but I don't think the definition of ByteArray# is -- exported (or it is deeply magical).