Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Applicative based API #522

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions vector-bench-papi/benchmarks/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import Bench.Vector.Algo.Spectral (spectral)
import Bench.Vector.Algo.Tridiag (tridiag)
import Bench.Vector.Algo.FindIndexR (findIndexR, findIndexR_naive, findIndexR_manual)
import Bench.Vector.Algo.NextPermutation (generatePermTests)
import Bench.Vector.Algo.Applicative ( generateState, generateStateUnfold, generateIO, generateIOPrim
, lensSum, lensMap, baselineSum, baselineMap)

import Bench.Vector.TestData.ParenTree (parenTree)
import Bench.Vector.TestData.Graph (randomGraph)
Expand Down Expand Up @@ -68,4 +70,14 @@ main = do
, bench "minimumOn" $ whnf (U.minimumOn (\x -> x*x*x)) as
, bench "maximumOn" $ whnf (U.maximumOn (\x -> x*x*x)) as
, bgroup "(next|prev)Permutation" $ map (\(name, act) -> bench name $ whnfIO act) permTests
, bgroup "Applicative"
[ bench "generateState" $ whnf generateState useSize
, bench "generateStateUnfold" $ whnf generateStateUnfold useSize
, bench "generateIO" $ whnfIO (generateIO useSize)
, bench "generateIOPrim" $ whnfIO (generateIOPrim useSize)
, bench "sum[lens]" $ whnf lensSum as
, bench "sum[base]" $ whnf baselineSum as
, bench "map[lens]" $ whnf lensMap as
, bench "map[base]" $ whnf baselineMap as
]
]
101 changes: 101 additions & 0 deletions vector/benchlib/Bench/Vector/Algo/Applicative.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- This module provides benchmarks for functions which use API based
-- on applicative. We use @generateA@ based benchmark for state and IO
-- and also benchmark folds and mapping using lens since it's one of
-- important consumers of this API.
module Bench.Vector.Algo.Applicative
( -- * Standard benchmarks
generateState
, generateStateUnfold
, generateIO
, generateIOPrim
-- * Lens benchmarks
, lensSum
, baselineSum
, lensMap
, baselineMap
) where

import Control.Applicative
import Data.Coerce
import Data.Functor.Identity
import Data.Int
import Data.Monoid
import Data.Word
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as MVG
import qualified Data.Vector.Unboxed as VU
import System.Random.Stateful
import System.Mem (getAllocationCounter)

-- | Benchmark which is running in state monad.
generateState :: Int -> VU.Vector Word64
generateState n
= runStateGen_ (mkStdGen 42)
$ \g -> VG.generateA n (\_ -> uniformM g)

-- | Benchmark which is running in state monad.
generateStateUnfold :: Int -> VU.Vector Word64
generateStateUnfold n = VU.unfoldrExactN n genWord64 (mkStdGen 42)

-- | Benchmark for running @generateA@ in IO monad.
generateIO :: Int -> IO (VU.Vector Int64)
generateIO n = VG.generateA n (\_ -> getAllocationCounter)

-- | Baseline for 'generateIO' it uses primitive operations
generateIOPrim :: Int -> IO (VU.Vector Int64)
generateIOPrim n = VG.unsafeFreeze =<< MVG.replicateM n getAllocationCounter

-- | Sum using lens
lensSum :: VU.Vector Double -> Double
{-# NOINLINE lensSum #-}
lensSum = foldlOf' VG.traverse (+) 0

-- | Baseline for sum.
baselineSum :: VU.Vector Double -> Double
{-# NOINLINE baselineSum #-}
baselineSum = VU.sum

-- | Mapping over vector elements using
lensMap :: VU.Vector Double -> VU.Vector Double
{-# NOINLINE lensMap #-}
lensMap = over VG.traverse (*2)

-- | Baseline for map
baselineMap :: VU.Vector Double -> VU.Vector Double
{-# NOINLINE baselineMap #-}
baselineMap = VU.map (*2)

----------------------------------------------------------------
-- Bits and pieces of lens
--
-- We don't want to depend on lens so we just copy relevant
-- parts. After all we don't need much
----------------------------------------------------------------

type ASetter s t a b = (a -> Identity b) -> s -> Identity t
type Getting r s a = (a -> Const r a) -> s -> Const r s

foldlOf' :: Getting (Endo (Endo r)) s a -> (r -> a -> r) -> r -> s -> r
foldlOf' l f z0 = \xs ->
let f' x (Endo k) = Endo $ \z -> k $! f z x
in foldrOf l f' (Endo id) xs `appEndo` z0
{-# INLINE foldlOf' #-}

foldrOf :: Getting (Endo r) s a -> (a -> r -> r) -> r -> s -> r
foldrOf l f z = flip appEndo z . foldMapOf l (Endo #. f)
{-# INLINE foldrOf #-}

foldMapOf :: Getting r s a -> (a -> r) -> s -> r
foldMapOf = coerce
{-# INLINE foldMapOf #-}

( #. ) :: Coercible c b => (b -> c) -> (a -> b) -> (a -> c)
( #. ) _ = coerce (\x -> x :: b) :: forall a b. Coercible b a => a -> b
{-# INLINE (#.) #-}

over :: ASetter s t a b -> (a -> b) -> s -> t
over = coerce
{-# INLINE over #-}
13 changes: 13 additions & 0 deletions vector/benchmarks/Main.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE BangPatterns #-}
module Main where


import Bench.Vector.Algo.MutableSet (mutableSet)
import Bench.Vector.Algo.ListRank (listRank)
import Bench.Vector.Algo.Rootfix (rootfix)
Expand All @@ -12,6 +13,8 @@ import Bench.Vector.Algo.Spectral (spectral)
import Bench.Vector.Algo.Tridiag (tridiag)
import Bench.Vector.Algo.FindIndexR (findIndexR, findIndexR_naive, findIndexR_manual)
import Bench.Vector.Algo.NextPermutation (generatePermTests)
import Bench.Vector.Algo.Applicative ( generateState, generateStateUnfold, generateIO, generateIOPrim
, lensSum, lensMap, baselineSum, baselineMap)

import Bench.Vector.TestData.ParenTree (parenTree)
import Bench.Vector.TestData.Graph (randomGraph)
Expand Down Expand Up @@ -69,4 +72,14 @@ main = do
, bench "minimumOn" $ whnf (U.minimumOn (\x -> x*x*x)) as
, bench "maximumOn" $ whnf (U.maximumOn (\x -> x*x*x)) as
, bgroup "(next|prev)Permutation" $ map (\(name, act) -> bench name $ whnfIO act) permTests
, bgroup "Applicative"
[ bench "generateState" $ whnf generateState useSize
, bench "generateStateUnfold" $ whnf generateStateUnfold useSize
, bench "generateIO" $ whnfIO (generateIO useSize)
, bench "generateIOPrim" $ whnfIO (generateIOPrim useSize)
, bench "sum[lens]" $ whnf lensSum as
, bench "sum[base]" $ whnf baselineSum as
, bench "map[lens]" $ whnf lensMap as
, bench "map[base]" $ whnf baselineMap as
]
]
60 changes: 56 additions & 4 deletions vector/src/Data/Vector/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ module Data.Vector.Generic (
scanr, scanr', scanr1, scanr1',
iscanr, iscanr',

-- * Applicative API
replicateA, generateA, traverse, itraverse,

-- * Conversions

-- ** Lists
Expand Down Expand Up @@ -197,10 +200,10 @@ import Data.Vector.Internal.Check
import Control.Monad.ST ( ST, runST )
import Control.Monad.Primitive
import Prelude
( Eq, Ord, Num, Enum, Monoid, Monad, Read, Show, Bool, Ordering(..), Int, Maybe(..), Either, IO, ShowS, ReadS, String
( Eq(..), Ord(..), Num, Enum, Monoid, Applicative(..), Monad, Read, Show, Bool, Ordering(..)
, Int, Maybe(..), Either, IO, ShowS, ReadS, String
, compare, mempty, mappend, return, fmap, otherwise, id, flip, seq, error, undefined, uncurry, shows, fst, snd, min, max, not
, (>>=), (+), (-), (*), (<), (==), (.), ($), (=<<), (>>), (<$>) )

, (>>=), (+), (-), (*), (.), ($), (=<<), (>>), (<$>))
import qualified Text.Read as Read
import qualified Data.List.NonEmpty as NonEmpty

Expand All @@ -210,7 +213,7 @@ import Data.Typeable ( Typeable, gcast1 )

import Data.Data ( Data, DataType, Constr, Fixity(Prefix),
mkDataType, mkConstr, constrIndex, mkNoRepType )
import qualified Data.Traversable as T (Traversable(mapM))
import qualified Data.Traversable as T (Traversable(mapM,traverse))

-- Length information
-- ------------------
Expand Down Expand Up @@ -2641,6 +2644,55 @@ clone v = v `seq` New.create (
unsafeCopy mv v
return mv)

-- Applicatives
-- ------------



newtype STA v a = STA {
_runSTA :: forall s. Mutable v s a -> ST s (v a)
}

runSTA :: Vector v a => Int -> STA v a -> v a
runSTA !sz = \(STA fun) -> runST $ fun =<< M.unsafeNew sz
{-# INLINE runSTA #-}




-- | Execute the applicative action the given number of times and store the
-- results in a vector.
replicateA :: (Vector v a, Applicative f) => Int -> f a -> f (v a)
{-# INLINE replicateA #-}
replicateA n f = generateA n (\_ -> f)


-- | Construct a vector of the given length by applying the applicative
-- action to each index.
generateA :: (Vector v a, Applicative f) => Int -> (Int -> f a) -> f (v a)
{-# INLINE generateA #-}
generateA 0 _ = pure empty
generateA n f = runSTA n <$> go 0
where
go !i | i >= n = pure $ STA unsafeFreeze
| otherwise = (\a (STA m) -> STA $ \mv -> M.unsafeWrite mv i a >> m mv)
<$> f i
<*> go (i + 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably better to use liftA2 here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably. But I'm not sure that STA will survive maybe some New-like variant will perform better


-- | Apply the applicative action to all elements of the vector, yielding a
-- vector of results.
traverse :: (Vector v a, Vector v b, Applicative f)
=> (a -> f b) -> v a -> f (v b)
{-# INLINE traverse #-}
traverse f v = generateA (length v) $ \i -> f (unsafeIndex v i)

-- | Apply the applicative action to every element of a vector and its
-- index, yielding a vector of results.
itraverse :: (Vector v a, Vector v b, Applicative f)
=> (Int -> a -> f b) -> v a -> f (v b)
{-# INLINE itraverse #-}
itraverse f v = generateA (length v) $ \i -> f i (unsafeIndex v i)

-- Comparisons
-- -----------

Expand Down
1 change: 1 addition & 0 deletions vector/vector.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ library benchmarks-O2
Bench.Vector.Algo.Quickhull
Bench.Vector.Algo.Spectral
Bench.Vector.Algo.Tridiag
Bench.Vector.Algo.Applicative
Bench.Vector.Algo.FindIndexR
Bench.Vector.Algo.NextPermutation
Bench.Vector.TestData.ParenTree
Expand Down
Loading