Skip to content

Commit

Permalink
Optimize Generic.Mutable.nextPermutation
Browse files Browse the repository at this point in the history
This implements some optimization of `nextPermutation` from
`Data.Vector.Generic.Mutable`. The main content of this
re-implementation is the following two points:

1. Wrapping the whole implementation in `stToPrim`. This
   allows the compiler to optimize the code better.
2. When finding the rightmost increasing pair v[k]<v[k+1], we now
   search from the right, instead of from the left. This allows us to
   abort the search as soon as we find such a pair, giving
   average-case constant performance, instead of best-case linear
   in the previous implementation.
  • Loading branch information
gksato committed Jul 21, 2024
1 parent 4ac750f commit 272fa34
Showing 1 changed file with 49 additions and 20 deletions.
69 changes: 49 additions & 20 deletions vector/src/Data/Vector/Generic/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ import Prelude
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..)
, return, otherwise, flip, const, seq, min, max, not, pure
, (>>=), (+), (-), (<), (<=), (>=), (==), (/=), (.), ($), (=<<), (>>), (<$>) )
import Data.Bits ( Bits(shiftR) )

#include "vector.h"

Expand Down Expand Up @@ -1213,6 +1214,16 @@ partitionWithUnknown f s
-- Modifying vectors
-- -----------------


-- | Compute the (lexicographically) next permutation of the given vector in-place.
-- Returns False when the input is the last permutation; in this case the vector
-- will not get updated, as opposed to the behavior of the C++ function
-- @std::next_permutation@.
nextPermutation :: (PrimMonad m, Ord e, MVector v e) => v (PrimState m) e -> m Bool
{-# INLINE nextPermutation #-}
nextPermutation = nextPermutationByLt (<)


{-
http://en.wikipedia.org/wiki/Permutation#Algorithms_to_generate_permutations
Expand All @@ -1224,32 +1235,50 @@ a given permutation. It changes the given permutation in-place.
2. Find the largest index l greater than k such that a[k] < a[l].
3. Swap the value of a[k] with that of a[l].
4. Reverse the sequence from a[k + 1] up to and including the final element a[n]
The algorithm has been updated to look up the k in Step 1 beginning from the
last of the vector; which renders the algorithm to achieve the average time
complexity of O(1) each call. The worst case time complexity is still O(n).
The orginal implementation, which scanned the vector from the left, had the
time complexity of O(n) on the best case.
-}

-- | Compute the (lexicographically) next permutation of the given vector in-place.
-- Here, the first argument should be a less-than comparison function.
-- Returns False when the input is the last permutation; in this case the vector
-- will not get updated, as opposed to the behavior of the C++ function
-- @std::next_permutation@.
nextPermutation :: (PrimMonad m,Ord e,MVector v e) => v (PrimState m) e -> m Bool
nextPermutation v
| dim < 2 = return False
| otherwise = do
val <- unsafeRead v 0
(k,l) <- loop val (-1) 0 val 1
if k < 0
then return False
else unsafeSwap v k l >>
reverse (unsafeSlice (k+1) (dim-k-1) v) >>
return True
where loop !kval !k !l !prev !i
| i == dim = return (k,l)
| otherwise = do
cur <- unsafeRead v i
-- TODO: make tuple unboxed
let (kval',k') = if prev < cur then (prev,i-1) else (kval,k)
l' = if kval' < cur then i else l
loop kval' k' l' cur (i+1)
dim = length v
nextPermutationByLt :: (PrimMonad m, MVector v e) => (e -> e -> Bool) -> v (PrimState m) e -> m Bool
nextPermutationByLt lt v
| dim < 2 = return False
| otherwise = stToPrim $ do
!vlast <- unsafeRead v (dim - 1)
decrLoop (dim - 2) vlast
where
dim = length v
-- find the largest index k such that a[k] < a[k + 1], and then pass to the rest.
decrLoop !i !vi1 | i >= 0 = do
!vi <- unsafeRead v i
if vi `lt` vi1 then swapLoop i vi (i+1) vi1 dim else decrLoop (i-1) vi
decrLoop _ !_ = return False
-- find the largest index l greater than k such that a[k] < a[l], and do the rest.
swapLoop !k !vk = go
where
-- binary search.
go !l !vl !r | r - l <= 1 = do
-- Done; do the rest of the algorithm.
unsafeWrite v k vl
unsafeWrite v l vk
reverse $ unsafeSlice (k + 1) (dim - k - 1) v
return True
go !l !vl !r = do
!vmid <- unsafeRead v mid
if vk `lt` vmid
then go mid vmid r
else go l vl mid
where
!mid = l + (r - l) `shiftR` 1


-- $setup
-- >>> import Prelude ((*))

0 comments on commit 272fa34

Please sign in to comment.