Skip to content

Commit

Permalink
✨ Add the apply function for uninterpreted functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lsrcz committed Jan 6, 2024
1 parent 623b574 commit 05f225b
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Exported some previously hidden API (`BVSignConversion`, `runFreshTFromIndex`) that we found useful or forgot to export. ([#138](https://github.com/lsrcz/grisette/pull/138), [#139](https://github.com/lsrcz/grisette/pull/139))
- Provided `mrgRunFreshT` to run `FreshT` with merging. ([#140](https://github.com/lsrcz/grisette/pull/140))
- Added `Grisette.Data.Class.SignConversion.SignConversion` for types from `Data.Int` and `Data.Word`. ([#142](https://github.com/lsrcz/grisette/pull/142))
- Added shift functions by symbolic shift amounts. ([#151](https://github.com/lsrcz/grisette/pull/151))
- Added `apply` for uninterpreted functions. ([#155](https://github.com/lsrcz/grisette/pull/155))

### Removed

Expand Down
3 changes: 2 additions & 1 deletion src/Grisette/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ module Grisette.Core
SafeDivision (..),
SafeLinearArith (..),
Function (..),
Apply (..),

-- ** Unsolvable types

Expand Down Expand Up @@ -1083,7 +1084,7 @@ import Grisette.Core.Data.Class.EvaluateSym
import Grisette.Core.Data.Class.ExtractSymbolics
( ExtractSymbolics (..),
)
import Grisette.Core.Data.Class.Function (Function (..))
import Grisette.Core.Data.Class.Function (Apply (..), Function (..))
import Grisette.Core.Data.Class.GPretty (GPretty (..))
import Grisette.Core.Data.Class.GenSym
( EnumGenBound (..),
Expand Down
29 changes: 29 additions & 0 deletions src/Grisette/Core/Data/Class/Function.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,21 @@
module Grisette.Core.Data.Class.Function
( -- * Function operations
Function (..),
Apply (..),
)
where

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> :set -XDataKinds
-- >>> :set -XBinaryLiterals
-- >>> :set -XFlexibleContexts
-- >>> :set -XFlexibleInstances
-- >>> :set -XFunctionalDependencies
-- >>> :set -XOverloadedStrings
-- >>> :set -XTypeOperators

-- | Abstraction for function-like types.
class Function f where
-- | Argument type
Expand All @@ -41,3 +53,20 @@ instance Function (a -> b) where
type Arg (a -> b) = a
type Ret (a -> b) = b
f # a = f a

-- | Applying an uninterpreted function.
--
-- >>> let f = "f" :: SymInteger =~> SymInteger =~> SymInteger
-- >>> apply f "a" "b"
-- (apply (apply f a) b)
--
-- Note that for implementation reasons, you can also use `apply` function on
-- a non-function symbolic value. In this case, the function is treated as an
-- `id` function.
class Apply uf where
type FunType uf
apply :: uf -> FunType uf

instance (Apply b) => Apply (a -> b) where
type FunType (a -> b) = a -> FunType b
apply f a = apply (f a)
26 changes: 25 additions & 1 deletion src/Grisette/IR/SymPrim/Data/SymPrim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ import Grisette.Core.Data.Class.BitVector
( BV (bvConcat, bvExt, bvSelect, bvSext, bvZext),
SizedBV (sizedBVConcat, sizedBVExt, sizedBVSelect, sizedBVSext, sizedBVZext),
)
import Grisette.Core.Data.Class.Function (Function (Arg, Ret, (#)))
import Grisette.Core.Data.Class.Function (Apply (FunType, apply), Function (Arg, Ret, (#)))
import Grisette.Core.Data.Class.ModelOps
( ModelOps (emptyModel, insertValue),
ModelRep (buildModel),
Expand Down Expand Up @@ -469,6 +469,10 @@ instance (SupportedPrim ca, SupportedPrim cb, LinkedRep ca sa, LinkedRep cb sb)
type Ret (sa =~> sb) = sb
(SymTabularFun f) # t = wrapTerm $ pevalTabularFunApplyTerm f (underlyingTerm t)

instance (LinkedRep ca sa, LinkedRep ct st, Apply st) => Apply (sa =~> st) where
type FunType (sa =~> st) = sa -> FunType st
apply uf a = apply (uf # a)

-- |
-- Symbolic general function type.
--
Expand Down Expand Up @@ -510,6 +514,10 @@ instance (SupportedPrim ca, SupportedPrim cb, LinkedRep ca sa, LinkedRep cb sb)
type Ret (sa -~> sb) = sb
(SymGeneralFun f) # t = wrapTerm $ pevalGeneralFunApplyTerm f (underlyingTerm t)

instance (LinkedRep ca sa, LinkedRep ct st, Apply st) => Apply (sa -~> st) where
type FunType (sa -~> st) = sa -> FunType st
apply uf a = apply (uf # a)

-- | Construction of general symbolic functions.
--
-- >>> f = "a" --> "a" + 1 :: Integer --> Integer
Expand Down Expand Up @@ -537,6 +545,22 @@ instance Hashable ARG where

-- Aggregate instances

instance Apply SymBool where
type FunType SymBool = SymBool
apply = id

instance Apply SymInteger where
type FunType SymInteger = SymInteger
apply = id

instance (KnownNat n, 1 <= n) => Apply (SymIntN n) where
type FunType (SymIntN n) = SymIntN n
apply = id

instance (KnownNat n, 1 <= n) => Apply (SymWordN n) where
type FunType (SymWordN n) = SymWordN n
apply = id

#define SOLVABLE_SIMPLE(contype, symtype) \
instance Solvable contype symtype where \
con = symtype . conTerm; \
Expand Down
29 changes: 25 additions & 4 deletions test/Grisette/IR/SymPrim/Data/SymPrimTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import Grisette.Core.Data.Class.EvaluateSym
import Grisette.Core.Data.Class.ExtractSymbolics
( ExtractSymbolics (extractSymbolics),
)
import Grisette.Core.Data.Class.Function (Function ((#)))
import Grisette.Core.Data.Class.Function (Apply (apply), Function ((#)))
import Grisette.Core.Data.Class.GenSym
( genSym,
genSymSimple,
Expand Down Expand Up @@ -1028,10 +1028,23 @@ symPrimTests =
],
testGroup
"TabularFun"
[ testCase "apply" $
[ testCase "#" $
(ssym "a" :: SymInteger =~> SymInteger)
# ssym "b"
@=? SymInteger (pevalTabularFunApplyTerm (ssymTerm "a" :: Term (Integer =-> Integer)) (ssymTerm "b"))
@=? SymInteger (pevalTabularFunApplyTerm (ssymTerm "a" :: Term (Integer =-> Integer)) (ssymTerm "b")),
testCase "apply" $
apply
(ssym "f" :: SymInteger =~> SymInteger =~> SymInteger)
(ssym "a")
(ssym "b")
@=? SymInteger
( pevalTabularFunApplyTerm
( pevalTabularFunApplyTerm
(ssymTerm "f" :: Term (Integer =-> Integer =-> Integer))
(ssymTerm "a")
)
(ssymTerm "b")
)
],
testGroup
"GeneralFun"
Expand All @@ -1045,7 +1058,15 @@ symPrimTests =
False
(buildModel ("a" := (1 :: Integer), "b" := (2 :: Integer), "c" := (3 :: Integer)))
(con ("a" --> con ("b" --> "a" + "b" + "c")) :: SymInteger -~> SymInteger -~> SymInteger)
@=? con ("a" --> con ("b" --> "a" + "b" + 3) :: Integer --> Integer --> Integer)
@=? con ("a" --> con ("b" --> "a" + "b" + 3) :: Integer --> Integer --> Integer),
testCase "#" $ do
let f :: SymInteger -~> SymInteger -~> SymInteger =
con ("a" --> con ("b" --> "a" + "b"))
f # ssym "x" @=? con ("b" --> "x" + "b"),
testCase "apply" $ do
let f :: SymInteger -~> SymInteger -~> SymInteger =
con ("a" --> con ("b" --> "a" + "b"))
apply f "x" "y" @=? "x" + "y"
],
testGroup
"Symbolic size"
Expand Down

0 comments on commit 05f225b

Please sign in to comment.