diff --git a/what4/CHANGES.md b/what4/CHANGES.md index 938335de..7dd6ef46 100644 --- a/what4/CHANGES.md +++ b/what4/CHANGES.md @@ -1,3 +1,10 @@ +# next + +* The `BoolMap` parameter of `ConjPred` is now a `ConjMap`. This is a `newtype` + wrapper around `BoolMap` that makes clear that the `BoolMap` in question + represents a conjunction (as `BoolMap`s may also represent disjunctions). + See the Haddocks on `ConjMap` for more details. + # 1.6.2 (Sep 2024) * Allow building with GHC 9.10. diff --git a/what4/src/What4/Expr/App.hs b/what4/src/What4/Expr/App.hs index f4943f33..511de035 100644 --- a/what4/src/What4/Expr/App.hs +++ b/what4/src/What4/Expr/App.hs @@ -80,7 +80,7 @@ import What4.ProgramLoc import qualified What4.SemiRing as SR import qualified What4.SpecialFunctions as SFn import qualified What4.Expr.ArrayUpdateMap as AUM -import What4.Expr.BoolMap (BoolMap, Polarity(..), BoolMapView(..), Wrap(..)) +import What4.Expr.BoolMap (BoolMap, Polarity(..), Wrap(..)) import qualified What4.Expr.BoolMap as BM import What4.Expr.MATLAB import What4.Expr.WeightedSum (WeightedSum, SemiRingProduct) @@ -191,10 +191,10 @@ data App (e :: BaseType -> Type) (tp :: BaseType) where -- Invariant: The argument to a NotPred must not be another NotPred. NotPred :: !(e BaseBoolType) -> App e BaseBoolType - -- Invariant: The BoolMap must contain at least two elements. No - -- element may be a NotPred; negated elements must be represented - -- with Negative element polarity. - ConjPred :: !(BoolMap e) -> App e BaseBoolType + -- Invariant: The 'BM.ConjMap' must contain at least two elements. No element + -- may be a NotPred; negated elements must be represented with Negative + -- element polarity. See also 'isNormal' in @test/Bool.hs@. + ConjPred :: !(BM.ConjMap e) -> App e BaseBoolType ------------------------------------------------------------------------ -- Semiring operations @@ -814,6 +814,9 @@ traverseApp = , ( ConType [t|BoolMap|] `TypeApp` AnyType , [| BM.traverseVars |] ) + , ( ConType [t|BM.ConjMap|] `TypeApp` AnyType + , [| \f cm -> BM.ConjMap <$> BM.traverseVars f (BM.getConjMap cm) |] + ) , ( ConType [t|Ctx.Assignment|] `TypeApp` AnyType `TypeApp` AnyType , [| traverseFC |] ) @@ -1158,20 +1161,20 @@ asWeightedSum sr x asConjunction :: Expr t BaseBoolType -> [(Expr t BaseBoolType, Polarity)] asConjunction (BoolExpr True _) = [] asConjunction (asApp -> Just (ConjPred xs)) = - case BM.viewBoolMap xs of - BoolMapUnit -> [] - BoolMapDualUnit -> [(BoolExpr False initializationLoc, Positive)] - BoolMapTerms (tm:|tms) -> tm:tms + case BM.viewConjMap xs of + BM.ConjTrue -> [] + BM.ConjFalse -> [(BoolExpr False initializationLoc, Positive)] + BM.Conjuncts (tm:|tms) -> tm:tms asConjunction x = [(x,Positive)] asDisjunction :: Expr t BaseBoolType -> [(Expr t BaseBoolType, Polarity)] asDisjunction (BoolExpr False _) = [] asDisjunction (asApp -> Just (NotPred (asApp -> Just (ConjPred xs)))) = - case BM.viewBoolMap xs of - BoolMapUnit -> [] - BoolMapDualUnit -> [(BoolExpr True initializationLoc, Positive)] - BoolMapTerms (tm:|tms) -> map (over _2 BM.negatePolarity) (tm:tms) + case BM.viewConjMap xs of + BM.ConjTrue -> [] + BM.ConjFalse -> [(BoolExpr True initializationLoc, Positive)] + BM.Conjuncts (tm:|tms) -> map (over _2 BM.negatePolarity) (tm:tms) asDisjunction x = [(x,Positive)] asPosAtom :: Expr t BaseBoolType -> (Expr t BaseBoolType, Polarity) @@ -2086,11 +2089,11 @@ reduceApp sym unary a0 = do BaseEq _ x y -> isEq sym x y NotPred x -> notPred sym x - ConjPred bm -> - case BM.viewBoolMap bm of - BoolMapDualUnit -> return $ falsePred sym - BoolMapUnit -> return $ truePred sym - BoolMapTerms tms -> + ConjPred cm -> + case BM.viewConjMap cm of + BM.ConjFalse -> return $ falsePred sym + BM.ConjTrue -> return $ truePred sym + BM.Conjuncts tms -> do let pol (p, Positive) = return p pol (p, Negative) = notPred sym p x:|xs <- mapM pol tms @@ -2337,14 +2340,14 @@ ppApp' a0 = do NotPred x -> ppSExpr "not" [x] - ConjPred xs -> + ConjPred cm -> let pol (x,Positive) = exprPrettyArg x pol (x,Negative) = PrettyFunc "not" [ exprPrettyArg x ] in - case BM.viewBoolMap xs of - BoolMapUnit -> prettyApp "true" [] - BoolMapDualUnit -> prettyApp "false" [] - BoolMapTerms tms -> prettyApp "and" (map pol (toList tms)) + case BM.viewConjMap cm of + BM.ConjTrue -> prettyApp "true" [] + BM.ConjFalse-> prettyApp "false" [] + BM.Conjuncts tms -> prettyApp "and" (map pol (toList tms)) RealIsInteger x -> ppSExpr "isInteger" [x] BVTestBit i x -> prettyApp "testBit" [exprPrettyArg x, showPrettyArg i] diff --git a/what4/src/What4/Expr/BoolMap.hs b/what4/src/What4/Expr/BoolMap.hs index d5ff7e4f..8e8f9767 100644 --- a/what4/src/What4/Expr/BoolMap.hs +++ b/what4/src/What4/Expr/BoolMap.hs @@ -11,8 +11,12 @@ laws like commutativity, associativity and resolution. -} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + module What4.Expr.BoolMap ( BoolMap , var @@ -26,18 +30,30 @@ module What4.Expr.BoolMap , isNull , BoolMapView(..) , viewBoolMap + , foldMapVars , traverseVars , reversePolarities , removeVar , Wrap(..) + -- * 'ConjMap' + , ConjMap(..) + , ConjMapView + , pattern ConjTrue + , pattern ConjFalse + , pattern Conjuncts + , viewConjMap + , addConjunct + , evalConj ) where import Control.Lens (_1, over) +import Data.Coerce (coerce) import Data.Hashable import qualified Data.List as List (foldl') import Data.List.NonEmpty (NonEmpty(..)) import Data.Kind (Type) import Data.Parameterized.Classes +import Data.Parameterized.TraversableF import What4.BaseTypes import qualified What4.Utils.AnnotatedMap as AM @@ -66,18 +82,20 @@ instance OrdF f => Ord (Wrap f x) where instance (HashableF f, TestEquality f) => Hashable (Wrap f x) where hashWithSalt s (Wrap a) = hashWithSaltF s a --- | This data structure keeps track of a collection of expressions --- together with their polarities. Such a collection might represent --- either a conjunction or a disjunction of expressions. The --- implementation uses a map from expression values to their --- polarities, and thus automatically implements the associative, --- commutative and idempotency laws common to both conjunctions and --- disjunctions. Moreover, if the same expression occurs in the --- collection with opposite polarities, the entire collection --- collapses via a resolution step to an \"inconsistent\" map. For --- conjunctions this corresponds to a contradiction and --- represents false; for disjunction, this corresponds to the law of --- the excluded middle and represents true. +-- | A representation of a conjunction or a disjunction. +-- +-- This data structure keeps track of a collection of expressions together +-- with their polarities. The implementation uses a map from expression +-- values to their polarities, and thus automatically implements the +-- associative, commutative and idempotency laws common to both conjunctions +-- and disjunctions. Moreover, if the same expression occurs in the +-- collection with opposite polarities, the entire collection collapses +-- via a resolution step to an \"inconsistent\" map. For conjunctions this +-- corresponds to a contradiction and represents false; for disjunction, this +-- corresponds to the law of the excluded middle and represents true. +-- +-- The annotation on the 'AM.AnnotatedMap' is an incremental hash ('IncrHash') +-- of the map, used to support a fast 'Hashable' instance. data BoolMap (f :: BaseType -> Type) = InconsistentMap @@ -88,6 +106,16 @@ instance OrdF f => Eq (BoolMap f) where BoolMap m1 == BoolMap m2 = AM.eqBy (==) m1 m2 _ == _ = False +instance OrdF f => Semigroup (BoolMap f) where + (<>) = combine + +-- | Specialized version of 'foldMapVars' +instance FoldableF BoolMap where + foldMapF f = foldMapVars f + +foldMapVars :: Monoid m => (f BaseBoolType -> m) -> BoolMap f -> m +foldMapVars _ InconsistentMap = mempty +foldMapVars f (BoolMap am) = foldMap (f . unWrap . fst) (AM.toList am) -- | Traverse the expressions in a bool map, and rebuild the map. traverseVars :: (Applicative m, HashableF g, OrdF g) => @@ -107,7 +135,10 @@ instance (OrdF f, HashableF f) => Hashable (BoolMap f) where Nothing -> hashWithSalt s (1::Int) Just h -> hashWithSalt (hashWithSalt s (1::Int)) h --- | Represents the state of a bool map +-- | Represents the state of a 'BoolMap' (either a conjunction or disjunction). +-- +-- If you know you are dealing with a 'BoolMap' that represents a conjunction, +-- consider using 'ConjMap' and 'viewConjMap' for the sake of clarity. data BoolMapView f = BoolMapUnit -- ^ A bool map with no expressions, represents the unit of the corresponding operation @@ -179,3 +210,66 @@ reversePolarities (BoolMap m) = BoolMap $! fmap negatePolarity m removeVar :: OrdF f => BoolMap f -> f BaseBoolType -> BoolMap f removeVar InconsistentMap _ = InconsistentMap removeVar (BoolMap m) x = BoolMap (AM.delete (Wrap x) m) + +-------------------------------------------------------------------------------- +-- ConjMap + +-- | A 'BoolMap' representing a conjunction. +newtype ConjMap f = ConjMap { getConjMap :: BoolMap f } + deriving (Eq, FoldableF, Hashable, Semigroup) + +-- | Represents the state of a 'ConjMap'. See 'viewConjMap'. +-- +-- Like 'BoolMapView', but with more specific patterns for readability. +newtype ConjMapView f = ConjMapView (BoolMapView f) + +pattern ConjTrue :: ConjMapView f +pattern ConjTrue = ConjMapView BoolMapUnit + +pattern ConjFalse :: ConjMapView f +pattern ConjFalse = ConjMapView BoolMapDualUnit + +pattern Conjuncts :: NonEmpty (f BaseBoolType, Polarity) -> ConjMapView f +pattern Conjuncts ts = ConjMapView (BoolMapTerms ts) + +{-# COMPLETE ConjTrue, ConjFalse, Conjuncts #-} + +-- | Deconstruct the given 'ConjMap' for later processing +viewConjMap :: forall f. ConjMap f -> ConjMapView f +viewConjMap = + -- The explicit type annotations on `coerce` are likely necessary because of + -- https://gitlab.haskell.org/ghc/ghc/-/issues/21003 + coerce @(BoolMap f -> BoolMapView f) @(ConjMap f -> ConjMapView f) viewBoolMap +{-# INLINE viewConjMap #-} + +-- | Add a conjunct to a 'ConjMap'. +-- +-- Wrapper around 'addVar'. +addConjunct :: + forall f. + (HashableF f, OrdF f) => + f BaseBoolType -> + Polarity -> + ConjMap f -> + ConjMap f +addConjunct = + -- The explicit type annotations on `coerce` are likely necessary because of + -- https://gitlab.haskell.org/ghc/ghc/-/issues/21003 + coerce + @(f BaseBoolType -> Polarity -> BoolMap f -> BoolMap f) + @(f BaseBoolType -> Polarity -> ConjMap f -> ConjMap f) + addVar +{-# INLINE addConjunct #-} + +-- | Given the means to evaluate the conjuncts of a 'ConjMap' to a concrete +-- 'Bool', evaluate the whole conjunction to a 'Bool'. +evalConj :: Applicative m => (f BaseBoolType -> m Bool) -> ConjMap f -> m Bool +evalConj f cm = + let pol (x, Positive) = f x + pol (x, Negative) = not <$> f x + in + case viewConjMap cm of + ConjTrue -> pure True + ConjFalse -> pure False + Conjuncts (t:|ts) -> + List.foldl' (&&) <$> pol t <*> traverse pol ts diff --git a/what4/src/What4/Expr/Builder.hs b/what4/src/What4/Expr/Builder.hs index dc39a467..d27a2829 100644 --- a/what4/src/What4/Expr/Builder.hs +++ b/what4/src/What4/Expr/Builder.hs @@ -19,6 +19,20 @@ threads may still clobber state others have set (e.g., the current program location) so the potential for truly multithreaded use is somewhat limited. Consider the @exprBuilderFreshConfig@ or @exprBuilderSplitConfig@ operations if this is a concern. + +-- * Boolean expressions + +'ExprBuilder' tries to rewrite expressions in order to keep them as simple +and concrete as possible. In particular, here are a few considerations for +boolean-typed expressions: + +* Disjunctions are implicitly represented as negated conjunctions +* Conjunctions are represented via 'BM.ConjMap' (see docs on that type) +* @xor@ is represented as the negation of equality + +Boolean expressions are expected to be somewhat normalized at all times. +For example, there should never be a double negation (nested 'NotPred'). +See @isNormal@ in @test/Bool.hs@ for the exact expectations. -} {-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} @@ -233,7 +247,7 @@ import What4.Symbol import What4.Expr.Allocator import What4.Expr.App import qualified What4.Expr.ArrayUpdateMap as AUM -import What4.Expr.BoolMap (BoolMap, Polarity(..), BoolMapView(..)) +import What4.Expr.BoolMap (BoolMap, Polarity(..)) import qualified What4.Expr.BoolMap as BM import What4.Expr.MATLAB import What4.Expr.WeightedSum (WeightedSum, SemiRingProduct) @@ -1548,16 +1562,16 @@ realSum sym s = semiRingSum sym s bvSum :: ExprBuilder t st fs -> WeightedSum (Expr t) (SR.SemiRingBV flv w) -> IO (BVExpr t w) bvSum sym s = semiRingSum sym s -conjPred :: ExprBuilder t st fs -> BoolMap (Expr t) -> IO (BoolExpr t) -conjPred sym bm = - case BM.viewBoolMap bm of - BoolMapUnit -> return $ truePred sym - BoolMapDualUnit -> return $ falsePred sym - BoolMapTerms ((x,p):|[]) -> +conjPred :: ExprBuilder t st fs -> BM.ConjMap (Expr t) -> IO (BoolExpr t) +conjPred sym cm = + case BM.viewConjMap cm of + BM.ConjTrue -> return $ truePred sym + BM.ConjFalse -> return $ falsePred sym + BM.Conjuncts ((x,p):|[]) -> case p of Positive -> return x Negative -> notPred sym x - _ -> sbMakeExpr sym $ ConjPred bm + _ -> sbMakeExpr sym $ ConjPred cm bvUnary :: (1 <= w) => ExprBuilder t st fs -> UnaryBV (BoolExpr t) w -> IO (BVExpr t w) bvUnary sym u @@ -1985,7 +1999,7 @@ tryAndAbsorption :: BoolExpr t -> Bool tryAndAbsorption (asApp -> Just (NotPred (asApp -> Just (ConjPred as)))) (asConjunction -> bs) - = checkAbsorption (BM.reversePolarities as) bs + = checkAbsorption (BM.reversePolarities (BM.getConjMap as)) bs tryAndAbsorption _ _ = False @@ -1997,7 +2011,7 @@ tryOrAbsorption :: BoolExpr t -> Bool tryOrAbsorption (asApp -> Just (ConjPred as)) (asDisjunction -> bs) - = checkAbsorption as bs + = checkAbsorption (BM.getConjMap as) bs tryOrAbsorption _ _ = False @@ -2048,6 +2062,8 @@ instance IsExprBuilder (ExprBuilder t st fs) where ---------------------------------------------------------------------- -- Bool operations. + -- + -- See Boolean expressions in the module-level docs for some discussion. truePred = sbTrue falsePred = sbFalse @@ -2095,7 +2111,7 @@ instance IsExprBuilder (ExprBuilder t st fs) where go a b | Just (ConjPred as) <- asApp a , Just (ConjPred bs) <- asApp b - = conjPred sym $ BM.combine as bs + = conjPred sym $ as <> bs | tryAndAbsorption a b = return b @@ -2104,13 +2120,13 @@ instance IsExprBuilder (ExprBuilder t st fs) where = return a | Just (ConjPred as) <- asApp a - = conjPred sym $ uncurry BM.addVar (asPosAtom b) as + = conjPred sym $ uncurry BM.addConjunct (asPosAtom b) as | Just (ConjPred bs) <- asApp b - = conjPred sym $ uncurry BM.addVar (asPosAtom a) bs + = conjPred sym $ uncurry BM.addConjunct (asPosAtom a) bs | otherwise - = conjPred sym $ BM.fromVars [asPosAtom a, asPosAtom b] + = conjPred sym $ BM.ConjMap (BM.fromVars [asPosAtom a, asPosAtom b]) orPred sym x y = case (asConstantPred x, asConstantPred y) of @@ -2125,7 +2141,7 @@ instance IsExprBuilder (ExprBuilder t st fs) where go a b | Just (NotPred (asApp -> Just (ConjPred as))) <- asApp a , Just (NotPred (asApp -> Just (ConjPred bs))) <- asApp b - = notPred sym =<< conjPred sym (BM.combine as bs) + = notPred sym =<< conjPred sym (as <> bs) | tryOrAbsorption a b = return b @@ -2134,13 +2150,13 @@ instance IsExprBuilder (ExprBuilder t st fs) where = return a | Just (NotPred (asApp -> Just (ConjPred as))) <- asApp a - = notPred sym =<< conjPred sym (uncurry BM.addVar (asNegAtom b) as) + = notPred sym =<< conjPred sym (uncurry BM.addConjunct (asNegAtom b) as) | Just (NotPred (asApp -> Just (ConjPred bs))) <- asApp b - = notPred sym =<< conjPred sym (uncurry BM.addVar (asNegAtom a) bs) + = notPred sym =<< conjPred sym (uncurry BM.addConjunct (asNegAtom a) bs) | otherwise - = notPred sym =<< conjPred sym (BM.fromVars [asNegAtom a, asNegAtom b]) + = notPred sym =<< conjPred sym (BM.ConjMap (BM.fromVars [asNegAtom a, asNegAtom b])) itePred sb c x y -- ite c c y = c || y diff --git a/what4/src/What4/Expr/GroundEval.hs b/what4/src/What4/Expr/GroundEval.hs index 4179cd77..827f9248 100644 --- a/what4/src/What4/Expr/GroundEval.hs +++ b/what4/src/What4/Expr/GroundEval.hs @@ -44,7 +44,6 @@ import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Maybe import qualified Data.BitVector.Sized as BV -import Data.List.NonEmpty (NonEmpty(..)) import Data.Foldable import qualified Data.Map.Strict as Map import Data.Maybe ( fromMaybe ) @@ -320,15 +319,7 @@ evalGroundApp f a0 = do if xv then f y else f z NotPred x -> not <$> f x - ConjPred xs -> - let pol (x,Positive) = f x - pol (x,Negative) = not <$> f x - in - case BM.viewBoolMap xs of - BM.BoolMapUnit -> return True - BM.BoolMapDualUnit -> return False - BM.BoolMapTerms (t:|ts) -> - foldl' (&&) <$> pol t <*> mapM pol ts + ConjPred cm -> BM.evalConj f cm RealIsInteger x -> (\xv -> denominator xv == 1) <$> f x BVTestBit i x -> diff --git a/what4/src/What4/Expr/VarIdentification.hs b/what4/src/What4/Expr/VarIdentification.hs index 279ac13e..b42c8870 100644 --- a/what4/src/What4/Expr/VarIdentification.hs +++ b/what4/src/What4/Expr/VarIdentification.hs @@ -306,14 +306,14 @@ recurseAssertedAppExprVars scope p e = go e go (asApp -> Just (NotPred x)) = recordAssertionVars scope (BM.negatePolarity p) x - go (asApp -> Just (ConjPred xs)) = + go (asApp -> Just (ConjPred cm)) = let pol (x,BM.Positive) = recordAssertionVars scope p x pol (x,BM.Negative) = recordAssertionVars scope (BM.negatePolarity p) x in - case BM.viewBoolMap xs of - BM.BoolMapUnit -> return () - BM.BoolMapDualUnit -> return () - BM.BoolMapTerms (t:|ts) -> mapM_ pol (t:ts) + case BM.viewConjMap cm of + BM.ConjTrue -> return () + BM.ConjFalse -> return () + BM.Conjuncts (t:|ts) -> mapM_ pol (t:ts) go (asApp -> Just (BaseIte BaseBoolRepr _ c x y)) = do recordExprVars scope c diff --git a/what4/src/What4/Protocol/SMTWriter.hs b/what4/src/What4/Protocol/SMTWriter.hs index a97e9344..4dfc7213 100644 --- a/what4/src/What4/Protocol/SMTWriter.hs +++ b/what4/src/What4/Protocol/SMTWriter.hs @@ -2134,14 +2134,14 @@ appSMTExpr ae = do let pol (x,Positive) = mkBaseExpr x pol (x,Negative) = notExpr <$> mkBaseExpr x in - case BM.viewBoolMap xs of - BM.BoolMapUnit -> + case BM.viewConjMap xs of + BM.ConjTrue -> return $ SMTExpr BoolTypeMap $ boolExpr True - BM.BoolMapDualUnit -> + BM.ConjFalse -> return $ SMTExpr BoolTypeMap $ boolExpr False - BM.BoolMapTerms (t:|[]) -> + BM.Conjuncts (t:|[]) -> SMTExpr BoolTypeMap <$> pol t - BM.BoolMapTerms (t:|ts) -> + BM.Conjuncts (t:|ts) -> do cnj <- andAll <$> mapM pol (t:ts) freshBoundTerm BoolTypeMap cnj diff --git a/what4/src/What4/Protocol/VerilogWriter/Backend.hs b/what4/src/What4/Protocol/VerilogWriter/Backend.hs index c10b6ed9..da85cc11 100644 --- a/what4/src/What4/Protocol/VerilogWriter/Backend.hs +++ b/what4/src/What4/Protocol/VerilogWriter/Backend.hs @@ -156,7 +156,7 @@ appVerilogExpr app = e' <- exprToVerilogExpr e unop Not e' --DisjPred es -> boolMapToExpr False True Or es - ConjPred es -> boolMapToExpr True False And es + ConjPred es -> boolMapToExpr True False And (BMap.getConjMap es) -- Semiring operations -- We only support bitvector semiring operations diff --git a/what4/src/What4/Serialize/Printer.hs b/what4/src/What4/Serialize/Printer.hs index 26fa5707..1d34f5ee 100644 --- a/what4/src/What4/Serialize/Printer.hs +++ b/what4/src/What4/Serialize/Printer.hs @@ -479,7 +479,7 @@ convertAppExpr' = go . W4.appExprApp go (W4.NotPred e) = do s <- goE e return $ S.L [ident "notp", s] - go (W4.ConjPred bm) = convertBoolMap "andp" True bm + go (W4.ConjPred cm) = convertBoolMap "andp" True (BooM.getConjMap cm) go (W4.BVSlt e1 e2) = do s1 <- goE e1 s2 <- goE e2 diff --git a/what4/test/Bool.hs b/what4/test/Bool.hs new file mode 100644 index 00000000..0a9f181f --- /dev/null +++ b/what4/test/Bool.hs @@ -0,0 +1,295 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ViewPatterns #-} + +module Bool where + +import Control.Monad (unless, when) +import Control.Monad.IO.Class (MonadIO, liftIO) +import qualified Control.Monad.State.Strict as State +import Control.Monad.Trans (lift) +import Data.Coerce (coerce) +import Data.Either (isRight) +import Data.Foldable (traverse_) +import qualified Data.Map as Map +import qualified Data.Parameterized.Map as MapF +import Data.Parameterized.Nonce (newIONonceGenerator) +import Data.Parameterized.Some (Some(Some)) +import Hedgehog (GenT) +import qualified Hedgehog.Gen as Gen +import qualified Hedgehog.Internal.Gen as HG +import qualified Hedgehog.Internal.Property as HG +import qualified Test.Tasty.Hedgehog as THG +import qualified Test.Tasty as T +import qualified What4.Expr.BoolMap as BM +import What4.Expr.Builder +import What4.Expr (EmptyExprBuilderState(EmptyExprBuilderState)) +import What4.Interface + +-- | A tree of API calls to 'IsExprBuilder' methods. +-- +-- Instances may be \"interpreted\" into 'IsExprBuilder' calls via 'toSymExpr'. +-- Data flows from children to parents. +-- +-- Given a means to evaluate variables to 'Bool's, these expressions can +-- also be evaluated directly (via 'eval'), in order to compare the result to +-- 'asConstantPred'. +data BExpr var + = -- 0-ary + -- | 'falsePred', 'truePred' + Lit !Bool + -- | 'freshConstant' + | Var !var + -- unary + -- | 'notPred' + | Not !(BExpr var) + -- binary + -- | 'andPred' + | And !(BExpr var) !(BExpr var) + -- | 'eqPred' + | Eq !(BExpr var) !(BExpr var) + -- | 'orPred' + | Or !(BExpr var) !(BExpr var) + -- | 'xorPred' + | Xor !(BExpr var) !(BExpr var) + -- tertiary + -- | 'itePred' + | Ite !(BExpr var) !(BExpr var) !(BExpr var) + deriving Show + +genBExpr :: HG.MonadGen m => m var -> m (BExpr var) +genBExpr var = + Gen.recursive + Gen.choice + [ -- 0-ary + Lit <$> Gen.bool + , Var <$> var + ] + [ -- unary + Not <$> genBExpr var + -- binary + , And <$> genBExpr var <*> genBExpr var + -- TODO: Generate Eq, Xor. + -- + -- This would require updating 'isNormal' to take these into account. + -- + -- , Eq <$> genBExpr var <*> genBExpr var + , Or <$> genBExpr var <*> genBExpr var + -- , Xor <$> genBExpr var <*> genBExpr var + , Ite <$> genBExpr var <*> genBExpr var <*> genBExpr var + ] + +newtype Valuation t + = Valuation { getValuation :: Map.Map (ExprBoundVar t BaseBoolType) Bool } + deriving Show + +getValue :: ExprBoundVar t BaseBoolType -> Valuation t -> Bool +getValue v vs = + case Map.lookup v (getValuation vs) of + Nothing -> error "getValue: bad variable" + Just b -> b + +genFreshVar :: + (HG.MonadGen m, MonadIO m) => + ExprBuilder t st fs -> + State.StateT (Valuation t) m (ExprBoundVar t BaseBoolType) +genFreshVar sym = do + v <- lift (liftIO (freshConstant sym (safeSymbol "b") BaseBoolRepr)) + case v of + BoundVarExpr v' -> do + b <- Gen.bool + State.modify (coerce (Map.insert v' b)) + pure v' + _ -> error "Not a bound variable?" + +-- | Generate a new variable ('genFreshVar') or reuse an existing one +genVar :: + (HG.MonadGen m, MonadIO m) => + ExprBuilder t st fs -> + State.StateT (Valuation t) m (ExprBoundVar t BaseBoolType) +genVar sym = do + b <- Gen.bool + if b + then genFreshVar sym + else do + vs <- State.gets (Map.toList . getValuation) + case vs of + [] -> genFreshVar sym + _ -> Gen.choice (map (pure . fst) vs) + +doGenExpr :: + ExprBuilder t st fs -> + GenT IO (BExpr (ExprBoundVar t BaseBoolType), Valuation t) +doGenExpr sym = + let vars0 = Valuation Map.empty in + State.runStateT (genBExpr @(State.StateT _ (GenT IO)) (genVar @(GenT IO) sym)) vars0 + +toSymExpr :: + IsExprBuilder sym => + sym -> + -- | How to handle variables + (var -> IO (SymExpr sym BaseBoolType)) -> + BExpr var -> + IO (SymExpr sym BaseBoolType) +toSymExpr sym var = go + where + go = + \case + Lit True -> pure (truePred sym) + Lit False -> pure (falsePred sym) + Var v -> var v + Not e -> notPred sym =<< go e + And l r -> do + l' <- go l + r' <- go r + andPred sym l' r' + Eq l r -> do + l' <- go l + r' <- go r + eqPred sym l' r' + Or l r -> do + l' <- go l + r' <- go r + orPred sym l' r' + Xor l r -> do + l' <- go l + r' <- go r + xorPred sym l' r' + Ite c l r -> do + c' <- go c + l' <- go l + r' <- go r + itePred sym c' l' r' + +-- | For use with 'toSymExpr', to leave variables uninterpreted +uninterpVar :: ExprBoundVar t BaseBoolType -> Expr t BaseBoolType +uninterpVar = BoundVarExpr + +eval :: Applicative f => (var -> f Bool) -> BExpr var -> f Bool +eval var = go + where + ite c l r = if c then l else r + go = + \case + Lit True -> pure True + Lit False -> pure False + Var v -> var v + Not e -> not <$> go e + And l r -> (&&) <$> go l <*> go r + Eq l r -> (==) <$> go l <*> go r + Or l r -> (||) <$> go l <*> go r + Xor l r -> (/=) <$> go l <*> go r + Ite c l r -> ite <$> go c <*> go l <*> go r + +-- | For use with 'eval', to interpret variables +getVar :: ExprBoundVar t BaseBoolType -> State.State (Valuation t) Bool +getVar v = State.gets (getValue v) + +isNot :: Expr t BaseBoolType -> Bool +isNot e = + case e of + AppExpr ae -> + case appExprApp ae of + NotPred {} -> True + _ -> False + _ -> False + +isNormalIte :: + ExprBuilder t st fs -> + Expr t BaseBoolType -> + Expr t BaseBoolType -> + Expr t BaseBoolType -> + Either String () +isNormalIte sym c l r = do + isNormal sym c + isNormal sym l + isNormal sym r + when (isNot c) (Left "negated ite condition") + when (c == l) (Left "ite cond == LHS") + when (c == r) (Left "ite cond == RHS") + when (c == truePred sym) (Left "ite cond == true") + when (c == falsePred sym) (Left "ite cond == false") + +isNormalConjunct :: + ExprBuilder t st fs -> + Expr t BaseBoolType -> + BM.Polarity -> + Either String () +isNormalConjunct sym expr pol = + case expr of + BoolExpr {} -> Left "boolean literal inside conjunction" + BoundVarExpr {} -> Right () + AppExpr ae -> + case appExprApp ae of + NotPred {} -> Left "not should be expressed via polarity" + -- This must be an OR, if it is an AND it should be combined with + -- its parent + ConjPred cm' -> do + when (pol == BM.Positive) (Left "and inside and") + -- Note that it is possible to have ORs inside ORs, e.g., if the outer + -- OR used to be an AND but was negated. + isNormalMap sym cm' + BaseIte BaseBoolRepr _sz c l r -> isNormalIte sym c l r + _ -> Left "non-normal app in conjunct" + _ -> Left "non-normal expr in conjunct" + +isNormalMap :: ExprBuilder t st fs -> BM.ConjMap (Expr t) -> Either String () +isNormalMap sym cm = + case BM.viewConjMap cm of + BM.ConjTrue -> Left "empty conjunction map" + BM.ConjFalse -> Left "inconsistent conjunction map" + BM.Conjuncts conjs -> traverse_ (uncurry (isNormalConjunct sym)) conjs + +-- | Is this boolean expression sufficiently normalized? +isNormal :: ExprBuilder t st fs -> Expr t BaseBoolType -> Either String () +isNormal sym = + \case + BoolExpr {} -> Right () + BoundVarExpr {} -> Right () + AppExpr ae -> + case appExprApp ae of + NotPred (asApp -> Just NotPred {}) -> Left "double negation" + NotPred e -> isNormal sym e + ConjPred cm -> isNormalMap sym cm + BaseIte BaseBoolRepr _sz c l r -> isNormalIte sym c l r + _ -> Left "non-normal app" + _ -> Left "non-normal expr" + +boolTests :: T.TestTree +boolTests = + T.testGroup + "boolean normalization tests" + [ -- Test that the rewrite rules rewrite expressions into a sufficiently + -- normal form (defined by 'isNormal'). + THG.testProperty "boolean rewrites normalize" $ + HG.property $ do + Some ng <- liftIO newIONonceGenerator + sym <- liftIO (newExprBuilder FloatIEEERepr EmptyExprBuilderState ng) + (e, _vars) <- HG.forAllT (doGenExpr sym) + e' <- liftIO (toSymExpr sym (pure . uninterpVar) e) + let ok = isNormal sym e' + unless (isRight ok) $ + liftIO (putStrLn ("Not normalized:\n" ++ show (printSymExpr e'))) + ok HG.=== Right () + , THG.testProperty "boolean rewrites preserve semantics" $ + HG.property $ do + Some ng <- liftIO newIONonceGenerator + sym <- liftIO (newExprBuilder FloatIEEERepr EmptyExprBuilderState ng) + (e, vars) <- HG.forAllT (doGenExpr sym) + -- Concretely evaluate the `BExpr` to get the expected semantics. + let expected = State.evalState (eval getVar e) vars + -- Generate a `Expr` with uninterpreted variables. It is important to + -- not interpret the variables into `truePred` and `falsePred` here, + -- to avoid only hitting the `asConstantPred` cases in the rewrites. + e' <- liftIO (toSymExpr sym (pure . uninterpVar) e) + -- Finally, substitute values in for the variables, simplifying the + -- `Expr` along the way until we get a concrete boolean. + let vs = Map.toList (getValuation vars) + let substs = foldr (\(v, b) -> MapF.insert v (if b then truePred sym else falsePred sym)) MapF.empty vs + e'' <- liftIO (substituteBoundVars sym substs e') + -- Check that the `BExpr` and `Expr` agreed on the semantics. + case asConstantPred e'' of + Just actual -> actual HG.=== expected + Nothing -> HG.failure + ] diff --git a/what4/test/BoolNormalization.hs b/what4/test/BoolNormalization.hs new file mode 100644 index 00000000..76d29699 --- /dev/null +++ b/what4/test/BoolNormalization.hs @@ -0,0 +1,62 @@ +-- See what percentage of randomly-generated boolean expressions can be +-- completely simplified away. Higher is better. This is one mechanism for +-- evaluating rewrite rules. + +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} + +module Main (main) where + +import Control.Monad (foldM) +import qualified Control.Monad.State.Strict as State +import Data.Parameterized.Nonce (newIONonceGenerator) +import Data.Parameterized.Some (Some(Some)) +import Data.Parameterized.TraversableFC (traverseFC_) +import qualified Hedgehog.Internal.Gen as HG +import qualified Hedgehog.Internal.Tree as HG +import qualified Hedgehog as HG +import What4.Expr.Builder +import What4.Expr (EmptyExprBuilderState(EmptyExprBuilderState)) +import What4.Interface + +import Bool + +-- | Get the size of an expression. Lower is better. +sz :: Expr t tp -> Int +sz = + \case + SemiRingLiteral {} -> 1 + BoolExpr {} -> 1 + FloatExpr {} -> 1 + StringExpr {} -> 1 + AppExpr ae -> + State.execState (traverseFC_ (\e -> State.modify (+ sz e)) (appExprApp ae)) 1 + NonceAppExpr nae -> + State.execState (traverseFC_ (\e -> State.modify (+ sz e)) (nonceExprApp nae)) 1 + BoundVarExpr {} -> 1 + +main :: IO () +main = do + Some ng <- newIONonceGenerator + sym <- newExprBuilder FloatIEEERepr EmptyExprBuilderState ng + let eliminated i = do + x <- HG.runTreeT (HG.evalGenT (HG.Size 100) (HG.Seed i 1) (doGenExpr sym)) + case HG.nodeValue x of + Nothing -> error "whoops" + Just (bExpr, _vars) -> do + e <- toSymExpr sym (pure . uninterpVar) bExpr + -- Audit the quality of the generated expressions: + -- putStrLn "--------------------------------------" + -- putStrLn (show bExpr) + -- putStrLn "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" + -- putStrLn (show (printSymExpr e)) + -- putStrLn "______________________________________" + -- putStrLn (show (sz e)) + case asConstantPred e of + Just {} -> pure (1, sz e) + Nothing -> pure (0, sz e) + let total = 20000 + let count (accFull, accSize) (full, size) = (accFull + full, accSize + size) + (full, size) <- foldM (\acc seed -> count acc <$> eliminated seed) (0 :: Int, 0 :: Int) [0..total] + putStrLn ("Fully eliminated " ++ show full ++ "/" ++ show total) + putStrLn ("Total size: " ++ show size) diff --git a/what4/test/ExprsTest.hs b/what4/test/ExprsTest.hs index d4b4d81f..3f2258f7 100644 --- a/what4/test/ExprsTest.hs +++ b/what4/test/ExprsTest.hs @@ -33,6 +33,8 @@ import What4.Concrete import What4.Expr import What4.Interface +import Bool (boolTests) + type IteExprBuilder t fs = ExprBuilder t EmptyExprBuilderState fs withTestSolver :: (forall t. IteExprBuilder t (Flags FloatIEEE) -> IO a) -> IO a @@ -385,4 +387,5 @@ main = defaultMain $ testGroup "What4 Expressions" return (asConcrete s) (fromConcreteString <$> s) === Just "" , testInjectiveConversions + , boolTests ] diff --git a/what4/what4.cabal b/what4/what4.cabal index 84a62292..e4c6ceb3 100644 --- a/what4/what4.cabal +++ b/what4/what4.cabal @@ -363,9 +363,20 @@ test-suite exprs_tests main-is: ExprsTest.hs other-modules: + Bool GenWhat4Expr build-depends: bv-sized + , containers + , mtl + +executable bool-normalization + import: bldflags, testdefs-hedgehog, testdefs-hunit + main-is: BoolNormalization.hs + other-modules: Bool + build-depends: containers + , mtl + , transformers test-suite iteexprs_tests