Skip to content
This repository was archived by the owner on Jun 4, 2024. It is now read-only.

Stalk updates #15

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/Snarkl/Language/Prelude.hs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ instance (Typeable ty, Derive ty k) => Derive ('TArr ty) k where
_ <- set (a, 0) v
return a

instance (Typeable ty, Derive ty k) => Derive ('TVec n ty) k where
derive n = do
a :: TExp ('TArr ty) k <- derive n
return $ unsafe_cast a

instance
( Typeable ty1,
Derive ty1 k,
Expand Down Expand Up @@ -599,6 +604,9 @@ instance
return $ TEApp e2 x
zip_vals b y1 y2

instance Zippable ('TVec n ty) k where
zip_vals _ x _ = return x

----------------------------------------------------
--
-- Recursive Types
Expand Down
69 changes: 62 additions & 7 deletions src/Snarkl/Language/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ module Snarkl.Language.Vector
transpose,
all,
any,
unzip,
zip,
tabulate,
)
where

Expand All @@ -27,24 +30,27 @@ import Data.Typeable (Proxy (Proxy), Typeable)
import Snarkl.AST.TExpr (TExp (TEApp))
import Snarkl.Language.Prelude
( Comp,
Ty (TArr, TBool, TFun, TUnit, TVec),
Ty (TArr, TBool, TFun, TProd, TUnit, TVec),
apply,
arr,
arr2,
dec,
false,
forall,
forall2,
fst_pair,
lambda,
pair,
return,
snd_pair,
true,
unsafe_cast,
(&&),
(>>=),
(||),
)
import qualified Snarkl.Language.Prelude as Snarkl
import Prelude hiding (all, any, concat, foldl, map, return, traverse, (&&), (*), (>>=), (||))
import Prelude hiding (all, any, concat, foldl, map, return, traverse, unzip, zip, (&&), (*), (>>=), (||))
import qualified Prelude as P

type Vector = 'TVec
Expand Down Expand Up @@ -94,7 +100,7 @@ map ::
Comp (Vector n b) k
map f a = do
b <- vec
_ <- forall (universe @n) $ \i -> do
_ <- forall universe $ \i -> do
ai <- get (a, i)
bi <- apply f ai
set (b, i) bi
Expand All @@ -110,7 +116,7 @@ foldl ::
TExp (Vector n a) k ->
Comp b k
foldl f b0 as = do
go (universe @n) b0
go universe b0
where
go ns acc = case ns of
[] -> return acc
Expand All @@ -129,7 +135,7 @@ traverse ::
Comp (Vector n b) k
traverse f as = do
bs <- vec
_ <- forall (universe @n) $ \i -> do
_ <- forall universe $ \i -> do
ai <- get (as, i)
bi <- f ai
set (bs, i) bi
Expand All @@ -145,7 +151,7 @@ traverseWithIndex ::
Comp (Vector n b) k
traverseWithIndex f as = do
bs <- vec
_ <- forall (universe @n) $ \i -> do
_ <- forall universe $ \i -> do
ai <- get (as, i)
bi <- f i ai
set (bs, i) bi
Expand All @@ -159,7 +165,7 @@ traverse_ ::
TExp (Vector n a) k ->
Comp 'TUnit k
traverse_ f as = do
forall (universe @n) $ \i -> do
forall universe $ \i -> do
ai <- get (as, i)
f ai

Expand Down Expand Up @@ -239,3 +245,52 @@ any as = do
lambda $ \x ->
return $ acc || x
foldl f false as

unzip ::
forall (n :: Nat) a b k.
(SNatI n) =>
(Typeable a) =>
(Typeable b) =>
(Typeable n) =>
TExp (Vector n ('TProd a b)) k ->
Comp ('TProd (Vector n a) (Vector n b)) k
unzip ps = do
as <- vec @n
bs <- vec @n
_ <- forall universe $ \i -> do
p <- get (ps, i)
a <- fst_pair p
b <- snd_pair p
_ <- set (as, i) a
set (bs, i) b
pair as bs

zip ::
forall (n :: Nat) a b k.
(SNatI n) =>
(Typeable a) =>
(Typeable b) =>
TExp (Vector n a) k ->
TExp (Vector n b) k ->
Comp (Vector n ('TProd a b)) k
zip as bs = do
ps <- vec
_ <- forall universe $ \i -> do
a <- get (as, i)
b <- get (bs, i)
p <- pair a b
set (ps, i) p
return ps

tabulate ::
forall (n :: Nat) a k.
(SNatI n) =>
(Typeable a) =>
(Fin n -> Comp a k) ->
Comp (Vector n a) k
tabulate f = do
as <- vec
_ <- forall universe $ \i -> do
a <- f i
set (as, i) a
return as
7 changes: 2 additions & 5 deletions tutorial/sudoku/Sudoku.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,8 @@ type SudokuSet = Vector Nat9 'TField

-- | Smart constructor to build the set [1..9]
mkSudokuSet :: (GaloisField k) => Comp SudokuSet k
mkSudokuSet = do
ss <- Vec.vec
forall (universe @Nat9) $ \i ->
Vec.set (ss, i) (fromField $ 1 P.+ fromIntegral i)
return ss
mkSudokuSet = Vec.tabulate $ \i ->
return $ fromField (fromIntegral i P.+ 1)

-- | Check that a number belongs to the valid range of numbers,
-- e.g. [1 .. 9]
Expand Down
Loading