Skip to content

Commit

Permalink
Add generate, modifyM & folds for mutable vectors (#338)
Browse files Browse the repository at this point in the history
Add generate, modifyM & folds for mutable vectors:

* Added `generate`, `generateM` for mutable vectors
* Added `modifyM` and `unsafeModifyM` for mutable vectors
* Add all variants of folds for mutable vectors:

 *  `mapM_`, `imapM_`, `forM_`, `iforM_`
 * `foldl`, `foldl'`, `foldM`, `foldM'`,
 * `foldr, `foldr'`, `foldrM`, `foldrM'`,
 * `ifoldl`, `ifoldl'`, `ifoldM`, `ifoldM'`,
 * `ifoldr`, `ifoldr'`, `ifoldrM, `ifoldrM'`

* Add tests for all new functions

* Update changelog
  • Loading branch information
Shimuuar authored and lehins committed Apr 1, 2021
1 parent 25a09a7 commit 5c4fcde
Show file tree
Hide file tree
Showing 8 changed files with 946 additions and 24 deletions.
184 changes: 180 additions & 4 deletions Data/Vector/Generic/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ module Data.Vector.Generic.Mutable (
-- * Construction

-- ** Initialisation
new, unsafeNew, replicate, replicateM, clone,
new, unsafeNew, replicate, replicateM, generate, generateM, clone,

-- ** Growing
grow, unsafeGrow,
Expand All @@ -40,8 +40,15 @@ module Data.Vector.Generic.Mutable (
clear,

-- * Accessing individual elements
read, write, modify, swap, exchange,
unsafeRead, unsafeWrite, unsafeModify, unsafeSwap, unsafeExchange,
read, write, modify, modifyM, swap, exchange,
unsafeRead, unsafeWrite, unsafeModify, unsafeModifyM, unsafeSwap, unsafeExchange,

-- * Folds
mapM_, imapM_, forM_, iforM_,
foldl, foldl', foldM, foldM',
foldr, foldr', foldrM, foldrM',
ifoldl, ifoldl', ifoldM, ifoldM',
ifoldr, ifoldr', ifoldrM, ifoldrM',

-- * Modifying vectors
nextPermutation,
Expand Down Expand Up @@ -74,7 +81,7 @@ import Data.Vector.Fusion.Util ( delay_inline )
import Control.Monad.Primitive ( PrimMonad, PrimState )

import Prelude hiding ( length, null, replicate, reverse, map, read,
take, drop, splitAt, init, tail )
take, drop, splitAt, init, tail, mapM_, foldr, foldl )

#include "vector.h"

Expand Down Expand Up @@ -616,6 +623,26 @@ replicateM :: (PrimMonad m, MVector v a) => Int -> m a -> m (v (PrimState m) a)
{-# INLINE replicateM #-}
replicateM n m = munstream (MBundle.replicateM n m)

-- | /O(n)/ Create a mutable vector of the given length (0 if the length is negative)
-- and fill it with the results of applying the function to each index.
generate :: (PrimMonad m, MVector v a) => Int -> (Int -> a) -> m (v (PrimState m) a)
{-# INLINE generate #-}
generate n f = stToPrim $ generateM n (return . f)

-- | /O(n)/ Create a mutable vector of the given length (0 if the length is
-- negative) and fill it with the results of applying the monadic function to each
-- index. Iteration starts at index 0.
generateM :: (PrimMonad m, MVector v a) => Int -> (Int -> m a) -> m (v (PrimState m) a)
{-# INLINE generateM #-}
generateM n f
| n <= 0 = new 0
| otherwise = do
vec <- new n
let loop i | i >= n = return vec
| otherwise = do unsafeWrite vec i =<< f i
loop (i + 1)
loop 0

-- | Create a copy of a mutable vector.
clone :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE clone #-}
Expand Down Expand Up @@ -755,6 +782,12 @@ modify :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> a) -> Int ->
modify v f i = BOUNDS_CHECK(checkIndex) "modify" i (length v)
$ unsafeModify v f i

-- | Modify the element at the given position using a monadic function.
modifyM :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> m a) -> Int -> m ()
{-# INLINE modifyM #-}
modifyM v f i = BOUNDS_CHECK(checkIndex) "modifyM" i (length v)
$ unsafeModifyM v f i

-- | Swap the elements at the given positions.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
Expand Down Expand Up @@ -788,6 +821,13 @@ unsafeModify v f i = UNSAFE_CHECK(checkIndex) "unsafeModify" i (length v)
$ basicUnsafeRead v i >>= \x ->
basicUnsafeWrite v i (f x)

-- | Modify the element at the given position using a monadic
-- function. No bounds checks are performed.
unsafeModifyM :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> m a) -> Int -> m ()
{-# INLINE unsafeModifyM #-}
unsafeModifyM v f i = UNSAFE_CHECK(checkIndex) "unsafeModifyM" i (length v)
$ stToPrim . basicUnsafeWrite v i =<< f =<< stToPrim (basicUnsafeRead v i)

-- | Swap the elements at the given positions. No bounds checks are performed.
unsafeSwap :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> Int -> m ()
Expand All @@ -811,6 +851,142 @@ unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
unsafeWrite v i x
return y

-- Folds
-- -----

forI_ :: (Monad m, MVector v a) => v (PrimState m) a -> (Int -> m b) -> m ()
{-# INLINE forI_ #-}
forI_ v f = loop 0
where
loop i | i >= n = return ()
| otherwise = f i >> loop (i + 1)
n = length v

-- | /O(n)/ Apply the monadic action to every element of the vector, discarding the results.
mapM_ :: (PrimMonad m, MVector v a) => (a -> m b) -> v (PrimState m) a -> m ()
{-# INLINE mapM_ #-}
mapM_ f v = forI_ v $ \i -> f =<< unsafeRead v i

-- | /O(n)/ Apply the monadic action to every element of the vector and its index, discarding the results.
imapM_ :: (PrimMonad m, MVector v a) => (Int -> a -> m b) -> v (PrimState m) a -> m ()
{-# INLINE imapM_ #-}
imapM_ f v = forI_ v $ \i -> f i =<< unsafeRead v i

-- | /O(n)/ Apply the monadic action to every element of the vector,
-- discarding the results. It's same as the @flip mapM_@.
forM_ :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (a -> m b) -> m ()
{-# INLINE forM_ #-}
forM_ = flip mapM_

-- | /O(n)/ Apply the monadic action to every element of the vector
-- and its index, discarding the results. It's same as the @flip imapM_@.
iforM_ :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (Int -> a -> m b) -> m ()
{-# INLINE iforM_ #-}
iforM_ = flip imapM_

-- | /O(n)/ Pure left fold.
foldl :: (PrimMonad m, MVector v a) => (b -> a -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldl #-}
foldl f = ifoldl (\b _ -> f b)

-- | /O(n)/ Pure left fold with strict accumulator.
foldl' :: (PrimMonad m, MVector v a) => (b -> a -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldl' #-}
foldl' f = ifoldl' (\b _ -> f b)

-- | /O(n)/ Pure left fold (function applied to each element and its index).
ifoldl :: (PrimMonad m, MVector v a) => (b -> Int -> a -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldl #-}
ifoldl f b0 v = stToPrim $ ifoldM (\b i a -> return $ f b i a) b0 v

-- | /O(n)/ Pure left fold with strict accumulator (function applied to each element and its index).
ifoldl' :: (PrimMonad m, MVector v a) => (b -> Int -> a -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldl' #-}
ifoldl' f b0 v = stToPrim $ ifoldM' (\b i a -> return $ f b i a) b0 v

-- | /O(n)/ Pure right fold.
foldr :: (PrimMonad m, MVector v a) => (a -> b -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldr #-}
foldr f = ifoldr (const f)

-- | /O(n)/ Pure right fold with strict accumulator.
foldr' :: (PrimMonad m, MVector v a) => (a -> b -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldr' #-}
foldr' f = ifoldr' (const f)

-- | /O(n)/ Pure right fold (function applied to each element and its index).
ifoldr :: (PrimMonad m, MVector v a) => (Int -> a -> b -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldr #-}
ifoldr f b0 v = stToPrim $ ifoldrM (\i a b -> return $ f i a b) b0 v

-- | /O(n)/ Pure right fold with strict accumulator (function applied
-- to each element and its index).
ifoldr' :: (PrimMonad m, MVector v a) => (Int -> a -> b -> b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldr' #-}
ifoldr' f b0 v = stToPrim $ ifoldrM' (\i a b -> return $ f i a b) b0 v

-- | /O(n)/ Monadic fold.
foldM :: (PrimMonad m, MVector v a) => (b -> a -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldM #-}
foldM f = ifoldM (\x _ -> f x)

-- | /O(n)/ Monadic fold with strict accumulator.
foldM' :: (PrimMonad m, MVector v a) => (b -> a -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldM' #-}
foldM' f = ifoldM' (\x _ -> f x)

-- | /O(n)/ Monadic fold (action applied to each element and its index).
ifoldM :: (PrimMonad m, MVector v a) => (b -> Int -> a -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldM #-}
ifoldM f b0 v = loop 0 b0
where
loop i b | i >= n = return b
| otherwise = do a <- unsafeRead v i
loop (i + 1) =<< f b i a
n = length v

-- | /O(n)/ Monadic fold with strict accumulator (action applied to each element and its index).
ifoldM' :: (PrimMonad m, MVector v a) => (b -> Int -> a -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldM' #-}
ifoldM' f b0 v = loop 0 b0
where
loop i !b | i >= n = return b
| otherwise = do a <- unsafeRead v i
loop (i + 1) =<< f b i a
n = length v

-- | /O(n)/ Monadic right fold.
foldrM :: (PrimMonad m, MVector v a) => (a -> b -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldrM #-}
foldrM f = ifoldrM (const f)

-- | /O(n)/ Monadic right fold with strict accumulator.
foldrM' :: (PrimMonad m, MVector v a) => (a -> b -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE foldrM' #-}
foldrM' f = ifoldrM' (const f)

-- | /O(n)/ Monadic right fold (action applied to each element and its index).
ifoldrM :: (PrimMonad m, MVector v a) => (Int -> a -> b -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldrM #-}
ifoldrM f b0 v = loop (n-1) b0
where
loop i b | i < 0 = return b
| otherwise = do a <- unsafeRead v i
loop (i - 1) =<< f i a b
n = length v

-- | /O(n)/ Monadic right fold with strict accumulator (action applied
-- to each element and its index).
ifoldrM' :: (PrimMonad m, MVector v a) => (Int -> a -> b -> m b) -> b -> v (PrimState m) a -> m b
{-# INLINE ifoldrM' #-}
ifoldrM' f b0 v = loop (n-1) b0
where
loop i !b | i < 0 = return b
| otherwise = do a <- unsafeRead v i
loop (i - 1) =<< f i a b
n = length v


-- Filling and copying
-- -------------------

Expand Down
Loading

0 comments on commit 5c4fcde

Please sign in to comment.