Skip to content

Commit

Permalink
Implement multiple folds using scans
Browse files Browse the repository at this point in the history
  • Loading branch information
adithyaov committed Jan 28, 2025
1 parent 5537b6b commit 44616ee
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 230 deletions.
160 changes: 26 additions & 134 deletions core/src/Streamly/Internal/Data/Fold/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ module Streamly.Internal.Data.Fold.Combinators
, the
, mean
, rollingHash
, defaultSalt
, Scanl.defaultSalt
, rollingHashWithSalt
, rollingHashFirstN
-- , rollingHashLastN
Expand Down Expand Up @@ -67,6 +67,7 @@ module Streamly.Internal.Data.Fold.Combinators
-- usually a transformation of the current element rather than an
-- aggregation of all elements till now.
-- , nthLast -- using RingArray array
, rollingMap
, rollingMapM

-- *** Filters
Expand Down Expand Up @@ -242,12 +243,10 @@ import Streamly.Internal.Data.Unfold.Type (Unfold(..))
import qualified Prelude
import qualified Streamly.Internal.Data.MutArray.Type as MA
import qualified Streamly.Internal.Data.Array.Type as Array
import qualified Streamly.Internal.Data.Fold.Type as Fold
import qualified Streamly.Internal.Data.Pipe.Type as Pipe
import qualified Streamly.Internal.Data.RingArray as RingArray
import qualified Streamly.Internal.Data.Scanl.Combinators as Scanl
import qualified Streamly.Internal.Data.Scanl.Type as Scanl
import qualified Streamly.Internal.Data.Scanl.Window as Scanl
import qualified Streamly.Internal.Data.Stream.Type as StreamD

import Prelude hiding
Expand Down Expand Up @@ -500,17 +499,7 @@ pipe (Pipe consume produce pinitial) (Fold fstep finitial fextract ffinal) =
--
{-# INLINE_NORMAL deleteBy #-}
deleteBy :: Monad m => (a -> a -> Bool) -> a -> Fold m a (Maybe a)
deleteBy eq x0 = fmap extract $ foldl' step (Tuple' False Nothing)

where

step (Tuple' False _) x =
if eq x x0
then Tuple' True Nothing
else Tuple' False (Just x)
step (Tuple' True _) x = Tuple' True (Just x)

extract (Tuple' _ x) = x
deleteBy eq = fromScanl . Scanl.deleteBy eq

-- | Provide a sliding window of length 2 elements.
--
Expand Down Expand Up @@ -550,14 +539,7 @@ slide2 (Fold step1 initial1 extract1 final1) = Fold step initial extract final
--
{-# INLINE uniqBy #-}
uniqBy :: Monad m => (a -> a -> Bool) -> Fold m a (Maybe a)
uniqBy eq = rollingMap f

where

f pre curr =
case pre of
Nothing -> Just curr
Just x -> if x `eq` curr then Nothing else Just curr
uniqBy = fromScanl . Scanl.uniqBy

-- | See 'uniqBy'.
--
Expand All @@ -567,7 +549,7 @@ uniqBy eq = rollingMap f
--
{-# INLINE uniq #-}
uniq :: (Monad m, Eq a) => Fold m a (Maybe a)
uniq = uniqBy (==)
uniq = fromScanl Scanl.uniq

-- | Strip all leading and trailing occurrences of an element passing a
-- predicate and make all other consecutive occurrences uniq.
Expand Down Expand Up @@ -628,17 +610,7 @@ drainBy = drainMapM
--
{-# INLINE the #-}
the :: (Monad m, Eq a) => Fold m a (Maybe a)
the = foldt' step initial id

where

initial = Partial Nothing

step Nothing x = Partial (Just x)
step old@(Just x0) x =
if x0 == x
then Partial old
else Done Nothing
the = fromScanl Scanl.the

------------------------------------------------------------------------------
-- To Summary
Expand All @@ -657,7 +629,7 @@ the = foldt' step initial id
--
{-# INLINE sum #-}
sum :: (Monad m, Num a) => Fold m a a
sum = Fold.fromScanl $ Scanl.cumulativeScan Scanl.incrSum
sum = fromScanl Scanl.sum

-- | Determine the product of all elements of a stream of numbers. Returns
-- multiplicative identity (@1@) when the stream is empty. The fold terminates
Expand All @@ -669,14 +641,7 @@ sum = Fold.fromScanl $ Scanl.cumulativeScan Scanl.incrSum
--
{-# INLINE product #-}
product :: (Monad m, Num a, Eq a) => Fold m a a
product = foldt' step (Partial 1) id

where

step x a =
if a == 0
then Done 0
else Partial $ x * a
product = fromScanl Scanl.product

------------------------------------------------------------------------------
-- To Summary (Maybe)
Expand Down Expand Up @@ -761,17 +726,7 @@ range = fromScanl Scanl.range
--
{-# INLINE mean #-}
mean :: (Monad m, Fractional a) => Fold m a a
mean = fmap done $ foldl' step begin

where

begin = Tuple' 0 0

step (Tuple' x n) y =
let n1 = n + 1
in Tuple' (x + (y - x) / n1) n1

done (Tuple' x _) = x
mean = fromScanl Scanl.mean

-- | Compute a numerically stable (population) variance over all elements in
-- the input stream.
Expand Down Expand Up @@ -817,26 +772,15 @@ stdDev = sqrt <$> variance
--
{-# INLINE rollingHashWithSalt #-}
rollingHashWithSalt :: (Monad m, Enum a) => Int64 -> Fold m a Int64
rollingHashWithSalt = foldl' step

where

k = 2891336453 :: Int64

step cksum a = cksum * k + fromIntegral (fromEnum a)

-- | A default salt used in the implementation of 'rollingHash'.
{-# INLINE defaultSalt #-}
defaultSalt :: Int64
defaultSalt = -2578643520546668380
rollingHashWithSalt = fromScanl . Scanl.rollingHashWithSalt

-- | Compute an 'Int' sized polynomial rolling hash of a stream.
--
-- >>> rollingHash = Fold.rollingHashWithSalt Fold.defaultSalt
--
{-# INLINE rollingHash #-}
rollingHash :: (Monad m, Enum a) => Fold m a Int64
rollingHash = rollingHashWithSalt defaultSalt
rollingHash = fromScanl Scanl.rollingHash

-- | Compute an 'Int' sized polynomial rolling hash of the first n elements of
-- a stream.
Expand All @@ -846,7 +790,7 @@ rollingHash = rollingHashWithSalt defaultSalt
-- /Pre-release/
{-# INLINE rollingHashFirstN #-}
rollingHashFirstN :: (Monad m, Enum a) => Int -> Fold m a Int64
rollingHashFirstN n = take n rollingHash
rollingHashFirstN = fromScanl . Scanl.rollingHashFirstN

-- XXX Compare this with the implementation in Fold.Window, preferrably use the
-- latter if performance is good.
Expand All @@ -860,26 +804,14 @@ rollingHashFirstN n = take n rollingHash
--
{-# INLINE rollingMapM #-}
rollingMapM :: Monad m => (Maybe a -> a -> m b) -> Fold m a b
rollingMapM f = Fold step initial extract extract

where

-- XXX We need just a postscan. We do not need an initial result here.
-- Or we can supply a default initial result as an argument to rollingMapM.
initial = return $ Partial (Nothing, error "Empty stream")

step (prev, _) cur = do
x <- f prev cur
return $ Partial (Just cur, x)

extract = return . snd
rollingMapM = fromScanl . Scanl.rollingMapM

-- |
-- >>> rollingMap f = Fold.rollingMapM (\x y -> return $ f x y)
--
{-# INLINE rollingMap #-}
rollingMap :: Monad m => (Maybe a -> a -> b) -> Fold m a b
rollingMap f = rollingMapM (\x y -> return $ f x y)
rollingMap = fromScanl . Scanl.rollingMap

------------------------------------------------------------------------------
-- Monoidal left folds
Expand All @@ -898,7 +830,7 @@ rollingMap f = rollingMapM (\x y -> return $ f x y)
--
{-# INLINE sconcat #-}
sconcat :: (Monad m, Semigroup a) => a -> Fold m a a
sconcat = foldl' (<>)
sconcat = fromScanl . Scanl.sconcat

-- | Monoid concat. Fold an input stream consisting of monoidal elements using
-- 'mappend' and 'mempty'.
Expand All @@ -915,7 +847,7 @@ sconcat = foldl' (<>)
mconcat ::
( Monad m
, Monoid a) => Fold m a a
mconcat = sconcat mempty
mconcat = fromScanl Scanl.mconcat

-- |
-- Definition:
Expand All @@ -931,7 +863,7 @@ mconcat = sconcat mempty
--
{-# INLINE foldMap #-}
foldMap :: (Monad m, Monoid b) => (a -> b) -> Fold m a b
foldMap f = lmap f mconcat
foldMap = fromScanl . Scanl.foldMap

-- |
-- Definition:
Expand All @@ -947,13 +879,7 @@ foldMap f = lmap f mconcat
--
{-# INLINE foldMapM #-}
foldMapM :: (Monad m, Monoid b) => (a -> m b) -> Fold m a b
foldMapM act = foldlM' step (pure mempty)

where

step m a = do
m' <- act a
return $! mappend m m'
foldMapM = fromScanl . Scanl.foldMapM

------------------------------------------------------------------------------
-- Partial Folds
Expand All @@ -969,7 +895,7 @@ foldMapM act = foldlM' step (pure mempty)
-- /Pre-release/
{-# INLINE drainN #-}
drainN :: Monad m => Int -> Fold m a ()
drainN n = take n drain
drainN = fromScanl . Scanl.drainN

------------------------------------------------------------------------------
-- To Elements
Expand Down Expand Up @@ -1134,16 +1060,7 @@ findIndex predicate = foldt' step (Partial 0) (const Nothing)
--
{-# INLINE findIndices #-}
findIndices :: Monad m => (a -> Bool) -> Fold m a (Maybe Int)
findIndices predicate =
-- XXX implement by combining indexing and filtering scans
fmap (either (const Nothing) Just) $ foldl' step (Left (-1))

where

step i a =
if predicate a
then Right (either id id i + 1)
else Left (either id id i + 1)
findIndices = fromScanl . Scanl.findIndices

-- | Returns the index of the latest element if the element matches the given
-- value.
Expand All @@ -1154,7 +1071,7 @@ findIndices predicate =
--
{-# INLINE elemIndices #-}
elemIndices :: (Monad m, Eq a) => a -> Fold m a (Maybe Int)
elemIndices a = findIndices (== a)
elemIndices = fromScanl . Scanl.elemIndices

-- | Returns the first index where a given value is found in the stream.
--
Expand Down Expand Up @@ -2256,7 +2173,7 @@ chunksBetween _low _high _f1 _f2 = undefined
-- /Pre-release/
{-# INLINE toStream #-}
toStream :: (Monad m, Monad n) => Fold m a (Stream n a)
toStream = fmap StreamD.fromList toList
toStream = fromScanl Scanl.toStream

-- This is more efficient than 'toStream'. toStream is exactly the same as
-- reversing the stream after toStreamRev.
Expand All @@ -2274,7 +2191,7 @@ toStream = fmap StreamD.fromList toList
-- xn : ... : x2 : x1 : []
{-# INLINE toStreamRev #-}
toStreamRev :: (Monad m, Monad n) => Fold m a (Stream n a)
toStreamRev = fmap StreamD.fromList toListRev
toStreamRev = fromScanl Scanl.toStreamRev

-- XXX This does not fuse. It contains a recursive step function. We will need
-- a Skip input constructor in the fold type to make it fuse.
Expand Down Expand Up @@ -2316,32 +2233,7 @@ bottomBy :: (MonadIO m, Unbox a) =>
(a -> a -> Ordering)
-> Int
-> Fold m a (MutArray a)
bottomBy cmp n = Fold step initial extract extract

where

initial = do
arr <- MA.emptyOf' n
if n <= 0
then return $ Done arr
else return $ Partial (arr, 0)

step (arr, i) x =
if i < n
then do
arr' <- MA.snoc arr x
MA.bubble cmp arr'
return $ Partial (arr', i + 1)
else do
x1 <- MA.unsafeGetIndex (i - 1) arr
case x `cmp` x1 of
LT -> do
MA.unsafePutIndex (i - 1) arr x
MA.bubble cmp arr
return $ Partial (arr, i)
_ -> return $ Partial (arr, i)

extract = return . fst
bottomBy cmp = fromScanl . Scanl.bottomBy cmp

-- | Get the top @n@ elements using the supplied comparison function.
--
Expand Down Expand Up @@ -2377,7 +2269,7 @@ topBy cmp = bottomBy (flip cmp)
-- /Pre-release/
{-# INLINE top #-}
top :: (MonadIO m, Unbox a, Ord a) => Int -> Fold m a (MutArray a)
top = bottomBy $ flip compare
top = fromScanl . Scanl.top

-- | Fold the input stream to bottom n elements.
--
Expand All @@ -2392,7 +2284,7 @@ top = bottomBy $ flip compare
-- /Pre-release/
{-# INLINE bottom #-}
bottom :: (MonadIO m, Unbox a, Ord a) => Int -> Fold m a (MutArray a)
bottom = bottomBy compare
bottom = fromScanl . Scanl.bottom

------------------------------------------------------------------------------
-- Interspersed parsing
Expand Down
Loading

0 comments on commit 44616ee

Please sign in to comment.