Skip to content

Commit

Permalink
add SingleNodeEntry datatype to represent explicitly single-sided nodes
Browse files Browse the repository at this point in the history
This clears up some of the ambiguity in the SyncPoint data which otherwise
had lots of implicit assumptions about the run-time contents of the NodeEntry
for each sync point

A SingleNodeEntry is functionally equivalent to a NodeEntry that's
annotated with the fact that it's single-sided
  • Loading branch information
danmatichuk committed May 15, 2024
1 parent ac18ddd commit 483b7df
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 63 deletions.
10 changes: 10 additions & 0 deletions src/Data/Parameterized/SetF.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ module Data.Parameterized.SetF
, union
, unions
, null
, toSet
, ppSetF
) where

Expand All @@ -55,6 +56,7 @@ import Prettyprinter ( (<+>) )

import Data.Set (Set)
import qualified Data.Set as S
import Unsafe.Coerce (unsafeCoerce)

newtype AsOrd f tp where
AsOrd :: { unAsOrd :: f tp } -> AsOrd f tp
Expand Down Expand Up @@ -130,6 +132,14 @@ null ::
SetF f tp -> Bool
null (SetF es) = S.null es

-- | Convert a 'SetF' to a 'Set', under the assumption
-- that the 'OrdF' and 'Ord' instances are consistent.
-- This uses coercion rather than re-building the set,
-- which is sound given the above assumption.
toSet ::
(OrdF f, Ord (f tp)) => SetF f tp -> Set (f tp)
toSet (SetF s) = unsafeCoerce s

ppSetF ::
(f tp -> PP.Doc a) ->
SetF f tp ->
Expand Down
18 changes: 18 additions & 0 deletions src/Pate/PatchPair.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ module Pate.PatchPair (
, ppPatchPair'
, forBins
, update
, insertWith
, forBinsC
, catBins
, get
Expand Down Expand Up @@ -332,6 +333,23 @@ update src f = do
(PatchPair _ b, PatchPairOriginal a) -> return $ PatchPair a b
(PatchPair a _, PatchPairPatched b) -> return $ PatchPair a b

-- | Add a value to a 'PatchPair', combining it with an existing entry if
-- present using the given function (i.e. similar to Map.insertWith)
insertWith ::
PB.WhichBinaryRepr bin ->
f bin ->
(f bin -> f bin -> f bin) ->
PatchPair f ->
PatchPair f
insertWith bin v f = \case
PatchPair vO vP | PB.OriginalRepr <- bin -> PatchPair (f v vO) vP
PatchPair vO vP | PB.PatchedRepr <- bin -> PatchPair vO (f v vP)
PatchPairSingle bin' v' -> case (bin, bin') of
(PB.OriginalRepr, PB.OriginalRepr) -> PatchPairSingle bin (f v v')
(PB.PatchedRepr, PB.PatchedRepr) -> PatchPairSingle bin (f v v')
(PB.PatchedRepr, PB.OriginalRepr) -> PatchPair v' v
( PB.OriginalRepr, PB.PatchedRepr) -> PatchPair v v'

-- | Specialization of 'PatchPair' to types which are not indexed on 'PB.WhichBinary'
type PatchPairC tp = PatchPair (Const tp)

Expand Down
58 changes: 23 additions & 35 deletions src/Pate/Verification/PairGraph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ import qualified Pate.Location as PL
import qualified Pate.Address as PAd
import Data.Foldable (find)
import Data.Parameterized.PairF
import Data.Parameterized.SetF (SetF)
import qualified Data.Parameterized.SetF as SetF

-- | Gas is used to ensure that our fixpoint computation terminates
-- in a reasonable amount of time. Gas is expended each time
Expand Down Expand Up @@ -476,7 +478,7 @@ data SyncPoint arch =
-- The "merge" nodes are the cross product of these, since we
-- must assume that any combination of exits is possible
-- during the single-sided analysis
{ syncStartNodes :: PPa.PatchPairC (Set (GraphNode arch))
{ syncStartNodes :: PPa.PatchPair (SetF (SingleNodeEntry arch))
, syncCutAddresses :: PPa.PatchPairC (PAd.ConcreteAddress arch)
}

Expand Down Expand Up @@ -921,10 +923,10 @@ getSyncPoints ::
forall sym arch bin.
PBi.WhichBinaryRepr bin ->
GraphNode arch ->
PairGraphM sym arch (Set (GraphNode arch))
PairGraphM sym arch (Set (SingleNodeEntry arch bin))
getSyncPoints bin nd = do
(SyncPoint syncPair _) <- lookupPairGraph @sym pairGraphSyncPoints nd
PPa.getC bin syncPair
SetF.toSet <$> PPa.get bin syncPair

getSyncAddress ::
forall sym arch bin.
Expand All @@ -950,40 +952,31 @@ updateSyncPoint nd f = do
getCombinedSyncPoints ::
forall sym arch.
GraphNode arch ->
PairGraphM sym arch ([((GraphNode arch, GraphNode arch), GraphNode arch)], SyncPoint arch)
PairGraphM sym arch ([((SingleNodeEntry arch PBi.Original, SingleNodeEntry arch PBi.Patched), GraphNode arch)], SyncPoint arch)
getCombinedSyncPoints ndDiv = do
sync@(SyncPoint syncSet _) <- lookupPairGraph @sym pairGraphSyncPoints ndDiv
case syncSet of
PPa.PatchPairC ndsO ndsP -> do
all_pairs <- forM (Set.toList $ Set.cartesianProduct ndsO ndsP) $ \(ndO, ndP) -> do
PPa.PatchPair ndsO ndsP -> do
all_pairs <- forM (Set.toList $ Set.cartesianProduct (SetF.toSet ndsO) (SetF.toSet ndsP)) $ \(ndO, ndP) -> do
combined <- pgMaybe "failed to combine nodes" $ combineNodes ndO ndP
return $ ((ndO, ndP), combined)
return (all_pairs, sync)
_ -> return ([], sync)

-- | Compute a merged node for two diverging nodes
-- FIXME: do we need to support mismatched node kinds here?
combineNodes :: GraphNode arch -> GraphNode arch -> Maybe (GraphNode arch)
combineNodes :: SingleNodeEntry arch bin -> SingleNodeEntry arch (PBi.OtherBinary bin) -> Maybe (GraphNode arch)
combineNodes node1 node2 = do
(nodeO, nodeP) <- case PPa.get PBi.OriginalRepr (graphNodeBlocks node1) of
Just{} -> return (node1, node2)
Nothing -> return (node2, node1)
let ndPair = PPa.mkPair (singleEntryBin node1) node1 node2
nodeO <- PPa.get PBi.OriginalRepr ndPair
nodeP <- PPa.get PBi.PatchedRepr ndPair
-- it only makes sense to combine nodes that share a divergence point,
-- where that divergence point will be used as the calling context for the
-- merged point
divergeO <- divergePoint $ nodeContext nodeO
divergeP <- divergePoint $ nodeContext nodeP
let divergeO = singleNodeDivergence nodeO
let divergeP = singleNodeDivergence nodeP
guard $ divergeO == divergeP
case (nodeO, nodeP) of
(GraphNode nodeO', GraphNode nodeP') -> do
blocksO <- PPa.get PBi.OriginalRepr (nodeBlocks nodeO')
blocksP <- PPa.get PBi.PatchedRepr (nodeBlocks nodeP')
return $ GraphNode $ mkMergedNodeEntry divergeO blocksO blocksP
(ReturnNode nodeO', ReturnNode nodeP') -> do
fnsO <- PPa.get PBi.OriginalRepr (nodeFuns nodeO')
fnsP <- PPa.get PBi.PatchedRepr (nodeFuns nodeP')
return $ ReturnNode $ mkMergedNodeReturn divergeO fnsO fnsP
_ -> Nothing
return $ GraphNode $ mkMergedNodeEntry divergeO (singleNodeBlock nodeO) (singleNodeBlock nodeP)

singleNodeRepr :: GraphNode arch -> Maybe (Some (PBi.WhichBinaryRepr))
singleNodeRepr nd = case graphNodeBlocks nd of
Expand All @@ -993,17 +986,12 @@ singleNodeRepr nd = case graphNodeBlocks nd of
addSyncNode ::
forall sym arch.
GraphNode arch {- ^ The divergent node -} ->
GraphNode arch {- ^ the sync node -} ->
NodeEntry arch {- ^ the sync node -} ->
PairGraphM sym arch ()
addSyncNode ndDiv ndSync = do
Pair bin _ <- PPa.asSingleton (graphNodeBlocks ndSync)
let ndSync' = PPa.PatchPairSingle bin (Const ndSync)
Some nd <- asSingleNodeEntry ndSync
(SyncPoint sp addrs) <- lookupPairGraph @sym pairGraphSyncPoints ndDiv
sp' <- PPa.update sp $ \bin' -> do
s <- PPa.getC bin' ndSync'
case PPa.getC bin' sp of
Nothing -> return $ (Const (Set.singleton s))
Just s' -> return $ (Const (Set.insert s s'))
let sp' = PPa.insertWith (singleEntryBin nd) (SetF.singleton nd) SetF.union sp
modify $ \pg -> pg { pairGraphSyncPoints = Map.insert ndDiv (SyncPoint sp' addrs) (pairGraphSyncPoints pg) }

tryPG :: PairGraphM sym arch a -> PairGraphM sym arch (Maybe a)
Expand All @@ -1022,7 +1010,7 @@ setSyncAddress ndDiv bin syncAddr = do
addrs' <- PPa.update addrs $ \bin' -> PPa.get bin' syncAddr'
modify $ \pg -> pg{ pairGraphSyncPoints = Map.insert ndDiv (SyncPoint sp addrs') (pairGraphSyncPoints pg) }
Nothing -> do
let sp = PPa.mkPair bin (Const Set.empty) (Const Set.empty)
let sp = PPa.mkPair bin SetF.empty SetF.empty
modify $ \pg -> pg{pairGraphSyncPoints = Map.insert ndDiv (SyncPoint sp syncAddr') (pairGraphSyncPoints pg) }

-- | Add a node back to the worklist to be re-analyzed if there is
Expand Down Expand Up @@ -1178,15 +1166,15 @@ checkForNodeSync ::
[(PPa.PatchPair (PB.BlockTarget arch))] ->
PairGraphM sym arch Bool
checkForNodeSync ne targets_pairs = fmap (fromMaybe False) $ tryPG $ do
Just (Some bin) <- return $ isSingleNodeEntry ne

Just dp <- return $ getDivergePoint (GraphNode ne)
Some sne <- asSingleNodeEntry ne
let bin = singleEntryBin sne
let dp = singleNodeDivergence sne
syncPoints <- getSyncPoints bin dp
syncAddr <- getSyncAddress bin dp
thisAddr <- fmap PB.concreteAddress $ PPa.get bin (nodeBlocks ne)
-- if this node is already defined as sync point then we don't
-- have to check anything else
if | Set.member (GraphNode ne) syncPoints -> return True
if | Set.member sne syncPoints -> return True
-- similarly if this is exactly the sync address then we should
-- stop the single-sided analysis
| thisAddr == syncAddr -> return True
Expand Down
115 changes: 105 additions & 10 deletions src/Pate/Verification/PairGraph/Node.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module Pate.Verification.PairGraph.Node (
, pattern GraphNodeEntry
, pattern GraphNodeReturn
, nodeContext
, graphNodeContext
, divergePoint
, graphNodeBlocks
, mkNodeEntry
Expand All @@ -49,6 +50,13 @@ module Pate.Verification.PairGraph.Node (
, eqUptoDivergePoint
, mkMergedNodeEntry
, mkMergedNodeReturn
, SingleNodeEntry
, singleEntryBin
, asSingleNodeEntry
, singleToNodeEntry
, combineSingleEntries
, singleNodeBlock
, singleNodeDivergence
) where

import Prettyprinter ( Pretty(..), sep, (<+>), Doc )
Expand All @@ -64,6 +72,9 @@ import qualified Pate.Binary as PB
import qualified Prettyprinter as PP
import Data.Parameterized (Some(..), Pair (..))
import qualified What4.JSON as W4S
import Control.Monad (guard)
import Data.Parameterized.Classes
import Pate.Panic

-- | Nodes in the program graph consist either of a pair of
-- program points (GraphNode), or a synthetic node representing
Expand Down Expand Up @@ -91,28 +102,48 @@ instance PA.ValidArch arch => W4S.W4Serializable sym (GraphNode arch) where
instance PA.ValidArch arch => W4S.W4Serializable sym (NodeEntry arch) where
w4Serialize r = return $ JSON.toJSON r

-- A frozen binary
data NodeEntry arch =
NodeEntry { graphNodeContext :: CallingContext arch, nodeBlocks :: PB.BlockPair arch }
data NodeContent arch e =
NodeContent { nodeContentCtx :: CallingContext arch, nodeContent :: e }
deriving (Eq, Ord)

data NodeReturn arch =
NodeReturn { returnNodeContext :: CallingContext arch, nodeFuns :: PB.FunPair arch }
deriving (Eq, Ord)
type NodeEntry arch = NodeContent arch (PB.BlockPair arch)

pattern NodeEntry :: CallingContext arch -> PB.BlockPair arch -> NodeEntry arch
pattern NodeEntry ctx bp = NodeContent ctx bp
{-# COMPLETE NodeEntry #-}


nodeBlocks :: NodeEntry arch -> PB.BlockPair arch
nodeBlocks = nodeContent

graphNodeContext :: NodeEntry arch -> CallingContext arch
graphNodeContext = nodeContentCtx

type NodeReturn arch = NodeContent arch (PB.FunPair arch)

nodeFuns :: NodeReturn arch -> PB.FunPair arch
nodeFuns = nodeContent

returnNodeContext :: NodeReturn arch -> CallingContext arch
returnNodeContext = nodeContentCtx

pattern NodeReturn :: CallingContext arch -> PB.FunPair arch -> NodeReturn arch
pattern NodeReturn ctx bp = NodeContent ctx bp
{-# COMPLETE NodeReturn #-}

graphNodeBlocks :: GraphNode arch -> PB.BlockPair arch
graphNodeBlocks (GraphNode ne) = nodeBlocks ne
graphNodeBlocks (ReturnNode ret) = TF.fmapF PB.functionEntryToConcreteBlock (nodeFuns ret)

nodeContext :: GraphNode arch -> CallingContext arch
nodeContext (GraphNode nd) = graphNodeContext nd
nodeContext (ReturnNode ret) = returnNodeContext ret
nodeContext (GraphNode nd) = nodeContentCtx nd
nodeContext (ReturnNode ret) = nodeContentCtx ret

pattern GraphNodeEntry :: PB.BlockPair arch -> GraphNode arch
pattern GraphNodeEntry blks <- (GraphNode (NodeEntry _ blks))
pattern GraphNodeEntry blks <- (GraphNode (NodeContent _ blks))

pattern GraphNodeReturn :: PB.FunPair arch -> GraphNode arch
pattern GraphNodeReturn blks <- (ReturnNode (NodeReturn _ blks))
pattern GraphNodeReturn blks <- (ReturnNode (NodeContent _ blks))

{-# COMPLETE GraphNodeEntry, GraphNodeReturn #-}

Expand Down Expand Up @@ -191,6 +222,7 @@ mkMergedNodeReturn nd fnO fnP = NodeReturn (CallingContext cctx (Just nd)) (PPa.
where
CallingContext cctx _ = nodeContext nd


-- | Project the given 'NodeReturn' into a singleton node for the given binary
toSingleReturn :: PPa.PatchPairM m => PB.WhichBinaryRepr bin -> GraphNode arch -> NodeReturn arch -> m (NodeReturn arch)
toSingleReturn bin divergedAt (NodeReturn ctx fPair) = do
Expand Down Expand Up @@ -352,3 +384,66 @@ instance forall sym arch. PA.ValidArch arch => IsTraceNode '(sym, arch) "entryno
prettyNode () = pretty
nodeTags = mkTags @'(sym,arch) @"entrynode" [Simplified, Summary]
jsonNode _ = nodeToJSON @'(sym,arch) @"entrynode"

-- | Equivalent to a 'NodeEntry' but necessarily a single-sided node.
-- Converting a 'SingleNodeEntry' to a 'NodeEntry' is always defined,
-- while converting a 'NodeEntry' to a 'SingleNodeEntry' is partial.
data SingleNodeEntry arch bin =
SingleNodeEntry
{ singleEntryBin :: PB.WhichBinaryRepr bin
, _singleEntry :: NodeContent arch (PB.ConcreteBlock arch bin)
}

instance TestEquality (SingleNodeEntry arch) where
testEquality se1 se2 | EQF <- compareF se1 se2 = Just Refl
testEquality _ _ = Nothing

instance Eq (SingleNodeEntry arch bin) where
se1 == se2 = compare se1 se2 == EQ

instance Ord (SingleNodeEntry arch bin) where
compare (SingleNodeEntry _ se1) (SingleNodeEntry _ se2) = compare se1 se2

instance OrdF (SingleNodeEntry arch) where
compareF (SingleNodeEntry bin1 se1) (SingleNodeEntry bin2 se2) =
lexCompareF bin1 bin2 $ fromOrdering (compare se1 se2)

instance PA.ValidArch arch => Show (SingleNodeEntry arch bin) where
show e = show (singleToNodeEntry e)

asSingleNodeEntry :: PPa.PatchPairM m => NodeEntry arch -> m (Some (SingleNodeEntry arch))
asSingleNodeEntry (NodeEntry cctx bPair) = do
Pair bin blk <- PPa.asSingleton bPair
case divergePoint cctx of
Just{} -> return $ Some (SingleNodeEntry bin (NodeContent cctx blk))
Nothing -> PPa.throwPairErr

singleToNodeEntry :: SingleNodeEntry arch bin -> NodeEntry arch
singleToNodeEntry (SingleNodeEntry bin (NodeContent cctx v)) =
NodeEntry cctx (PPa.PatchPairSingle bin v)

singleNodeBlock :: SingleNodeEntry arch bin -> PB.ConcreteBlock arch bin
singleNodeBlock (SingleNodeEntry _ (NodeContent _ blk)) = blk

singleNodeDivergence :: SingleNodeEntry arch bin -> GraphNode arch
singleNodeDivergence (SingleNodeEntry _ (NodeContent cctx _)) = case divergePoint cctx of
Just dp -> dp
Nothing -> panic Verifier "singleNodeDivergence" ["Unexpected missing divergence point"]

-- | Create a combined two-sided 'NodeEntry' based on
-- a pair of single-sided entries. The given entries
-- must have the same diverge point (returns 'Nothing' otherwise),
-- and the calling context of the resulting node is inherited from
-- that point (i.e. any additional context accumulated during
-- the either single-sided analysis is discarded)
combineSingleEntries ::
SingleNodeEntry arch PB.Original ->
SingleNodeEntry arch PB.Patched ->
Maybe (NodeEntry arch)
combineSingleEntries (SingleNodeEntry _ eO) (SingleNodeEntry _ eP) = do
GraphNode divergeO <- divergePoint $ nodeContentCtx eO
GraphNode divergeP <- divergePoint $ nodeContentCtx eP
guard $ divergeO == divergeP
let blksO = nodeContent eO
let blksP = nodeContent eP
return $ mkNodeEntry divergeO (PPa.PatchPair blksO blksP)
Loading

0 comments on commit 483b7df

Please sign in to comment.