From 34fd8090195cdfa287551844ec83ded896f62e7f Mon Sep 17 00:00:00 2001 From: martyall Date: Mon, 29 Jan 2024 22:54:15 -0800 Subject: [PATCH 1/4] zip/unzip --- src/Snarkl/Language/Vector.hs | 45 +++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/Snarkl/Language/Vector.hs b/src/Snarkl/Language/Vector.hs index 151f8c3..7a370f7 100644 --- a/src/Snarkl/Language/Vector.hs +++ b/src/Snarkl/Language/Vector.hs @@ -17,6 +17,8 @@ module Snarkl.Language.Vector transpose, all, any, + unzip, + zip, ) where @@ -27,7 +29,7 @@ import Data.Typeable (Proxy (Proxy), Typeable) import Snarkl.AST.TExpr (TExp (TEApp)) import Snarkl.Language.Prelude ( Comp, - Ty (TArr, TBool, TFun, TUnit, TVec), + Ty (TArr, TBool, TFun, TProd, TUnit, TVec), apply, arr, arr2, @@ -35,8 +37,11 @@ import Snarkl.Language.Prelude false, forall, forall2, + fst_pair, lambda, + pair, return, + snd_pair, true, unsafe_cast, (&&), @@ -44,7 +49,7 @@ import Snarkl.Language.Prelude (||), ) import qualified Snarkl.Language.Prelude as Snarkl -import Prelude hiding (all, any, concat, foldl, map, return, traverse, (&&), (*), (>>=), (||)) +import Prelude hiding (all, any, concat, foldl, map, return, traverse, unzip, zip, (&&), (*), (>>=), (||)) import qualified Prelude as P type Vector = 'TVec @@ -239,3 +244,39 @@ any as = do lambda $ \x -> return $ acc || x foldl f false as + +unzip :: + forall (n :: Nat) a b k. + (SNatI n) => + (Typeable a) => + (Typeable b) => + (Typeable n) => + TExp (Vector n ('TProd a b)) k -> + Comp ('TProd (Vector n a) (Vector n b)) k +unzip ps = do + as <- vec @n + bs <- vec @n + _ <- forall (universe @n) $ \i -> do + p <- get (ps, i) + a <- fst_pair p + b <- snd_pair p + _ <- set (as, i) a + set (bs, i) b + pair as bs + +zip :: + forall (n :: Nat) a b k. + (SNatI n) => + (Typeable a) => + (Typeable b) => + TExp (Vector n a) k -> + TExp (Vector n b) k -> + Comp (Vector n ('TProd a b)) k +zip as bs = do + ps <- vec @n + _ <- forall (universe @n) $ \i -> do + a <- get (as, i) + b <- get (bs, i) + p <- pair a b + set (ps, i) p + return ps \ No newline at end of file From 04fbcfb8c07d97e155860ae6d98dea976bdcc7f3 Mon Sep 17 00:00:00 2001 From: martyall Date: Mon, 29 Jan 2024 23:23:40 -0800 Subject: [PATCH 2/4] zippable vec instance --- src/Snarkl/Language/Prelude.hs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Snarkl/Language/Prelude.hs b/src/Snarkl/Language/Prelude.hs index a6288ae..a3c2b7f 100644 --- a/src/Snarkl/Language/Prelude.hs +++ b/src/Snarkl/Language/Prelude.hs @@ -599,6 +599,9 @@ instance return $ TEApp e2 x zip_vals b y1 y2 +instance Zippable ('TVec n ty) k where + zip_vals _ x _ = return x + ---------------------------------------------------- -- -- Recursive Types From af375c7c123f062fd955ad7c8d25cf8e98cd6f71 Mon Sep 17 00:00:00 2001 From: martyall Date: Mon, 29 Jan 2024 23:29:50 -0800 Subject: [PATCH 3/4] derive instance --- src/Snarkl/Language/Prelude.hs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Snarkl/Language/Prelude.hs b/src/Snarkl/Language/Prelude.hs index a3c2b7f..3d6614a 100644 --- a/src/Snarkl/Language/Prelude.hs +++ b/src/Snarkl/Language/Prelude.hs @@ -404,6 +404,11 @@ instance (Typeable ty, Derive ty k) => Derive ('TArr ty) k where _ <- set (a, 0) v return a +instance (Typeable ty, Derive ty k) => Derive ('TVec n ty) k where + derive n = do + a :: TExp ('TArr ty) k <- derive n + return $ unsafe_cast a + instance ( Typeable ty1, Derive ty1 k, From 57be9177576b5c6dbf810ca197895f0cd06f02f2 Mon Sep 17 00:00:00 2001 From: martyall Date: Tue, 30 Jan 2024 21:12:19 -0800 Subject: [PATCH 4/4] add tabulate, use in tutorial --- src/Snarkl/Language/Vector.hs | 32 +++++++++++++++++++++++--------- tutorial/sudoku/Sudoku.md | 7 ++----- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/Snarkl/Language/Vector.hs b/src/Snarkl/Language/Vector.hs index 7a370f7..9243651 100644 --- a/src/Snarkl/Language/Vector.hs +++ b/src/Snarkl/Language/Vector.hs @@ -19,6 +19,7 @@ module Snarkl.Language.Vector any, unzip, zip, + tabulate, ) where @@ -99,7 +100,7 @@ map :: Comp (Vector n b) k map f a = do b <- vec - _ <- forall (universe @n) $ \i -> do + _ <- forall universe $ \i -> do ai <- get (a, i) bi <- apply f ai set (b, i) bi @@ -115,7 +116,7 @@ foldl :: TExp (Vector n a) k -> Comp b k foldl f b0 as = do - go (universe @n) b0 + go universe b0 where go ns acc = case ns of [] -> return acc @@ -134,7 +135,7 @@ traverse :: Comp (Vector n b) k traverse f as = do bs <- vec - _ <- forall (universe @n) $ \i -> do + _ <- forall universe $ \i -> do ai <- get (as, i) bi <- f ai set (bs, i) bi @@ -150,7 +151,7 @@ traverseWithIndex :: Comp (Vector n b) k traverseWithIndex f as = do bs <- vec - _ <- forall (universe @n) $ \i -> do + _ <- forall universe $ \i -> do ai <- get (as, i) bi <- f i ai set (bs, i) bi @@ -164,7 +165,7 @@ traverse_ :: TExp (Vector n a) k -> Comp 'TUnit k traverse_ f as = do - forall (universe @n) $ \i -> do + forall universe $ \i -> do ai <- get (as, i) f ai @@ -256,7 +257,7 @@ unzip :: unzip ps = do as <- vec @n bs <- vec @n - _ <- forall (universe @n) $ \i -> do + _ <- forall universe $ \i -> do p <- get (ps, i) a <- fst_pair p b <- snd_pair p @@ -273,10 +274,23 @@ zip :: TExp (Vector n b) k -> Comp (Vector n ('TProd a b)) k zip as bs = do - ps <- vec @n - _ <- forall (universe @n) $ \i -> do + ps <- vec + _ <- forall universe $ \i -> do a <- get (as, i) b <- get (bs, i) p <- pair a b set (ps, i) p - return ps \ No newline at end of file + return ps + +tabulate :: + forall (n :: Nat) a k. + (SNatI n) => + (Typeable a) => + (Fin n -> Comp a k) -> + Comp (Vector n a) k +tabulate f = do + as <- vec + _ <- forall universe $ \i -> do + a <- f i + set (as, i) a + return as \ No newline at end of file diff --git a/tutorial/sudoku/Sudoku.md b/tutorial/sudoku/Sudoku.md index 1eb48ed..4b7a026 100644 --- a/tutorial/sudoku/Sudoku.md +++ b/tutorial/sudoku/Sudoku.md @@ -58,11 +58,8 @@ type SudokuSet = Vector Nat9 'TField -- | Smart constructor to build the set [1..9] mkSudokuSet :: (GaloisField k) => Comp SudokuSet k -mkSudokuSet = do - ss <- Vec.vec - forall (universe @Nat9) $ \i -> - Vec.set (ss, i) (fromField $ 1 P.+ fromIntegral i) - return ss +mkSudokuSet = Vec.tabulate $ \i -> + return $ fromField (fromIntegral i P.+ 1) -- | Check that a number belongs to the valid range of numbers, -- e.g. [1 .. 9]