Skip to content

Commit

Permalink
agent: fail if non-unique connection IDs are passed to sendMessages (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
epoberezkin authored May 23, 2024
1 parent 984394d commit 6309f92
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 17 deletions.
30 changes: 21 additions & 9 deletions src/Simplex/Messaging/Agent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock
Expand Down Expand Up @@ -354,12 +356,12 @@ sendMessage c = withAgentEnv c .:: sendMessage' c
type MsgReq = (ConnId, PQEncryption, MsgFlags, MsgBody)

-- | Send multiple messages to different connections (SEND command)
sendMessages :: AgentClient -> [MsgReq] -> IO [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages c = withAgentEnv' c . sendMessages' c
sendMessages :: AgentClient -> [MsgReq] -> AE [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages c = withAgentEnv c . sendMessages' c
{-# INLINE sendMessages #-}

sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> IO (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB c = withAgentEnv' c . sendMessagesB' c
sendMessagesB :: Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AE (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB c = withAgentEnv c . sendMessagesB' c
{-# INLINE sendMessagesB #-}

ackMessage :: AgentClient -> ConnId -> AgentMsgId -> Maybe MsgReceiptInfo -> AE ()
Expand Down Expand Up @@ -1033,16 +1035,27 @@ getNotificationMessage' c nonce encNtfInfo = do

-- | Send message to the connection (SEND command) in Reader monad
sendMessage' :: AgentClient -> ConnId -> PQEncryption -> MsgFlags -> MsgBody -> AM (AgentMsgId, PQEncryption)
sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB' c (Identity (Right (connId, pqEnc, msgFlags, msg)))
sendMessage' c connId pqEnc msgFlags msg = ExceptT $ runIdentity <$> sendMessagesB_ c (Identity (Right (connId, pqEnc, msgFlags, msg))) (S.singleton connId)
{-# INLINE sendMessage' #-}

-- | Send multiple messages to different connections (SEND command) in Reader monad
sendMessages' :: AgentClient -> [MsgReq] -> AM' [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages' :: AgentClient -> [MsgReq] -> AM [Either AgentErrorType (AgentMsgId, PQEncryption)]
sendMessages' c = sendMessagesB' c . map Right
{-# INLINE sendMessages' #-}

sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do
sendMessagesB' :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> AM (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB' c reqs = do
connIds <- liftEither $ foldl' addConnId (Right S.empty) reqs
lift $ sendMessagesB_ c reqs connIds
where
addConnId s@(Right s') (Right (connId, _, _, _))
| B.null connId = s
| connId `S.notMember` s' = Right $ S.insert connId s'
| otherwise = Left $ INTERNAL "sendMessages: duplicate connection ID"
addConnId s _ = s

sendMessagesB_ :: forall t. Traversable t => AgentClient -> t (Either AgentErrorType MsgReq) -> Set ConnId -> AM' (t (Either AgentErrorType (AgentMsgId, PQEncryption)))
sendMessagesB_ c reqs connIds = withConnLocks c connIds "sendMessages" $ do
reqs' <- withStoreBatch c (\db -> fmap (bindRight $ \req@(connId, _, _, _) -> bimap storeError (req,) <$> getConn db connId) reqs)
let (toEnable, reqs'') = mapAccumL prepareConn [] reqs'
void $ withStoreBatch' c $ \db -> map (\connId -> setConnPQSupport db connId PQSupportOn) toEnable
Expand All @@ -1064,7 +1077,6 @@ sendMessagesB' c reqs = withConnLocks c connIds "sendMessages" $ do
let cData' = cData {pqSupport = PQSupportOn} :: ConnData
in (connId : acc, Right (cData', sqs, Just pqEnc, msgFlags, A_MSG msg))
| otherwise = (acc, Right (cData, sqs, Just pqEnc, msgFlags, A_MSG msg))
connIds = map (\(connId, _, _, _) -> connId) $ rights $ toList reqs

-- / async command processing v v v

Expand Down
6 changes: 3 additions & 3 deletions src/Simplex/Messaging/Agent/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -826,15 +826,15 @@ withInvLock' :: AgentClient -> ByteString -> String -> AM' a -> AM' a
withInvLock' AgentClient {invLocks} = withLockMap invLocks
{-# INLINE withInvLock' #-}

withConnLocks :: AgentClient -> [ConnId] -> String -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks . filter (not . B.null)
withConnLocks :: AgentClient -> Set ConnId -> String -> AM' a -> AM' a
withConnLocks AgentClient {connLocks} = withLocksMap_ connLocks
{-# INLINE withConnLocks #-}

withLockMap :: (Ord k, MonadUnliftIO m) => TMap k Lock -> k -> String -> m a -> m a
withLockMap = withGetLock . getMapLock
{-# INLINE withLockMap #-}

withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> [k] -> String -> m a -> m a
withLocksMap_ :: (Ord k, MonadUnliftIO m) => TMap k Lock -> Set k -> String -> m a -> m a
withLocksMap_ = withGetLocks . getMapLock
{-# INLINE withLocksMap_ #-}

Expand Down
10 changes: 5 additions & 5 deletions src/Simplex/Messaging/Agent/Lock.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import Control.Monad (void)
import Control.Monad.Except (ExceptT (..), runExceptT)
import Control.Monad.IO.Unlift
import Data.Functor (($>))
import Data.Set (Set)
import qualified Data.Set as S
import UnliftIO.Async (forConcurrently)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
Expand Down Expand Up @@ -39,13 +41,11 @@ withGetLock getLock key name a =
(atomically . takeTMVar)
(const a)

withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> [k] -> String -> m a -> m a
withGetLocks :: MonadUnliftIO m => (k -> STM Lock) -> Set k -> String -> m a -> m a
withGetLocks getLock keys name = E.bracket holdLocks releaseLocks . const
where
holdLocks = forConcurrently keys $ \key -> atomically $ getPutLock getLock key name
-- only this withGetLocks would be holding the locks,
-- so it's safe to combine all lock releases into one transaction
releaseLocks = atomically . mapM_ takeTMVar
holdLocks = forConcurrently (S.toList keys) $ \key -> atomically $ getPutLock getLock key name
releaseLocks = mapM_ (atomically . takeTMVar)

-- getLock and putTMVar can be in one transaction on the assumption that getLock doesn't write in case the lock already exists,
-- and in case it is created and added to some shared resource (we use TMap) it also helps avoid contention for the newly created lock.
Expand Down

0 comments on commit 6309f92

Please sign in to comment.