Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector indexing and insertion operations #509

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
/docs/_build
*.hi
*.o

hie.yaml
1 change: 1 addition & 0 deletions accelerate.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ library
Data.Array.Accelerate.Classes.RealFloat
Data.Array.Accelerate.Classes.RealFrac
Data.Array.Accelerate.Classes.ToFloating
Data.Array.Accelerate.Classes.Vector
Data.Array.Accelerate.Debug.Internal.Clock
Data.Array.Accelerate.Debug.Internal.Flags
Data.Array.Accelerate.Debug.Internal.Graph
Expand Down
5 changes: 5 additions & 0 deletions src/Data/Array/Accelerate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,17 @@ module Data.Array.Accelerate (

-- ** SIMD vectors
Vec, VecElt,
Vectoring(..),
vecOfList,
listOfVec,

-- ** Type classes
-- *** Basic type classes
Eq(..),
Ord(..), Ordering(..), pattern LT_, pattern EQ_, pattern GT_,
Enum, succ, pred,
Bounded, minBound, maxBound,

-- Functor(..), (<$>), ($>), void,
-- Monad(..),

Expand Down Expand Up @@ -445,6 +449,7 @@ import Data.Array.Accelerate.Classes.Rational
import Data.Array.Accelerate.Classes.RealFloat
import Data.Array.Accelerate.Classes.RealFrac
import Data.Array.Accelerate.Classes.ToFloating
import Data.Array.Accelerate.Classes.Vector
import Data.Array.Accelerate.Data.Either
import Data.Array.Accelerate.Data.Maybe
import Data.Array.Accelerate.Language
Expand Down
50 changes: 37 additions & 13 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
Expand All @@ -7,6 +8,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
Expand Down Expand Up @@ -559,6 +561,21 @@ data OpenExp env aenv t where
-> OpenExp env aenv (Vec n s)
-> OpenExp env aenv tup

VecIndex :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> OpenExp env aenv (Vec n s)
-> OpenExp env aenv i
-> OpenExp env aenv s

VecWrite :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> OpenExp env aenv (Vec n s)
-> OpenExp env aenv i
-> OpenExp env aenv s
-> OpenExp env aenv (Vec n s)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove VecPac and VecUnpack? Those just get turned your new index and write instructions. We might as well also add VecShuffle, corresponding to the shufflevector instruction.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the more complicated things about VecPack and Unpack is the need to derive the tuple type. I got rather stuck trying to make changes to that whole pipeline, could you assist me on this live sometime?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. Okay let's just shelve that for now.

I've been thinking how this vector support can be restructured after you mentioned you couldn't have a Vec Node (or any non-primitive thing), which is sort of a weird limitation for our language...

Copy link
Author

@HugoPeters1024 HugoPeters1024 Dec 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two blocking constraints are Prim Node (which can be implemented I think). However EltR Node ~ Node is the culprit. Perhaps some unsafe coerceon from and to a flattened byte vector?

-- Array indices & shapes
IndexSlice :: SliceIndex slix sl co sh
-> OpenExp env aenv slix
Expand Down Expand Up @@ -655,7 +672,6 @@ data PrimConst ty where
-- constant from Floating
PrimPi :: FloatingType a -> PrimConst a


-- |Primitive scalar operations
--
data PrimFun sig where
Expand Down Expand Up @@ -814,6 +830,8 @@ expType = \case
Nil -> TupRunit
VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR
VecUnpack vecR _ -> vecRtuple vecR
VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s
VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT
IndexSlice si _ _ -> shapeType $ sliceShapeR si
IndexFull si _ _ -> shapeType $ sliceDomainR si
ToIndex{} -> TupRsingle scalarTypeInt
Expand All @@ -825,7 +843,7 @@ expType = \case
While _ (Lam lhs _) _ -> lhsToTupR lhs
While{} -> error "What's the matter, you're running in the shadows"
Const tR _ -> TupRsingle tR
PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c
PrimConst c -> TupRsingle $ primConstType c
PrimApp f _ -> snd $ primFunType f
Index (Var repr _) _ -> arrayRtype repr
LinearIndex (Var repr _) _ -> arrayRtype repr
Expand All @@ -834,17 +852,17 @@ expType = \case
Undef tR -> TupRsingle tR
Coerce _ tR _ -> TupRsingle tR

primConstType :: PrimConst a -> SingleType a
primConstType :: PrimConst a -> ScalarType a
primConstType = \case
PrimMinBound t -> bounded t
PrimMaxBound t -> bounded t
PrimPi t -> floating t
where
bounded :: BoundedType a -> SingleType a
bounded (IntegralBoundedType t) = NumSingleType $ IntegralNumType t
bounded :: BoundedType a -> ScalarType a
bounded (IntegralBoundedType t) = SingleScalarType $ NumSingleType $ IntegralNumType t

floating :: FloatingType t -> SingleType t
floating = NumSingleType . FloatingNumType
floating :: FloatingType t -> ScalarType t
floating = SingleScalarType . NumSingleType . FloatingNumType

primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b)
primFunType = \case
Expand Down Expand Up @@ -1073,6 +1091,8 @@ rnfOpenExp topExp =
Nil -> ()
VecPack vecr e -> rnfVecR vecr `seq` rnfE e
VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e
VecIndex vt it v i -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i
VecWrite vt it v i e -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i `seq` rnfE e
IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh
IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl
ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix
Expand Down Expand Up @@ -1100,9 +1120,9 @@ rnfConst (TupRsingle t) !_ = rnfScalarType t -- scalars should have (nf =
rnfConst (TupRpair ta tb) (a,b) = rnfConst ta a `seq` rnfConst tb b

rnfPrimConst :: PrimConst c -> ()
rnfPrimConst (PrimMinBound t) = rnfBoundedType t
rnfPrimConst (PrimMaxBound t) = rnfBoundedType t
rnfPrimConst (PrimPi t) = rnfFloatingType t
rnfPrimConst (PrimMinBound t) = rnfBoundedType t
rnfPrimConst (PrimMaxBound t) = rnfBoundedType t
rnfPrimConst (PrimPi t) = rnfFloatingType t

rnfPrimFun :: PrimFun f -> ()
rnfPrimFun (PrimAdd t) = rnfNumType t
Expand Down Expand Up @@ -1293,6 +1313,8 @@ liftOpenExp pexp =
Nil -> [|| Nil ||]
VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||]
VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||]
VecIndex vt it v i -> [|| VecIndex $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) ||]
VecWrite vt it v i e -> [|| VecWrite $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) $$(liftE e) ||]
IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||]
IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||]
ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||]
Expand Down Expand Up @@ -1326,9 +1348,9 @@ liftBoundary (ArrayR _ tp) (Constant v) = [|| Constant $$(liftElt tp v) ||]
liftBoundary _ (Function f) = [|| Function $$(liftOpenFun f) ||]

liftPrimConst :: PrimConst c -> CodeQ (PrimConst c)
liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||]
liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||]
liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||]
liftPrimConst (PrimMinBound t) = [|| PrimMinBound $$(liftBoundedType t) ||]
liftPrimConst (PrimMaxBound t) = [|| PrimMaxBound $$(liftBoundedType t) ||]
liftPrimConst (PrimPi t) = [|| PrimPi $$(liftFloatingType t) ||]

liftPrimFun :: PrimFun f -> CodeQ (PrimFun f)
liftPrimFun (PrimAdd t) = [|| PrimAdd $$(liftNumType t) ||]
Expand Down Expand Up @@ -1440,6 +1462,8 @@ formatExpOp = later $ \case
Nil{} -> "Nil"
VecPack{} -> "VecPack"
VecUnpack{} -> "VecUnpack"
VecIndex{} -> "VecIndex"
VecWrite{} -> "VecWrite"
IndexSlice{} -> "IndexSlice"
IndexFull{} -> "IndexFull"
ToIndex{} -> "ToIndex"
Expand Down
2 changes: 2 additions & 0 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ encodeOpenExp exp =
Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2
VecPack _ e -> intHost $(hashQ "VecPack") <> travE e
VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e
VecIndex _ _ v i -> intHost $(hashQ "VecIndex") <> travE v <> travE i
VecWrite _ _ v i e -> intHost $(hashQ "VecWrite") <> travE v <> travE i <> travE e
Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c
Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp
IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec
Expand Down
3 changes: 1 addition & 2 deletions src/Data/Array/Accelerate/Classes/Enum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ defaultFromEnum = preludeError "fromEnum"
preludeError :: String -> a
preludeError x
= error
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x
, ""
$ unlines [ printf "Prelude.%s is not supported for Accelerate types" x , ""
, "These Prelude.Enum instances are present only to fulfil superclass"
, "constraints for subsequent classes in the standard Haskell numeric hierarchy."
]
Expand Down
36 changes: 36 additions & 0 deletions src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module : Data.Array.Accelerate.Classes.Vector
-- Copyright : [2016..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <[email protected]>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
module Data.Array.Accelerate.Classes.Vector where

import GHC.TypeLits
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Smart
import Data.Primitive.Vec



instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where
type IndexType (Exp (Vec n a)) = Exp Int
vecIndex = mkVectorIndex
vecWrite = mkVectorWrite
vecEmpty = undef


10 changes: 10 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import qualified Data.Array.Accelerate.Sugar.Array as Sugar
import qualified Data.Array.Accelerate.Sugar.Elt as Sugar
import qualified Data.Array.Accelerate.Trafo.Delayed as AST

import GHC.TypeLits
import Control.DeepSeq
import Control.Exception
import Control.Monad
Expand Down Expand Up @@ -1168,6 +1169,12 @@ evalLOr (x, y) = fromBool (toBool x || toBool y)
evalLNot :: PrimBool -> PrimBool
evalLNot = fromBool . not . toBool

evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a
evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i)

evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a
evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a

evalFromIntegral :: IntegralType a -> NumType b -> a -> b
evalFromIntegral ta (IntegralNumType tb)
| IntegralDict <- integralDict ta
Expand Down Expand Up @@ -1213,6 +1220,9 @@ evalMaxBound (IntegralBoundedType ty)
evalPi :: FloatingType a -> a
evalPi ty | FloatingDict <- floatingDict ty = pi

evalVectorCreate :: (KnownNat n, Prim a) => VectorType (Vec n a) -> Vec n a
evalVectorCreate (VectorType n _) = vecEmpty

evalSin :: FloatingType a -> (a -> a)
evalSin ty | FloatingDict <- floatingDict ty = sin

Expand Down
4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/Representation/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ data VecR (n :: Nat) single tuple where
VecRnil :: SingleType s -> VecR 0 s ()
VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s)


vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s)
vecRvector = uncurry VectorType . go
where
go :: VecR n s tuple -> (Int, SingleType s)
go (VecRnil tp) = (0, tp)
go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp)

vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s
vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s

vecRtuple :: VecR n s tuple -> TypeR tuple
vecRtuple = snd . go
where
Expand Down
41 changes: 39 additions & 2 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
Expand All @@ -12,6 +12,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.Smart
Expand Down Expand Up @@ -71,6 +72,10 @@ module Data.Array.Accelerate.Smart (
-- ** Smart constructors for type coercion functions
mkFromIntegral, mkToFloating, mkBitcast, mkCoerce, Coerce(..),

-- ** Smart constructors for vector operations
mkVectorIndex,
mkVectorWrite,

-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
ApplyAcc(..),
Expand All @@ -83,6 +88,7 @@ module Data.Array.Accelerate.Smart (
) where


import Data.Proxy
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
Expand All @@ -95,6 +101,7 @@ import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Sugar.Array ( Arrays )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Sugar.Shape ( (:.)(..) )
import Data.Array.Accelerate.Type
Expand Down Expand Up @@ -520,6 +527,21 @@ data PreSmartExp acc exp t where
-> exp (Vec n s)
-> PreSmartExp acc exp tup

VecIndex :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> exp (Vec n s)
-> exp i
-> PreSmartExp acc exp s

VecWrite :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> exp (Vec n s)
-> exp i
-> exp s
-> PreSmartExp acc exp (Vec n s)

ToIndex :: ShapeR sh
-> exp sh
-> exp sh
Expand Down Expand Up @@ -853,13 +875,15 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
Prj _ _ -> error "I never joke about my work"
VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR
VecUnpack vecR _ -> vecRtuple vecR
VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s
VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT
ToIndex _ _ _ -> TupRsingle scalarTypeInt
FromIndex shr _ _ -> shapeType shr
Case _ ((_,c):_) -> typeR c
Case{} -> internalError "encountered empty case"
Cond _ e _ -> typeR e
While t _ _ _ -> t
PrimConst c -> TupRsingle $ SingleScalarType $ primConstType c
PrimConst c -> TupRsingle $ primConstType c
PrimApp f _ -> snd $ primFunType f
Index tp _ _ -> tp
LinearIndex tp _ _ -> tp
Expand Down Expand Up @@ -1172,6 +1196,16 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil
where
x = SmartExp $ Prj PairIdxLeft a


inferNat :: forall n. KnownNat n => Int
inferNat = fromInteger $ natVal (Proxy @n)

mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a
mkVectorIndex (Exp v) (Exp i) = mkExp $ VecIndex (VectorType (inferNat @n) singleType) integralType v i

mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a)
mkVectorWrite (Exp v) (Exp i) (Exp el) = mkExp $ VecWrite (VectorType (inferNat @n) singleType) integralType v i el

-- Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
Expand Down Expand Up @@ -1259,6 +1293,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a
mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c
mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b)

mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d
mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c)))

mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool
mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary

Expand Down
Loading