diff --git a/pate.cabal b/pate.cabal index d2423fae..120c48f1 100644 --- a/pate.cabal +++ b/pate.cabal @@ -168,6 +168,8 @@ library What4.UninterpFns, What4.JSON, What4.SymSequence, + What4.Simplify, + What4.Simplify.Bitvector, Data.Parameterized.SetF, Data.Parameterized.SetCtx, Data.Parameterized.PairF, diff --git a/src/Pate/Verification/ConditionalEquiv.hs b/src/Pate/Verification/ConditionalEquiv.hs index 49523f9f..fffc0120 100644 --- a/src/Pate/Verification/ConditionalEquiv.hs +++ b/src/Pate/Verification/ConditionalEquiv.hs @@ -256,7 +256,7 @@ checkAndMinimizeEqCondition cond goal = fnTrace "checkAndMinimizeEqCondition" $ goalTimeout <- CMR.asks (PC.cfgGoalTimeout . envConfig) -- this check is not strictly necessary, as the incremental checks should guarantee it -- for the moment it's a sanity check on this process as well as the final simplifications - cond' <- PSi.simplifyPred_deep cond >>= \cond' -> (liftIO $ WEH.stripAnnotations sym cond') + cond' <- PSi.applySimpStrategy PSi.deepPredicateSimplifier cond result <- withSatAssumption (PAS.fromPred cond') $ do isPredTrue' goalTimeout goal case result of diff --git a/src/Pate/Verification/Simplify.hs b/src/Pate/Verification/Simplify.hs index 2ca21fcd..03a74d19 100644 --- a/src/Pate/Verification/Simplify.hs +++ b/src/Pate/Verification/Simplify.hs @@ -15,9 +15,12 @@ module Pate.Verification.Simplify ( , simplifyBVOps_trace , Simplifier , applySimplifier - , runSimplifier - , getSimplifier + , WEH.runSimplifier + , mkSimplifier , deepPredicateSimplifier + , prettySimplifier + , coreStrategy + , applySimpStrategy ) where import Control.Monad (foldM) @@ -38,6 +41,7 @@ import qualified Pate.ExprMappable as PEM import qualified Pate.Equivalence.Error as PEE import Pate.Monad import qualified What4.ExprHelpers as WEH +import What4.ExprHelpers (Simplifier, SimpStrategy) import Pate.TraceTree import qualified Data.Set as Set import Pate.AssumptionSet @@ -121,7 +125,7 @@ simplifyWithSolver a = withValid $ withSym $ \sym -> do IO.withRunInIO $ \runInIO -> PEM.mapExpr sym (\e -> runInIO (doSimp e)) a tracedSimpCheck :: forall sym arch. WEH.SimpCheck sym (EquivM_ sym arch) -tracedSimpCheck = WEH.SimpCheck $ \e_orig e_simp -> withValid $ withSym $ \sym -> do +tracedSimpCheck = WEH.SimpCheck (emitTrace @"debug") $ \ce_trace e_orig e_simp -> withValid $ withSym $ \sym -> do not_valid <- liftIO $ (W4.isEq sym e_orig e_simp >>= W4.notPred sym) goalSat "tracedSimpCheck" not_valid $ \case W4R.Unsat{} -> return e_simp @@ -137,6 +141,9 @@ tracedSimpCheck = WEH.SimpCheck $ \e_orig e_simp -> withValid $ withSym $ \sym - emitTraceLabel @"expr" "simplified" (Some e_simp) e_orig_conc <- concretizeWithModel fn e_orig e_simp_conc <- concretizeWithModel fn e_simp + do + SymGroundEvalFn fn' <- return fn + ce_trace fn' vars <- fmap Set.toList $ liftIO $ WEH.boundVars e_orig binds <- foldM (\asms (Some var) -> do conc <- concretizeWithModel fn (W4.varExpr sym var) @@ -161,81 +168,85 @@ getSimpCheck = do -- Additionally, simplify array lookups across unrelated updates. simplifyPred_deep :: forall sym arch. - W4.Pred sym -> - EquivM sym arch (W4.Pred sym) -simplifyPred_deep p = withSym $ \sym -> do + SimpStrategy sym (EquivM_ sym arch) +simplifyPred_deep = WEH.SimpStrategy $ \_ simp_check -> withValid $ withSym $ \sym -> do heuristicTimeout <- CMR.asks (PC.cfgHeuristicTimeout . envConfig) cache <- W4B.newIdxCache + fn_cache <- W4B.newIdxCache let checkPred :: W4.Pred sym -> EquivM sym arch Bool - checkPred p' = fmap getConst $ W4B.idxCacheEval cache p' $ - Const <$> isPredTrue' heuristicTimeout p' - -- remove redundant atoms - p1 <- WEH.minimalPredAtoms sym (\x -> checkPred x) p - -- resolve array lookups across unrelated updates - p2 <- WEH.resolveConcreteLookups sym (\p' -> return $ W4.asConstantPred p') p1 - -- additional bitvector simplifications - p3 <- liftIO $ WEH.simplifyBVOps sym p2 - -- drop any muxes across equality tests - p4 <- liftIO $ WEH.expandMuxEquality sym p3 - -- remove redundant conjuncts - p_final <- WEH.simplifyConjuncts sym (\x -> checkPred x) p4 - -- TODO: redundant sanity check that simplification hasn't clobbered anything - validSimpl <- liftIO $ W4.isEq sym p p_final - goal <- liftIO $ W4.notPred sym validSimpl - r <- checkSatisfiableWithModel heuristicTimeout "SimplifierConsistent" goal $ \sr -> - case sr of - W4R.Unsat _ -> return p_final - W4R.Sat _ -> do - traceM "ERROR: simplifyPred_deep: simplifier broken" - traceM "Original:" - traceM (show (W4.printSymExpr p)) - traceM "Simplified:" - traceM (show (W4.printSymExpr p_final)) - return p - W4R.Unknown -> do - traceM ("WARNING: simplifyPred_deep: simplifier timeout") - return p - case r of - Left exn -> do - traceM ("ERROR: simplifyPred_deep: exception " ++ show exn) - return p - Right r' -> return r' - - - - - -data Simplifier sym arch = Simplifier { runSimplifier :: forall tp. W4.SymExpr sym tp -> EquivM_ sym arch (W4.SymExpr sym tp) } + checkPred p' = fmap getConst $ W4B.idxCacheEval cache p' $ do + p'' <- WEH.unfoldDefinedFns sym (Just fn_cache) p' + Const <$> isPredTrue' heuristicTimeout p'' + return $ WEH.Simplifier $ \p -> case W4.exprType p of + W4.BaseBoolRepr -> do + -- remove redundant atoms + p1 <- WEH.minimalPredAtoms sym (\x -> checkPred x) p + -- resolve array lookups across unrelated updates + p2 <- WEH.resolveConcreteLookups sym (\p' -> return $ W4.asConstantPred p') p1 + -- additional bitvector simplifications + p3 <- liftIO $ WEH.simplifyBVOps sym p2 + -- drop any muxes across equality tests + p4 <- liftIO $ WEH.expandMuxEquality sym p3 + -- remove redundant conjuncts + p_final <- WEH.simplifyConjuncts sym (\x -> checkPred x) p4 + WEH.runSimpCheck simp_check p p_final + _ -> return p applySimplifier :: PEM.ExprMappable sym v => - Simplifier sym arch -> + Simplifier sym (EquivM_ sym arch) -> v -> EquivM sym arch v applySimplifier simplifier v = withSym $ \sym -> do shouldCheck <- CMR.asks (PC.cfgCheckSimplifier . envConfig) case shouldCheck of - True -> withTracing @"debug_tree" "Simplifier" $ PEM.mapExpr sym (runSimplifier simplifier) v - False -> withNoTracing $ PEM.mapExpr sym (runSimplifier simplifier) v + True -> withTracing @"debug_tree" "Simplifier" $ PEM.mapExpr sym (WEH.runSimplifier simplifier) v + False -> withNoTracing $ PEM.mapExpr sym (WEH.runSimplifier simplifier) v -deepPredicateSimplifier :: forall sym arch. EquivM sym arch (Simplifier sym arch) -deepPredicateSimplifier = withSym $ \sym -> do - Simplifier f <- getSimplifier - return $ Simplifier $ \e0 -> do - e1 <- liftIO $ WEH.stripAnnotations sym e0 - e2 <- f e1 - e4 <- case W4.exprType e0 of - W4.BaseBoolRepr -> simplifyPred_deep e2 - _ -> return e2 - applyCurrentAsms e4 +applySimpStrategy :: + PEM.ExprMappable sym v => + SimpStrategy sym (EquivM_ sym arch) -> + v -> + EquivM sym arch v +applySimpStrategy strat v = do + simp <- mkSimplifier strat + applySimplifier simp v -getSimplifier :: forall sym arch. EquivM sym arch (Simplifier sym arch) -getSimplifier = withSym $ \sym -> do - heuristicTimeout <- CMR.asks (PC.cfgHeuristicTimeout . envConfig) +mkSimplifier :: SimpStrategy sym (EquivM_ sym arch) -> EquivM sym arch (Simplifier sym (EquivM_ sym arch)) +mkSimplifier strat = withSym $ \sym -> do + check <- getSimpCheck + WEH.mkSimplifier sym check strat + +deepPredicateSimplifier :: SimpStrategy sym (EquivM_ sym arch) +deepPredicateSimplifier = WEH.joinStrategy $ withValid $ do + let + stripAnnStrat = WEH.mkSimpleStrategy $ \sym e -> liftIO $ WEH.stripAnnotations sym e + applyAsmsStrat = WEH.mkSimpleStrategy $ \_ -> applyCurrentAsmsExpr + return $ stripAnnStrat <> coreStrategy <> simplifyPred_deep <> applyAsmsStrat + + +-- | Simplifier that should only be used to display terms. +-- Interleaved with the deep predicate simplifier in order to +-- drop any redundant terms that are introduced. +prettySimplifier :: forall sym arch. SimpStrategy sym (EquivM_ sym arch) +prettySimplifier = deepPredicateSimplifier <> base <> deepPredicateSimplifier <> base + where + base :: SimpStrategy sym (EquivM_ sym arch) + base = WEH.joinStrategy $ withValid $ + return $ WEH.bvPrettySimplify <> WEH.memReadPrettySimplify <> WEH.collapseBVOps + +-- TODO: the "core" simplification strategy that stitches together the main strategies +-- from 'What4.ExprHelpers'. These are implemented in "old" style (i.e. as expression +-- transformers instead of 'SimpStrategy's.) and so we lift them into a 'SimpStrategy' here. +-- In general these should individually be implemented as strategies so that +-- this glue code is just trivially combining them. + +coreStrategy :: forall sym arch. SimpStrategy sym (EquivM_ sym arch) +coreStrategy = WEH.joinStrategy $ withValid $ return $ WEH.SimpStrategy $ \sym check -> do + ecache <- W4B.newIdxCache conccache <- W4B.newIdxCache - ecache <- W4B.newIdxCache - + heuristicTimeout <- CMR.asks (PC.cfgHeuristicTimeout . envConfig) let concPred :: W4.Pred sym -> EquivM_ sym arch (Maybe Bool) concPred p | Just b <- W4.asConstantPred p = return $ Just b @@ -251,12 +262,12 @@ getSimplifier = withSym $ \sym -> do shouldCheck <- CMR.asks (PC.cfgCheckSimplifier . envConfig) case shouldCheck of True -> withTracing @"debug_tree" "Simplifier Check" $ do - e'' <- WEH.runSimpCheck tracedSimpCheck e e' + e'' <- WEH.runSimpCheck check e e' case W4.testEquality e' e'' of Just W4.Refl -> return e'' Nothing -> do -- re-run the simplifier with tracing enabled - _ <- simp tracedSimpCheck e + _ <- simp check e return e False -> return e' @@ -272,7 +283,7 @@ getSimplifier = withSym $ \sym -> do e3 <- liftIO $ WEH.fixMux sym e2 emitIfChanged "fixMux" e2 e3 return e3 - return $ Simplifier $ \v -> PEM.mapExpr sym simp_wrapped v + return $ WEH.Simplifier simp_wrapped emitIfChanged :: ExprLabel -> diff --git a/src/Pate/Verification/StrongestPosts.hs b/src/Pate/Verification/StrongestPosts.hs index 5c5fa3b7..a6f5df72 100644 --- a/src/Pate/Verification/StrongestPosts.hs +++ b/src/Pate/Verification/StrongestPosts.hs @@ -828,13 +828,13 @@ showFinalResult pg = withTracing @"final_result" () $ withSym $ \sym -> do s <- withFreshScope (graphNodeBlocks nd) $ \scope -> do (_,cond) <- IO.liftIO $ PS.bindSpec sym (PS.scopeVarsPair scope) cond_spec (tr, _) <- withGraphNode scope nd pg $ \bundle d -> do - simplifier <- PSi.deepPredicateSimplifier - cond_simplified <- PSi.applySimplifier simplifier cond + cond_simplified <- PSi.applySimpStrategy PSi.deepPredicateSimplifier cond eqCond_pred <- PEC.toPred sym cond_simplified (mtraceT, mtraceF) <- getTracesForPred scope bundle d eqCond_pred case (mtraceT, mtraceF) of - (Just traceT, Just traceF) -> - return $ (Just (FinalEquivCond eqCond_pred traceT traceF), pg) + (Just traceT, Just traceF) -> do + cond_pretty <- PSi.applySimpStrategy PSi.prettySimplifier eqCond_pred + return $ (Just (FinalEquivCond cond_pretty traceT traceF), pg) _ -> return (Nothing, pg) return $ (Const (fmap (nd,) tr)) return $ PS.viewSpec s (\_ -> getConst) @@ -1185,7 +1185,8 @@ getTracesForPred :: EquivM sym arch (Maybe (CE.TraceEvents sym arch), Maybe (CE.TraceEvents sym arch)) getTracesForPred scope bundle dom p = withSym $ \sym -> do not_p <- liftIO $ W4.notPred sym p - withTracing @"expr" (Some p) $ do + p_pretty <- PSi.applySimpStrategy PSi.prettySimplifier p + withTracing @"expr" (Some p_pretty) $ do mtraceT <- withTracing @"message" "With condition assumed" $ withSatAssumption (PAS.fromPred p) $ do traceT <- getSomeGroundTrace scope bundle dom Nothing @@ -1291,8 +1292,7 @@ withSimBundle :: EquivM sym arch a withSimBundle pg vars node f = do bundle0 <- mkSimBundle pg node vars - simplifier <- PSi.getSimplifier - bundle1 <- PSi.applySimplifier simplifier bundle0 + bundle1 <- PSi.applySimpStrategy PSi.coreStrategy bundle0 bundle <- applyCurrentAsms bundle1 emitTrace @"bundle" (Some bundle) f bundle @@ -1638,7 +1638,7 @@ doCheckObservables scope ne bundle preD pg = case PS.simOut bundle of case mcondK of Just (condK, p) -> do let do_propagate = shouldPropagate (getPropagationKind pg nd condK) - simplifier <- PSi.deepPredicateSimplifier + simplifier <- PSi.mkSimplifier PSi.deepPredicateSimplifier eqSeq_simp <- PSi.applySimplifier simplifier eqSeq withPG pg $ do liftEqM_ $ addToEquivCondition scope nd condK eqSeq_simp diff --git a/src/Pate/Verification/Widening.hs b/src/Pate/Verification/Widening.hs index 53ef22dc..d55d3998 100644 --- a/src/Pate/Verification/Widening.hs +++ b/src/Pate/Verification/Widening.hs @@ -442,10 +442,9 @@ addRefinementChoice nd gr0 = withTracing @"message" "Modify Proof Node" $ do emitTrace @"message" (conditionName condK ++ " Discharged") return Nothing False -> do - simplifier <- PSi.deepPredicateSimplifier curAsm <- currentAsm emitTrace @"assumption" curAsm - eqCond_pred_simp <- PSi.applySimplifier simplifier eqCond_pred + eqCond_pred_simp <- PSi.applySimpStrategy PSi.deepPredicateSimplifier eqCond_pred emitTraceLabel @"expr" (ExprLabel $ "Simplified " ++ conditionName condK) (Some eqCond_pred_simp) return $ Just eqCond_pred_simp case meqCond_pred' of diff --git a/src/What4/ExprHelpers.hs b/src/What4/ExprHelpers.hs index 28430f73..81b73dc7 100644 --- a/src/What4/ExprHelpers.hs +++ b/src/What4/ExprHelpers.hs @@ -64,12 +64,8 @@ module What4.ExprHelpers ( , simplifyBVOps' , simplifyConjuncts , boundVars - , setProgramLoc , idxCacheEvalWriter , Tagged - , SimpCheck(..) - , noSimpCheck - , unliftSimpCheck , assumePositiveInt , integerToNat , asConstantOffset @@ -78,23 +74,37 @@ module What4.ExprHelpers ( , assertPositiveNat , printAtoms , iteToImp + , unfoldDefinedFns + -- re-exports from What4.Simplify + , setProgramLoc + , Simplifier(..) + , mkSimplifier + , SimpStrategy(..) + , joinStrategy + , mkSimpleStrategy + , SimpCheck(..) + , wrapSimpSolverCheck + , runSimpCheck + , noSimpCheck + -- re-exports from What4.Simplify.Bitvector + , W4SBV.bvPrettySimplify + , W4SBV.memReadPrettySimplify + , W4SBV.collapseBVOps ) where import GHC.TypeNats import Unsafe.Coerce ( unsafeCoerce ) -- for mulMono axiom -import Control.Lens ( (.~), (&), (^.) ) import Control.Applicative import Control.Monad (foldM) -import Control.Monad.Except import Control.Monad.IO.Class (MonadIO, liftIO) import qualified Control.Monad.IO.Class as IO import qualified Control.Monad.IO.Unlift as IO import Control.Monad.ST ( RealWorld, stToIO ) import qualified Control.Monad.Writer as CMW -import qualified Control.Monad.State as CMS -import Control.Monad.Trans (lift) -import Control.Monad.Trans.Maybe (MaybeT, runMaybeT) + +import Control.Monad.Trans (lift, MonadTrans(..)) +import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT) import qualified System.IO as IO import qualified Prettyprinter as PP @@ -121,7 +131,6 @@ import qualified Lang.Crucible.CFG.Core as CC import qualified Lang.Crucible.LLVM.MemModel as CLM import qualified What4.Expr.Builder as W4B -import qualified What4.ProgramLoc as W4PL import qualified What4.Expr.ArrayUpdateMap as AUM import qualified What4.Expr.GroundEval as W4G import qualified What4.Expr.WeightedSum as WSum @@ -131,10 +140,13 @@ import qualified What4.SemiRing as SR import qualified What4.Expr.BoolMap as BM import qualified What4.Symbol as WS import qualified What4.Utils.AbstractDomains as W4AD +import What4.Simplify import Data.Parameterized.SetF (SetF) import qualified Data.Parameterized.SetF as SetF import Data.Maybe (fromMaybe) +import qualified What4.Simplify.Bitvector as W4SBV +import What4.Simplify.Bitvector (asSimpleSum) -- | Sets the abstract domain of the given integer to assume -- that it is positive. @@ -1413,81 +1425,27 @@ simplifyBVOps' sym simp_check outer = do inIO (go outer) - -asSimpleSum :: - forall sym sr w. - sym -> - W4.NatRepr w -> - WSum.WeightedSum (W4.SymExpr sym) (SR.SemiRingBV sr w) -> - Maybe ([W4.SymBV sym w], BVS.BV w) -asSimpleSum _ _ ws = do - terms <- WSum.evalM - (\x y -> return $ x ++ y) - (\c e -> case c == one of {True -> return [e]; False -> fail ""}) - (\c -> case c == zero of { True -> return []; False -> fail ""}) - (ws & WSum.sumOffset .~ zero) - return $ (terms, ws ^. WSum.sumOffset ) - where - one :: BVS.BV w - one = SR.one (WSum.sumRepr ws) - - zero :: BVS.BV w - zero = SR.zero (WSum.sumRepr ws) - - - --- | An action for validating a simplification step. --- After a step is taken, this function is given the original expression as the --- first argument and the simplified expression as the second argument. --- This action should check that the original and simplified expressions are equal, --- and return the simplified expression if they are, or the original expression if they are not, --- optionally raising any exceptions or warnings in the given monad. -newtype SimpCheck sym m = SimpCheck - { runSimpCheck :: forall tp. W4.SymExpr sym tp -> W4.SymExpr sym tp -> m (W4.SymExpr sym tp) } - - -noSimpCheck :: Applicative m => SimpCheck sym m -noSimpCheck = SimpCheck (\_ e_patched -> pure e_patched) - -unliftSimpCheck :: IO.MonadUnliftIO m => SimpCheck sym m -> m (SimpCheck sym IO) -unliftSimpCheck simp_check = IO.withRunInIO $ \inIO -> do - return $ SimpCheck (\e1 e2 -> inIO (runSimpCheck simp_check e1 e2)) - -simplifyApp :: +unfoldDefinedFns :: forall sym m t solver fs tp. - IO.MonadIO m => sym ~ (W4B.ExprBuilder t solver fs) => + IO.MonadIO m => sym -> - W4B.IdxCache t (W4B.Expr t) -> - SimpCheck sym m {- ^ double-check simplification step -} -> - (forall tp'. W4B.App (W4B.Expr t) tp' -> m (Maybe (W4.SymExpr sym tp'))) {- ^ app simplification -} -> + Maybe (W4B.IdxCache t (W4B.Expr t)) -> W4.SymExpr sym tp -> m (W4.SymExpr sym tp) -simplifyApp sym cache simp_check simp_app outer = do +unfoldDefinedFns sym mcache e_outer = do + cache <- fromMaybe W4B.newIdxCache (fmap return mcache) let - else_ :: forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp') - else_ e = do - e' <- case e of - W4B.AppExpr a0 -> do - a0' <- W4B.traverseApp go (W4B.appExprApp a0) - if (W4B.appExprApp a0) == a0' then return e - else (liftIO $ W4B.sbMakeExpr sym a0') >>= go - W4B.NonceAppExpr a0 -> do - a0' <- TFC.traverseFC go (W4B.nonceExprApp a0) - if (W4B.nonceExprApp a0) == a0' then return e - else (liftIO $ W4B.sbNonceExpr sym a0') >>= go - _ -> return e - runSimpCheck simp_check e e' - - go :: forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp') - go e = W4B.idxCacheEval cache e $ do - setProgramLoc sym e - case e of - W4B.AppExpr a0 -> simp_app (W4B.appExprApp a0) >>= \case - Just e' -> runSimpCheck simp_check e e' - Nothing -> else_ e - _ -> else_ e - go outer + go :: forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp') + go = simplifyGenApp sym cache noSimpCheck (\_ -> return Nothing) goNonceApp + + goNonceApp :: forall tp'. W4B.NonceApp t (W4B.Expr t) tp' -> m (Maybe (W4.SymExpr sym tp')) + goNonceApp (W4B.FnApp fn args) | W4B.DefinedFnInfo vars body _ <- W4B.symFnInfo fn = do + fn' <- IO.liftIO $ W4.definedFn sym W4.emptySymbol vars body W4.AlwaysUnfold + e <- IO.liftIO $ W4.applySymFn sym fn' args + Just <$> go e + goNonceApp _ = return Nothing + go e_outer -- (if x then y else z) ==> (x -> y) AND (NOT(x) -> z) -- Truth table: @@ -1634,16 +1592,7 @@ simplifyConjuncts sym provable p_outer = do go ConjunctFoldLeft (W4.truePred sym) p -setProgramLoc :: - forall m sym t solver fs tp. - IO.MonadIO m => - sym ~ (W4B.ExprBuilder t solver fs) => - sym -> - W4.SymExpr sym tp -> - m () -setProgramLoc sym e = case W4PL.plSourceLoc (W4B.exprLoc e) of - W4PL.InternalPos -> return () - _ -> liftIO $ W4.setCurrentProgramLoc sym (W4B.exprLoc e) + data Tagged w f tp where Tagged :: w -> f tp -> Tagged w f tp @@ -1661,4 +1610,4 @@ idxCacheEvalWriter cache e f = do (result, w) <- CMW.listen $ f return $ Tagged w result CMW.tell w - return result + return result \ No newline at end of file diff --git a/src/What4/Simplify.hs b/src/What4/Simplify.hs new file mode 100644 index 00000000..1cb7c97c --- /dev/null +++ b/src/What4/Simplify.hs @@ -0,0 +1,357 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE LambdaCase #-} + +module What4.Simplify ( + SimpCheck(..) + , noSimpCheck + , unliftSimpCheck + , runSimpCheck + , runSimpCheckTrace + , wrapSimpSolverCheck + , SimpT + , simpLog + , simpCheck + , simpCheckTraced + , runSimpT + , simpMaybe + , withSym + , appAlts + , asApp + , tryAll + , Simplifier(..) + , SimpStrategy(..) + , mkSimpleStrategy + , joinStrategy + , mkSimplifier + , simplifyGenApp + , simplifyApp + , genAppSimplifier + , liftSimpTGenStrategy + , liftSimpTStrategy + , setProgramLoc +) where + +import Control.Applicative +import Control.Monad +import qualified Control.Monad.IO.Class as IO +import Control.Monad.IO.Class (liftIO) +import qualified Control.Monad.IO.Unlift as IO +import qualified Control.Monad.Trans.Reader as RWS hiding (ask, local) +import qualified Control.Monad.Reader as RWS +import Control.Monad.Trans.Maybe (MaybeT(..)) +import Control.Monad.Trans (MonadTrans(..)) + +import Data.Parameterized.Classes +import qualified Data.Parameterized.TraversableFC as TFC + +import qualified What4.Expr.Builder as W4B +import qualified What4.Interface as W4 +import qualified What4.Expr.GroundEval as W4G +import qualified What4.ProgramLoc as W4PL + + +-- | An action for validating a simplification step. +-- After a step is taken, this function is given the original expression as the +-- first argument and the simplified expression as the second argument. +-- This action should check that the original and simplified expressions are equal, +-- and return the simplified expression if they are, or the original expression if they are not, +-- optionally raising any exceptions or warnings in the given monad. +data SimpCheck sym m = SimpCheck + { simpCheckLog :: String -> m () + , runSimpCheck_ :: forall tp. + (forall t fs solver. + sym ~ W4B.ExprBuilder t solver fs => W4G.GroundEvalFn t -> m ()) -> + W4.SymExpr sym tp -> + W4.SymExpr sym tp -> + m (W4.SymExpr sym tp) + } + +instance Monad m => Monoid (SimpCheck sym m) where + mempty = noSimpCheck + +instance Monad m => Semigroup (SimpCheck sym m) where + (SimpCheck l1 f1) <> (SimpCheck l2 f2) = SimpCheck (\msg -> l1 msg >> l2 msg) + (\cet e_orig e_simp -> f1 cet e_orig e_simp >>= \e_simp' -> f2 cet e_orig e_simp') + +noSimpCheck :: Applicative m => SimpCheck sym m +noSimpCheck = SimpCheck (\_ -> pure ()) (\_ _ -> pure) + +unliftSimpCheck :: IO.MonadUnliftIO m => SimpCheck sym m -> m (SimpCheck sym IO) +unliftSimpCheck simp_check = IO.withRunInIO $ \inIO -> do + return $ SimpCheck (\msg -> inIO (simpCheckLog simp_check msg)) (\ce e1 e2 -> inIO (runSimpCheck_ simp_check ((\x -> IO.liftIO $ ce x)) e1 e2)) + + +runSimpCheck :: Monad m => SimpCheck sym m -> W4.SymExpr sym tp -> W4.SymExpr sym tp -> m (W4.SymExpr sym tp) +runSimpCheck simp_check = runSimpCheck_ simp_check (\_ -> pure ()) + +runSimpCheckTrace :: + Monad m => + SimpCheck sym m -> + (forall t fs solver. + sym ~ W4B.ExprBuilder t solver fs => W4G.GroundEvalFn t -> m ()) -> + W4.SymExpr sym tp -> W4.SymExpr sym tp -> m (W4.SymExpr sym tp) +runSimpCheckTrace simp_check f = runSimpCheck_ simp_check f + + +-- | Add a pre-processing step before sending to the solver. +-- This step is assumed to produce an equivalent term, but its +-- result is discarded in the final output. +wrapSimpSolverCheck :: + Monad m => + W4.IsSymExprBuilder sym => + (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + SimpCheck sym m -> + SimpCheck sym m +wrapSimpSolverCheck f (SimpCheck l r) = SimpCheck l $ \tr e_orig e_simp -> do + e_orig' <- f e_orig + e_simp' <- f e_simp + e_simp'' <- r tr e_orig' e_simp' + case testEquality e_simp'' e_simp' of + Just Refl -> return e_simp + Nothing -> return e_orig + +data Sym sym where + Sym :: (W4B.ExprBuilder t solver fs) -> Sym (W4B.ExprBuilder t solver fs) + +newtype SimpT sym m a = SimpT { _unSimpT :: MaybeT (RWS.ReaderT (Sym sym, SimpCheck sym m) m) a } + deriving (Functor, Applicative, Alternative, MonadPlus, Monad, RWS.MonadReader (Sym sym, SimpCheck sym m), IO.MonadIO) + +instance Monad m => MonadFail (SimpT sym m) where + fail msg = do + simpLog ("Failed: " ++ msg) + SimpT $ fail msg + +instance MonadTrans (SimpT sym) where + lift f = SimpT $ lift $ lift f + +runSimpT :: + Monad m => + sym ~ (W4B.ExprBuilder t solver fs) => + sym -> + SimpCheck sym m -> + SimpT sym m a -> m (Maybe a) +runSimpT sym simp_check (SimpT f) = RWS.runReaderT (runMaybeT f) (Sym sym, simp_check) + +simpMaybe :: Monad m => Maybe a -> SimpT sym m a +simpMaybe (Just a) = return a +simpMaybe Nothing = fail "" + +withSym :: Monad m => (forall t solver fs. sym ~ (W4B.ExprBuilder t solver fs) => sym -> SimpT sym m a) -> SimpT sym m a +withSym f = do + (Sym sym, _) <- RWS.ask + f sym + +appAlts :: W4B.App (W4.SymExpr sym) tp -> [W4B.App (W4.SymExpr sym) tp] +appAlts app = [app] ++ case app of + W4B.BaseEq r e1 e2 -> [W4B.BaseEq r e2 e1] + _ -> [] + +asApp :: Monad m => W4.SymExpr sym tp -> SimpT sym m (W4B.App (W4.SymExpr sym) tp) +asApp e = withSym $ \_ -> simpMaybe $ W4B.asApp e + + +tryAll :: Alternative m => [a] -> (a -> m b) -> m b +tryAll (a : as) f = f a <|> tryAll as f +tryAll [] _f = empty + +simpLog :: Monad m => String -> SimpT sym m () +simpLog msg = do + (_, simp_check) <- RWS.ask + lift $ simpCheckLog simp_check msg + +simpCheck :: Monad m => W4.SymExpr sym tp -> W4.SymExpr sym tp -> SimpT sym m (W4.SymExpr sym tp) +simpCheck orig_expr simp_expr = do + (_, simp_check) <- RWS.ask + lift $ runSimpCheck simp_check orig_expr simp_expr + +simpCheckTraced :: Monad m => + W4.SymExpr sym tp -> W4.SymExpr sym tp -> + (forall t fs solver. + sym ~ W4B.ExprBuilder t solver fs => W4G.GroundEvalFn t -> m ()) -> + SimpT sym m (W4.SymExpr sym tp) +simpCheckTraced orig_expr simp_expr tr = do + (_, simp_check) <- RWS.ask + lift $ runSimpCheckTrace simp_check tr orig_expr simp_expr + +-- | A thin wrapper around a monadic expression ('W4.SymExpr') transformer. +data Simplifier sym m = + Simplifier { runSimplifier :: forall tp. W4.SymExpr sym tp -> m (W4.SymExpr sym tp) } + +instance Monad m => Monoid (Simplifier sym m) where + mempty = Simplifier pure + +instance Monad m => Semigroup (Simplifier sym m) where + (Simplifier f1) <> (Simplifier f2) = Simplifier $ \e -> f1 e >>= f2 + +-- | A 'SimpStrategy' is a function that produces a 'Simplifier' in the given +-- monad 'm'. This allows the strategy to first perform any required initialization +-- (e.g. creating fresh term caches) before it is applied. Subsequent applications +-- of the resulting 'Simplifier' will then re-use the initialized data (e.g. using +-- cached results). +-- Importantly, in composite strategies all initialization occurs before any +-- simplification. +data SimpStrategy sym m where + SimpStrategy :: + { getStrategy :: + sym -> + SimpCheck sym m -> + m (Simplifier sym m) + } -> SimpStrategy sym m + +instance Monad m => Monoid (SimpStrategy sym m) where + mempty = SimpStrategy (\_ _ -> return mempty) + +instance Monad m => Semigroup (SimpStrategy sym m) where + (SimpStrategy f1) <> (SimpStrategy f2) = SimpStrategy $ \sym check -> do + simp_f1 <- f1 sym check + simp_f2 <- f2 sym check + return $ simp_f1 <> simp_f2 + +mkSimpleStrategy :: + forall sym m. + Monad m => + (forall tp. sym -> W4.SymExpr sym tp -> m (W4.SymExpr sym tp)) -> SimpStrategy sym m +mkSimpleStrategy f = SimpStrategy $ \sym _ -> return $ Simplifier $ \e -> f sym e + +joinStrategy :: + Monad m => + m (SimpStrategy sym m) -> + SimpStrategy sym m +joinStrategy f = SimpStrategy $ \sym check -> do + strat <- f + getStrategy strat sym check + +mkSimplifier :: + Monad m => + sym -> + SimpCheck sym m -> + SimpStrategy sym m -> + m (Simplifier sym m) +mkSimplifier sym simp_check strat = do + Simplifier strat' <- getStrategy strat sym simp_check + return $ Simplifier strat' + +setProgramLoc :: + forall m sym t solver fs tp. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + sym -> + W4.SymExpr sym tp -> + m () +setProgramLoc sym e = case W4PL.plSourceLoc (W4B.exprLoc e) of + W4PL.InternalPos -> return () + _ -> liftIO $ W4.setCurrentProgramLoc sym (W4B.exprLoc e) + +-- Create a 'Simplifier' that recurses into the sub-term structure of +-- expressions using the given 'W4B.App' and 'W4B.NonceApp' transformers. +-- For each subterm 'e', if the corresponding 'SimpT' operation succeeds +-- the result is used (i.e. replaced in the term) without further simplification. Otherwise, the simplification +-- traverses further into the sub-terms of 'e'. +-- The given operations are therefore responsible for handling any recursive +-- application of this simplification. +-- See 'liftSimpTGenStrategy', where each transformer is passed a recursor function. +genAppSimplifier :: + forall sym m t solver fs. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + sym -> + W4B.IdxCache t (W4B.Expr t) -> + SimpCheck sym m -> + (forall tp'. W4B.App (W4B.Expr t) tp' -> SimpT sym m (W4.SymExpr sym tp')) {- ^ app simplification -} -> + (forall tp'. W4B.NonceApp t (W4B.Expr t) tp' -> SimpT sym m (W4.SymExpr sym tp')) {- ^ nonce app simplification -} -> + Simplifier sym m +genAppSimplifier sym cache simp_check simp_app simp_nonce_app = + let + else_ :: forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp') + else_ e = do + e' <- case e of + W4B.AppExpr a0 -> do + a0' <- W4B.traverseApp go (W4B.appExprApp a0) + if (W4B.appExprApp a0) == a0' then return e + else (liftIO $ W4B.sbMakeExpr sym a0') >>= go + W4B.NonceAppExpr a0 -> do + a0' <- TFC.traverseFC go (W4B.nonceExprApp a0) + if (W4B.nonceExprApp a0) == a0' then return e + else (liftIO $ W4B.sbNonceExpr sym a0') >>= go + _ -> return e + runSimpCheck simp_check e e' + + go :: forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp') + go e = W4B.idxCacheEval cache e $ do + setProgramLoc sym e + case e of + W4B.AppExpr a0 -> runSimpT sym simp_check (simp_app (W4B.appExprApp a0)) >>= \case + Just e' -> runSimpCheck simp_check e e' + Nothing -> else_ e + W4B.NonceAppExpr a0 -> runSimpT sym simp_check (simp_nonce_app (W4B.nonceExprApp a0)) >>= \case + Just e' -> runSimpCheck simp_check e e' + Nothing -> else_ e + _ -> else_ e + in Simplifier go + +simplifyGenApp :: + forall sym m t solver fs tp. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + sym -> + W4B.IdxCache t (W4B.Expr t) -> + SimpCheck sym m {- ^ double-check simplification step -} -> + (forall tp'. W4B.App (W4B.Expr t) tp' -> m (Maybe (W4.SymExpr sym tp'))) {- ^ app simplification -} -> + (forall tp'. W4B.NonceApp t (W4B.Expr t) tp' -> m (Maybe (W4.SymExpr sym tp'))) {- ^ nonce app simplification -} -> + W4.SymExpr sym tp -> + m (W4.SymExpr sym tp) +simplifyGenApp sym cache check f_app f_nonce_app e = do + let s = genAppSimplifier sym cache check (\app -> simpMaybe =<< lift (f_app app)) (\nonce_app -> simpMaybe =<< lift (f_nonce_app nonce_app)) + runSimplifier s e + +simplifyApp :: + forall sym m t solver fs tp. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + sym -> + W4B.IdxCache t (W4B.Expr t) -> + SimpCheck sym m {- ^ double-check simplification step -} -> + (forall tp'. W4B.App (W4B.Expr t) tp' -> m (Maybe (W4.SymExpr sym tp'))) {- ^ app simplification -} -> + W4.SymExpr sym tp -> + m (W4.SymExpr sym tp) +simplifyApp sym cache simp_check simp_app e = simplifyGenApp sym cache simp_check simp_app (\_ -> return Nothing) e + + +-- | Lift a pair of 'W4B.App' and 'W4B.NonceApp' transformers into a 'SimpStrategy' by recursively applying them +-- to the sub-terms of an expression. For each subterm 'e', if the corresponding 'SimpT' operation succeeds +-- the result is used without further simplification. Otherwise, the simplification +-- traverses further into the sub-terms of 'e'. +-- The first argument to the given function is the recursive application of this +-- strategy, which can be used to selectively simplify sub-terms during transformation. +liftSimpTGenStrategy :: + forall m sym t solver fs. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + (forall tp''. (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + W4B.App (W4.SymExpr sym) tp'' -> SimpT sym m (W4.SymExpr sym tp'')) -> + (forall tp''. (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + W4B.NonceApp t (W4B.Expr t) tp'' -> SimpT sym m (W4.SymExpr sym tp'')) -> + SimpStrategy sym m +liftSimpTGenStrategy f_app f_nonce_app = SimpStrategy $ \sym check -> do + cache <- W4B.newIdxCache + let + go :: Simplifier sym m + go = genAppSimplifier sym cache check (\app -> f_app (\e' -> runSimplifier go e') app) (\nonce_app -> f_nonce_app (\e' -> runSimplifier go e') nonce_app) + return go + +-- | Specialized form of 'liftSimpTGenStrategy' that only takes an 'W4B.App' transformer. +liftSimpTStrategy :: + forall m sym t solver fs. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + (forall tp''. (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + W4B.App (W4.SymExpr sym) tp'' -> SimpT sym m (W4.SymExpr sym tp'')) -> + SimpStrategy sym m +liftSimpTStrategy f_app = liftSimpTGenStrategy f_app (\_ _ -> fail "") \ No newline at end of file diff --git a/src/What4/Simplify/Bitvector.hs b/src/What4/Simplify/Bitvector.hs new file mode 100644 index 00000000..0f63da95 --- /dev/null +++ b/src/What4/Simplify/Bitvector.hs @@ -0,0 +1,681 @@ +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE AllowAmbiguousTypes #-} + +module What4.Simplify.Bitvector ( + memReadPrettySimplify + , bvPrettySimplify + , collapseBVOps + , asSimpleSum +) where + +import GHC.TypeNats +import Control.Lens ( (.~), (&), (^.) ) + +import Control.Monad +import Control.Applicative +import qualified Control.Monad.IO.Class as IO +import Control.Monad.Trans + +import Data.List ( permutations ) + +import Data.Parameterized.Classes +import Data.Parameterized.Some +import qualified Data.Parameterized.Vector as V + +import What4.Simplify +import qualified What4.Interface as W4 +import qualified What4.Expr.Builder as W4B +import qualified Data.Parameterized.Context as Ctx +import qualified Data.Parameterized.TraversableFC as TFC +import qualified What4.Expr.BoolMap as BM +import Data.List.NonEmpty +import qualified What4.Concrete as W4C +import qualified Data.BitVector.Sized as BVS +import qualified What4.SemiRing as SR +import qualified What4.Expr.WeightedSum as WSum + + +-- ====================================== +-- Detecting memory-like array accesses and wrapping them in defined +-- functions. Specifically we look for a concatenation of bitvectors +-- from an array where the indexes (addresses) are sequential. +-- i.e. concat (select arr 0) (select arr 1) --> readLE2 0 + +data BVConcatResult sym w_outer w_inner where + BVConcatResult :: forall n sym w_inner. (1 <= n, 1 <= w_inner) => + W4.NatRepr n -> W4.NatRepr w_inner -> V.Vector n (W4.SymBV sym w_inner) -> BVConcatResult sym (n * w_inner) w_inner + +bvConcatResultWidth :: BVConcatResult sym w_outer w_inner -> W4.NatRepr w_inner +bvConcatResultWidth = \case + BVConcatResult _ w _ -> w + +combineBVConcatResults :: + BVConcatResult sym w_outer1 w_inner -> + BVConcatResult sym w_outer2 w_inner -> + BVConcatResult sym (w_outer1 + w_outer2) w_inner +combineBVConcatResults (BVConcatResult lhs_n_inner lhs_w_inner lhs_bvs) (BVConcatResult rhs_n_inner _ rhs_bvs) + | Refl <- W4.addMulDistribRight lhs_n_inner rhs_n_inner lhs_w_inner + , W4.LeqProof <- W4.leqAddPos lhs_n_inner rhs_n_inner + = BVConcatResult (W4.addNat lhs_n_inner rhs_n_inner) lhs_w_inner (V.append lhs_bvs rhs_bvs) + + +-- Could be implemented as a pure/total function, since +-- there is always a base case of no concatenations +getBVConcats :: + Monad m => + W4.SymBV sym w -> + SimpT sym m (Some (BVConcatResult sym w)) +getBVConcats e = withSym $ \_ -> do + W4B.BVConcat _ lhs rhs <- asApp e + getBVConcatsApp lhs rhs + <|> do + W4.BaseBVRepr w <- return $ W4.exprType e + return $ Some $ BVConcatResult (W4.knownNat @1) w (V.singleton e) + +getBVConcatsApp :: + Monad m => + W4.SymBV sym w1 -> + W4.SymBV sym w2 -> + SimpT sym m (Some (BVConcatResult sym (w1 + w2))) +getBVConcatsApp lhs rhs = do + Some lhs_result <- getBVConcats lhs + Some rhs_result <- getBVConcats rhs + lhs_inner_w <- return $ bvConcatResultWidth lhs_result + rhs_inner_w <- return $ bvConcatResultWidth rhs_result + Refl <- simpMaybe $ testEquality lhs_inner_w rhs_inner_w + return $ Some $ combineBVConcatResults lhs_result rhs_result + +asSequential :: + forall m sym w n. + IO.MonadIO m => + Bool -> + W4.BoundVar sym (W4.BaseBVType w) -> + V.Vector n (W4.SymBV sym w) -> + SimpT sym m (V.Vector n (W4.SymBV sym w)) +asSequential be var v_outer = go 0 v_outer + where + go :: forall n_. Integer -> V.Vector n_ (W4.SymBV sym w) -> SimpT sym m (V.Vector n_ (W4.SymBV sym w)) + go offset v = withSym $ \sym -> do + + let (x1, rest) = next_bv v + W4.BaseBVRepr w <- return $ W4.exprType x1 + offset_bv <- IO.liftIO $ W4.bvLit sym w (BVS.mkBV w offset) + var_as_offset <- IO.liftIO $ W4.bvAdd sym (W4.varExpr sym var) offset_bv + x1_as_offset <- IO.liftIO $ W4.bvAdd sym first_bv offset_bv + check <- IO.liftIO $ W4.isEq sym x1_as_offset x1 + case W4.asConstantPred check of + Just True -> case rest of + Left Refl -> return $ V.singleton var_as_offset + Right v' -> do + v_result <- go (offset + 1) v' + W4.LeqProof <- return $ V.nonEmpty v + Refl <- return $ W4.minusPlusCancel (V.length v) (W4.knownNat @1) + return $ mk_v var_as_offset v_result + _ -> fail $ "not sequential:" ++ show x1 ++ "vs. " ++ show x1_as_offset + + mk_v :: tp -> V.Vector n_ tp -> V.Vector (n_+1) tp + mk_v x v_ = case be of + True -> V.cons x v_ + False -> V.snoc v_ x + + next_bv :: forall n_ tp. V.Vector n_ tp -> (tp, Either (n_ :~: 1) (V.Vector (n_-1) tp)) + next_bv v_ = case be of + True -> V.uncons v_ + False -> V.unsnoc v_ + + first_bv = fst (next_bv v_outer) + +concatBVs :: + W4.IsExprBuilder sym => + sym -> + V.Vector n (W4.SymBV sym w) -> + IO (W4.SymBV sym (n*w)) +concatBVs sym v = do + let (x1, rest) = V.uncons v + W4.BaseBVRepr w <- return $ W4.exprType x1 + case rest of + Left Refl -> return x1 + Right v' -> do + W4.LeqProof <- return $ V.nonEmpty v + W4.LeqProof <- return $ V.nonEmpty v' + bv' <- concatBVs sym v' + W4.LeqProof <- return $ W4.leqMulPos (V.length v') w + Refl <- return $ W4.lemmaMul w (V.length v) + W4.bvConcat sym x1 bv' + +memReadPrettySimplifyApp :: + forall sym m tp. + IO.MonadIO m => + (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + W4B.App (W4.SymExpr sym) tp -> + SimpT sym m (W4.SymExpr sym tp) +memReadPrettySimplifyApp rec app = withSym $ \sym -> do + W4B.BVConcat _ lhs rhs <- return app + Some (BVConcatResult _ _ inner_bvs) <- getBVConcatsApp lhs rhs + let (fst_bv, _) = V.uncons inner_bvs + W4B.SelectArray _ arr (Ctx.Empty Ctx.:> fst_addr) <- asApp fst_bv + W4.BaseBVRepr (addr_w :: W4.NatRepr addr_w) <- return $ W4.exprType fst_addr + (addrs :: V.Vector n (W4.SymBV sym addr_w)) <- forM inner_bvs $ \inner_bv -> do + W4B.SelectArray _ arr' (Ctx.Empty Ctx.:> inner_addr) <- asApp inner_bv + Refl <- simpMaybe $ testEquality arr arr' + return inner_addr + tryAll [True,False] $ \b -> do + addr_var <- IO.liftIO $ W4.freshBoundVar sym W4.emptySymbol (W4.BaseBVRepr addr_w) + addrs_seq <- asSequential b addr_var addrs + let index_addr = if b then fst (V.uncons addrs) else fst (V.unsnoc addrs) + arr_var <- IO.liftIO $ W4.freshBoundVar sym W4.emptySymbol (W4.exprType arr) + vals <- IO.liftIO $ mapM (\addr_ -> W4.arrayLookup sym (W4.varExpr sym arr_var) (Ctx.empty Ctx.:> addr_)) addrs_seq + new_val <- IO.liftIO $ concatBVs sym vals + let nm = (if b then "readBE" else "readLE") ++ show (W4.natValue (V.length addrs)) + fn <- IO.liftIO $ W4.definedFn sym (W4.safeSymbol nm) (Ctx.Empty Ctx.:> arr_var Ctx.:> addr_var) new_val W4.NeverUnfold + arr' <- lift $ rec arr + index_addr' <- lift $ rec index_addr + IO.liftIO $ W4.applySymFn sym fn (Ctx.empty Ctx.:> arr' Ctx.:> index_addr') + +memReadPrettySimplify :: + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + SimpStrategy sym m +memReadPrettySimplify = liftSimpTStrategy memReadPrettySimplifyApp + + +-- ====================================== +-- Bitvector simplification strategies + +-- | Simplification rules that are for display purposes only, +-- as they can make terms more difficult for the solver to handle. +-- TODO: if we add implicit function unfolding to all solver calls then +-- we can safely apply these unconditionally +-- TODO: these were lifted directly from term forms that appear in the ARM semantics when comparing values, +-- but can likely be generalized. +bvPrettySimplifyApp :: + forall sym m tp. + IO.MonadIO m => + (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + W4B.App (W4.SymExpr sym) tp -> + SimpT sym m (W4.SymExpr sym tp) +bvPrettySimplifyApp rec app = withSym $ \sym -> tryAll (appAlts @sym app) $ \app' -> do + W4B.BaseEq _ (isUInt @sym 0 -> True) (W4B.asApp -> Just (W4B.BVSelect idx n bv_)) <- return app' + W4.BaseBVRepr bv_w <- return $ W4.exprType bv_ + Refl <- simpMaybe $ testEquality (W4.knownNat @1) n + let bv_dec = W4.decNat bv_w + Refl <- simpMaybe $ testEquality bv_dec idx + let bv_min_i = BVS.asSigned bv_w $ BVS.minSigned bv_w + bv_min_sym <- IO.liftIO $ W4.intLit sym bv_min_i + bv <- lift $ rec bv_ + ss@[_,_] <- asSimpleSumM bv + tryAll (permutations ss) $ \ss' -> do + [bv_base, (W4.asConcrete -> Just (W4C.ConcreteBV _ bv_c))] <- return ss' + let bv_c_i = BVS.asSigned bv_w bv_c + guard $ bv_c_i < 0 + let le = GE True (CompatibleBVToInt2 bv_w) + bv_c_i_sym <- IO.liftIO $ W4.intLit sym (-bv_c_i) + bounded <- IO.liftIO $ mvBVOpFn sym le bv_base bv_c_i_sym + let sub = BVSub True (CompatibleBVToInt2 bv_w) + new_sum <- IO.liftIO $ mvBVOpFn sym sub bv_base bv_c_i_sym + let lt = COrd LT True CompatibleInts + overflowed <- IO.liftIO $ mvBVOpFn sym lt new_sum bv_min_sym + IO.liftIO $ W4.orPred sym bounded overflowed + <|> do + simpLog $ "check app: " ++ show app' + W4B.BaseEq _ bv_sum1_ (W4B.asApp -> Just (W4B.BVSext sext_w bv_sum2)) <- return app' + + simpLog $ "as_eq: " ++ show bv_sum1_ ++ " AND " ++ show bv_sum2 + W4.BaseBVRepr bv_sum2_w <- return $ W4.exprType bv_sum2 + Refl <- simpMaybe $ testEquality (W4.knownNat @65) sext_w + + bv_sum2_sext <- transformSum bv_sum2 sext_w (\bv_ -> IO.liftIO $ W4.bvSext sym sext_w bv_) + sums_eq <- IO.liftIO $ W4.isEq sym bv_sum1_ bv_sum2_sext + simpLog $ "sums_eq: " ++ show sums_eq + Just True <- return $ W4.asConstantPred sums_eq + bv_sum1 <- lift $ rec bv_sum1_ + + let bv_min_i = BVS.asSigned bv_sum2_w $ BVS.minSigned bv_sum2_w + let bv_max_i = BVS.asSigned bv_sum2_w $ BVS.maxSigned bv_sum2_w + + ss@[_,_] <- asSimpleSumM bv_sum1 + simpLog $ "simple sum:" ++ show ss + tryAll (permutations ss) $ \ss' -> do + [(W4B.asApp -> Just (W4B.BVSext _ bv_s1)), (W4.asConcrete -> Just (W4C.ConcreteBV _ bv_c))] <- return ss' + let bv_c_i = BVS.asSigned sext_w bv_c + + bv_min_sym <- IO.liftIO $ W4.intLit sym bv_min_i + bv_max_sym <- IO.liftIO $ W4.intLit sym bv_max_i + e <- case bv_c_i < 0 of + True -> do + let sub = BVSub True (CompatibleBVToInt2 (W4.bvWidth bv_s1)) + bv_c_i_sym <- IO.liftIO $ W4.intLit sym (-bv_c_i) + IO.liftIO $ mvBVOpFn sym sub bv_s1 bv_c_i_sym + False -> do + let add = BVAdd True (CompatibleBVToInt2 (W4.bvWidth bv_s1)) + bv_c_i_sym <- IO.liftIO $ W4.intLit sym bv_c_i + IO.liftIO $ mvBVOpFn sym add bv_s1 bv_c_i_sym + + let le = LE True CompatibleInts + let ge = GE True CompatibleInts + upper_bound <- IO.liftIO $ mvBVOpFn sym le e bv_max_sym + lower_bound <- IO.liftIO $ mvBVOpFn sym ge e bv_min_sym + p <- IO.liftIO $ W4.andPred sym upper_bound lower_bound + return p + +bvPrettySimplify :: + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + SimpStrategy sym m +bvPrettySimplify = liftSimpTStrategy bvPrettySimplifyApp + +isUInt :: W4.IsExpr (W4.SymExpr sym) => Integer -> W4.SymExpr sym tp -> Bool +isUInt i e = case W4.asConcrete e of + Just (W4C.ConcreteBV _w bv_c) | BVS.asUnsigned bv_c == i -> True + Just (W4C.ConcreteInteger int_c) | int_c == i -> True + _ -> False + +transformSum :: IO.MonadIO m => + 1 <= w' => + (W4.SymBV sym w) -> + W4.NatRepr w' -> + (W4.SymBV sym w -> SimpT sym m (W4.SymBV sym w')) -> + SimpT sym m (W4.SymBV sym w') +transformSum bv w' f = withSym $ \sym -> do + W4B.SemiRingSum s <- asApp bv + SR.SemiRingBVRepr baserepr w <- return $ WSum.sumRepr s + let repr = SR.SemiRingBVRepr baserepr w' + let f_lit c = do + c' <- IO.liftIO $ W4.bvLit sym w c + Just (W4C.ConcreteBV _ c'') <- W4.asConcrete <$> f c' + return c'' + s' <- WSum.transformSum repr f_lit f s + IO.liftIO $ WSum.evalM (W4.bvAdd sym) (\c x -> W4.bvMul sym x =<< W4.bvLit sym w' c) (W4.bvLit sym w') s' + +asSimpleSum :: + forall sym sr w. + sym -> + W4.NatRepr w -> + WSum.WeightedSum (W4.SymExpr sym) (SR.SemiRingBV sr w) -> + Maybe ([W4.SymBV sym w], BVS.BV w) +asSimpleSum _ _ ws = do + terms <- WSum.evalM + (\x y -> return $ x ++ y) + (\c e -> case c == one of {True -> return [e]; False -> fail ""}) + (\c -> case c == zero of { True -> return []; False -> fail ""}) + (ws & WSum.sumOffset .~ zero) + return $ (terms, ws ^. WSum.sumOffset ) + where + one :: BVS.BV w + one = SR.one (WSum.sumRepr ws) + + zero :: BVS.BV w + zero = SR.zero (WSum.sumRepr ws) + +asSimpleSumM :: IO.MonadIO m => W4.SymBV sym w -> SimpT sym m [W4.SymBV sym w] +asSimpleSumM bv = withSym $ \sym -> do + W4B.SemiRingSum s <- asApp bv + W4.BaseBVRepr w <- return $ W4.exprType bv + SR.SemiRingBVRepr SR.BVArithRepr _ <- return $ WSum.sumRepr s + (bvs, c) <- simpMaybe $ asSimpleSum sym w s + const_expr <- IO.liftIO $ W4.bvLit sym w c + return $ const_expr:bvs + + +-- ====================================== +-- Wrappers around bitvector operations to help with pretty-printing. +-- Specifically they hide type conversions and add explicit representations +-- for comparisons/operations that are otherwise simplified into a normal +-- form by the What4 expression builder. +-- ====================================== + +data CompatibleTypes tp1 tp2 tp3 where + CompatibleInts :: CompatibleTypes W4.BaseIntegerType W4.BaseIntegerType W4.BaseIntegerType + CompatibleBVs :: 1 <= w => W4.NatRepr w -> CompatibleTypes (W4.BaseBVType w) (W4.BaseBVType w) (W4.BaseBVType w) + CompatibleBVsExt1 :: (1 <= w1, w1+1 <= w2, 1 <= w2) => W4.NatRepr w1 -> W4.NatRepr w2 -> CompatibleTypes (W4.BaseBVType w1) (W4.BaseBVType w2) (W4.BaseBVType w2) + CompatibleBVsExt2 :: (1 <= w2, w2+1 <= w1, 1 <= w1) => W4.NatRepr w1 -> W4.NatRepr w2 -> CompatibleTypes (W4.BaseBVType w1) (W4.BaseBVType w2) (W4.BaseBVType w1) + CompatibleBVToInt1 :: 1 <= w => W4.NatRepr w -> CompatibleTypes W4.BaseIntegerType (W4.BaseBVType w) W4.BaseIntegerType + CompatibleBVToInt2 :: 1 <= w => W4.NatRepr w -> CompatibleTypes (W4.BaseBVType w) W4.BaseIntegerType W4.BaseIntegerType + +compatibleTypeRepr :: CompatibleTypes tp1 tp2 tp3 -> W4.BaseTypeRepr tp3 +compatibleTypeRepr = \case + CompatibleInts -> W4.BaseIntegerRepr + CompatibleBVs w -> W4.BaseBVRepr w + CompatibleBVsExt1 _ w2 -> W4.BaseBVRepr w2 + CompatibleBVsExt2 w1 _ -> W4.BaseBVRepr w1 + CompatibleBVToInt1{} -> W4.BaseIntegerRepr + CompatibleBVToInt2{} -> W4.BaseIntegerRepr + +instance Show (CompatibleTypes tp1 tp2 tp3) where + show ct = case ct of + CompatibleInts -> go showInt showInt + CompatibleBVs w -> go (showBV w) (showBV w) + CompatibleBVsExt1 w1 w2 -> go (showBV w1) (showBV w2) + CompatibleBVsExt2 w1 w2 -> go (showBV w1) (showBV w2) + CompatibleBVToInt1 w -> go showInt (showBV w) + CompatibleBVToInt2 w -> go (showBV w) showInt + where + showInt = showTp W4.BaseIntegerRepr + showBV w = showTp (W4.BaseBVRepr w) + + showTp :: forall tp'. W4.BaseTypeRepr tp' -> String + showTp = \case + W4.BaseIntegerRepr -> "int" + W4.BaseBVRepr w -> "bv" ++ show (W4.natValue w) + tp -> show tp + + parens :: String -> String -> String + parens s1 s2 = "(" ++ s1 ++ "," ++ s2 ++ ")" + + go :: String -> String -> String + go s1 s2 = parens s1 s2 ++ "→" ++ showTp (compatibleTypeRepr ct) + +compatibleTypes :: + W4.BaseTypeRepr t1 -> + W4.BaseTypeRepr t2 -> + Maybe (Some (CompatibleTypes t1 t2)) +compatibleTypes t1 t2 = case (t1, t2) of + (W4.BaseIntegerRepr, W4.BaseIntegerRepr) -> Just $ Some CompatibleInts + (W4.BaseBVRepr w1, W4.BaseBVRepr w2) -> Just $ case W4.testNatCases w1 w2 of + W4.NatCaseLT W4.LeqProof -> Some $ CompatibleBVsExt1 w1 w2 + W4.NatCaseGT W4.LeqProof -> Some $ CompatibleBVsExt2 w1 w2 + W4.NatCaseEQ -> Some $ CompatibleBVs w1 + (W4.BaseIntegerRepr, W4.BaseBVRepr w) -> Just $ Some $ CompatibleBVToInt1 w + (W4.BaseBVRepr w, W4.BaseIntegerRepr) -> Just $ Some $ CompatibleBVToInt2 w + _ -> Nothing + +-- Turn two incompatible operands into the same type using +-- conversion operations +mkOperands :: + forall sym tp1 tp2 tp3. + W4.IsExprBuilder sym => + sym -> + Bool {- signed comparison -} -> + CompatibleTypes tp1 tp2 tp3 {- proof that the types are compatible for comparison -} -> + W4.SymExpr sym tp1 -> + W4.SymExpr sym tp2 -> + IO (W4.SymExpr sym tp3, W4.SymExpr sym tp3) +mkOperands sym b ct e1 e2 = case ct of + CompatibleBVsExt1 _ w2 | b -> do + e1' <- W4.bvSext sym w2 e1 + return $ (e1', e2) + CompatibleBVsExt1 _ w2 | False <- b -> do + e1' <- W4.bvZext sym w2 e1 + return $ (e1', e2) + CompatibleBVsExt2 w1 _ | b -> do + e2' <- W4.bvSext sym w1 e2 + return $ (e1, e2') + CompatibleBVsExt2 w1 _ | False <- b -> do + e2' <- W4.bvZext sym w1 e2 + return $ (e1, e2') + CompatibleBVToInt1{} -> case b of + True -> do + e2' <- W4.sbvToInteger sym e2 + return $ (e1, e2') + False -> do + e2' <- W4.bvToInteger sym e2 + return $ (e1, e2') + CompatibleBVToInt2{} -> case b of + True -> do + e1' <- W4.sbvToInteger sym e1 + return $ (e1', e2) + False -> do + e1' <- W4.bvToInteger sym e1 + return $ (e1', e2) + CompatibleInts -> return (e1,e2) + CompatibleBVs{} -> return (e1,e2) + +data BVOp tp1 tp2 tp3 where + COrd :: forall tp1 tp2 tp3. Ordering -> Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 W4.BaseBoolType + NEQ :: Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 W4.BaseBoolType + LE :: Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 W4.BaseBoolType + GE :: Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 W4.BaseBoolType + BVAdd :: Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 tp3 + BVSub :: Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 tp3 + BVMul :: Bool -> CompatibleTypes tp1 tp2 tp3 -> BVOp tp1 tp2 tp3 + + +instance Show (BVOp tp1 tp2 tp3) where + show bvop | Some ct <- getCompatibleTypes bvop = simpleShowBVOp bvop ++ show ct + +parseBVOp :: + String -> + W4.BaseTypeRepr tp1 -> + W4.BaseTypeRepr tp2 -> + W4.BaseTypeRepr tp3 -> + Maybe (BVOp tp1 tp2 tp3) +parseBVOp nm tp1 tp2 tp3 = case compatibleTypes tp1 tp2 of + Just (Some ct) -> + case tp3 of + W4.BaseBoolRepr -> case nm of + "LTs" -> Just $ COrd LT True ct + "GTs" -> Just $ COrd GT True ct + "EQs" -> Just $ COrd EQ True ct + "LTu" -> Just $ COrd LT False ct + "GTu" -> Just $ COrd GT False ct + "EQu" -> Just $ COrd EQ False ct + "LEs" -> Just $ LE True ct + "GEs" -> Just $ GE True ct + "LEu" -> Just $ LE False ct + "GEu" -> Just $ GE False ct + "NEQu" -> Just $ NEQ False ct + "NEQs" -> Just $ NEQ True ct + _ -> Nothing + W4.BaseBVRepr{} | Just Refl <- testEquality (compatibleTypeRepr ct) tp3 -> case nm of + "ADDs" -> Just $ BVAdd True ct + "ADDu" -> Just $ BVAdd False ct + "MULs" -> Just $ BVMul True ct + "MULu" -> Just $ BVMul False ct + "SUBs" -> Just $ BVSub True ct + "SUBu" -> Just $ BVSub False ct + _ -> Nothing + _ -> Nothing + _ -> Nothing + +isSignedOp :: BVOp tp1 tp2 tp3 -> Bool +isSignedOp = \case + COrd LT b _ -> b + COrd GT b _ -> b + COrd EQ b _ -> b + NEQ b _ -> b + LE b _ -> b + GE b _ -> b + BVAdd b _ -> b + BVSub b _ -> b + BVMul b _ -> b + +simpleShowBVOp :: BVOp tp1 tp2 tp3 -> String +simpleShowBVOp bvop = case bvop of + COrd LT _ _ -> "LT" ++ suf + COrd EQ _ _ -> "EQ" ++ suf + COrd GT _ _ -> "GT" ++ suf + NEQ _ _ -> "NEQ" ++ suf + LE _ _ -> "LE" ++ suf + GE _ _ -> "GE" ++ suf + BVAdd _ _ -> "ADD" ++ suf + BVSub _ _ -> "SUB" ++ suf + BVMul _ _ -> "MUL" ++ suf + where + suf :: String + suf = case (isSignedOp bvop) of + True -> "s" + False -> "u" + +getCompatibleTypes :: BVOp tp1 tp2 tp3 -> Some (CompatibleTypes tp1 tp2) +getCompatibleTypes = \case + COrd _ _ ct -> Some ct + NEQ _ ct -> Some ct + LE _ ct -> Some ct + GE _ ct -> Some ct + BVAdd _ ct -> Some ct + BVSub _ ct -> Some ct + BVMul _ ct -> Some ct + +notBVOp :: BVOp tp1 tp2 W4.BaseBoolType -> BVOp tp1 tp2 W4.BaseBoolType +notBVOp = \case + COrd LT b ct -> GE b ct + COrd GT b ct -> LE b ct + COrd EQ b ct -> NEQ b ct + NEQ b ct -> COrd EQ b ct + LE b ct -> COrd GT b ct + GE b ct -> COrd LT b ct + -- arithmetic ops can't occur with bool result type + BVAdd _ ct -> case ct of + BVSub _ ct -> case ct of + BVMul _ ct -> case ct of + +appBVOp :: + W4.IsExprBuilder sym => + sym -> + BVOp tp1 tp2 tp3 -> + CompatibleTypes tp1 tp2 tpOp -> + (sym -> W4.SymExpr sym tpOp -> W4.SymExpr sym tpOp -> IO (W4.SymExpr sym tp3)) -> + W4.SymExpr sym tp1 -> + W4.SymExpr sym tp2 -> + IO (W4.SymExpr sym tp3) +appBVOp sym bvop ct f e1 e2 = do + (e1', e2') <- mkOperands sym (isSignedOp bvop) ct e1 e2 + f sym e1' e2' + +mkBVOp :: + forall sym tp1 tp2 tp3. + W4.IsExprBuilder sym => + sym -> + BVOp tp1 tp2 tp3 -> + W4.SymExpr sym tp1 -> + W4.SymExpr sym tp2 -> + IO (W4.SymExpr sym tp3) +mkBVOp sym bvop e1 e2 = do + case bvop of + COrd EQ _ ct -> appBVOp sym bvop ct W4.isEq e1 e2 + NEQ _ ct -> appBVOp sym bvop ct (\sym' e1' e2' -> W4.isEq sym' e1' e2' >>= W4.notPred sym') e1 e2 + BVAdd _ ct -> case compatibleTypeRepr ct of + W4.BaseBVRepr{} -> appBVOp sym bvop ct W4.bvAdd e1 e2 + W4.BaseIntegerRepr{} -> appBVOp sym bvop ct W4.intAdd e1 e2 + _ -> case ct of + + BVSub _ ct -> case compatibleTypeRepr ct of + W4.BaseBVRepr{} -> appBVOp sym bvop ct W4.bvSub e1 e2 + W4.BaseIntegerRepr{} -> appBVOp sym bvop ct W4.intSub e1 e2 + _ -> case ct of + + BVMul _ ct -> case compatibleTypeRepr ct of + W4.BaseBVRepr{} -> appBVOp sym bvop ct W4.bvMul e1 e2 + W4.BaseIntegerRepr{} -> appBVOp sym bvop ct W4.intMul e1 e2 + _ -> case ct of + + COrd ord s ct -> case compatibleTypeRepr ct of + W4.BaseBVRepr{} -> case (ord, s) of + (LT, True) -> appBVOp sym bvop ct W4.bvSlt e1 e2 + (LT, False) -> appBVOp sym bvop ct W4.bvUlt e1 e2 + (GT, True) -> appBVOp sym bvop ct W4.bvSgt e1 e2 + (GT, False) -> appBVOp sym bvop ct W4.bvUgt e1 e2 + W4.BaseIntegerRepr{} -> case ord of + LT -> appBVOp sym bvop ct W4.intLt e1 e2 + GT -> appBVOp sym bvop ct (\sym' e1' e2' -> W4.intLt sym' e2' e1') e1 e2 + _ -> case ct of + + GE s ct -> case compatibleTypeRepr ct of + W4.BaseBVRepr{} | s -> appBVOp sym bvop ct W4.bvSge e1 e2 + W4.BaseBVRepr{} | False <- s -> appBVOp sym bvop ct W4.bvUge e1 e2 + W4.BaseIntegerRepr{} -> appBVOp sym bvop ct (\sym' e1' e2' -> W4.intLe sym' e2' e1') e1 e2 + _ -> case ct of + + LE s ct -> case compatibleTypeRepr ct of + W4.BaseBVRepr{} | s -> appBVOp sym bvop ct W4.bvSle e1 e2 + W4.BaseBVRepr{} | False <- s -> appBVOp sym bvop ct W4.bvUle e1 e2 + W4.BaseIntegerRepr{} -> appBVOp sym bvop ct W4.intLe e1 e2 + _ -> case ct of + +-- | Wrap an operation in an applied defined function. +wrapFn :: + forall sym args tp. + W4.IsSymExprBuilder sym => + Ctx.CurryAssignmentClass args => + sym -> + String -> + Ctx.Assignment (W4.SymExpr sym) args -> + (Ctx.CurryAssignment args (W4.SymExpr sym) (IO (W4.SymExpr sym tp))) -> + IO (W4.SymExpr sym tp) +wrapFn sym nm args f = do + let tps = TFC.fmapFC W4.exprType args + fn <- W4.inlineDefineFun sym (W4.safeSymbol nm) tps W4.NeverUnfold f + W4.applySymFn sym fn args + +mvBVOpFn :: + forall sym tp1 tp2 tp3. + W4.IsSymExprBuilder sym => + sym -> + BVOp tp1 tp2 tp3 -> + W4.SymExpr sym tp1 -> + W4.SymExpr sym tp2 -> + IO (W4.SymExpr sym tp3) +mvBVOpFn sym bvop e1 e2 = do + wrapFn sym (simpleShowBVOp bvop) (Ctx.empty Ctx.:> e1 Ctx.:> e2) (mkBVOp sym bvop) + +data SomeAppBVOp sym tp where + SomeAppBVOp :: forall sym tp1 tp2 tp3. BVOp tp1 tp2 tp3 -> W4.SymExpr sym tp1 -> W4.SymExpr sym tp2 -> SomeAppBVOp sym tp3 + +asSomeAppBVOp :: + Monad m => + W4.SymExpr sym tp -> + SimpT sym m (SomeAppBVOp sym tp) +asSomeAppBVOp e = withSym $ \_ -> do + W4B.FnApp fn (Ctx.Empty Ctx.:> arg1 Ctx.:> arg2) <- simpMaybe $ W4B.asNonceApp e + nm <- return $ show (W4B.symFnName fn) + case parseBVOp nm (W4.exprType arg1) (W4.exprType arg2) (W4.exprType e) of + Just bvop -> return $ (SomeAppBVOp bvop arg1 arg2) + Nothing -> fail $ "Failed to parse " ++ show nm ++ ": " ++ show e + +toSimpleConj :: + forall sym m t solver fs. + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + sym -> + BM.BoolMap (W4B.Expr t) -> + m [W4.Pred sym] +toSimpleConj sym bm = case BM.viewBoolMap bm of + BM.BoolMapTerms (t1:|ts) -> mapM go (t1:ts) + BM.BoolMapUnit -> return [W4.truePred sym] + BM.BoolMapDualUnit -> return [W4.falsePred sym] + where + go :: (W4.Pred sym, BM.Polarity) -> m (W4.Pred sym) + go (p, BM.Positive) = return p + go (p, BM.Negative) = IO.liftIO $ W4.notPred sym p + +collapseAppBVOps :: + forall sym m tp. + IO.MonadIO m => + (forall tp'. W4.SymExpr sym tp' -> m (W4.SymExpr sym tp')) -> + W4B.App (W4.SymExpr sym) tp -> + SimpT sym m (W4.SymExpr sym tp) +collapseAppBVOps rec app = case app of + W4B.NotPred e -> go e + W4B.ConjPred bm -> withSym $ \sym -> do + ps <- toSimpleConj sym bm + (p:ps') <- lift $ mapM rec ps + IO.liftIO $ foldM (W4.andPred sym) p ps' + _ -> withSym $ \_ -> fail $ "not negated predicate" ++ show app + where + go :: W4.Pred sym -> SimpT sym m (W4.Pred sym) + go e = withSym $ \sym -> do + e' <- lift $ rec e + SomeAppBVOp bvop e1 e2 <- asSomeAppBVOp e' + let bvop' = notBVOp bvop + IO.liftIO $ mvBVOpFn sym bvop' e1 e2 + +collapseBVOps :: + IO.MonadIO m => + sym ~ (W4B.ExprBuilder t solver fs) => + SimpStrategy sym m +collapseBVOps = liftSimpTStrategy collapseAppBVOps + + +