Skip to content

Commit

Permalink
what4: Add a ConjMap newtype around BoolMap for readability
Browse files Browse the repository at this point in the history
Matches on `ConjMapView` encode the semantic meaning of the
constructors.
  • Loading branch information
langston-barrett committed Feb 1, 2025
1 parent 746ab5c commit e4a00f0
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 58 deletions.
41 changes: 22 additions & 19 deletions what4/src/What4/Expr/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ data App (e :: BaseType -> Type) (tp :: BaseType) where
-- Invariant: The BoolMap must contain at least two elements. No
-- element may be a NotPred; negated elements must be represented
-- with Negative element polarity.
ConjPred :: !(BoolMap e) -> App e BaseBoolType
ConjPred :: !(BM.ConjMap e) -> App e BaseBoolType

------------------------------------------------------------------------
-- Semiring operations
Expand Down Expand Up @@ -814,6 +814,9 @@ traverseApp =
, ( ConType [t|BoolMap|] `TypeApp` AnyType
, [| BM.traverseVars |]
)
, ( ConType [t|BM.ConjMap|] `TypeApp` AnyType
, [| \f cm -> BM.ConjMap <$> BM.traverseVars f (BM.getConjMap cm) |]
)
, ( ConType [t|Ctx.Assignment|] `TypeApp` AnyType `TypeApp` AnyType
, [| traverseFC |]
)
Expand Down Expand Up @@ -1158,20 +1161,20 @@ asWeightedSum sr x
asConjunction :: Expr t BaseBoolType -> [(Expr t BaseBoolType, Polarity)]
asConjunction (BoolExpr True _) = []
asConjunction (asApp -> Just (ConjPred xs)) =
case BM.viewBoolMap xs of
BoolMapUnit -> []
BoolMapDualUnit -> [(BoolExpr False initializationLoc, Positive)]
BoolMapTerms (tm:|tms) -> tm:tms
case BM.viewConjMap xs of
BM.ConjTrue -> []
BM.ConjFalse -> [(BoolExpr False initializationLoc, Positive)]
BM.Conjuncts (tm:|tms) -> tm:tms
asConjunction x = [(x,Positive)]


asDisjunction :: Expr t BaseBoolType -> [(Expr t BaseBoolType, Polarity)]
asDisjunction (BoolExpr False _) = []
asDisjunction (asApp -> Just (NotPred (asApp -> Just (ConjPred xs)))) =
case BM.viewBoolMap xs of
BoolMapUnit -> []
BoolMapDualUnit -> [(BoolExpr True initializationLoc, Positive)]
BoolMapTerms (tm:|tms) -> map (over _2 BM.negatePolarity) (tm:tms)
case BM.viewConjMap xs of
BM.ConjTrue -> []
BM.ConjFalse -> [(BoolExpr True initializationLoc, Positive)]
BM.Conjuncts (tm:|tms) -> map (over _2 BM.negatePolarity) (tm:tms)
asDisjunction x = [(x,Positive)]

asPosAtom :: Expr t BaseBoolType -> (Expr t BaseBoolType, Polarity)
Expand Down Expand Up @@ -2086,11 +2089,11 @@ reduceApp sym unary a0 = do
BaseEq _ x y -> isEq sym x y

NotPred x -> notPred sym x
ConjPred bm ->
case BM.viewBoolMap bm of
BoolMapDualUnit -> return $ falsePred sym
BoolMapUnit -> return $ truePred sym
BoolMapTerms tms ->
ConjPred cm ->
case BM.viewConjMap cm of
BM.ConjFalse -> return $ falsePred sym
BM.ConjTrue -> return $ truePred sym
BM.Conjuncts tms ->
do let pol (p, Positive) = return p
pol (p, Negative) = notPred sym p
x:|xs <- mapM pol tms
Expand Down Expand Up @@ -2337,14 +2340,14 @@ ppApp' a0 = do

NotPred x -> ppSExpr "not" [x]

ConjPred xs ->
ConjPred cm ->
let pol (x,Positive) = exprPrettyArg x
pol (x,Negative) = PrettyFunc "not" [ exprPrettyArg x ]
in
case BM.viewBoolMap xs of
BoolMapUnit -> prettyApp "true" []
BoolMapDualUnit -> prettyApp "false" []
BoolMapTerms tms -> prettyApp "and" (map pol (toList tms))
case BM.viewConjMap cm of
BM.ConjTrue -> prettyApp "true" []
BM.ConjFalse-> prettyApp "false" []
BM.Conjuncts tms -> prettyApp "and" (map pol (toList tms))

RealIsInteger x -> ppSExpr "isInteger" [x]
BVTestBit i x -> prettyApp "testBit" [exprPrettyArg x, showPrettyArg i]
Expand Down
55 changes: 54 additions & 1 deletion what4/src/What4/Expr/BoolMap.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ laws like commutativity, associativity and resolution.
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ViewPatterns #-}

module What4.Expr.BoolMap
( BoolMap
, var
Expand All @@ -30,6 +32,12 @@ module What4.Expr.BoolMap
, reversePolarities
, removeVar
, Wrap(..)
-- * 'ConjMap'
, ConjMap(..)
, ConjMapView(..)
, viewConjMap
, addConjunct
, evalConj
) where

import Control.Lens (_1, over)
Expand All @@ -42,6 +50,7 @@ import Data.Parameterized.Classes
import What4.BaseTypes
import qualified What4.Utils.AnnotatedMap as AM
import What4.Utils.IncrHash
import Data.Coerce (coerce)

-- | Describes the occurrence of a variable or expression, whether it is
-- negated or not.
Expand Down Expand Up @@ -91,6 +100,8 @@ instance OrdF f => Eq (BoolMap f) where
BoolMap m1 == BoolMap m2 = AM.eqBy (==) m1 m2
_ == _ = False

instance OrdF f => Semigroup (BoolMap f) where
(<>) = combine

-- | Traverse the expressions in a bool map, and rebuild the map.
traverseVars :: (Applicative m, HashableF g, OrdF g) =>
Expand Down Expand Up @@ -182,3 +193,45 @@ reversePolarities (BoolMap m) = BoolMap $! fmap negatePolarity m
removeVar :: OrdF f => BoolMap f -> f BaseBoolType -> BoolMap f
removeVar InconsistentMap _ = InconsistentMap
removeVar (BoolMap m) x = BoolMap (AM.delete (Wrap x) m)

--------------------------------------------------------------------------------
-- ConjMap

newtype ConjMap f = ConjMap { getConjMap :: BoolMap f }
deriving (Eq, Hashable, Semigroup)

-- | Represents the state of a 'ConjMap'. See 'viewConjMap'.
--
-- Like 'BoolMapView', but with more specific names for readability.
data ConjMapView f
= ConjFalse
-- ^ A 'ConjMap' with no expressions
| ConjTrue
-- ^ An inconsistent 'ConjMap'
| Conjuncts (NonEmpty (f BaseBoolType, Polarity))
-- ^ The terms appearing in the 'ConjMap', of which there is at least one

conjMapView :: BoolMapView f -> ConjMapView f
conjMapView BoolMapUnit = ConjTrue
conjMapView BoolMapDualUnit = ConjFalse
conjMapView (BoolMapTerms ts) = Conjuncts ts

-- | Deconstruct the given 'ConjMap' for later processing
viewConjMap :: ConjMap f -> ConjMapView f
viewConjMap = conjMapView . viewBoolMap . getConjMap
{-# INLINE viewConjMap #-}

addConjunct :: (HashableF f, OrdF f) => f BaseBoolType -> Polarity -> ConjMap f -> ConjMap f
addConjunct t p = coerce (addVar t p)
{-# INLINE addConjunct #-}

evalConj :: Applicative m => (f BaseBoolType -> m Bool) -> ConjMap f -> m Bool
evalConj f cm =
let pol (x, Positive) = f x
pol (x, Negative) = not <$> f x
in
case viewConjMap cm of
ConjTrue -> pure True
ConjFalse -> pure False
Conjuncts (t:|ts) ->
List.foldl' (&&) <$> pol t <*> traverse pol ts
34 changes: 17 additions & 17 deletions what4/src/What4/Expr/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1548,16 +1548,16 @@ realSum sym s = semiRingSum sym s
bvSum :: ExprBuilder t st fs -> WeightedSum (Expr t) (SR.SemiRingBV flv w) -> IO (BVExpr t w)
bvSum sym s = semiRingSum sym s

conjPred :: ExprBuilder t st fs -> BoolMap (Expr t) -> IO (BoolExpr t)
conjPred sym bm =
case BM.viewBoolMap bm of
BoolMapUnit -> return $ truePred sym
BoolMapDualUnit -> return $ falsePred sym
BoolMapTerms ((x,p):|[]) ->
conjPred :: ExprBuilder t st fs -> BM.ConjMap (Expr t) -> IO (BoolExpr t)
conjPred sym cm =
case BM.viewConjMap cm of
BM.ConjTrue -> return $ truePred sym
BM.ConjFalse -> return $ falsePred sym
BM.Conjuncts ((x,p):|[]) ->
case p of
Positive -> return x
Negative -> notPred sym x
_ -> sbMakeExpr sym $ ConjPred bm
_ -> sbMakeExpr sym $ ConjPred cm

bvUnary :: (1 <= w) => ExprBuilder t st fs -> UnaryBV (BoolExpr t) w -> IO (BVExpr t w)
bvUnary sym u
Expand Down Expand Up @@ -1985,7 +1985,7 @@ tryAndAbsorption ::
BoolExpr t ->
Bool
tryAndAbsorption (asApp -> Just (NotPred (asApp -> Just (ConjPred as)))) (asConjunction -> bs)
= checkAbsorption (BM.reversePolarities as) bs
= checkAbsorption (BM.reversePolarities (BM.getConjMap as)) bs
tryAndAbsorption _ _ = False


Expand All @@ -1997,7 +1997,7 @@ tryOrAbsorption ::
BoolExpr t ->
Bool
tryOrAbsorption (asApp -> Just (ConjPred as)) (asDisjunction -> bs)
= checkAbsorption as bs
= checkAbsorption (BM.getConjMap as) bs
tryOrAbsorption _ _ = False


Expand Down Expand Up @@ -2095,7 +2095,7 @@ instance IsExprBuilder (ExprBuilder t st fs) where
go a b
| Just (ConjPred as) <- asApp a
, Just (ConjPred bs) <- asApp b
= conjPred sym $ BM.combine as bs
= conjPred sym $ as <> bs

| tryAndAbsorption a b
= return b
Expand All @@ -2104,13 +2104,13 @@ instance IsExprBuilder (ExprBuilder t st fs) where
= return a

| Just (ConjPred as) <- asApp a
= conjPred sym $ uncurry BM.addVar (asPosAtom b) as
= conjPred sym $ uncurry BM.addConjunct (asPosAtom b) as

| Just (ConjPred bs) <- asApp b
= conjPred sym $ uncurry BM.addVar (asPosAtom a) bs
= conjPred sym $ uncurry BM.addConjunct (asPosAtom a) bs

| otherwise
= conjPred sym $ BM.fromVars [asPosAtom a, asPosAtom b]
= conjPred sym $ BM.ConjMap (BM.fromVars [asPosAtom a, asPosAtom b])

orPred sym x y =
case (asConstantPred x, asConstantPred y) of
Expand All @@ -2125,7 +2125,7 @@ instance IsExprBuilder (ExprBuilder t st fs) where
go a b
| Just (NotPred (asApp -> Just (ConjPred as))) <- asApp a
, Just (NotPred (asApp -> Just (ConjPred bs))) <- asApp b
= notPred sym =<< conjPred sym (BM.combine as bs)
= notPred sym =<< conjPred sym (as <> bs)

| tryOrAbsorption a b
= return b
Expand All @@ -2134,13 +2134,13 @@ instance IsExprBuilder (ExprBuilder t st fs) where
= return a

| Just (NotPred (asApp -> Just (ConjPred as))) <- asApp a
= notPred sym =<< conjPred sym (uncurry BM.addVar (asNegAtom b) as)
= notPred sym =<< conjPred sym (uncurry BM.addConjunct (asNegAtom b) as)

| Just (NotPred (asApp -> Just (ConjPred bs))) <- asApp b
= notPred sym =<< conjPred sym (uncurry BM.addVar (asNegAtom a) bs)
= notPred sym =<< conjPred sym (uncurry BM.addConjunct (asNegAtom a) bs)

| otherwise
= notPred sym =<< conjPred sym (BM.fromVars [asNegAtom a, asNegAtom b])
= notPred sym =<< conjPred sym (BM.ConjMap (BM.fromVars [asNegAtom a, asNegAtom b]))

itePred sb c x y
-- ite c c y = c || y
Expand Down
10 changes: 1 addition & 9 deletions what4/src/What4/Expr/GroundEval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,7 @@ evalGroundApp f a0 = do
if xv then f y else f z

NotPred x -> not <$> f x
ConjPred xs ->
let pol (x,Positive) = f x
pol (x,Negative) = not <$> f x
in
case BM.viewBoolMap xs of
BM.BoolMapUnit -> return True
BM.BoolMapDualUnit -> return False
BM.BoolMapTerms (t:|ts) ->
foldl' (&&) <$> pol t <*> mapM pol ts
ConjPred cm -> BM.evalConj f cm

RealIsInteger x -> (\xv -> denominator xv == 1) <$> f x
BVTestBit i x ->
Expand Down
10 changes: 5 additions & 5 deletions what4/src/What4/Expr/VarIdentification.hs
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,14 @@ recurseAssertedAppExprVars scope p e = go e
go (asApp -> Just (NotPred x)) =
recordAssertionVars scope (BM.negatePolarity p) x

go (asApp -> Just (ConjPred xs)) =
go (asApp -> Just (ConjPred cm)) =
let pol (x,BM.Positive) = recordAssertionVars scope p x
pol (x,BM.Negative) = recordAssertionVars scope (BM.negatePolarity p) x
in
case BM.viewBoolMap xs of
BM.BoolMapUnit -> return ()
BM.BoolMapDualUnit -> return ()
BM.BoolMapTerms (t:|ts) -> mapM_ pol (t:ts)
case BM.viewConjMap cm of
BM.ConjTrue -> return ()
BM.ConjFalse -> return ()
BM.Conjuncts (t:|ts) -> mapM_ pol (t:ts)

go (asApp -> Just (BaseIte BaseBoolRepr _ c x y)) =
do recordExprVars scope c
Expand Down
10 changes: 5 additions & 5 deletions what4/src/What4/Protocol/SMTWriter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2134,14 +2134,14 @@ appSMTExpr ae = do
let pol (x,Positive) = mkBaseExpr x
pol (x,Negative) = notExpr <$> mkBaseExpr x
in
case BM.viewBoolMap xs of
BM.BoolMapUnit ->
case BM.viewConjMap xs of
BM.ConjTrue ->
return $ SMTExpr BoolTypeMap $ boolExpr True
BM.BoolMapDualUnit ->
BM.ConjFalse ->
return $ SMTExpr BoolTypeMap $ boolExpr False
BM.BoolMapTerms (t:|[]) ->
BM.Conjuncts (t:|[]) ->
SMTExpr BoolTypeMap <$> pol t
BM.BoolMapTerms (t:|ts) ->
BM.Conjuncts (t:|ts) ->
do cnj <- andAll <$> mapM pol (t:ts)
freshBoundTerm BoolTypeMap cnj

Expand Down
2 changes: 1 addition & 1 deletion what4/src/What4/Protocol/VerilogWriter/Backend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ appVerilogExpr app =
e' <- exprToVerilogExpr e
unop Not e'
--DisjPred es -> boolMapToExpr False True Or es
ConjPred es -> boolMapToExpr True False And es
ConjPred es -> boolMapToExpr True False And (BMap.getConjMap es)

-- Semiring operations
-- We only support bitvector semiring operations
Expand Down
2 changes: 1 addition & 1 deletion what4/src/What4/Serialize/Printer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ convertAppExpr' = go . W4.appExprApp
go (W4.NotPred e) = do
s <- goE e
return $ S.L [ident "notp", s]
go (W4.ConjPred bm) = convertBoolMap "andp" True bm
go (W4.ConjPred cm) = convertBoolMap "andp" True (BooM.getConjMap cm)
go (W4.BVSlt e1 e2) = do
s1 <- goE e1
s2 <- goE e2
Expand Down

0 comments on commit e4a00f0

Please sign in to comment.