diff --git a/simplexmq.cabal b/simplexmq.cabal index 97ada9f55..30b265592 100644 --- a/simplexmq.cabal +++ b/simplexmq.cabal @@ -67,6 +67,11 @@ flag client_library manual: True default: False +flag client_postgres + description: Build with PostgreSQL instead of SQLite. + manual: True + default: False + library exposed-modules: Simplex.FileTransfer.Agent @@ -90,47 +95,11 @@ library Simplex.Messaging.Agent.RetryInterval Simplex.Messaging.Agent.Stats Simplex.Messaging.Agent.Store - Simplex.Messaging.Agent.Store.SQLite - Simplex.Messaging.Agent.Store.SQLite.Common - Simplex.Messaging.Agent.Store.SQLite.DB - Simplex.Messaging.Agent.Store.SQLite.Migrations - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2 - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230223_files - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230320_retry_state - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230401_snd_files - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230510_files_pending_replicas_indexes - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230516_encrypted_rcv_message_hashes - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230531_switch_status - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230615_ratchet_sync - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230701_delivery_receipts - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230720_delete_expired_messages - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230722_indexes - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230814_indexes - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230829_crypto_files - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231222_command_created_at - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240417_rcv_files_approved_relays - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240624_snd_secure - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240702_servers_stats - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240930_ntf_tokens_to_delete - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241007_rcv_queues_last_broker_ts - Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241224_ratchet_e2e_snd_params + Simplex.Messaging.Agent.Store.AgentStore + Simplex.Messaging.Agent.Store.Common + 
Simplex.Messaging.Agent.Store.DB + Simplex.Messaging.Agent.Store.Migrations + Simplex.Messaging.Agent.Store.Shared Simplex.Messaging.Agent.TRcvQueues Simplex.Messaging.Client Simplex.Messaging.Client.Agent @@ -175,6 +144,59 @@ library Simplex.RemoteControl.Discovery.Multicast Simplex.RemoteControl.Invitation Simplex.RemoteControl.Types + if flag(client_postgres) + exposed-modules: + Simplex.Messaging.Agent.Store.Postgres + Simplex.Messaging.Agent.Store.Postgres.Common + Simplex.Messaging.Agent.Store.Postgres.DB + Simplex.Messaging.Agent.Store.Postgres.Migrations + Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial + if !flag(client_library) + exposed-modules: + Simplex.Messaging.Agent.Store.Postgres.Util + else + exposed-modules: + Simplex.Messaging.Agent.Store.SQLite + Simplex.Messaging.Agent.Store.SQLite.Common + Simplex.Messaging.Agent.Store.SQLite.DB + Simplex.Messaging.Agent.Store.SQLite.Migrations + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2 + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220811_onion_hosts + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220817_connection_ntfs + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220905_commands + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220915_connection_queues + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230110_users + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230117_fkey_indexes + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230120_delete_errors + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230217_server_key_hash + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230223_files + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230320_retry_state + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230401_snd_files + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230510_files_pending_replicas_indexes + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230516_encrypted_rcv_message_hashes + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230531_switch_status + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230615_ratchet_sync + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230701_delivery_receipts + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230720_delete_expired_messages + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230722_indexes + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230814_indexes + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20230829_crypto_files + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231222_command_created_at + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20231225_failed_work_items + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240121_message_delivery_indexes + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240124_file_redirect + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240223_connections_wait_delivery + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240225_ratchet_kem + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240417_rcv_files_approved_relays + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240624_snd_secure + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240702_servers_stats + 
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240930_ntf_tokens_to_delete + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241007_rcv_queues_last_broker_ts + Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241224_ratchet_e2e_snd_params if !flag(client_library) exposed-modules: Simplex.FileTransfer.Client.Main @@ -243,7 +265,6 @@ library , crypton-x509-validation ==1.6.* , cryptostore ==0.3.* , data-default ==0.7.* - , direct-sqlcipher ==2.3.* , directory ==1.3.* , filepath ==1.4.* , hourglass ==0.2.* @@ -263,7 +284,6 @@ library , random >=1.1 && <1.3 , simple-logger ==0.1.* , socks ==0.6.* - , sqlcipher-simple ==0.4.* , stm ==2.5.* , temporary ==1.3.* , time ==1.12.* @@ -282,6 +302,16 @@ library case-insensitive ==1.2.* , hashable ==1.4.* , websockets ==0.12.* + if flag(client_postgres) + build-depends: + postgresql-libpq >=0.10.0.0 + , postgresql-simple ==0.7.* + , raw-strings-qq ==1.1.* + cpp-options: -DdbPostgres + else + build-depends: + direct-sqlcipher ==2.3.* + , sqlcipher-simple ==0.4.* if impl(ghc >= 9.6.2) build-depends: bytestring ==0.11.* @@ -384,10 +414,7 @@ test-suite simplexmq-test AgentTests.EqInstances AgentTests.FunctionalAPITests AgentTests.MigrationTests - AgentTests.NotificationTests - AgentTests.SchemaDump AgentTests.ServerChoice - AgentTests.SQLiteTests CLITests CoreTests.BatchingTests CoreTests.CryptoFileTests @@ -401,6 +428,7 @@ test-suite simplexmq-test CoreTests.UtilTests CoreTests.VersionRangeTests FileDescriptionTests + Fixtures NtfClient NtfServerTests RemoteControl @@ -416,6 +444,11 @@ test-suite simplexmq-test Static Static.Embedded Paths_simplexmq + if !flag(client_postgres) + other-modules: + AgentTests.NotificationTests + AgentTests.SchemaDump + AgentTests.SQLiteTests hs-source-dirs: tests apps/smp-server/web @@ -456,7 +489,6 @@ test-suite simplexmq-test , silently ==1.2.* , simple-logger , simplexmq - , sqlcipher-simple , stm , text , time @@ -471,3 +503,12 @@ test-suite simplexmq-test , warp-tls , yaml default-language: Haskell2010 + if flag(client_postgres) + build-depends: + postgresql-libpq >=0.10.0.0 + , postgresql-simple ==0.7.* + , raw-strings-qq ==1.1.* + cpp-options: -DdbPostgres + else + build-depends: + sqlcipher-simple diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index 9506b465c..cfee308bc 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -65,8 +65,8 @@ import Simplex.Messaging.Agent.Env.SQLite import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats -import Simplex.Messaging.Agent.Store.SQLite -import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB +import Simplex.Messaging.Agent.Store.AgentStore +import qualified Simplex.Messaging.Agent.Store.DB as DB import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs) import qualified Simplex.Messaging.Crypto.File as CF diff --git a/src/Simplex/FileTransfer/Description.hs b/src/Simplex/FileTransfer/Description.hs index 865daf23d..0c7c42ab4 100644 --- a/src/Simplex/FileTransfer/Description.hs +++ b/src/Simplex/FileTransfer/Description.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DerivingStrategies #-} @@ -9,6 +10,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC 
-fno-warn-ambiguous-fields #-} @@ -66,17 +68,23 @@ import Data.Text (Text) import Data.Text.Encoding (encodeUtf8) import Data.Word (Word32) import qualified Data.Yaml as Y -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.FileTransfer.Chunks import Simplex.FileTransfer.Protocol import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (defaultJSON, parseAll) import Simplex.Messaging.Protocol (XFTPServer) import Simplex.Messaging.ServiceScheme (ServiceScheme (..)) import Simplex.Messaging.Util (bshow, safeDecodeUtf8, (<$?>)) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif data FileDescription (p :: FileParty) = FileDescription { party :: SFileParty p, @@ -113,6 +121,9 @@ fdSeparator = "################################\n" newtype FileDigest = FileDigest {unFileDigest :: ByteString} deriving (Eq, Show) + deriving newtype (FromField) + +instance ToField FileDigest where toField (FileDigest s) = toField $ Binary s instance StrEncoding FileDigest where strEncode (FileDigest fd) = strEncode fd @@ -126,10 +137,6 @@ instance ToJSON FileDigest where toJSON = strToJSON toEncoding = strToJEncoding -instance FromField FileDigest where fromField f = FileDigest <$> fromField f - -instance ToField FileDigest where toField (FileDigest s) = toField s - data FileChunk = FileChunk { chunkNo :: Int, chunkSize :: FileSize Word32, @@ -307,9 +314,9 @@ instance (Integral a, Show a) => StrEncoding (FileSize a) where instance (Integral a, Show a) => IsString (FileSize a) where fromString = either error id . strDecode . 
B.pack -instance FromField a => FromField (FileSize a) where fromField f = FileSize <$> fromField f +deriving newtype instance FromField a => FromField (FileSize a) -instance ToField a => ToField (FileSize a) where toField (FileSize s) = toField s +deriving newtype instance ToField a => ToField (FileSize a) groupReplicasByServer :: FileSize Word32 -> [FileChunk] -> [NonEmpty FileServerReplica] groupReplicasByServer defChunkSize = diff --git a/src/Simplex/FileTransfer/Types.hs b/src/Simplex/FileTransfer/Types.hs index 571cf3748..c18d31779 100644 --- a/src/Simplex/FileTransfer/Types.hs +++ b/src/Simplex/FileTransfer/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} @@ -13,8 +14,6 @@ import Data.Int (Int64) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Data.Word (Word32) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) import Simplex.FileTransfer.Description import qualified Simplex.Messaging.Crypto as C @@ -24,6 +23,13 @@ import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol (XFTPServer) import System.FilePath (()) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif type RcvFileId = ByteString -- Agent entity ID diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 1f839b297..a7371b935 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -167,9 +167,11 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store -import Simplex.Messaging.Agent.Store.SQLite -import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations +import Simplex.Messaging.Agent.Store.AgentStore +import Simplex.Messaging.Agent.Store.Common (DBStore) +import qualified Simplex.Messaging.Agent.Store.DB as DB +import qualified Simplex.Messaging.Agent.Store.Migrations as Migrations +import Simplex.Messaging.Agent.Store.Shared (UpMigration (..), upMigration) import Simplex.Messaging.Client (SMPClientError, ServerTransmission (..), ServerTransmissionBatch, temporaryClientError, unexpectedResponse) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile, CryptoFileArgs) @@ -200,11 +202,11 @@ import UnliftIO.STM type AE a = ExceptT AgentErrorType IO a -- | Creates an SMP agent client instance -getSMPAgentClient :: AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> IO AgentClient +getSMPAgentClient :: AgentConfig -> InitialAgentServers -> DBStore -> Bool -> IO AgentClient getSMPAgentClient = getSMPAgentClient_ 1 {-# INLINE getSMPAgentClient #-} -getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> SQLiteStore -> Bool -> IO AgentClient +getSMPAgentClient_ :: Int -> AgentConfig -> InitialAgentServers -> DBStore -> Bool -> IO AgentClient getSMPAgentClient_ clientId cfg initServers@InitialAgentServers {smp, xftp} store backgroundMode = newSMPAgentEnv cfg store >>= runReaderT runAgent where @@ -277,7 +279,7 @@ 
disposeAgentClient c@AgentClient {acThread, agentEnv = Env {store}} = do t_ <- atomically (swapTVar acThread Nothing) $>>= (liftIO . deRefWeak) disconnectAgentClient c mapM_ killThread t_ - liftIO $ closeSQLiteStore store + liftIO $ closeStore store resumeAgentClient :: AgentClient -> IO () resumeAgentClient c = atomically $ writeTVar (active c) True @@ -2168,7 +2170,7 @@ execAgentStoreSQL :: AgentClient -> Text -> AE [Text] execAgentStoreSQL c sql = withAgentEnv c $ withStore' c (`execSQL` sql) getAgentMigrations :: AgentClient -> AE [UpMigration] -getAgentMigrations c = withAgentEnv c $ map upMigration <$> withStore' c (Migrations.getCurrent . DB.conn) +getAgentMigrations c = withAgentEnv c $ map upMigration <$> withStore' c Migrations.getCurrent debugAgentLocks :: AgentClient -> IO AgentLocks debugAgentLocks AgentClient {connLocks = cs, invLocks = is, deleteLock = d} = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index ef55870fe..5d5e6fcaf 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -1,5 +1,6 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} @@ -205,7 +206,6 @@ import Data.Text.Encoding import Data.Time (UTCTime, addUTCTime, defaultTimeLocale, formatTime, getCurrentTime) import Data.Time.Clock.System (getSystemTime) import Data.Word (Word16) -import qualified Database.SQLite.Simple as SQL import Network.Socket (HostName) import Simplex.FileTransfer.Client (XFTPChunkSpec (..), XFTPClient, XFTPClientConfig (..), XFTPClientError) import qualified Simplex.FileTransfer.Client as X @@ -221,8 +221,8 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store -import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction) -import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB +import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction) +import qualified Simplex.Messaging.Agent.Store.DB as DB import Simplex.Messaging.Agent.TRcvQueues (TRcvQueues (getRcvQueues)) import qualified Simplex.Messaging.Agent.TRcvQueues as RQ import Simplex.Messaging.Client @@ -282,6 +282,9 @@ import UnliftIO.Concurrent (forkIO, mkWeakThreadId) import UnliftIO.Directory (doesFileExist, getTemporaryDirectory, removeFile) import qualified UnliftIO.Exception as E import UnliftIO.STM +#if !defined(dbPostgres) +import qualified Database.SQLite.Simple as SQL +#endif type ClientVar msg = SessionVar (Either (AgentErrorType, Maybe UTCTime) (Client msg)) @@ -555,7 +558,7 @@ slowNetworkConfig cfg@NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpTimeoutPe slow :: Integral a => a -> a slow t = (t * 3) `div` 2 -agentClientStore :: AgentClient -> SQLiteStore +agentClientStore :: AgentClient -> DBStore agentClientStore AgentClient {agentEnv = Env {store}} = store {-# INLINE agentClientStore #-} @@ -1649,7 +1652,7 @@ disableQueuesNtfs = sendTSessionBatches "NDEL" snd disableQueues_ sendAck :: AgentClient -> RcvQueue -> MsgId -> AM () sendAck c rq@RcvQueue {rcvId, rcvPrivateKey} msgId = withSMPClient c rq ("ACK:" <> logSecret' msgId) $ \smp -> - ackSMPMessage smp rcvPrivateKey rcvId msgId + ackSMPMessage smp rcvPrivateKey rcvId msgId hasGetLock :: AgentClient -> RcvQueue -> IO Bool hasGetLock c RcvQueue {server, rcvId} = @@ -1989,6 +1992,13 @@ withStore c action = do 
withExceptT storeError . ExceptT . liftIO . agentOperationBracket c AODatabase (\_ -> pure ()) $ withTransaction st action `E.catches` handleDBErrors where +#if defined(dbPostgres) + -- TODO [postgres] postgres specific error handling + handleDBErrors :: [E.Handler IO (Either StoreError a)] + handleDBErrors = + [ E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e + ] +#else handleDBErrors :: [E.Handler IO (Either StoreError a)] handleDBErrors = [ E.Handler $ \(e :: SQL.SQLError) -> @@ -1997,6 +2007,7 @@ withStore c action = do in pure . Left . (if busy then SEDatabaseBusy else SEInternal) $ bshow se, E.Handler $ \(E.SomeException e) -> pure . Left $ SEInternal $ bshow e ] +#endif withStoreBatch :: Traversable t => AgentClient -> (DB.Connection -> t (IO (Either AgentErrorType a))) -> AM' (t (Either AgentErrorType a)) withStoreBatch c actions = do @@ -2044,7 +2055,7 @@ pickServer = \case getNextServer :: (ProtocolTypeI p, UserProtocol p) => AgentClient -> - UserId -> + UserId -> (UserServers p -> NonEmpty (Maybe OperatorId, ProtoServerWithAuth p)) -> [ProtocolServer p] -> AM (ProtoServerWithAuth p) @@ -2097,7 +2108,7 @@ withNextSrv :: UserId -> (UserServers p -> NonEmpty (Maybe OperatorId, ProtoServerWithAuth p)) -> TVar (Set TransportHost) -> - [ProtocolServer p] -> + [ProtocolServer p] -> (ProtoServerWithAuth p -> AM a) -> AM a withNextSrv c userId srvsSel triedHosts usedSrvs action = do diff --git a/src/Simplex/Messaging/Agent/Env/SQLite.hs b/src/Simplex/Messaging/Agent/Env/SQLite.hs index a78fb428e..80a307efa 100644 --- a/src/Simplex/Messaging/Agent/Env/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Env/SQLite.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -52,7 +53,6 @@ import Control.Monad.Reader import Crypto.Random import Data.Aeson (FromJSON (..), ToJSON (..)) import qualified Data.Aeson.TH as JQ -import Data.ByteArray (ScrubbedBytes) import Data.Int (Int64) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L @@ -68,8 +68,9 @@ import Numeric.Natural import Simplex.FileTransfer.Client (XFTPClientConfig (..), defaultXFTPClientConfig) import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval -import Simplex.Messaging.Agent.Store.SQLite -import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations +import Simplex.Messaging.Agent.Store (createStore) +import Simplex.Messaging.Agent.Store.Common (DBStore) +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..)) import Simplex.Messaging.Client import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (VersionRangeE2E, supportedE2EEncryptVRange) @@ -86,6 +87,11 @@ import Simplex.Messaging.Util (allFinally, catchAllErrors, catchAllErrors', tryA import System.Mem.Weak (Weak) import System.Random (StdGen, newStdGen) import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (ConnectInfo (..)) +#else +import Data.ByteArray (ScrubbedBytes) +#endif type AM' a = ReaderT Env IO a @@ -254,7 +260,7 @@ defaultAgentConfig = data Env = Env { config :: AgentConfig, - store :: SQLiteStore, + store :: DBStore, random :: TVar ChaChaDRG, randomServer :: TVar StdGen, ntfSupervisor :: NtfSupervisor, @@ -262,7 +268,7 @@ data Env = Env multicastSubscribers :: TMVar Int } -newSMPAgentEnv :: AgentConfig -> SQLiteStore -> IO Env +newSMPAgentEnv :: AgentConfig -> DBStore -> IO Env newSMPAgentEnv 
config store = do random <- C.newRandom randomServer <- newTVarIO =<< liftIO newStdGen @@ -271,8 +277,13 @@ newSMPAgentEnv config store = do multicastSubscribers <- newTMVarIO 0 pure Env {config, store, random, randomServer, ntfSupervisor, xftpAgent, multicastSubscribers} -createAgentStore :: FilePath -> ScrubbedBytes -> Bool -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore) -createAgentStore dbFilePath dbKey keepKey = createSQLiteStore dbFilePath dbKey keepKey Migrations.app +#if defined(dbPostgres) +createAgentStore :: ConnectInfo -> String -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createAgentStore = createStore +#else +createAgentStore :: FilePath -> ScrubbedBytes -> Bool -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createAgentStore = createStore +#endif data NtfSupervisor = NtfSupervisor { ntfTkn :: TVar (Maybe NtfToken), diff --git a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs index f000de8be..3da1b74b6 100644 --- a/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs +++ b/src/Simplex/Messaging/Agent/NtfSubSupervisor.hs @@ -43,8 +43,8 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Stats import Simplex.Messaging.Agent.Store -import Simplex.Messaging.Agent.Store.SQLite -import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB +import Simplex.Messaging.Agent.Store.AgentStore +import qualified Simplex.Messaging.Agent.Store.DB as DB import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Types diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 08e8add24..b87f87f18 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -167,13 +168,12 @@ import Data.Time.Clock.System (SystemTime) import Data.Type.Equality import Data.Typeable () import Data.Word (Word16, Word32) -import Database.SQLite.Simple.FromField -import Database.SQLite.Simple.ToField import Simplex.FileTransfer.Description import Simplex.FileTransfer.Protocol (FileParty (..)) import Simplex.FileTransfer.Transport (XFTPErrorType) import Simplex.FileTransfer.Types (FileErrorType) import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import Simplex.Messaging.Client (ProxyClientError) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet @@ -224,6 +224,13 @@ import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal import Simplex.RemoteControl.Types import UnliftIO.Exception (Exception) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif -- SMP agent protocol version history: -- 1 - binary protocol encoding (1/1/2022) @@ -644,7 +651,7 @@ instance ToJSON NotificationsMode where instance FromJSON NotificationsMode where parseJSON = strParseJSON "NotificationsMode" -instance ToField NotificationsMode where toField = toField . strEncode +instance ToField NotificationsMode where toField = toField . Binary . 
strEncode instance FromField NotificationsMode where fromField = blobFieldDecoder $ parseAll strP diff --git a/src/Simplex/Messaging/Agent/Stats.hs b/src/Simplex/Messaging/Agent/Stats.hs index d4663bfb1..1d174622e 100644 --- a/src/Simplex/Messaging/Agent/Stats.hs +++ b/src/Simplex/Messaging/Agent/Stats.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE NamedFieldPuns #-} @@ -10,13 +11,18 @@ import qualified Data.Aeson.TH as J import Data.Int (Int64) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol (UserId) import Simplex.Messaging.Parsers (defaultJSON, fromTextField_) -import Simplex.Messaging.Protocol (SMPServer, XFTPServer, NtfServer) +import Simplex.Messaging.Protocol (NtfServer, SMPServer, XFTPServer) import Simplex.Messaging.Util (decodeJSON, encodeJSON) import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif data AgentSMPServerStats = AgentSMPServerStats { sentDirect :: TVar Int, -- successfully sent messages diff --git a/src/Simplex/Messaging/Agent/Store.hs b/src/Simplex/Messaging/Agent/Store.hs index 0e2a7bbe9..c199e480b 100644 --- a/src/Simplex/Messaging/Agent/Store.hs +++ b/src/Simplex/Messaging/Agent/Store.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} @@ -25,10 +26,15 @@ import Data.List (find) import Data.List.NonEmpty (NonEmpty) import qualified Data.List.NonEmpty as L import Data.Maybe (isJust) +import Data.Text (Text) import Data.Time (UTCTime) import Data.Type.Equality import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval (RI2State) +import Simplex.Messaging.Agent.Store.Common +import qualified Simplex.Messaging.Agent.Store.DB as DB +import qualified Simplex.Messaging.Agent.Store.Migrations as Migrations +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (PQEncryption, PQSupport, RatchetX448) import Simplex.Messaging.Encoding.String @@ -42,12 +48,33 @@ import Simplex.Messaging.Protocol RcvDhSecret, RcvNtfDhSecret, RcvPrivateAuthKey, + SenderCanSecure, SndPrivateAuthKey, SndPublicAuthKey, - SenderCanSecure, VersionSMPC, ) import qualified Simplex.Messaging.Protocol as SMP +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (ConnectInfo (..)) +import qualified Simplex.Messaging.Agent.Store.Postgres as StoreFunctions +#else +import Data.ByteArray (ScrubbedBytes) +import qualified Simplex.Messaging.Agent.Store.SQLite as StoreFunctions +#endif + +#if defined(dbPostgres) +createStore :: ConnectInfo -> String -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createStore connectInfo schema = StoreFunctions.createDBStore connectInfo schema Migrations.app +#else +createStore :: FilePath -> ScrubbedBytes -> Bool -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createStore dbFilePath dbKey keepKey = StoreFunctions.createDBStore dbFilePath dbKey keepKey Migrations.app +#endif + +closeStore :: DBStore -> IO () +closeStore = 
StoreFunctions.closeDBStore + +execSQL :: DB.Connection -> Text -> IO [Text] +execSQL = StoreFunctions.execSQL -- * Queue types diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs new file mode 100644 index 000000000..5dcda9c79 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -0,0 +1,3032 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module Simplex.Messaging.Agent.Store.AgentStore + ( -- * Users + createUserRecord, + deleteUserRecord, + setUserDeleted, + deleteUserWithoutConns, + deleteUsersWithoutConns, + checkUser, + + -- * Queues and connections + createNewConn, + updateNewConnRcv, + updateNewConnSnd, + createSndConn, + getConn, + getDeletedConn, + getConns, + getDeletedConns, + getConnData, + setConnDeleted, + setConnUserId, + setConnAgentVersion, + setConnPQSupport, + getDeletedConnIds, + getDeletedWaitingDeliveryConnIds, + setConnRatchetSync, + addProcessedRatchetKeyHash, + checkRatchetKeyHashExists, + deleteRatchetKeyHashesExpired, + getRcvConn, + getRcvQueueById, + getSndQueueById, + deleteConn, + upgradeRcvConnToDuplex, + upgradeSndConnToDuplex, + addConnRcvQueue, + addConnSndQueue, + setRcvQueueStatus, + setRcvSwitchStatus, + setRcvQueueDeleted, + setRcvQueueConfirmedE2E, + setSndQueueStatus, + setSndSwitchStatus, + setRcvQueuePrimary, + setSndQueuePrimary, + deleteConnRcvQueue, + incRcvDeleteErrors, + deleteConnSndQueue, + getPrimaryRcvQueue, + getRcvQueue, + getDeletedRcvQueue, + setRcvQueueNtfCreds, + -- Confirmations + createConfirmation, + acceptConfirmation, + getAcceptedConfirmation, + removeConfirmations, + -- Invitations - sent via Contact connections + createInvitation, + getInvitation, + acceptInvitation, + unacceptInvitation, + deleteInvitation, + -- Messages + updateRcvIds, + createRcvMsg, + updateRcvMsgHash, + updateSndIds, + createSndMsg, + updateSndMsgHash, + createSndMsgDelivery, + getSndMsgViaRcpt, + updateSndMsgRcpt, + getPendingQueueMsg, + getConnectionsForDelivery, + updatePendingMsgRIState, + deletePendingMsgs, + getExpiredSndMessages, + setMsgUserAck, + getRcvMsg, + getLastMsg, + checkRcvMsgHashExists, + getRcvMsgBrokerTs, + deleteMsg, + deleteDeliveredSndMsg, + deleteSndMsgDelivery, + deleteRcvMsgHashesExpired, + deleteSndMsgsExpired, + -- Double ratchet persistence + createRatchetX3dhKeys, + getRatchetX3dhKeys, + setRatchetX3dhKeys, + createSndRatchet, + getSndRatchet, + createRatchet, + deleteRatchet, + getRatchet, + getSkippedMsgKeys, + updateRatchet, + -- Async commands + createCommand, + getPendingCommandServers, + getPendingServerCommand, + updateCommandServer, + deleteCommand, + -- Notification device token persistence + 
createNtfToken, + getSavedNtfToken, + updateNtfTokenRegistration, + updateDeviceToken, + updateNtfMode, + updateNtfToken, + removeNtfToken, + addNtfTokenToDelete, + deleteExpiredNtfTokensToDelete, + NtfTokenToDelete, + getNextNtfTokenToDelete, + markNtfTokenToDeleteFailed_, -- exported for tests + getPendingDelTknServers, + deleteNtfTokenToDelete, + -- Notification subscription persistence + NtfSupervisorSub, + getNtfSubscription, + createNtfSubscription, + supervisorUpdateNtfSub, + supervisorUpdateNtfAction, + updateNtfSubscription, + setNullNtfSubscriptionAction, + deleteNtfSubscription, + deleteNtfSubscription', + getNextNtfSubNTFActions, + markNtfSubActionNtfFailed_, -- exported for tests + getNextNtfSubSMPActions, + markNtfSubActionSMPFailed_, -- exported for tests + getActiveNtfToken, + getNtfRcvQueue, + setConnectionNtfs, + + -- * File transfer + + -- Rcv files + createRcvFile, + createRcvFileRedirect, + getRcvFile, + getRcvFileByEntityId, + getRcvFileRedirects, + updateRcvChunkReplicaDelay, + updateRcvFileChunkReceived, + updateRcvFileStatus, + updateRcvFileError, + updateRcvFileComplete, + updateRcvFileRedirect, + updateRcvFileNoTmpPath, + updateRcvFileDeleted, + deleteRcvFile', + getNextRcvChunkToDownload, + getNextRcvFileToDecrypt, + getPendingRcvFilesServers, + getCleanupRcvFilesTmpPaths, + getCleanupRcvFilesDeleted, + getRcvFilesExpired, + -- Snd files + createSndFile, + getSndFile, + getSndFileByEntityId, + getNextSndFileToPrepare, + updateSndFileError, + updateSndFileStatus, + updateSndFileEncrypted, + updateSndFileComplete, + updateSndFileNoPrefixPath, + updateSndFileDeleted, + deleteSndFile', + getSndFileDeleted, + createSndFileReplica, + createSndFileReplica_, -- exported for tests + getNextSndChunkToUpload, + updateSndChunkReplicaDelay, + addSndChunkReplicaRecipients, + updateSndChunkReplicaStatus, + getPendingSndFilesServers, + getCleanupSndFilesPrefixPaths, + getCleanupSndFilesDeleted, + getSndFilesExpired, + createDeletedSndChunkReplica, + getNextDeletedSndChunkReplica, + updateDeletedSndChunkReplicaDelay, + deleteDeletedSndChunkReplica, + getPendingDelFilesServers, + deleteDeletedSndChunkReplicasExpired, + -- Stats + updateServersStats, + getServersStats, + resetServersStats, + + -- * utilities + withConnection, + withTransaction, + withTransactionPriority, + firstRow, + firstRow', + maybeFirstRow, + fromOnlyBI, + ) +where + +import Control.Logger.Simple +import Control.Monad +import Control.Monad.IO.Class +import Control.Monad.Trans.Except +import Crypto.Random (ChaChaDRG) +import Data.Bifunctor (first, second) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Base64.URL as U +import qualified Data.ByteString.Char8 as B +import Data.Functor (($>)) +import Data.Int (Int64) +import Data.List (foldl', sortBy) +import Data.List.NonEmpty (NonEmpty (..)) +import qualified Data.List.NonEmpty as L +import qualified Data.Map.Strict as M +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe) +import Data.Ord (Down (..)) +import Data.Text.Encoding (decodeLatin1, encodeUtf8) +import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) +import Data.Word (Word32) +import Network.Socket (ServiceName) +import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) +import Simplex.FileTransfer.Description +import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..)) +import Simplex.FileTransfer.Types +import Simplex.Messaging.Agent.Protocol +import Simplex.Messaging.Agent.RetryInterval (RI2State (..)) +import 
Simplex.Messaging.Agent.Stats +import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.Common +import qualified Simplex.Messaging.Agent.Store.DB as DB +import Simplex.Messaging.Agent.Store.DB (Binary (..), BoolInt (..)) +import qualified Simplex.Messaging.Crypto as C +import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) +import Simplex.Messaging.Crypto.Ratchet (PQEncryption (..), PQSupport (..), RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) +import qualified Simplex.Messaging.Crypto.Ratchet as CR +import Simplex.Messaging.Encoding +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) +import Simplex.Messaging.Notifications.Types +import Simplex.Messaging.Parsers (blobFieldParser, fromTextField_) +import Simplex.Messaging.Protocol +import qualified Simplex.Messaging.Protocol as SMP +import Simplex.Messaging.Transport.Client (TransportHost) +import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, tshow, ($>>=), (<$$>)) +import Simplex.Messaging.Version.Internal +import qualified UnliftIO.Exception as E +import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (Only (..), Query, SqlError, (:.) (..)) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.Errors (constraintViolation) +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple (FromRow (..), Only (..), Query (..), SQLError, ToRow (..), field, (:.) (..)) +import qualified Database.SQLite.Simple as SQL +import Database.SQLite.Simple.FromField +import Database.SQLite.Simple.QQ (sql) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif + +checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a) +checkConstraint err action = action `E.catch` (pure . Left . handleSQLError err) + +#if defined(dbPostgres) +handleSQLError :: StoreError -> SqlError -> StoreError +handleSQLError err e = case constraintViolation e of + Just _ -> err + Nothing -> SEInternal $ bshow e +#else +handleSQLError :: StoreError -> SQLError -> StoreError +handleSQLError err e + | SQL.sqlError e == SQL.ErrorConstraint = err + | otherwise = SEInternal $ bshow e +#endif + +createUserRecord :: DB.Connection -> IO UserId +createUserRecord db = do + DB.execute_ db "INSERT INTO users DEFAULT VALUES" + insertedRowId db + +checkUser :: DB.Connection -> UserId -> IO (Either StoreError ()) +checkUser db userId = + firstRow (\(_ :: Only Int64) -> ()) SEUserNotFound $ + DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, BI False) + +deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ()) +deleteUserRecord db userId = runExceptT $ do + ExceptT $ checkUser db userId + liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) + +setUserDeleted :: DB.Connection -> UserId -> IO (Either StoreError [ConnId]) +setUserDeleted db userId = runExceptT $ do + ExceptT $ checkUser db userId + liftIO $ do + DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (BI True, userId) + map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" 
(Only userId) + +deleteUserWithoutConns :: DB.Connection -> UserId -> IO Bool +deleteUserWithoutConns db userId = do + userId_ :: Maybe Int64 <- + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT user_id FROM users u + WHERE u.user_id = ? + AND u.deleted = ? + AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id) + |] + (userId, BI True) + case userId_ of + Just _ -> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) $> True + _ -> pure False + +deleteUsersWithoutConns :: DB.Connection -> IO [Int64] +deleteUsersWithoutConns db = do + userIds <- + map fromOnly + <$> DB.query + db + [sql| + SELECT user_id FROM users u + WHERE u.deleted = ? + AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id) + |] + (Only (BI True)) + forM_ userIds $ DB.execute db "DELETE FROM users WHERE user_id = ?" . Only + pure userIds + +createConn_ :: + TVar ChaChaDRG -> + ConnData -> + (ConnId -> IO a) -> + IO (Either StoreError (ConnId, a)) +createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of + ConnData {connId = ""} -> createWithRandomId' gVar create + ConnData {connId} -> Right . (connId,) <$> create connId + +createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId) +createNewConn db gVar cData cMode = do + fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) + +updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) +updateNewConnRcv db connId rq = + getConn db connId $>>= \case + (SomeConn _ NewConnection {}) -> updateConn + (SomeConn _ RcvConnection {}) -> updateConn -- to allow retries + (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + where + updateConn :: IO (Either StoreError RcvQueue) + updateConn = Right <$> addConnRcvQueue_ db connId rq + +updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) +updateNewConnSnd db connId sq = + getConn db connId $>>= \case + (SomeConn _ NewConnection {}) -> updateConn + (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + where + updateConn :: IO (Either StoreError SndQueue) + updateConn = Right <$> addConnSndQueue_ db connId sq + +createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue)) +createSndConn db gVar cData q@SndQueue {server} = + -- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_ + ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $ + createConn_ gVar cData $ \connId -> do + serverKeyHash_ <- createServer_ db server + createConnRecord db connId cData SCMInvitation + insertSndQueue_ db connId q serverKeyHash_ + +createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO () +createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSupport} cMode = + DB.execute + db + [sql| + INSERT INTO connections + (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?) + |] + (userId, connId, cMode, connAgentVersion, BI enableNtfs, pqSupport, BI True) + +checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool +checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do + fromMaybe False + <$> maybeFirstRow + fromOnly + ( DB.query + db + "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? 
AND status != ? LIMIT 1" + (host server, port server, sndId, New) + ) + +getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) +getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do + rq@RcvQueue {connId} <- + ExceptT . firstRow toRcvQueue SEConnNotFound $ + DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId) + (rq,) <$> ExceptT (getConn db connId) + +-- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted +deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId) +deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of + Nothing -> delete + Just timeout -> + ifM + checkNoPendingDeliveries_ + delete + ( ifM + (checkWaitDeliveryTimeout_ timeout) + delete + (pure Nothing) + ) + where + delete = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) $> Just connId + checkNoPendingDeliveries_ = do + r :: (Maybe Int64) <- + maybeFirstRow fromOnly $ + DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId) + pure $ isNothing r + checkWaitDeliveryTimeout_ timeout = do + cutoffTs <- addUTCTime (-timeout) <$> getCurrentTime + r :: (Maybe Int64) <- + maybeFirstRow fromOnly $ + DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs) + pure $ isJust r + +upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) +upgradeRcvConnToDuplex db connId sq = + getConn db connId $>>= \case + (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq + (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + +upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) +upgradeSndConnToDuplex db connId rq = + getConn db connId >>= \case + Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound + +addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) +addConnRcvQueue db connId rq = + getConn db connId >>= \case + Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound + +addConnRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> IO RcvQueue +addConnRcvQueue_ db connId rq@RcvQueue {server} = do + serverKeyHash_ <- createServer_ db server + insertRcvQueue_ db connId rq serverKeyHash_ + +addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) +addConnSndQueue db connId sq = + getConn db connId >>= \case + Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq + Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c + _ -> pure $ Left SEConnNotFound + +addConnSndQueue_ :: DB.Connection -> ConnId -> NewSndQueue -> IO SndQueue +addConnSndQueue_ db connId sq@SndQueue {server} = do + serverKeyHash_ <- createServer_ db server + insertSndQueue_ db connId sq serverKeyHash_ + +setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () +setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status = + -- ? return error if queue does not exist? 
+ DB.execute + db + [sql| + UPDATE rcv_queues + SET status = ? + WHERE host = ? AND port = ? AND rcv_id = ? + |] + (status, host, port, rcvId) + +setRcvSwitchStatus :: DB.Connection -> RcvQueue -> Maybe RcvSwitchStatus -> IO RcvQueue +setRcvSwitchStatus db rq@RcvQueue {rcvId, server = ProtocolServer {host, port}} rcvSwchStatus = do + DB.execute + db + [sql| + UPDATE rcv_queues + SET switch_status = ? + WHERE host = ? AND port = ? AND rcv_id = ? + |] + (rcvSwchStatus, host, port, rcvId) + pure rq {rcvSwchStatus} + +setRcvQueueDeleted :: DB.Connection -> RcvQueue -> IO () +setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = do + DB.execute + db + [sql| + UPDATE rcv_queues + SET deleted = 1 + WHERE host = ? AND port = ? AND rcv_id = ? + |] + (host, port, rcvId) + +setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO () +setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion = + DB.execute + db + [sql| + UPDATE rcv_queues + SET e2e_dh_secret = ?, + status = ?, + smp_client_version = ? + WHERE host = ? AND port = ? AND rcv_id = ? + |] + (e2eDhSecret, Confirmed, smpClientVersion, host, port, rcvId) + +setSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO () +setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} status = + -- ? return error if queue does not exist? + DB.execute + db + [sql| + UPDATE snd_queues + SET status = ? + WHERE host = ? AND port = ? AND snd_id = ? + |] + (status, host, port, sndId) + +setSndSwitchStatus :: DB.Connection -> SndQueue -> Maybe SndSwitchStatus -> IO SndQueue +setSndSwitchStatus db sq@SndQueue {sndId, server = ProtocolServer {host, port}} sndSwchStatus = do + DB.execute + db + [sql| + UPDATE snd_queues + SET switch_status = ? + WHERE host = ? AND port = ? AND snd_id = ? + |] + (sndSwchStatus, host, port, sndId) + pure sq {sndSwchStatus} + +setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO () +setRcvQueuePrimary db connId RcvQueue {dbQueueId} = do + DB.execute db "UPDATE rcv_queues SET rcv_primary = ? WHERE conn_id = ?" (BI False, connId) + DB.execute + db + "UPDATE rcv_queues SET rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?" + (BI True, Nothing :: Maybe Int64, connId, dbQueueId) + +setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO () +setSndQueuePrimary db connId SndQueue {dbQueueId} = do + DB.execute db "UPDATE snd_queues SET snd_primary = ? WHERE conn_id = ?" (BI False, connId) + DB.execute + db + "UPDATE snd_queues SET snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?" + (BI True, Nothing :: Maybe Int64, connId, dbQueueId) + +incRcvDeleteErrors :: DB.Connection -> RcvQueue -> IO () +incRcvDeleteErrors db RcvQueue {connId, dbQueueId} = + DB.execute db "UPDATE rcv_queues SET delete_errors = delete_errors + 1 WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId) + +deleteConnRcvQueue :: DB.Connection -> RcvQueue -> IO () +deleteConnRcvQueue db RcvQueue {connId, dbQueueId} = + DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId) + +deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO () +deleteConnSndQueue db connId SndQueue {dbQueueId} = do + DB.execute db "DELETE FROM snd_queues WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) + DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" 
(connId, dbQueueId) + +getPrimaryRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue) +getPrimaryRcvQueue db connId = + maybe (Left SEConnNotFound) (Right . L.head) <$> getRcvQueuesByConnId_ db connId + +getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) +getRcvQueue db connId (SMPServer host port _) rcvId = + firstRow toRcvQueue SEConnNotFound $ + DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (connId, host, port, rcvId) + +getDeletedRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) +getDeletedRcvQueue db connId (SMPServer host port _) rcvId = + firstRow toRcvQueue SEConnNotFound $ + DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 1") (connId, host, port, rcvId) + +setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO () +setRcvQueueNtfCreds db connId clientNtfCreds = + DB.execute + db + [sql| + UPDATE rcv_queues + SET ntf_public_key = ?, ntf_private_key = ?, ntf_id = ?, rcv_ntf_dh_secret = ? + WHERE conn_id = ? + |] + (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_, connId) + where + (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) = case clientNtfCreds of + Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) + Nothing -> (Nothing, Nothing, Nothing, Nothing) + +type SMPConfirmationRow = (Maybe SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC) + +smpConfirmation :: SMPConfirmationRow -> SMPConfirmation +smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) = + SMPConfirmation + { senderKey, + e2ePubKey, + connInfo, + smpReplyQueues = fromMaybe [] smpReplyQueues_, + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ + } + +createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId) +createConfirmation db gVar NewConfirmation {connId, senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo, smpReplyQueues, smpClientVersion}, ratchetState} = + createWithRandomId gVar $ \confirmationId -> + DB.execute + db + [sql| + INSERT INTO conn_confirmations + (confirmation_id, conn_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info, smp_reply_queues, smp_client_version, accepted) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0); + |] + (Binary confirmationId, connId, senderKey, e2ePubKey, ratchetState, Binary connInfo, smpReplyQueues, smpClientVersion) + +acceptConfirmation :: DB.Connection -> ConfirmationId -> ConnInfo -> IO (Either StoreError AcceptedConfirmation) +acceptConfirmation db confirmationId ownConnInfo = do + DB.execute + db + [sql| + UPDATE conn_confirmations + SET accepted = 1, + own_conn_info = ? + WHERE confirmation_id = ? + |] + (Binary ownConnInfo, Binary confirmationId) + firstRow confirmation SEConfirmationNotFound $ + DB.query + db + [sql| + SELECT conn_id, ratchet_state, sender_key, e2e_snd_pub_key, sender_conn_info, smp_reply_queues, smp_client_version + FROM conn_confirmations + WHERE confirmation_id = ?; + |] + (Only (Binary confirmationId)) + where + confirmation ((connId, ratchetState) :. 
confRow) = + AcceptedConfirmation + { confirmationId, + connId, + senderConf = smpConfirmation confRow, + ratchetState, + ownConnInfo + } + +getAcceptedConfirmation :: DB.Connection -> ConnId -> IO (Either StoreError AcceptedConfirmation) +getAcceptedConfirmation db connId = + firstRow confirmation SEConfirmationNotFound $ + DB.query + db + [sql| + SELECT confirmation_id, ratchet_state, own_conn_info, sender_key, e2e_snd_pub_key, sender_conn_info, smp_reply_queues, smp_client_version + FROM conn_confirmations + WHERE conn_id = ? AND accepted = 1; + |] + (Only connId) + where + confirmation ((confirmationId, ratchetState, ownConnInfo) :. confRow) = + AcceptedConfirmation + { confirmationId, + connId, + senderConf = smpConfirmation confRow, + ratchetState, + ownConnInfo + } + +removeConfirmations :: DB.Connection -> ConnId -> IO () +removeConfirmations db connId = + DB.execute + db + [sql| + DELETE FROM conn_confirmations + WHERE conn_id = ? + |] + (Only connId) + +createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId) +createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} = + createWithRandomId gVar $ \invitationId -> + DB.execute + db + [sql| + INSERT INTO conn_invitations + (invitation_id, contact_conn_id, cr_invitation, recipient_conn_info, accepted) VALUES (?, ?, ?, ?, 0); + |] + (Binary invitationId, contactConnId, connReq, Binary recipientConnInfo) + +getInvitation :: DB.Connection -> String -> InvitationId -> IO (Either StoreError Invitation) +getInvitation db cxt invitationId = + firstRow invitation (SEInvitationNotFound cxt invitationId) $ + DB.query + db + [sql| + SELECT contact_conn_id, cr_invitation, recipient_conn_info, own_conn_info, accepted + FROM conn_invitations + WHERE invitation_id = ? + AND accepted = 0 + |] + (Only (Binary invitationId)) + where + invitation (contactConnId, connReq, recipientConnInfo, ownConnInfo, BI accepted) = + Invitation {invitationId, contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted} + +acceptInvitation :: DB.Connection -> InvitationId -> ConnInfo -> IO () +acceptInvitation db invitationId ownConnInfo = + DB.execute + db + [sql| + UPDATE conn_invitations + SET accepted = 1, + own_conn_info = ? + WHERE invitation_id = ? + |] + (Binary ownConnInfo, Binary invitationId) + +unacceptInvitation :: DB.Connection -> InvitationId -> IO () +unacceptInvitation db invitationId = + DB.execute db "UPDATE conn_invitations SET accepted = 0, own_conn_info = NULL WHERE invitation_id = ?" (Only (Binary invitationId)) + +deleteInvitation :: DB.Connection -> ConnId -> InvitationId -> IO (Either StoreError ()) +deleteInvitation db contactConnId invId = + getConn db contactConnId $>>= \case + SomeConn SCContact _ -> + Right <$> DB.execute db "DELETE FROM conn_invitations WHERE contact_conn_id = ? AND invitation_id = ?" 
(contactConnId, Binary invId) + _ -> pure $ Left SEConnNotFound + +updateRcvIds :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +updateRcvIds db connId = do + (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ db connId + let internalId = InternalId $ unId lastInternalId + 1 + internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 + updateLastIdsRcv_ db connId internalId internalRcvId + pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) + +createRcvMsg :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () +createRcvMsg db connId rq@RcvQueue {dbQueueId} rcvMsgData@RcvMsgData {msgMeta = MsgMeta {sndMsgId, broker = (_, brokerTs)}, internalRcvId, internalHash} = do + insertRcvMsgBase_ db connId rcvMsgData + insertRcvMsgDetails_ db connId rq rcvMsgData + updateRcvMsgHash db connId sndMsgId internalRcvId internalHash + DB.execute db "UPDATE rcv_queues SET last_broker_ts = ? WHERE conn_id = ? AND rcv_queue_id = ?" (brokerTs, connId, dbQueueId) + +updateSndIds :: DB.Connection -> ConnId -> IO (Either StoreError (InternalId, InternalSndId, PrevSndMsgHash)) +updateSndIds db connId = runExceptT $ do + (lastInternalId, lastInternalSndId, prevSndHash) <- ExceptT $ retrieveLastIdsAndHashSnd_ db connId + let internalId = InternalId $ unId lastInternalId + 1 + internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 + liftIO $ updateLastIdsSnd_ db connId internalId internalSndId + pure (internalId, internalSndId, prevSndHash) + +createSndMsg :: DB.Connection -> ConnId -> SndMsgData -> IO () +createSndMsg db connId sndMsgData@SndMsgData {internalSndId, internalHash} = do + insertSndMsgBase_ db connId sndMsgData + insertSndMsgDetails_ db connId sndMsgData + updateSndMsgHash db connId internalSndId internalHash + +createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO () +createSndMsgDelivery db connId SndQueue {dbQueueId} msgId = + DB.execute db "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)" (connId, dbQueueId, msgId) + +getSndMsgViaRcpt :: DB.Connection -> ConnId -> InternalSndId -> IO (Either StoreError SndMsg) +getSndMsgViaRcpt db connId sndMsgId = + firstRow toSndMsg SEMsgNotFound $ + DB.query + db + [sql| + SELECT s.internal_id, m.msg_type, s.internal_hash, s.rcpt_internal_id, s.rcpt_status + FROM snd_messages s + JOIN messages m ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id + WHERE s.conn_id = ? AND s.internal_snd_id = ? + |] + (connId, sndMsgId) + where + toSndMsg :: (InternalId, AgentMessageType, MsgHash, Maybe AgentMsgId, Maybe MsgReceiptStatus) -> SndMsg + toSndMsg (internalId, msgType, internalHash, rcptInternalId_, rcptStatus_) = + let msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ + in SndMsg {internalId, internalSndId = sndMsgId, msgType, internalHash, msgReceipt} + +updateSndMsgRcpt :: DB.Connection -> ConnId -> InternalSndId -> MsgReceipt -> IO () +updateSndMsgRcpt db connId sndMsgId MsgReceipt {agentMsgId, msgRcptStatus} = + DB.execute + db + "UPDATE snd_messages SET rcpt_internal_id = ?, rcpt_status = ? WHERE conn_id = ? AND internal_snd_id = ?" 
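+-- Illustrative sketch of the send path (a hypothetical `mkSndMsgData` stands in
+-- for building SndMsgData from the reserved ids; `sq :: SndQueue` is assumed):
+--
+-- > sendFlow = runExceptT $ do
+-- >   (internalId, internalSndId, prevHash) <- ExceptT $ updateSndIds db connId
+-- >   let sndMsgData = mkSndMsgData internalId internalSndId prevHash
+-- >   liftIO $ createSndMsg db connId sndMsgData
+-- >   liftIO $ createSndMsgDelivery db connId sq internalId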
+ (agentMsgId, msgRcptStatus, connId, sndMsgId) + +getConnectionsForDelivery :: DB.Connection -> IO [ConnId] +getConnectionsForDelivery db = + map fromOnly <$> DB.query_ db "SELECT DISTINCT conn_id FROM snd_message_deliveries WHERE failed = 0" + +getPendingQueueMsg :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError (Maybe (Maybe RcvQueue, PendingMsgData))) +getPendingQueueMsg db connId SndQueue {dbQueueId} = + getWorkItem "message" getMsgId getMsgData markMsgFailed + where + getMsgId :: IO (Maybe InternalId) + getMsgId = + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT internal_id + FROM snd_message_deliveries d + WHERE conn_id = ? AND snd_queue_id = ? AND failed = 0 + ORDER BY internal_id ASC + LIMIT 1 + |] + (connId, dbQueueId) + getMsgData :: InternalId -> IO (Either StoreError (Maybe RcvQueue, PendingMsgData)) + getMsgData msgId = runExceptT $ do + msg <- ExceptT $ firstRow pendingMsgData err getMsgData_ + rq_ <- liftIO $ L.head <$$> getRcvQueuesByConnId_ db connId + pure (rq_, msg) + where + getMsgData_ = + DB.query + db + [sql| + SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast + FROM messages m + JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id + WHERE m.conn_id = ? AND m.internal_id = ? + |] + (connId, msgId) + err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []" + pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData + pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) = + let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_ + msgRetryState = RI2State <$> riSlow_ <*> riFast_ + in PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs} + markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId) + +getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a)) +getWorkItem itemName getId getItem markFailed = + runExceptT $ handleWrkErr itemName "getId" getId >>= mapM (tryGetItem itemName getItem markFailed) + +getWorkItems :: Show i => ByteString -> IO [i] -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError [Either StoreError a]) +getWorkItems itemName getIds getItem markFailed = + runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed) + +tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a +tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e + where + mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId + +catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a +catchStoreError = catchAllErrors (SEInternal . 
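+-- Illustrative sketch of the work-item pattern used by getPendingQueueMsg and
+-- getPendingServerCommand: one action selects a candidate id, one loads it, and
+-- one marks it failed so a broken row cannot wedge the worker. The table and
+-- columns below are hypothetical:
+--
+-- > getNextThing db = getWorkItem "thing" getId getItem markFailed
+-- >   where
+-- >     getId = maybeFirstRow fromOnly $ DB.query_ db "SELECT thing_id FROM things WHERE failed = 0 LIMIT 1"
+-- >     getItem i = firstRow id (SEInternal "no thing") $ DB.query db "SELECT body FROM things WHERE thing_id = ?" (Only i)
+-- >     markFailed i = DB.execute db "UPDATE things SET failed = 1 WHERE thing_id = ?" (Only i)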
bshow) + +-- Errors caught by this function will suspend worker as if there is no more work, +handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a +handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action + where + mkError :: E.SomeException -> StoreError + mkError e = SEWorkItemError $ itemName <> " " <> opName <> " error: " <> bshow e + +updatePendingMsgRIState :: DB.Connection -> ConnId -> InternalId -> RI2State -> IO () +updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} = + DB.execute db "UPDATE snd_messages SET retry_int_slow = ?, retry_int_fast = ? WHERE conn_id = ? AND internal_id = ?" (slowInterval, fastInterval, connId, msgId) + +deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO () +deletePendingMsgs db connId SndQueue {dbQueueId} = + DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) + +getExpiredSndMessages :: DB.Connection -> ConnId -> SndQueue -> UTCTime -> IO [InternalId] +getExpiredSndMessages db connId SndQueue {dbQueueId} expireTs = do + -- type is Maybe InternalId because MAX always returns one row, possibly with NULL value + maxId :: [Maybe InternalId] <- + map fromOnly + <$> DB.query + db + [sql| + SELECT MAX(internal_id) + FROM messages + WHERE conn_id = ? AND internal_snd_id IS NOT NULL AND internal_ts < ? + |] + (connId, expireTs) + case maxId of + Just msgId : _ -> + map fromOnly + <$> DB.query + db + [sql| + SELECT internal_id + FROM snd_message_deliveries + WHERE conn_id = ? AND snd_queue_id = ? AND failed = 0 AND internal_id <= ? + ORDER BY internal_id ASC + |] + (connId, dbQueueId, msgId) + _ -> pure [] + +setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (RcvQueue, SMP.MsgId)) +setMsgUserAck db connId agentMsgId = runExceptT $ do + (dbRcvId, srvMsgId) <- + ExceptT . firstRow id SEMsgNotFound $ + DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId) + rq <- ExceptT $ getRcvQueueById db connId dbRcvId + liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (BI True, connId, agentMsgId) + pure (rq, srvMsgId) + +getRcvMsg :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError RcvMsg) +getRcvMsg db connId agentMsgId = + firstRow toRcvMsg SEMsgNotFound $ + DB.query + db + [sql| + SELECT + r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash, + m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack + FROM rcv_messages r + JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id + LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id + WHERE r.conn_id = ? AND r.internal_id = ? 
+ |] + (connId, agentMsgId) + +getLastMsg :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Maybe RcvMsg) +getLastMsg db connId msgId = + maybeFirstRow toRcvMsg $ + DB.query + db + [sql| + SELECT + r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash, + m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack + FROM rcv_messages r + JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id + JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id + LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id + WHERE r.conn_id = ? AND r.broker_id = ? + |] + (connId, Binary msgId) + +toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, BoolInt) -> RcvMsg +toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, BI userAck)) = + let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption} + msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ + in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck} + +checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool +checkRcvMsgHashExists db connId hash = do + fromMaybe False + <$> maybeFirstRow + fromOnly + ( DB.query + db + "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" + (connId, Binary hash) + ) + +getRcvMsgBrokerTs :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Either StoreError BrokerTs) +getRcvMsgBrokerTs db connId msgId = + firstRow fromOnly SEMsgNotFound $ + DB.query db "SELECT broker_ts FROM rcv_messages WHERE conn_id = ? AND broker_id = ?" (connId, Binary msgId) + +deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO () +deleteMsg db connId msgId = + DB.execute db "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;" (connId, msgId) + +deleteMsgContent :: DB.Connection -> ConnId -> InternalId -> IO () +deleteMsgContent db connId msgId = +#if defined(dbPostgres) + DB.execute db "UPDATE messages SET msg_body = ''::BYTEA WHERE conn_id = ? AND internal_id = ?" (connId, msgId) +#else + DB.execute db "UPDATE messages SET msg_body = x'' WHERE conn_id = ? AND internal_id = ?" (connId, msgId) +#endif + +deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO () +deleteDeliveredSndMsg db connId msgId = do + cnt <- countPendingSndDeliveries_ db connId msgId + when (cnt == 0) $ deleteMsg db connId msgId + +deleteSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> Bool -> IO () +deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do + DB.execute + db + "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? AND internal_id = ?" + (connId, dbQueueId, msgId) + cnt <- countPendingSndDeliveries_ db connId msgId + when (cnt == 0) $ do + del <- + maybeFirstRow id (DB.query db "SELECT rcpt_internal_id, rcpt_status FROM snd_messages WHERE conn_id = ? AND internal_id = ?" 
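+-- Illustrative sketch of acknowledging a received message: record the user ACK,
+-- send the broker ACK outside the store (elided here), then delete the message:
+--
+-- > ackFlow agentMsgId = runExceptT $ do
+-- >   (rq, srvMsgId) <- ExceptT $ setMsgUserAck db connId agentMsgId
+-- >   -- ... ACK srvMsgId on rq's SMP server ...
+-- >   liftIO $ deleteMsg db connId agentMsgId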
(connId, msgId)) >>= \case + Just (Just (_ :: Int64), Just MROk) -> pure deleteMsg + _ -> pure $ if keepForReceipt then deleteMsgContent else deleteMsg + del db connId msgId + +countPendingSndDeliveries_ :: DB.Connection -> ConnId -> InternalId -> IO Int +countPendingSndDeliveries_ db connId msgId = do + (Only cnt : _) <- DB.query db "SELECT count(*) FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0" (connId, msgId) + pure cnt + +deleteRcvMsgHashesExpired :: DB.Connection -> NominalDiffTime -> IO () +deleteRcvMsgHashesExpired db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.execute db "DELETE FROM encrypted_rcv_message_hashes WHERE created_at < ?" (Only cutoffTs) + +deleteSndMsgsExpired :: DB.Connection -> NominalDiffTime -> IO () +deleteSndMsgsExpired db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.execute + db + "DELETE FROM messages WHERE internal_ts < ? AND internal_snd_id IS NOT NULL" + (Only cutoffTs) + +createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () +createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = + DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem) + +getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams)) +getRatchetX3dhKeys db connId = + firstRow' keys SEX3dhKeysNotFound $ + DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId) + where + keys = \case + (Just k1, Just k2, pKem) -> Right (k1, k2, pKem) + _ -> Left SEX3dhKeysNotFound + +-- used to remember new keys when starting ratchet re-synchronization +setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () +setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = + DB.execute + db + [sql| + UPDATE ratchets + SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, pq_priv_kem = ? + WHERE conn_id = ? + |] + (x3dhPrivKey1, x3dhPrivKey2, pqPrivKem, connId) + +createSndRatchet :: DB.Connection -> ConnId -> RatchetX448 -> CR.AE2ERatchetParams 'C.X448 -> IO () +createSndRatchet db connId ratchetState (CR.AE2ERatchetParams s (CR.E2ERatchetParams _ x3dhPubKey1 x3dhPubKey2 pqPubKem)) = + DB.execute + db + [sql| + INSERT INTO ratchets + (conn_id, ratchet_state, x3dh_pub_key_1, x3dh_pub_key_2, pq_pub_kem) VALUES (?, ?, ?, ?, ?) + ON CONFLICT (conn_id) DO UPDATE SET + ratchet_state = EXCLUDED.ratchet_state, + x3dh_priv_key_1 = NULL, + x3dh_priv_key_2 = NULL, + x3dh_pub_key_1 = EXCLUDED.x3dh_pub_key_1, + x3dh_pub_key_2 = EXCLUDED.x3dh_pub_key_2, + pq_priv_kem = NULL, + pq_pub_kem = EXCLUDED.pq_pub_kem + |] + (connId, ratchetState, x3dhPubKey1, x3dhPubKey2, CR.ARKP s <$> pqPubKem) + +getSndRatchet :: DB.Connection -> ConnId -> CR.VersionE2E -> IO (Either StoreError (RatchetX448, CR.AE2ERatchetParams 'C.X448)) +getSndRatchet db connId v = + firstRow' result SEX3dhKeysNotFound $ + DB.query db "SELECT ratchet_state, x3dh_pub_key_1, x3dh_pub_key_2, pq_pub_kem FROM ratchets WHERE conn_id = ?" 
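+-- Illustrative sketch of the X3DH key round trip during connection setup
+-- (`pk1`, `pk2` and the optional `pqKem` are assumed to be freshly generated by
+-- the caller; the private keys are cleared again once the ratchet is created):
+--
+-- > createRatchetX3dhKeys db connId pk1 pk2 pqKem
+-- > Right (k1, k2, pqKem_) <- getRatchetX3dhKeys db connId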
(Only connId) + where + result = \case + (Just ratchetState, Just k1, Just k2, pKem_) -> + let params = case pKem_ of + Nothing -> CR.AE2ERatchetParams CR.SRKSProposed (CR.E2ERatchetParams v k1 k2 Nothing) + Just (CR.ARKP s pKem) -> CR.AE2ERatchetParams s (CR.E2ERatchetParams v k1 k2 (Just pKem)) + in Right (ratchetState, params) + _ -> Left SEX3dhKeysNotFound + +-- TODO remove the columns for public keys in v5.7. +createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO () +createRatchet db connId rc = + DB.execute + db + [sql| + INSERT INTO ratchets (conn_id, ratchet_state) + VALUES (?, ?) + ON CONFLICT (conn_id) DO UPDATE SET + ratchet_state = ?, + x3dh_priv_key_1 = NULL, + x3dh_priv_key_2 = NULL, + x3dh_pub_key_1 = NULL, + x3dh_pub_key_2 = NULL, + pq_priv_kem = NULL, + pq_pub_kem = NULL + |] + (connId, rc, rc) + +deleteRatchet :: DB.Connection -> ConnId -> IO () +deleteRatchet db connId = + DB.execute db "DELETE FROM ratchets WHERE conn_id = ?" (Only connId) + +getRatchet :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448) +getRatchet db connId = + firstRow' ratchet SERatchetNotFound $ DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId) + where + ratchet = maybe (Left SERatchetNotFound) Right . fromOnly + +getSkippedMsgKeys :: DB.Connection -> ConnId -> IO SkippedMsgKeys +getSkippedMsgKeys db connId = + skipped <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId) + where + skipped = foldl' addSkippedKey M.empty + addSkippedKey smks (hk, msgN, mk) = M.alter (Just . addMsgKey) hk smks + where + addMsgKey = maybe (M.singleton msgN mk) (M.insert msgN mk) + +updateRatchet :: DB.Connection -> ConnId -> RatchetX448 -> SkippedMsgDiff -> IO () +updateRatchet db connId rc skipped = do + DB.execute db "UPDATE ratchets SET ratchet_state = ? WHERE conn_id = ?" (rc, connId) + case skipped of + SMDNoChange -> pure () + SMDRemove hk msgN -> + DB.execute db "DELETE FROM skipped_messages WHERE conn_id = ? AND header_key = ? AND msg_n = ?" (connId, hk, msgN) + SMDAdd smks -> + forM_ (M.assocs smks) $ \(hk, mks) -> + forM_ (M.assocs mks) $ \(msgN, mk) -> + DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk) + +createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO (Either StoreError ()) +createCommand db corrId connId srv_ cmd = runExceptT $ do + (host_, port_, serverKeyHash_) <- serverFields + createdAt <- liftIO getCurrentTime + liftIO . 
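+-- Illustrative sketch of persisting one ratchet step together with its skipped
+-- message keys diff (`rc'`, `hk`, `msgN` and `mk` are assumed to come from the
+-- decryption step):
+--
+-- > updateRatchet db connId rc' SMDNoChange                                      -- nothing skipped
+-- > updateRatchet db connId rc' (SMDRemove hk msgN)                              -- stored key consumed
+-- > updateRatchet db connId rc' (SMDAdd $ M.singleton hk (M.singleton msgN mk))  -- new keys to remember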
E.handle handleErr $ + DB.execute + db + "INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash, created_at) VALUES (?,?,?,?,?,?,?,?)" + (host_, port_, Binary corrId, connId, cmdTag, cmd, serverKeyHash_, createdAt) + where + cmdTag = agentCommandTag cmd +#if defined(dbPostgres) + handleErr e = case constraintViolation e of + Just _ -> logError $ "tried to create command " <> tshow cmdTag <> " for deleted connection" + Nothing -> E.throwIO e +#else + handleErr e + | SQL.sqlError e == SQL.ErrorConstraint = logError $ "tried to create command " <> tshow cmdTag <> " for deleted connection" + | otherwise = E.throwIO e +#endif + serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash) + serverFields = case srv_ of + Just srv@(SMPServer host port _) -> + (Just host,Just port,) <$> ExceptT (getServerKeyHash_ db srv) + Nothing -> pure (Nothing, Nothing, Nothing) + +insertedRowId :: DB.Connection -> IO Int64 +insertedRowId db = fromOnly . head <$> DB.query_ db q + where +#if defined(dbPostgres) + q = "SELECT lastval()" +#else + q = "SELECT last_insert_rowid()" +#endif + +getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer] +getPendingCommandServers db connId = do + -- TODO review whether this can break if, e.g., the server has another key hash. + map smpServer + <$> DB.query + db + [sql| + SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) + FROM commands c + LEFT JOIN servers s ON s.host = c.host AND s.port = c.port + WHERE conn_id = ? + |] + (Only connId) + where + smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash + +getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) +getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed + where + getCmdId :: IO (Maybe Int64) + getCmdId = + maybeFirstRow fromOnly $ case srv_ of + Nothing -> + DB.query + db + [sql| + SELECT command_id FROM commands + WHERE conn_id = ? AND host IS NULL AND port IS NULL AND failed = 0 + ORDER BY created_at ASC, command_id ASC + LIMIT 1 + |] + (Only connId) + Just (SMPServer host port _) -> + DB.query + db + [sql| + SELECT command_id FROM commands + WHERE conn_id = ? AND host = ? AND port = ? AND failed = 0 + ORDER BY created_at ASC, command_id ASC + LIMIT 1 + |] + (connId, host, port) + getCommand :: Int64 -> IO (Either StoreError PendingCommand) + getCommand cmdId = + firstRow pendingCommand err $ + DB.query + db + [sql| + SELECT c.corr_id, cs.user_id, c.command + FROM commands c + JOIN connections cs USING (conn_id) + WHERE c.command_id = ? + |] + (Only cmdId) + where + err = SEInternal $ "command " <> bshow cmdId <> " returned []" + pendingCommand (corrId, userId, command) = PendingCommand {cmdId, corrId, userId, connId, command} + markCommandFailed cmdId = DB.execute db "UPDATE commands SET failed = 1 WHERE command_id = ?" (Only cmdId) + +updateCommandServer :: DB.Connection -> AsyncCmdId -> SMPServer -> IO (Either StoreError ()) +updateCommandServer db cmdId srv@(SMPServer host port _) = runExceptT $ do + serverKeyHash_ <- ExceptT $ getServerKeyHash_ db srv + liftIO $ + DB.execute + db + [sql| + UPDATE commands + SET host = ?, port = ?, server_key_hash = ? + WHERE command_id = ? + |] + (host, port, serverKeyHash_, cmdId) + +deleteCommand :: DB.Connection -> AsyncCmdId -> IO () +deleteCommand db cmdId = + DB.execute db "DELETE FROM commands WHERE command_id = ?" 
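+-- Illustrative sketch of draining the async command queue for one connection
+-- (`runCommand` is a hypothetical stand-in for executing a PendingCommand; the
+-- real worker also batches and retries):
+--
+-- > drainCommands srv_ = getPendingServerCommand db connId srv_ >>= \case
+-- >   Right (Just PendingCommand {cmdId}) -> runCommand cmdId >> deleteCommand db cmdId >> drainCommands srv_
+-- >   _ -> pure ()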
(Only cmdId) + +createNtfToken :: DB.Connection -> NtfToken -> IO () +createNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} = do + upsertNtfServer_ db srv + DB.execute + db + [sql| + INSERT INTO ntf_tokens + (provider, device_token, ntf_host, ntf_port, tkn_id, tkn_pub_key, tkn_priv_key, tkn_pub_dh_key, tkn_priv_dh_key, tkn_dh_secret, tkn_status, tkn_action, ntf_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + |] + ((provider, token, host, port, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode)) + +getSavedNtfToken :: DB.Connection -> IO (Maybe NtfToken) +getSavedNtfToken db = do + maybeFirstRow ntfToken $ + DB.query_ + db + [sql| + SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, + t.provider, t.device_token, t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret, + t.tkn_status, t.tkn_action, t.ntf_mode + FROM ntf_tokens t + JOIN ntf_servers s USING (ntf_host, ntf_port) + |] + where + ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) = + let ntfServer = NtfServer host port keyHash + ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) + ntfMode = fromMaybe NMPeriodic ntfMode_ + in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} + +updateNtfTokenRegistration :: DB.Connection -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> IO () +updateNtfTokenRegistration db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret = do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_tokens + SET tkn_id = ?, tkn_dh_secret = ?, tkn_status = ?, tkn_action = ?, updated_at = ? + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (tknId, ntfDhSecret, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) + +updateDeviceToken :: DB.Connection -> NtfToken -> DeviceToken -> IO () +updateDeviceToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} (DeviceToken toProvider toToken) = do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_tokens + SET provider = ?, device_token = ?, tkn_status = ?, tkn_action = ?, updated_at = ? + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (toProvider, toToken, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) + +updateNtfMode :: DB.Connection -> NtfToken -> NotificationsMode -> IO () +updateNtfMode db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} ntfMode = do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_tokens + SET ntf_mode = ?, updated_at = ? + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? 
+ |] + (ntfMode, updatedAt, provider, token, host, port) + +updateNtfToken :: DB.Connection -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> IO () +updateNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction = do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_tokens + SET tkn_status = ?, tkn_action = ?, updated_at = ? + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (tknStatus, tknAction, updatedAt, provider, token, host, port) + +removeNtfToken :: DB.Connection -> NtfToken -> IO () +removeNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} = + DB.execute + db + [sql| + DELETE FROM ntf_tokens + WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? + |] + (provider, token, host, port) + +addNtfTokenToDelete :: DB.Connection -> NtfServer -> C.APrivateAuthKey -> NtfTokenId -> IO () +addNtfTokenToDelete db ProtocolServer {host, port, keyHash} ntfPrivKey tknId = + DB.execute db "INSERT INTO ntf_tokens_to_delete (ntf_host, ntf_port, ntf_key_hash, tkn_id, tkn_priv_key) VALUES (?,?,?,?,?)" (host, port, keyHash, tknId, ntfPrivKey) + +deleteExpiredNtfTokensToDelete :: DB.Connection -> NominalDiffTime -> IO () +deleteExpiredNtfTokensToDelete db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.execute db "DELETE FROM ntf_tokens_to_delete WHERE created_at < ?" (Only cutoffTs) + +type NtfTokenToDelete = (Int64, C.APrivateAuthKey, NtfTokenId) + +getNextNtfTokenToDelete :: DB.Connection -> NtfServer -> IO (Either StoreError (Maybe NtfTokenToDelete)) +getNextNtfTokenToDelete db (NtfServer ntfHost ntfPort _) = + getWorkItem "ntf tkn del" getNtfTknDbId getNtfTknToDelete (markNtfTokenToDeleteFailed_ db) + where + getNtfTknDbId :: IO (Maybe Int64) + getNtfTknDbId = + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT ntf_token_to_delete_id + FROM ntf_tokens_to_delete + WHERE ntf_host = ? AND ntf_port = ? + AND del_failed = 0 + ORDER BY created_at ASC + LIMIT 1 + |] + (ntfHost, ntfPort) + getNtfTknToDelete :: Int64 -> IO (Either StoreError NtfTokenToDelete) + getNtfTknToDelete tknDbId = + firstRow ntfTokenToDelete err $ + DB.query + db + [sql| + SELECT tkn_priv_key, tkn_id + FROM ntf_tokens_to_delete + WHERE ntf_token_to_delete_id = ? + |] + (Only tknDbId) + where + err = SEInternal $ "ntf token to delete " <> bshow tknDbId <> " returned []" + ntfTokenToDelete (tknPrivKey, tknId) = (tknDbId, tknPrivKey, tknId) + +markNtfTokenToDeleteFailed_ :: DB.Connection -> Int64 -> IO () +markNtfTokenToDeleteFailed_ db tknDbId = + DB.execute db "UPDATE ntf_tokens_to_delete SET del_failed = 1 where ntf_token_to_delete_id = ?" (Only tknDbId) + +getPendingDelTknServers :: DB.Connection -> IO [NtfServer] +getPendingDelTknServers db = + map toNtfServer + <$> DB.query_ + db + [sql| + SELECT DISTINCT ntf_host, ntf_port, ntf_key_hash + FROM ntf_tokens_to_delete + |] + where + toNtfServer (host, port, keyHash) = NtfServer host port keyHash + +deleteNtfTokenToDelete :: DB.Connection -> Int64 -> IO () +deleteNtfTokenToDelete db tknDbId = + DB.execute db "DELETE FROM ntf_tokens_to_delete WHERE ntf_token_to_delete_id = ?" 
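+-- Illustrative sketch of the token-deletion worker loop (`deleteRemoteToken`
+-- is a hypothetical stand-in for the NTF server call):
+--
+-- > delTokensLoop srv = getNextNtfTokenToDelete db srv >>= \case
+-- >   Right (Just (tknDbId, tknPrivKey, tknId)) -> do
+-- >     deleteRemoteToken tknPrivKey tknId
+-- >     deleteNtfTokenToDelete db tknDbId
+-- >     delTokensLoop srv
+-- >   _ -> pure ()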
(Only tknDbId) + +type NtfSupervisorSub = (NtfSubscription, Maybe (NtfSubAction, NtfActionTs)) + +getNtfSubscription :: DB.Connection -> ConnId -> IO (Maybe NtfSupervisorSub) +getNtfSubscription db connId = + maybeFirstRow ntfSubscription $ + DB.query + db + [sql| + SELECT c.user_id, s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash, + nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts + FROM ntf_subscriptions nsb + JOIN connections c USING (conn_id) + JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port + JOIN ntf_servers ns USING (ntf_host, ntf_port) + WHERE nsb.conn_id = ? + |] + (Only connId) + where + ntfSubscription ((userId, smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash) :. (ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_)) = + let smpServer = SMPServer smpHost smpPort smpKeyHash + ntfServer = NtfServer ntfHost ntfPort ntfKeyHash + action = case (ntfAction_, smpAction_, actionTs_) of + (Just ntfAction, Nothing, Just actionTs) -> Just (NSANtf ntfAction, actionTs) + (Nothing, Just smpAction, Just actionTs) -> Just (NSASMP smpAction, actionTs) + _ -> Nothing + in (NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action) + +createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ()) +createNtfSubscription db ntfSubscription action = runExceptT $ do + let NtfSubscription {connId, smpServer = smpServer@(SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription + smpServerKeyHash_ <- ExceptT $ getServerKeyHash_ db smpServer + actionTs <- liftIO getCurrentTime + liftIO $ + DB.execute + db + [sql| + INSERT INTO ntf_subscriptions + (conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id, + ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts, smp_server_key_hash) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?) + |] + ( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId) + :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, smpServerKeyHash_) + ) + where + (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action + +supervisorUpdateNtfSub :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO () +supervisorUpdateNtfSub db NtfSubscription {connId, smpServer = (SMPServer smpHost smpPort _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action = do + ts <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_subscriptions + SET smp_host = ?, smp_port = ?, smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, + ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? + WHERE conn_id = ? + |] + ( (smpHost, smpPort, ntfQueueId, ntfHost, ntfPort, ntfSubId) + :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, ts, BI True, ts, connId) + ) + where + (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action + +supervisorUpdateNtfAction :: DB.Connection -> ConnId -> NtfSubAction -> IO () +supervisorUpdateNtfAction db connId action = do + ts <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_subscriptions + SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? + WHERE conn_id = ? 
+ |] + (ntfSubAction, ntfSubSMPAction, ts, BI True, ts, connId) + where + (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action + +updateNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO () +updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do + r <- maybeFirstRow fromOnlyBI $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + forM_ r $ \updatedBySupervisor -> do + updatedAt <- getCurrentTime + if updatedBySupervisor + then + DB.execute + db + [sql| + UPDATE ntf_subscriptions + SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ? + WHERE conn_id = ? + |] + (ntfQueueId, ntfSubId, ntfSubStatus, BI False, updatedAt, connId) + else + DB.execute + db + [sql| + UPDATE ntf_subscriptions + SET smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? + WHERE conn_id = ? + |] + ((ntfQueueId, ntfHost, ntfPort, ntfSubId) :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, BI False, updatedAt, connId)) + where + (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action + +setNullNtfSubscriptionAction :: DB.Connection -> ConnId -> IO () +setNullNtfSubscriptionAction db connId = do + r <- maybeFirstRow fromOnlyBI $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + forM_ r $ \updatedBySupervisor -> + unless updatedBySupervisor $ do + updatedAt <- getCurrentTime + DB.execute + db + [sql| + UPDATE ntf_subscriptions + SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? + WHERE conn_id = ? + |] + (Nothing :: Maybe NtfSubNTFAction, Nothing :: Maybe NtfSubSMPAction, Nothing :: Maybe UTCTime, BI False, updatedAt, connId) + +deleteNtfSubscription :: DB.Connection -> ConnId -> IO () +deleteNtfSubscription db connId = do + r <- maybeFirstRow fromOnlyBI $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + forM_ r $ \updatedBySupervisor -> do + updatedAt <- getCurrentTime + if updatedBySupervisor + then + DB.execute + db + [sql| + UPDATE ntf_subscriptions + SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ? + WHERE conn_id = ? + |] + (Nothing :: Maybe SMP.NotifierId, Nothing :: Maybe NtfSubscriptionId, NASDeleted, BI False, updatedAt, connId) + else deleteNtfSubscription' db connId + +deleteNtfSubscription' :: DB.Connection -> ConnId -> IO () +deleteNtfSubscription' db connId = do + DB.execute db "DELETE FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) + +getNextNtfSubNTFActions :: DB.Connection -> NtfServer -> Int -> IO (Either StoreError [Either StoreError (NtfSubNTFAction, NtfSubscription, NtfActionTs)]) +getNextNtfSubNTFActions db ntfServer@(NtfServer ntfHost ntfPort _) ntfBatchSize = + getWorkItems "ntf NTF" getNtfConnIds getNtfSubAction (markNtfSubActionNtfFailed_ db) + where + getNtfConnIds :: IO [ConnId] + getNtfConnIds = + map fromOnly + <$> DB.query + db + [sql| + SELECT conn_id + FROM ntf_subscriptions + WHERE ntf_host = ? AND ntf_port = ? AND ntf_sub_action IS NOT NULL + AND (ntf_failed = 0 OR updated_by_supervisor = 1) + ORDER BY ntf_sub_action_ts ASC + LIMIT ? 
+ |] + (ntfHost, ntfPort, ntfBatchSize) + getNtfSubAction :: ConnId -> IO (Either StoreError (NtfSubNTFAction, NtfSubscription, NtfActionTs)) + getNtfSubAction connId = do + markUpdatedByWorker db connId + firstRow ntfSubAction err $ + DB.query + db + [sql| + SELECT c.user_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash), + ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action + FROM ntf_subscriptions ns + JOIN connections c USING (conn_id) + JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port + WHERE ns.conn_id = ? + |] + (Only connId) + where + err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" + ntfSubAction (userId, smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = + let smpServer = SMPServer smpHost smpPort smpKeyHash + ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} + in (action, ntfSubscription, actionTs) + +markNtfSubActionNtfFailed_ :: DB.Connection -> ConnId -> IO () +markNtfSubActionNtfFailed_ db connId = + DB.execute db "UPDATE ntf_subscriptions SET ntf_failed = 1 where conn_id = ?" (Only connId) + +getNextNtfSubSMPActions :: DB.Connection -> SMPServer -> Int -> IO (Either StoreError [Either StoreError (NtfSubSMPAction, NtfSubscription)]) +getNextNtfSubSMPActions db smpServer@(SMPServer smpHost smpPort _) ntfBatchSize = + getWorkItems "ntf SMP" getNtfConnIds getNtfSubAction (markNtfSubActionSMPFailed_ db) + where + getNtfConnIds :: IO [ConnId] + getNtfConnIds = + map fromOnly + <$> DB.query + db + [sql| + SELECT conn_id + FROM ntf_subscriptions ns + WHERE smp_host = ? AND smp_port = ? AND ntf_sub_smp_action IS NOT NULL AND ntf_sub_action_ts IS NOT NULL + AND (smp_failed = 0 OR updated_by_supervisor = 1) + ORDER BY ntf_sub_action_ts ASC + LIMIT ? + |] + (smpHost, smpPort, ntfBatchSize) + getNtfSubAction :: ConnId -> IO (Either StoreError (NtfSubSMPAction, NtfSubscription)) + getNtfSubAction connId = do + markUpdatedByWorker db connId + firstRow ntfSubAction err $ + DB.query + db + [sql| + SELECT c.user_id, s.ntf_host, s.ntf_port, s.ntf_key_hash, + ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_smp_action + FROM ntf_subscriptions ns + JOIN connections c USING (conn_id) + JOIN ntf_servers s USING (ntf_host, ntf_port) + WHERE ns.conn_id = ? + |] + (Only connId) + where + err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" + ntfSubAction (userId, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, action) = + let ntfServer = NtfServer ntfHost ntfPort ntfKeyHash + ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} + in (action, ntfSubscription) + +markNtfSubActionSMPFailed_ :: DB.Connection -> ConnId -> IO () +markNtfSubActionSMPFailed_ db connId = + DB.execute db "UPDATE ntf_subscriptions SET smp_failed = 1 where conn_id = ?" (Only connId) + +markUpdatedByWorker :: DB.Connection -> ConnId -> IO () +markUpdatedByWorker db connId = + DB.execute db "UPDATE ntf_subscriptions SET updated_by_supervisor = 0 WHERE conn_id = ?" 
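+-- Illustrative sketch of consuming one batch of NTF subscription actions:
+-- failed rows come back as Left values and are already marked failed via
+-- markNtfSubActionNtfFailed_ (`processNtfAction` is a hypothetical handler):
+--
+-- > getNextNtfSubNTFActions db ntfServer 100 >>= \case
+-- >   Right actions -> mapM_ (either print processNtfAction) actions
+-- >   Left e -> print e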
(Only connId) + +getActiveNtfToken :: DB.Connection -> IO (Maybe NtfToken) +getActiveNtfToken db = + maybeFirstRow ntfToken $ + DB.query + db + [sql| + SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, + t.provider, t.device_token, t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret, + t.tkn_status, t.tkn_action, t.ntf_mode + FROM ntf_tokens t + JOIN ntf_servers s USING (ntf_host, ntf_port) + WHERE t.tkn_status = ? + |] + (Only NTActive) + where + ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) = + let ntfServer = NtfServer host port keyHash + ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) + ntfMode = fromMaybe NMPeriodic ntfMode_ + in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} + +getNtfRcvQueue :: DB.Connection -> SMPQueueNtf -> IO (Either StoreError (ConnId, RcvNtfDhSecret, Maybe UTCTime)) +getNtfRcvQueue db SMPQueueNtf {smpServer = (SMPServer host port _), notifierId} = + firstRow' res SEConnNotFound $ + DB.query + db + [sql| + SELECT conn_id, rcv_ntf_dh_secret, last_broker_ts + FROM rcv_queues + WHERE host = ? AND port = ? AND ntf_id = ? AND deleted = 0 + |] + (host, port, notifierId) + where + res (connId, Just rcvNtfDhSecret, lastBrokerTs_) = Right (connId, rcvNtfDhSecret, lastBrokerTs_) + res _ = Left SEConnNotFound + +setConnectionNtfs :: DB.Connection -> ConnId -> Bool -> IO () +setConnectionNtfs db connId enableNtfs = + DB.execute db "UPDATE connections SET enable_ntfs = ? WHERE conn_id = ?" (BI enableNtfs, connId) + +-- * Auxiliary helpers + +instance ToField QueueStatus where toField = toField . serializeQueueStatus + +instance FromField QueueStatus where fromField = fromTextField_ queueStatusT + +instance ToField (DBQueueId 'QSStored) where toField (DBQueueId qId) = toField qId + +instance FromField (DBQueueId 'QSStored) where +#if defined(dbPostgres) + fromField x dat = DBQueueId <$> fromField x dat +#else + fromField x = DBQueueId <$> fromField x +#endif + +instance ToField InternalRcvId where toField (InternalRcvId x) = toField x + +deriving newtype instance FromField InternalRcvId + +instance ToField InternalSndId where toField (InternalSndId x) = toField x + +deriving newtype instance FromField InternalSndId + +instance ToField InternalId where toField (InternalId x) = toField x + +deriving newtype instance FromField InternalId + +instance ToField AgentMessageType where toField = toField . Binary . smpEncode + +instance FromField AgentMessageType where fromField = blobFieldParser smpP + +instance ToField MsgIntegrity where toField = toField . Binary . strEncode + +instance FromField MsgIntegrity where fromField = blobFieldParser strP + +instance ToField SMPQueueUri where toField = toField . Binary . strEncode + +instance FromField SMPQueueUri where fromField = blobFieldParser strP + +instance ToField AConnectionRequestUri where toField = toField . Binary . strEncode + +instance FromField AConnectionRequestUri where fromField = blobFieldParser strP + +instance ConnectionModeI c => ToField (ConnectionRequestUri c) where toField = toField . Binary . strEncode + +instance (E.Typeable c, ConnectionModeI c) => FromField (ConnectionRequestUri c) where fromField = blobFieldParser strP + +instance ToField ConnectionMode where toField = toField . decodeLatin1 . 
strEncode + +instance FromField ConnectionMode where fromField = fromTextField_ connModeT + +instance ToField (SConnectionMode c) where toField = toField . connMode + +instance FromField AConnectionMode where fromField = fromTextField_ $ fmap connMode' . connModeT + +instance ToField MsgFlags where toField = toField . decodeLatin1 . smpEncode + +instance FromField MsgFlags where fromField = fromTextField_ $ eitherToMaybe . smpDecode . encodeUtf8 + +instance ToField [SMPQueueInfo] where toField = toField . Binary . smpEncodeList + +instance FromField [SMPQueueInfo] where fromField = blobFieldParser smpListP + +instance ToField (NonEmpty TransportHost) where toField = toField . decodeLatin1 . strEncode + +instance FromField (NonEmpty TransportHost) where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 + +instance ToField AgentCommand where toField = toField . Binary . strEncode + +instance FromField AgentCommand where fromField = blobFieldParser strP + +instance ToField AgentCommandTag where toField = toField . Binary . strEncode + +instance FromField AgentCommandTag where fromField = blobFieldParser strP + +instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEncode + +instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8 + +instance ToField (Version v) where toField (Version v) = toField v + +deriving newtype instance FromField (Version v) + +instance ToField EntityId where toField (EntityId s) = toField $ Binary s + +deriving newtype instance FromField EntityId + +deriving newtype instance ToField ChunkReplicaId + +deriving newtype instance FromField ChunkReplicaId + +listToEither :: e -> [a] -> Either e a +listToEither _ (x : _) = Right x +listToEither e _ = Left e + +firstRow :: (a -> b) -> e -> IO [a] -> IO (Either e b) +firstRow f e a = second f . listToEither e <$> a + +maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b) +maybeFirstRow f q = fmap f . 
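+-- Illustrative sketch of the row helpers defined here: both keep only the first
+-- row of a DB.query result, either as an Either with the supplied error or as a
+-- Maybe (the queries are examples only):
+--
+-- > e <- firstRow fromOnly SEConnNotFound $ DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" (Only userId)
+-- > m <- maybeFirstRow fromOnly $ DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" (Only userId)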
listToMaybe <$> q + +fromOnlyBI :: Only BoolInt -> Bool +fromOnlyBI (Only (BI b)) = b +{-# INLINE fromOnlyBI #-} + +firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b) +firstRow' f e a = (f <=< listToEither e) <$> a + +#if !defined(dbPostgres) +{- ORMOLU_DISABLE -} +-- SQLite.Simple only has these up to 10 fields, which is insufficient for some of our queries +instance (FromField a, FromField b, FromField c, FromField d, FromField e, + FromField f, FromField g, FromField h, FromField i, FromField j, + FromField k) => + FromRow (a,b,c,d,e,f,g,h,i,j,k) where + fromRow = (,,,,,,,,,,) <$> field <*> field <*> field <*> field <*> field + <*> field <*> field <*> field <*> field <*> field + <*> field + +instance (FromField a, FromField b, FromField c, FromField d, FromField e, + FromField f, FromField g, FromField h, FromField i, FromField j, + FromField k, FromField l) => + FromRow (a,b,c,d,e,f,g,h,i,j,k,l) where + fromRow = (,,,,,,,,,,,) <$> field <*> field <*> field <*> field <*> field + <*> field <*> field <*> field <*> field <*> field + <*> field <*> field + +instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, + ToField g, ToField h, ToField i, ToField j, ToField k, ToField l) => + ToRow (a,b,c,d,e,f,g,h,i,j,k,l) where + toRow (a,b,c,d,e,f,g,h,i,j,k,l) = + [ toField a, toField b, toField c, toField d, toField e, toField f, + toField g, toField h, toField i, toField j, toField k, toField l + ] + +{- ORMOLU_ENABLE -} +#endif + +-- * Server helper + +-- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored. +createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash) +createServer_ db newSrv@ProtocolServer {host, port, keyHash} = + getServerKeyHash_ db newSrv >>= \case + Right keyHash_ -> pure keyHash_ + Left _ -> insertNewServer_ $> Nothing + where + insertNewServer_ = + DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash) + +-- | Returns the passed server key hash if it is different from the stored one, or the error if the server does not exist. +getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash)) +getServerKeyHash_ db ProtocolServer {host, port, keyHash} = do + firstRow useKeyHash SEServerNotFound $ + DB.query db "SELECT key_hash FROM servers WHERE host = ? AND port = ?" (host, port) + where + useKeyHash (Only keyHash') = if keyHash /= keyHash' then Just keyHash else Nothing + +upsertNtfServer_ :: DB.Connection -> NtfServer -> IO () +upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do + DB.execute + db + [sql| + INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (?,?,?) + ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET + ntf_host=excluded.ntf_host, + ntf_port=excluded.ntf_port, + ntf_key_hash=excluded.ntf_key_hash; + |] + (host, port, keyHash) + +-- * createRcvConn helpers + +insertRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> Maybe C.KeyHash -> IO RcvQueue +insertRcvQueue_ db connId' rq@RcvQueue {..} serverKeyHash_ = do + -- to preserve ID if the queue already exists. + -- possibly, it can be done in one query. + currQId_ <- maybeFirstRow fromOnly $ DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? AND host = ? AND port = ? AND snd_id = ?" (connId', host server, port server, sndId) + qId <- maybe (newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? 
ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')) pure currQId_ + DB.execute + db + [sql| + INSERT INTO rcv_queues + (host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, snd_secure, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); + |] + ((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, BI sndSecure, status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) + pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId} + +-- * createSndConn helpers + +insertSndQueue_ :: DB.Connection -> ConnId -> NewSndQueue -> Maybe C.KeyHash -> IO SndQueue +insertSndQueue_ db connId' sq@SndQueue {..} serverKeyHash_ = do + -- to preserve ID if the queue already exists. + -- possibly, it can be done in one query. + currQId_ <- maybeFirstRow fromOnly $ DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? AND host = ? AND port = ? AND snd_id = ?" (connId', host server, port server, sndId) + qId <- maybe (newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')) pure currQId_ + DB.execute + db + [sql| + INSERT INTO snd_queues + (host, port, snd_id, snd_secure, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, + status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) + VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + ON CONFLICT (host, port, snd_id) DO UPDATE SET + host=EXCLUDED.host, + port=EXCLUDED.port, + snd_id=EXCLUDED.snd_id, + snd_secure=EXCLUDED.snd_secure, + conn_id=EXCLUDED.conn_id, + snd_public_key=EXCLUDED.snd_public_key, + snd_private_key=EXCLUDED.snd_private_key, + e2e_pub_key=EXCLUDED.e2e_pub_key, + e2e_dh_secret=EXCLUDED.e2e_dh_secret, + status=EXCLUDED.status, + snd_queue_id=EXCLUDED.snd_queue_id, + snd_primary=EXCLUDED.snd_primary, + replace_snd_queue_id=EXCLUDED.replace_snd_queue_id, + smp_client_version=EXCLUDED.smp_client_version, + server_key_hash=EXCLUDED.server_key_hash + |] + ((host server, port server, sndId, BI sndSecure, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) + :. 
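+-- Note (illustrative): both insertRcvQueue_ and insertSndQueue_ preserve an
+-- existing queue id when the same queue is re-inserted, and otherwise allocate
+-- the next id via newQueueId_ below, e.g.
+--
+-- > newQueueId_ []       == DBQueueId 1
+-- > newQueueId_ [Only 5] == DBQueueId 6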
(status, qId, BI primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) + pure (sq :: NewSndQueue) {connId = connId', dbQueueId = qId} + +newQueueId_ :: [Only Int64] -> DBQueueId 'QSStored +newQueueId_ [] = DBQueueId 1 +newQueueId_ (Only maxId : _) = DBQueueId (maxId + 1) + +-- * getConn helpers + +getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getConn = getAnyConn False +{-# INLINE getConn #-} + +getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getDeletedConn = getAnyConn True +{-# INLINE getDeletedConn #-} + +getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getAnyConn deleted' dbConn connId = + getConnData dbConn connId >>= \case + Nothing -> pure $ Left SEConnNotFound + Just (cData@ConnData {deleted}, cMode) + | deleted /= deleted' -> pure $ Left SEConnNotFound + | otherwise -> do + rQ <- getRcvQueuesByConnId_ dbConn connId + sQ <- getSndQueuesByConnId_ dbConn connId + pure $ case (rQ, sQ, cMode) of + (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) + (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) + (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) + (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) + (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) + _ -> Left SEConnNotFound + +getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] +getConns = getAnyConns_ False +{-# INLINE getConns #-} + +getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] +getDeletedConns = getAnyConns_ True +{-# INLINE getDeletedConns #-} + +getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] +getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db + where + handleDBError :: E.SomeException -> IO (Either StoreError SomeConn) + handleDBError = pure . Left . SEInternal . bshow + +getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) +getConnData db connId' = + maybeFirstRow cData $ + DB.query + db + [sql| + SELECT + user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support + FROM connections + WHERE conn_id = ? + |] + (Only connId') + where + cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = + (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) + +setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () +setConnDeleted db waitDelivery connId + | waitDelivery = do + currentTs <- getCurrentTime + DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId) + | otherwise = + DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (BI True, connId) + +setConnUserId :: DB.Connection -> UserId -> ConnId -> UserId -> IO () +setConnUserId db oldUserId connId newUserId = + DB.execute db "UPDATE connections SET user_id = ? WHERE conn_id = ? and user_id = ?" (newUserId, connId, oldUserId) + +setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () +setConnAgentVersion db connId aVersion = + DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" 
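+-- Illustrative sketch of consuming getConn's existentially wrapped result: the
+-- connection type is only known at runtime, so callers match on the connection
+-- type tag (`useQueues` is a hypothetical handler):
+--
+-- > getConn db connId >>= \case
+-- >   Right (SomeConn SCDuplex (DuplexConnection _ rqs sqs)) -> useQueues rqs sqs
+-- >   Right (SomeConn _ _) -> pure ()
+-- >   Left e -> print e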
(aVersion, connId) + +setConnPQSupport :: DB.Connection -> ConnId -> PQSupport -> IO () +setConnPQSupport db connId pqSupport = + DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId) + +getDeletedConnIds :: DB.Connection -> IO [ConnId] +getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only (BI True)) + +getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId] +getDeletedWaitingDeliveryConnIds db = + map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL" + +setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO () +setConnRatchetSync db connId ratchetSyncState = + DB.execute db "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?" (ratchetSyncState, connId) + +addProcessedRatchetKeyHash :: DB.Connection -> ConnId -> ByteString -> IO () +addProcessedRatchetKeyHash db connId hash = + DB.execute db "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)" (connId, Binary hash) + +checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool +checkRatchetKeyHashExists db connId hash = do + fromMaybe False + <$> maybeFirstRow + fromOnly + ( DB.query + db + "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" + (connId, Binary hash) + ) + +deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO () +deleteRatchetKeyHashesExpired db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.execute db "DELETE FROM processed_ratchet_key_hashes WHERE created_at < ?" (Only cutoffTs) + +-- | returns all connection queues, the first queue is the primary one +getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue)) +getRcvQueuesByConnId_ db connId = + L.nonEmpty . sortBy primaryFirst . map toRcvQueue + <$> DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.deleted = 0") (Only connId) + where + primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = + -- the current primary queue is ordered first, the next primary - second + compare (Down p) (Down p') <> compare i i' + +rcvQueueQuery :: Query +rcvQueueQuery = + [sql| + SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, + q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_secure, q.status, + q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.switch_status, q.smp_client_version, q.delete_errors, + q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret + FROM rcv_queues q + JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN connections c ON q.conn_id = c.conn_id + |] + +toRcvQueue :: + (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, BoolInt) + :. (QueueStatus, DBQueueId 'QSStored, BoolInt, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) + :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> + RcvQueue +toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, BI sndSecure) :. (status, dbQueueId, BI primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. 
(ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = + let server = SMPServer host port keyHash + smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ + clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of + (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} + _ -> Nothing + in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndSecure, status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} + +getRcvQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue) +getRcvQueueById db connId dbRcvId = + firstRow toRcvQueue SEConnNotFound $ + DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ? AND q.deleted = 0") (connId, dbRcvId) + +-- | returns all connection queues, the first queue is the primary one +getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue)) +getSndQueuesByConnId_ dbConn connId = + L.nonEmpty . sortBy primaryFirst . map toSndQueue + <$> DB.query dbConn (sndQueueQuery <> " WHERE q.conn_id = ?") (Only connId) + where + primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = + -- the current primary queue is ordered first, the next primary - second + compare (Down p) (Down p') <> compare i i' + +sndQueueQuery :: Query +sndQueueQuery = + [sql| + SELECT + c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.snd_id, q.snd_secure, + q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, + q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.switch_status, q.smp_client_version + FROM snd_queues q + JOIN servers s ON q.host = s.host AND q.port = s.port + JOIN connections c ON q.conn_id = c.conn_id + |] + +toSndQueue :: + (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId, BoolInt) + :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) + :. (DBQueueId 'QSStored, BoolInt, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> + SndQueue +toSndQueue + ( (userId, keyHash, connId, host, port, sndId, BI sndSecure) + :. (sndPubKey, sndPrivateKey@(C.APrivateAuthKey a pk), e2ePubKey, e2eDhSecret, status) + :. (dbQueueId, BI primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion) + ) = + let server = SMPServer host port keyHash + sndPublicKey = fromMaybe (C.APublicAuthKey a (C.publicKey pk)) sndPubKey + in SndQueue {userId, connId, server, sndId, sndSecure, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion} + +getSndQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError SndQueue) +getSndQueueById db connId dbSndId = + firstRow toSndQueue SEConnNotFound $ + DB.query db (sndQueueQuery <> " WHERE q.conn_id = ? 
AND q.snd_queue_id = ?") (connId, dbSndId) + +-- * updateRcvIds helpers + +retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) +retrieveLastIdsAndHashRcv_ dbConn connId = do + [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- + DB.query + dbConn + [sql| + SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash + FROM connections + WHERE conn_id = ? + |] + (Only connId) + return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) + +updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO () +updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = + DB.execute + dbConn + [sql| + UPDATE connections + SET last_internal_msg_id = ?, + last_internal_rcv_msg_id = ? + WHERE conn_id = ? + |] + (newInternalId, newInternalRcvId, connId) + +-- * createRcvMsg helpers + +insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () +insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} = do + let MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta + DB.execute + dbConn + [sql| + INSERT INTO messages + (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) + VALUES (?,?,?,?,?,?,?,?,?); + |] + (connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, Binary msgBody, pqEncryption) + +insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () +insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do + let MsgMeta {integrity, recipient, broker, sndMsgId} = msgMeta + DB.execute + db + [sql| + INSERT INTO rcv_messages + ( conn_id, rcv_queue_id, internal_rcv_id, internal_id, external_snd_id, + broker_id, broker_ts, + internal_hash, external_prev_snd_hash, integrity) + VALUES + (?,?,?,?,?,?,?,?,?,?) + |] + (connId, dbQueueId, internalRcvId, fst recipient, sndMsgId, Binary (fst broker), snd broker, Binary internalHash, Binary externalPrevSndHash, integrity) + DB.execute db "INSERT INTO encrypted_rcv_message_hashes (conn_id, hash) VALUES (?,?)" (connId, Binary encryptedMsgHash) + +updateRcvMsgHash :: DB.Connection -> ConnId -> AgentMsgId -> InternalRcvId -> MsgHash -> IO () +updateRcvMsgHash db connId sndMsgId internalRcvId internalHash = + DB.execute + db + -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved + [sql| + UPDATE connections + SET last_external_snd_msg_id = ?, + last_rcv_msg_hash = ? + WHERE conn_id = ? + AND last_internal_rcv_msg_id = ? + |] + (sndMsgId, Binary internalHash, connId, internalRcvId) + +-- * updateSndIds helpers + +retrieveLastIdsAndHashSnd_ :: DB.Connection -> ConnId -> IO (Either StoreError (InternalId, InternalSndId, PrevSndMsgHash)) +retrieveLastIdsAndHashSnd_ dbConn connId = do + firstRow id SEConnNotFound $ + DB.query + dbConn + [sql| + SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash + FROM connections + WHERE conn_id = ? + |] + (Only connId) + +updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO () +updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = + DB.execute + dbConn + [sql| + UPDATE connections + SET last_internal_msg_id = ?, + last_internal_snd_msg_id = ? + WHERE conn_id = ? 
+ |] + (newInternalId, newInternalSndId, connId) + +-- * createSndMsg helpers + +insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} = do + DB.execute + db + [sql| + INSERT INTO messages + (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) + VALUES + (?,?,?,?,?,?,?,?,?); + |] + (connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, Binary msgBody, pqEncryption) + +insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () +insertSndMsgDetails_ dbConn connId SndMsgData {..} = + DB.execute + dbConn + [sql| + INSERT INTO snd_messages + ( conn_id, internal_snd_id, internal_id, internal_hash, previous_msg_hash) + VALUES + (?,?,?,?,?) + |] + (connId, internalSndId, internalId, Binary internalHash, Binary prevMsgHash) + +updateSndMsgHash :: DB.Connection -> ConnId -> InternalSndId -> MsgHash -> IO () +updateSndMsgHash db connId internalSndId internalHash = + DB.execute + db + -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved + [sql| + UPDATE connections + SET last_snd_msg_hash = ? + WHERE conn_id = ? + AND last_internal_snd_msg_id = ?; + |] + (Binary internalHash, connId, internalSndId) + +-- create record with a random ID +createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString) +createWithRandomId gVar create = fst <$$> createWithRandomId' gVar create + +createWithRandomId' :: forall a. TVar ChaChaDRG -> (ByteString -> IO a) -> IO (Either StoreError (ByteString, a)) +createWithRandomId' gVar create = tryCreate 3 + where + tryCreate :: Int -> IO (Either StoreError (ByteString, a)) + tryCreate 0 = pure $ Left SEUniqueID + tryCreate n = do + id' <- randomId gVar 12 + E.try (create id') >>= \case + Right r -> pure $ Right (id', r) + Left e -> handleErr n e +#if defined(dbPostgres) + handleErr n e = case constraintViolation e of + Just _ -> tryCreate (n - 1) + Nothing -> pure . Left . SEInternal $ bshow e +#else + handleErr n e + | SQL.sqlError e == SQL.ErrorConstraint = tryCreate (n - 1) + | otherwise = pure . Left . SEInternal $ bshow e +#endif + +randomId :: TVar ChaChaDRG -> Int -> IO ByteString +randomId gVar n = atomically $ U.encode <$> C.randomBytes n gVar + +ntfSubAndSMPAction :: NtfSubAction -> (Maybe NtfSubNTFAction, Maybe NtfSubSMPAction) +ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing) +ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action) + +createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64 +createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} = + getXFTPServerId_ db newSrv >>= \case + Right srvId -> pure srvId + Left _ -> insertNewServer_ + where + insertNewServer_ = do + DB.execute db "INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) VALUES (?,?,?)" (host, port, keyHash) + insertedRowId db + +getXFTPServerId_ :: DB.Connection -> XFTPServer -> IO (Either StoreError Int64) +getXFTPServerId_ db ProtocolServer {host, port, keyHash} = do + firstRow fromOnly SEXFTPServerNotFound $ + DB.query db "SELECT xftp_server_id FROM xftp_servers WHERE xftp_host = ? AND xftp_port = ? AND xftp_key_hash = ?" 
(host, port, keyHash) + +createRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId) +createRcvFile db gVar userId fd@FileDescription {chunks} prefixPath tmpPath file approvedRelays = runExceptT $ do + (rcvFileEntityId, rcvFileId) <- ExceptT $ insertRcvFile db gVar userId fd prefixPath tmpPath file Nothing Nothing approvedRelays + liftIO $ + forM_ chunks $ \fc@FileChunk {replicas} -> do + chunkId <- insertRcvFileChunk db fc rcvFileId + forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId + pure rcvFileEntityId + +createRcvFileRedirect :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId) +createRcvFileRedirect _ _ _ FileDescription {redirect = Nothing} _ _ _ _ _ _ = pure $ Left $ SEInternal "createRcvFileRedirect called without redirect" +createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redirectChunks, redirect = Just RedirectFileInfo {size, digest}} prefixPath redirectPath redirectFile dstPath dstFile approvedRelays = runExceptT $ do + (dstEntityId, dstId) <- ExceptT $ insertRcvFile db gVar userId dummyDst prefixPath dstPath dstFile Nothing Nothing approvedRelays + (_, redirectId) <- ExceptT $ insertRcvFile db gVar userId redirectFd prefixPath redirectPath redirectFile (Just dstId) (Just dstEntityId) approvedRelays + liftIO $ + forM_ redirectChunks $ \fc@FileChunk {replicas} -> do + chunkId <- insertRcvFileChunk db fc redirectId + forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId + pure dstEntityId + where + dummyDst = + FileDescription + { party = SFRecipient, + size, + digest, + redirect = Nothing, + -- updated later with updateRcvFileRedirect + key = C.unsafeSbKey $ B.replicate 32 '#', + nonce = C.cbNonce "", + chunkSize = FileSize 0, + chunks = [] + } + +insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> Bool -> IO (Either StoreError (RcvFileId, DBRcvFileId)) +insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ approvedRelays = runExceptT $ do + let (redirectDigest_, redirectSize_) = case redirect of + Just RedirectFileInfo {digest = d, size = s} -> (Just d, Just s) + Nothing -> (Nothing, Nothing) + rcvFileEntityId <- ExceptT $ + createWithRandomId gVar $ \rcvFileEntityId -> + DB.execute + db + "INSERT INTO rcv_files (rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, redirect_id, redirect_entity_id, redirect_digest, redirect_size, approved_relays) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" + ((Binary rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. 
(savePath, fileKey <$> cfArgs, fileNonce <$> cfArgs, RFSReceiving, redirectId_, Binary <$> redirectEntityId_, redirectDigest_, redirectSize_, BI approvedRelays)) + rcvFileId <- liftIO $ insertedRowId db + pure (rcvFileEntityId, rcvFileId) + +insertRcvFileChunk :: DB.Connection -> FileChunk -> DBRcvFileId -> IO Int64 +insertRcvFileChunk db FileChunk {chunkNo, chunkSize, digest} rcvFileId = do + DB.execute + db + "INSERT INTO rcv_file_chunks (rcv_file_id, chunk_no, chunk_size, digest) VALUES (?,?,?,?)" + (rcvFileId, chunkNo, chunkSize, digest) + insertedRowId db + +insertRcvFileChunkReplica :: DB.Connection -> Int -> FileChunkReplica -> Int64 -> IO () +insertRcvFileChunkReplica db replicaNo FileChunkReplica {server, replicaId, replicaKey} chunkId = do + srvId <- createXFTPServer_ db server + DB.execute + db + "INSERT INTO rcv_file_chunk_replicas (replica_number, rcv_file_chunk_id, xftp_server_id, replica_id, replica_key) VALUES (?,?,?,?,?)" + (replicaNo, chunkId, srvId, replicaId, replicaKey) + +getRcvFileByEntityId :: DB.Connection -> RcvFileId -> IO (Either StoreError RcvFile) +getRcvFileByEntityId db rcvFileEntityId = runExceptT $ do + rcvFileId <- ExceptT $ getRcvFileIdByEntityId_ db rcvFileEntityId + ExceptT $ getRcvFile db rcvFileId + +getRcvFileIdByEntityId_ :: DB.Connection -> RcvFileId -> IO (Either StoreError DBRcvFileId) +getRcvFileIdByEntityId_ db rcvFileEntityId = + firstRow fromOnly SEFileNotFound $ + DB.query db "SELECT rcv_file_id FROM rcv_files WHERE rcv_file_entity_id = ?" (Only (Binary rcvFileEntityId)) + +getRcvFileRedirects :: DB.Connection -> DBRcvFileId -> IO [RcvFile] +getRcvFileRedirects db rcvFileId = do + redirects <- fromOnly <$$> DB.query db "SELECT rcv_file_id FROM rcv_files WHERE redirect_id = ?" (Only rcvFileId) + fmap catMaybes . forM redirects $ getRcvFile db >=> either (const $ pure Nothing) (pure . Just) + +getRcvFile :: DB.Connection -> DBRcvFileId -> IO (Either StoreError RcvFile) +getRcvFile db rcvFileId = runExceptT $ do + f@RcvFile {rcvFileEntityId, userId, tmpPath} <- ExceptT getFile + chunks <- maybe (pure []) (liftIO . getChunks rcvFileEntityId userId) tmpPath + pure (f {chunks} :: RcvFile) + where + getFile :: IO (Either StoreError RcvFile) + getFile = do + firstRow toFile SEFileNotFound $ + DB.query + db + [sql| + SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest + FROM rcv_files + WHERE rcv_file_id = ? + |] + (Only rcvFileId) + where + toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, BoolInt, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile + toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. 
(savePath, saveKey_, saveNonce_, status, BI deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) = + let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_ + saveFile = CryptoFile savePath cfArgs + redirect = + RcvFileRedirect + <$> redirectDbId + <*> redirectEntityId + <*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_) + in RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []} + getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk] + getChunks rcvFileEntityId userId fileTmpPath = do + chunks <- + map toChunk + <$> DB.query + db + [sql| + SELECT rcv_file_chunk_id, chunk_no, chunk_size, digest, tmp_path + FROM rcv_file_chunks + WHERE rcv_file_id = ? + |] + (Only rcvFileId) + forM chunks $ \chunk@RcvFileChunk {rcvChunkId} -> do + replicas' <- getChunkReplicas rcvChunkId + pure (chunk {replicas = replicas'} :: RcvFileChunk) + where + toChunk :: (Int64, Int, FileSize Word32, FileDigest, Maybe FilePath) -> RcvFileChunk + toChunk (rcvChunkId, chunkNo, chunkSize, digest, chunkTmpPath) = + RcvFileChunk {rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath, replicas = []} + getChunkReplicas :: Int64 -> IO [RcvFileChunkReplica] + getChunkReplicas chunkId = do + map toReplica + <$> DB.query + db + [sql| + SELECT + r.rcv_file_chunk_replica_id, r.replica_id, r.replica_key, r.received, r.delay, r.retries, + s.xftp_host, s.xftp_port, s.xftp_key_hash + FROM rcv_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + WHERE r.rcv_file_chunk_id = ? + |] + (Only chunkId) + where + toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, BoolInt, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> RcvFileChunkReplica + toReplica (rcvChunkReplicaId, replicaId, replicaKey, BI received, delay, retries, host, port, keyHash) = + let server = XFTPServer host port keyHash + in RcvFileChunkReplica {rcvChunkReplicaId, server, replicaId, replicaKey, received, delay, retries} + +updateRcvChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO () +updateRcvChunkReplicaDelay db replicaId delay = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_file_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE rcv_file_chunk_replica_id = ?" (delay, updatedAt, replicaId) + +updateRcvFileChunkReceived :: DB.Connection -> Int64 -> Int64 -> FilePath -> IO () +updateRcvFileChunkReceived db replicaId chunkId chunkTmpPath = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_file_chunk_replicas SET received = 1, updated_at = ? WHERE rcv_file_chunk_replica_id = ?" (updatedAt, replicaId) + DB.execute db "UPDATE rcv_file_chunks SET tmp_path = ?, updated_at = ? WHERE rcv_file_chunk_id = ?" (chunkTmpPath, updatedAt, chunkId) + +updateRcvFileStatus :: DB.Connection -> DBRcvFileId -> RcvFileStatus -> IO () +updateRcvFileStatus db rcvFileId status = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_files SET status = ?, updated_at = ? WHERE rcv_file_id = ?" (status, updatedAt, rcvFileId) + +updateRcvFileError :: DB.Connection -> DBRcvFileId -> String -> IO () +updateRcvFileError db rcvFileId errStr = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_files SET tmp_path = NULL, error = ?, status = ?, updated_at = ? WHERE rcv_file_id = ?" 
(errStr, RFSError, updatedAt, rcvFileId) + +updateRcvFileComplete :: DB.Connection -> DBRcvFileId -> IO () +updateRcvFileComplete db rcvFileId = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_files SET tmp_path = NULL, status = ?, updated_at = ? WHERE rcv_file_id = ?" (RFSComplete, updatedAt, rcvFileId) + +updateRcvFileRedirect :: DB.Connection -> DBRcvFileId -> FileDescription 'FRecipient -> IO (Either StoreError ()) +updateRcvFileRedirect db rcvFileId FileDescription {key, nonce, chunkSize, chunks} = runExceptT $ do + updatedAt <- liftIO getCurrentTime + liftIO $ DB.execute db "UPDATE rcv_files SET key = ?, nonce = ?, chunk_size = ?, updated_at = ? WHERE rcv_file_id = ?" (key, nonce, chunkSize, updatedAt, rcvFileId) + liftIO $ forM_ chunks $ \fc@FileChunk {replicas} -> do + chunkId <- insertRcvFileChunk db fc rcvFileId + forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId + +updateRcvFileNoTmpPath :: DB.Connection -> DBRcvFileId -> IO () +updateRcvFileNoTmpPath db rcvFileId = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_files SET tmp_path = NULL, updated_at = ? WHERE rcv_file_id = ?" (updatedAt, rcvFileId) + +updateRcvFileDeleted :: DB.Connection -> DBRcvFileId -> IO () +updateRcvFileDeleted db rcvFileId = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE rcv_files SET deleted = 1, updated_at = ? WHERE rcv_file_id = ?" (updatedAt, rcvFileId) + +deleteRcvFile' :: DB.Connection -> DBRcvFileId -> IO () +deleteRcvFile' db rcvFileId = + DB.execute db "DELETE FROM rcv_files WHERE rcv_file_id = ?" (Only rcvFileId) + +getNextRcvChunkToDownload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe (RcvFileChunk, Bool, Maybe RcvFileId))) +getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = do + getWorkItem "rcv_file_download" getReplicaId getChunkData (markRcvFileFailed db . snd) + where + getReplicaId :: IO (Maybe (Int64, DBRcvFileId)) + getReplicaId = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + maybeFirstRow id $ + DB.query + db + [sql| + SELECT r.rcv_file_chunk_replica_id, f.rcv_file_id + FROM rcv_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id + JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id + WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ? + AND r.received = 0 AND r.replica_number = 1 + AND f.status = ? AND f.deleted = 0 AND f.created_at >= ? + AND f.failed = 0 + ORDER BY r.retries ASC, r.created_at ASC + LIMIT 1 + |] + (host, port, keyHash, RFSReceiving, cutoffTs) + getChunkData :: (Int64, DBRcvFileId) -> IO (Either StoreError (RcvFileChunk, Bool, Maybe RcvFileId)) + getChunkData (rcvFileChunkReplicaId, _fileId) = + firstRow toChunk SEFileNotFound $ + DB.query + db + [sql| + SELECT + f.rcv_file_id, f.rcv_file_entity_id, f.user_id, c.rcv_file_chunk_id, c.chunk_no, c.chunk_size, c.digest, f.tmp_path, c.tmp_path, + r.rcv_file_chunk_replica_id, r.replica_id, r.replica_key, r.received, r.delay, r.retries, + f.approved_relays, f.redirect_entity_id + FROM rcv_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id + JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id + WHERE r.rcv_file_chunk_replica_id = ? 
+ |] + (Only rcvFileChunkReplicaId) + where + toChunk :: ((DBRcvFileId, RcvFileId, UserId, Int64, Int, FileSize Word32, FileDigest, FilePath, Maybe FilePath) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, BoolInt, Maybe Int64, Int) :. (BoolInt, Maybe RcvFileId)) -> (RcvFileChunk, Bool, Maybe RcvFileId) + toChunk ((rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath) :. (rcvChunkReplicaId, replicaId, replicaKey, BI received, delay, retries) :. (BI approvedRelays, redirectEntityId_)) = + ( RcvFileChunk + { rcvFileId, + rcvFileEntityId, + userId, + rcvChunkId, + chunkNo, + chunkSize, + digest, + fileTmpPath, + chunkTmpPath, + replicas = [RcvFileChunkReplica {rcvChunkReplicaId, server, replicaId, replicaKey, received, delay, retries}] + }, + approvedRelays, + redirectEntityId_ + ) + +getNextRcvFileToDecrypt :: DB.Connection -> NominalDiffTime -> IO (Either StoreError (Maybe RcvFile)) +getNextRcvFileToDecrypt db ttl = + getWorkItem "rcv_file_decrypt" getFileId (getRcvFile db) (markRcvFileFailed db) + where + getFileId :: IO (Maybe DBRcvFileId) + getFileId = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT rcv_file_id + FROM rcv_files + WHERE status IN (?,?) AND deleted = 0 AND created_at >= ? + AND failed = 0 + ORDER BY created_at ASC LIMIT 1 + |] + (RFSReceived, RFSDecrypting, cutoffTs) + +markRcvFileFailed :: DB.Connection -> DBRcvFileId -> IO () +markRcvFileFailed db fileId = do + DB.execute db "UPDATE rcv_files SET failed = 1 WHERE rcv_file_id = ?" (Only fileId) + +getPendingRcvFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] +getPendingRcvFilesServers db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + map toXFTPServer + <$> DB.query + db + [sql| + SELECT DISTINCT + s.xftp_host, s.xftp_port, s.xftp_key_hash + FROM rcv_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id + JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id + WHERE r.received = 0 AND r.replica_number = 1 + AND f.status = ? AND f.deleted = 0 AND f.created_at >= ? + |] + (RFSReceiving, cutoffTs) + +toXFTPServer :: (NonEmpty TransportHost, ServiceName, C.KeyHash) -> XFTPServer +toXFTPServer (host, port, keyHash) = XFTPServer host port keyHash + +getCleanupRcvFilesTmpPaths :: DB.Connection -> IO [(DBRcvFileId, RcvFileId, FilePath)] +getCleanupRcvFilesTmpPaths db = + DB.query + db + [sql| + SELECT rcv_file_id, rcv_file_entity_id, tmp_path + FROM rcv_files + WHERE status IN (?,?) AND tmp_path IS NOT NULL + |] + (RFSComplete, RFSError) + +getCleanupRcvFilesDeleted :: DB.Connection -> IO [(DBRcvFileId, RcvFileId, FilePath)] +getCleanupRcvFilesDeleted db = + DB.query_ + db + [sql| + SELECT rcv_file_id, rcv_file_entity_id, prefix_path + FROM rcv_files + WHERE deleted = 1 + |] + +getRcvFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBRcvFileId, RcvFileId, FilePath)] +getRcvFilesExpired db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.query + db + [sql| + SELECT rcv_file_id, rcv_file_entity_id, prefix_path + FROM rcv_files + WHERE created_at < ? 
+ |] + (Only cutoffTs) + +createSndFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> CryptoFile -> Int -> FilePath -> C.SbKey -> C.CbNonce -> Maybe RedirectFileInfo -> IO (Either StoreError SndFileId) +createSndFile db gVar userId (CryptoFile path cfArgs) numRecipients prefixPath key nonce redirect_ = + createWithRandomId gVar $ \sndFileEntityId -> + DB.execute + db + "INSERT INTO snd_files (snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, prefix_path, key, nonce, status, redirect_size, redirect_digest) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)" + ((Binary sndFileEntityId, userId, path, fileKey <$> cfArgs, fileNonce <$> cfArgs, numRecipients) :. (prefixPath, key, nonce, SFSNew, redirectSize_, redirectDigest_)) + where + (redirectSize_, redirectDigest_) = + case redirect_ of + Nothing -> (Nothing, Nothing) + Just RedirectFileInfo {size, digest} -> (Just size, Just digest) + +getSndFileByEntityId :: DB.Connection -> SndFileId -> IO (Either StoreError SndFile) +getSndFileByEntityId db sndFileEntityId = runExceptT $ do + sndFileId <- ExceptT $ getSndFileIdByEntityId_ db sndFileEntityId + ExceptT $ getSndFile db sndFileId + +getSndFileIdByEntityId_ :: DB.Connection -> SndFileId -> IO (Either StoreError DBSndFileId) +getSndFileIdByEntityId_ db sndFileEntityId = + firstRow fromOnly SEFileNotFound $ + DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only (Binary sndFileEntityId)) + +getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile) +getSndFile db sndFileId = runExceptT $ do + f@SndFile {sndFileEntityId, userId, numRecipients, prefixPath} <- ExceptT getFile + chunks <- maybe (pure []) (liftIO . getChunks sndFileEntityId userId numRecipients) prefixPath + pure (f {chunks} :: SndFile) + where + getFile :: IO (Either StoreError SndFile) + getFile = do + firstRow toFile SEFileNotFound $ + DB.query + db + [sql| + SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest + FROM snd_files + WHERE snd_file_id = ? + |] + (Only sndFileId) + where + toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, BoolInt, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile + toFile ((sndFileEntityId, userId, srcPath, srcKey_, srcNonce_, numRecipients, digest, prefixPath, key, nonce) :. (status, BI deleted, redirectSize_, redirectDigest_)) = + let cfArgs = CFArgs <$> srcKey_ <*> srcNonce_ + srcFile = CryptoFile srcPath cfArgs + redirect = RedirectFileInfo <$> redirectSize_ <*> redirectDigest_ + in SndFile {sndFileId, sndFileEntityId, userId, srcFile, numRecipients, digest, prefixPath, key, nonce, status, deleted, redirect, chunks = []} + getChunks :: SndFileId -> UserId -> Int -> FilePath -> IO [SndFileChunk] + getChunks sndFileEntityId userId numRecipients filePrefixPath = do + chunks <- + map toChunk + <$> DB.query + db + [sql| + SELECT snd_file_chunk_id, chunk_no, chunk_offset, chunk_size, digest + FROM snd_file_chunks + WHERE snd_file_id = ? 
+ |] + (Only sndFileId) + forM chunks $ \chunk@SndFileChunk {sndChunkId} -> do + replicas' <- getChunkReplicas sndChunkId + pure (chunk {replicas = replicas'} :: SndFileChunk) + where + toChunk :: (Int64, Int, Int64, Word32, FileDigest) -> SndFileChunk + toChunk (sndChunkId, chunkNo, chunkOffset, chunkSize, digest) = + let chunkSpec = XFTPChunkSpec {filePath = sndFileEncPath filePrefixPath, chunkOffset, chunkSize} + in SndFileChunk {sndFileId, sndFileEntityId, userId, numRecipients, sndChunkId, chunkNo, chunkSpec, filePrefixPath, digest, replicas = []} + getChunkReplicas :: Int64 -> IO [SndFileChunkReplica] + getChunkReplicas chunkId = do + replicas <- + map toReplica + <$> DB.query + db + [sql| + SELECT + r.snd_file_chunk_replica_id, r.replica_id, r.replica_key, r.replica_status, r.delay, r.retries, + s.xftp_host, s.xftp_port, s.xftp_key_hash + FROM snd_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + WHERE r.snd_file_chunk_id = ? + |] + (Only chunkId) + forM replicas $ \replica@SndFileChunkReplica {sndChunkReplicaId} -> do + rcvIdsKeys <- getChunkReplicaRecipients_ db sndChunkReplicaId + pure (replica :: SndFileChunkReplica) {rcvIdsKeys} + where + toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, SndFileReplicaStatus, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> SndFileChunkReplica + toReplica (sndChunkReplicaId, replicaId, replicaKey, replicaStatus, delay, retries, host, port, keyHash) = + let server = XFTPServer host port keyHash + in SndFileChunkReplica {sndChunkReplicaId, server, replicaId, replicaKey, replicaStatus, delay, retries, rcvIdsKeys = []} + +getChunkReplicaRecipients_ :: DB.Connection -> Int64 -> IO [(ChunkReplicaId, C.APrivateAuthKey)] +getChunkReplicaRecipients_ db replicaId = + DB.query + db + [sql| + SELECT rcv_replica_id, rcv_replica_key + FROM snd_file_chunk_replica_recipients + WHERE snd_file_chunk_replica_id = ? + |] + (Only replicaId) + +getNextSndFileToPrepare :: DB.Connection -> NominalDiffTime -> IO (Either StoreError (Maybe SndFile)) +getNextSndFileToPrepare db ttl = + getWorkItem "snd_file_prepare" getFileId (getSndFile db) (markSndFileFailed db) + where + getFileId :: IO (Maybe DBSndFileId) + getFileId = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT snd_file_id + FROM snd_files + WHERE status IN (?,?,?) AND deleted = 0 AND created_at >= ? + AND failed = 0 + ORDER BY created_at ASC LIMIT 1 + |] + (SFSNew, SFSEncrypting, SFSEncrypted, cutoffTs) + +markSndFileFailed :: DB.Connection -> DBSndFileId -> IO () +markSndFileFailed db fileId = + DB.execute db "UPDATE snd_files SET failed = 1 WHERE snd_file_id = ?" (Only fileId) + +updateSndFileError :: DB.Connection -> DBSndFileId -> String -> IO () +updateSndFileError db sndFileId errStr = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_files SET prefix_path = NULL, error = ?, status = ?, updated_at = ? WHERE snd_file_id = ?" (errStr, SFSError, updatedAt, sndFileId) + +updateSndFileStatus :: DB.Connection -> DBSndFileId -> SndFileStatus -> IO () +updateSndFileStatus db sndFileId status = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_files SET status = ?, updated_at = ? WHERE snd_file_id = ?" 
(status, updatedAt, sndFileId) + +updateSndFileEncrypted :: DB.Connection -> DBSndFileId -> FileDigest -> [(XFTPChunkSpec, FileDigest)] -> IO () +updateSndFileEncrypted db sndFileId digest chunkSpecsDigests = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_files SET status = ?, digest = ?, updated_at = ? WHERE snd_file_id = ?" (SFSEncrypted, digest, updatedAt, sndFileId) + forM_ (zip [1 ..] chunkSpecsDigests) $ \(chunkNo :: Int, (XFTPChunkSpec {chunkOffset, chunkSize}, chunkDigest)) -> + DB.execute db "INSERT INTO snd_file_chunks (snd_file_id, chunk_no, chunk_offset, chunk_size, digest) VALUES (?,?,?,?,?)" (sndFileId, chunkNo, chunkOffset, chunkSize, chunkDigest) + +updateSndFileComplete :: DB.Connection -> DBSndFileId -> IO () +updateSndFileComplete db sndFileId = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_files SET prefix_path = NULL, status = ?, updated_at = ? WHERE snd_file_id = ?" (SFSComplete, updatedAt, sndFileId) + +updateSndFileNoPrefixPath :: DB.Connection -> DBSndFileId -> IO () +updateSndFileNoPrefixPath db sndFileId = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_files SET prefix_path = NULL, updated_at = ? WHERE snd_file_id = ?" (updatedAt, sndFileId) + +updateSndFileDeleted :: DB.Connection -> DBSndFileId -> IO () +updateSndFileDeleted db sndFileId = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_files SET deleted = 1, updated_at = ? WHERE snd_file_id = ?" (updatedAt, sndFileId) + +deleteSndFile' :: DB.Connection -> DBSndFileId -> IO () +deleteSndFile' db sndFileId = + DB.execute db "DELETE FROM snd_files WHERE snd_file_id = ?" (Only sndFileId) + +getSndFileDeleted :: DB.Connection -> DBSndFileId -> IO Bool +getSndFileDeleted db sndFileId = + fromMaybe True + <$> maybeFirstRow fromOnlyBI (DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)) + +createSndFileReplica :: DB.Connection -> SndFileChunk -> NewSndChunkReplica -> IO () +createSndFileReplica db SndFileChunk {sndChunkId} = createSndFileReplica_ db sndChunkId + +createSndFileReplica_ :: DB.Connection -> Int64 -> NewSndChunkReplica -> IO () +createSndFileReplica_ db sndChunkId NewSndChunkReplica {server, replicaId, replicaKey, rcvIdsKeys} = do + srvId <- createXFTPServer_ db server + DB.execute + db + [sql| + INSERT INTO snd_file_chunk_replicas + (snd_file_chunk_id, replica_number, xftp_server_id, replica_id, replica_key, replica_status) + VALUES (?,?,?,?,?,?) + |] + (sndChunkId, 1 :: Int, srvId, replicaId, replicaKey, SFRSCreated) + rId <- insertedRowId db + forM_ rcvIdsKeys $ \(rcvId, rcvKey) -> do + DB.execute + db + [sql| + INSERT INTO snd_file_chunk_replica_recipients + (snd_file_chunk_replica_id, rcv_replica_id, rcv_replica_key) + VALUES (?,?,?) + |] + (rId, rcvId, rcvKey) + +getNextSndChunkToUpload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe SndFileChunk)) +getNextSndChunkToUpload db server@ProtocolServer {host, port, keyHash} ttl = do + getWorkItem "snd_file_upload" getReplicaId getChunkData (markSndFileFailed db . snd) + where + getReplicaId :: IO (Maybe (Int64, DBSndFileId)) + getReplicaId = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + maybeFirstRow id $ + DB.query + db + [sql| + SELECT r.snd_file_chunk_replica_id, f.snd_file_id + FROM snd_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id + JOIN snd_files f ON f.snd_file_id = c.snd_file_id + WHERE s.xftp_host = ? 
AND s.xftp_port = ? AND s.xftp_key_hash = ? + AND r.replica_status = ? AND r.replica_number = 1 + AND (f.status = ? OR f.status = ?) AND f.deleted = 0 AND f.created_at >= ? + AND f.failed = 0 + ORDER BY r.retries ASC, r.created_at ASC + LIMIT 1 + |] + (host, port, keyHash, SFRSCreated, SFSEncrypted, SFSUploading, cutoffTs) + getChunkData :: (Int64, DBSndFileId) -> IO (Either StoreError SndFileChunk) + getChunkData (sndFileChunkReplicaId, _fileId) = do + chunk_ <- + firstRow toChunk SEFileNotFound $ + DB.query + db + [sql| + SELECT + f.snd_file_id, f.snd_file_entity_id, f.user_id, f.num_recipients, f.prefix_path, + c.snd_file_chunk_id, c.chunk_no, c.chunk_offset, c.chunk_size, c.digest, + r.snd_file_chunk_replica_id, r.replica_id, r.replica_key, r.replica_status, r.delay, r.retries + FROM snd_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id + JOIN snd_files f ON f.snd_file_id = c.snd_file_id + WHERE r.snd_file_chunk_replica_id = ? + |] + (Only sndFileChunkReplicaId) + forM chunk_ $ \chunk@SndFileChunk {replicas} -> do + replicas' <- forM replicas $ \replica@SndFileChunkReplica {sndChunkReplicaId} -> do + rcvIdsKeys <- getChunkReplicaRecipients_ db sndChunkReplicaId + pure (replica :: SndFileChunkReplica) {rcvIdsKeys} + pure (chunk {replicas = replicas'} :: SndFileChunk) + where + toChunk :: ((DBSndFileId, SndFileId, UserId, Int, FilePath) :. (Int64, Int, Int64, Word32, FileDigest) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, SndFileReplicaStatus, Maybe Int64, Int)) -> SndFileChunk + toChunk ((sndFileId, sndFileEntityId, userId, numRecipients, filePrefixPath) :. (sndChunkId, chunkNo, chunkOffset, chunkSize, digest) :. (sndChunkReplicaId, replicaId, replicaKey, replicaStatus, delay, retries)) = + let chunkSpec = XFTPChunkSpec {filePath = sndFileEncPath filePrefixPath, chunkOffset, chunkSize} + in SndFileChunk + { sndFileId, + sndFileEntityId, + userId, + numRecipients, + sndChunkId, + chunkNo, + chunkSpec, + digest, + filePrefixPath, + replicas = [SndFileChunkReplica {sndChunkReplicaId, server, replicaId, replicaKey, replicaStatus, delay, retries, rcvIdsKeys = []}] + } + +updateSndChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO () +updateSndChunkReplicaDelay db replicaId delay = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_file_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE snd_file_chunk_replica_id = ?" (delay, updatedAt, replicaId) + +addSndChunkReplicaRecipients :: DB.Connection -> SndFileChunkReplica -> [(ChunkReplicaId, C.APrivateAuthKey)] -> IO SndFileChunkReplica +addSndChunkReplicaRecipients db r@SndFileChunkReplica {sndChunkReplicaId} rcvIdsKeys = do + forM_ rcvIdsKeys $ \(rcvId, rcvKey) -> do + DB.execute + db + [sql| + INSERT INTO snd_file_chunk_replica_recipients + (snd_file_chunk_replica_id, rcv_replica_id, rcv_replica_key) + VALUES (?,?,?) + |] + (sndChunkReplicaId, rcvId, rcvKey) + rcvIdsKeys' <- getChunkReplicaRecipients_ db sndChunkReplicaId + pure (r :: SndFileChunkReplica) {rcvIdsKeys = rcvIdsKeys'} + +updateSndChunkReplicaStatus :: DB.Connection -> Int64 -> SndFileReplicaStatus -> IO () +updateSndChunkReplicaStatus db replicaId status = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE snd_file_chunk_replicas SET replica_status = ?, updated_at = ? WHERE snd_file_chunk_replica_id = ?" 
(status, updatedAt, replicaId) + +getPendingSndFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] +getPendingSndFilesServers db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + map toXFTPServer + <$> DB.query + db + [sql| + SELECT DISTINCT + s.xftp_host, s.xftp_port, s.xftp_key_hash + FROM snd_file_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id + JOIN snd_files f ON f.snd_file_id = c.snd_file_id + WHERE r.replica_status = ? AND r.replica_number = 1 + AND (f.status = ? OR f.status = ?) AND f.deleted = 0 AND f.created_at >= ? + |] + (SFRSCreated, SFSEncrypted, SFSUploading, cutoffTs) + +getCleanupSndFilesPrefixPaths :: DB.Connection -> IO [(DBSndFileId, SndFileId, FilePath)] +getCleanupSndFilesPrefixPaths db = + DB.query + db + [sql| + SELECT snd_file_id, snd_file_entity_id, prefix_path + FROM snd_files + WHERE status IN (?,?) AND prefix_path IS NOT NULL + |] + (SFSComplete, SFSError) + +getCleanupSndFilesDeleted :: DB.Connection -> IO [(DBSndFileId, SndFileId, Maybe FilePath)] +getCleanupSndFilesDeleted db = + DB.query_ + db + [sql| + SELECT snd_file_id, snd_file_entity_id, prefix_path + FROM snd_files + WHERE deleted = 1 + |] + +getSndFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBSndFileId, SndFileId, Maybe FilePath)] +getSndFilesExpired db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.query + db + [sql| + SELECT snd_file_id, snd_file_entity_id, prefix_path + FROM snd_files + WHERE created_at < ? + |] + (Only cutoffTs) + +createDeletedSndChunkReplica :: DB.Connection -> UserId -> FileChunkReplica -> FileDigest -> IO () +createDeletedSndChunkReplica db userId FileChunkReplica {server, replicaId, replicaKey} chunkDigest = do + srvId <- createXFTPServer_ db server + DB.execute + db + "INSERT INTO deleted_snd_chunk_replicas (user_id, xftp_server_id, replica_id, replica_key, chunk_digest) VALUES (?,?,?,?,?)" + (userId, srvId, replicaId, replicaKey, chunkDigest) + +getDeletedSndChunkReplica :: DB.Connection -> DBSndFileId -> IO (Either StoreError DeletedSndChunkReplica) +getDeletedSndChunkReplica db deletedSndChunkReplicaId = + firstRow toReplica SEDeletedSndChunkReplicaNotFound $ + DB.query + db + [sql| + SELECT + r.user_id, r.replica_id, r.replica_key, r.chunk_digest, r.delay, r.retries, + s.xftp_host, s.xftp_port, s.xftp_key_hash + FROM deleted_snd_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + WHERE r.deleted_snd_chunk_replica_id = ? 
+ |] + (Only deletedSndChunkReplicaId) + where + toReplica :: (UserId, ChunkReplicaId, C.APrivateAuthKey, FileDigest, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> DeletedSndChunkReplica + toReplica (userId, replicaId, replicaKey, chunkDigest, delay, retries, host, port, keyHash) = + let server = XFTPServer host port keyHash + in DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, replicaId, replicaKey, chunkDigest, delay, retries} + +getNextDeletedSndChunkReplica :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe DeletedSndChunkReplica)) +getNextDeletedSndChunkReplica db ProtocolServer {host, port, keyHash} ttl = + getWorkItem "deleted replica" getReplicaId (getDeletedSndChunkReplica db) markReplicaFailed + where + getReplicaId :: IO (Maybe Int64) + getReplicaId = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + maybeFirstRow fromOnly $ + DB.query + db + [sql| + SELECT r.deleted_snd_chunk_replica_id + FROM deleted_snd_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ? + AND r.created_at >= ? + AND failed = 0 + ORDER BY r.retries ASC, r.created_at ASC + LIMIT 1 + |] + (host, port, keyHash, cutoffTs) + markReplicaFailed :: Int64 -> IO () + markReplicaFailed replicaId = do + DB.execute db "UPDATE deleted_snd_chunk_replicas SET failed = 1 WHERE deleted_snd_chunk_replica_id = ?" (Only replicaId) + +updateDeletedSndChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO () +updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId delay = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE deleted_snd_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE deleted_snd_chunk_replica_id = ?" (delay, updatedAt, deletedSndChunkReplicaId) + +deleteDeletedSndChunkReplica :: DB.Connection -> Int64 -> IO () +deleteDeletedSndChunkReplica db deletedSndChunkReplicaId = + DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE deleted_snd_chunk_replica_id = ?" (Only deletedSndChunkReplicaId) + +getPendingDelFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] +getPendingDelFilesServers db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + map toXFTPServer + <$> DB.query + db + [sql| + SELECT DISTINCT + s.xftp_host, s.xftp_port, s.xftp_key_hash + FROM deleted_snd_chunk_replicas r + JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id + WHERE r.created_at >= ? + |] + (Only cutoffTs) + +deleteDeletedSndChunkReplicasExpired :: DB.Connection -> NominalDiffTime -> IO () +deleteDeletedSndChunkReplicasExpired db ttl = do + cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime + DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE created_at < ?" (Only cutoffTs) + +updateServersStats :: DB.Connection -> AgentPersistedServerStats -> IO () +updateServersStats db stats = do + updatedAt <- getCurrentTime + DB.execute db "UPDATE servers_stats SET servers_stats = ?, updated_at = ? WHERE servers_stats_id = 1" (stats, updatedAt) + +getServersStats :: DB.Connection -> IO (Either StoreError (UTCTime, Maybe AgentPersistedServerStats)) +getServersStats db = + firstRow id SEServersStatsNotFound $ + DB.query_ db "SELECT started_at, servers_stats FROM servers_stats WHERE servers_stats_id = 1" + +resetServersStats :: DB.Connection -> UTCTime -> IO () +resetServersStats db startedAt = + DB.execute db "UPDATE servers_stats SET servers_stats = NULL, started_at = ?, updated_at = ? 
WHERE servers_stats_id = 1" (startedAt, startedAt) diff --git a/src/Simplex/Messaging/Agent/Store/Common.hs b/src/Simplex/Messaging/Agent/Store/Common.hs new file mode 100644 index 000000000..e83cfe03b --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Common.hs @@ -0,0 +1,14 @@ +{-# LANGUAGE CPP #-} + +module Simplex.Messaging.Agent.Store.Common +#if defined(dbPostgres) + ( module Simplex.Messaging.Agent.Store.Postgres.Common, + ) + where +import Simplex.Messaging.Agent.Store.Postgres.Common +#else + ( module Simplex.Messaging.Agent.Store.SQLite.Common, + ) + where +import Simplex.Messaging.Agent.Store.SQLite.Common +#endif diff --git a/src/Simplex/Messaging/Agent/Store/DB.hs b/src/Simplex/Messaging/Agent/Store/DB.hs new file mode 100644 index 000000000..f8c54e463 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/DB.hs @@ -0,0 +1,15 @@ +{-# LANGUAGE CPP #-} + +module Simplex.Messaging.Agent.Store.DB +#if defined(dbPostgres) + ( module Simplex.Messaging.Agent.Store.Postgres.DB, + ) + where +import Simplex.Messaging.Agent.Store.Postgres.DB +#else + ( module Simplex.Messaging.Agent.Store.SQLite.DB, + ) + where +import Simplex.Messaging.Agent.Store.SQLite.DB +#endif + diff --git a/src/Simplex/Messaging/Agent/Store/Migrations.hs b/src/Simplex/Messaging/Agent/Store/Migrations.hs new file mode 100644 index 000000000..35015f634 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Migrations.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE LambdaCase #-} + +module Simplex.Messaging.Agent.Store.Migrations + ( Migration (..), + MigrationsToRun (..), + DownMigration (..), + Migrations.app, + Migrations.getCurrent, + get, + Migrations.initialize, + Migrations.run, + migrateSchema, + -- for tests + migrationsToRun, + toDownMigration, + ) +where + +import Control.Monad +import Data.Char (toLower) +import Data.Functor (($>)) +import Data.Maybe (isNothing, mapMaybe) +import Simplex.Messaging.Agent.Store.Common +import Simplex.Messaging.Agent.Store.Shared +import System.Exit (exitFailure) +import System.IO (hFlush, stdout) +#if defined(dbPostgres) +import qualified Simplex.Messaging.Agent.Store.Postgres.Migrations as Migrations +#else +import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations +import System.Directory (copyFile) +#endif + +get :: DBStore -> [Migration] -> IO (Either MTRError MigrationsToRun) +get st migrations = migrationsToRun migrations <$> withTransaction st Migrations.getCurrent + +migrationsToRun :: [Migration] -> [Migration] -> Either MTRError MigrationsToRun +migrationsToRun [] [] = Right MTRNone +migrationsToRun appMs [] = Right $ MTRUp appMs +migrationsToRun [] dbMs + | length dms == length dbMs = Right $ MTRDown dms + | otherwise = Left $ MTRENoDown $ mapMaybe nameNoDown dbMs + where + dms = mapMaybe toDownMigration dbMs + nameNoDown m = if isNothing (down m) then Just $ name m else Nothing +migrationsToRun (a : as) (d : ds) + | name a == name d = migrationsToRun as ds + | otherwise = Left $ MTREDifferent (name a) (name d) + +migrateSchema :: DBStore -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError ()) +migrateSchema st migrations confirmMigrations = do + Migrations.initialize st + get st migrations >>= \case + Left e -> do + when (confirmMigrations == MCConsole) $ confirmOrExit ("Database state error: " <> mtrErrorDescription e) + pure . 
Left $ MigrationError e + Right MTRNone -> pure $ Right () + Right ms@(MTRUp ums) + | dbNew st -> Migrations.run st ms $> Right () + | otherwise -> case confirmMigrations of + MCYesUp -> runWithBackup st ms + MCYesUpDown -> runWithBackup st ms + MCConsole -> confirm err >> runWithBackup st ms + MCError -> pure $ Left err + where + err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums) + Right ms@(MTRDown dms) -> case confirmMigrations of + MCYesUpDown -> runWithBackup st ms + MCConsole -> confirm err >> runWithBackup st ms + MCYesUp -> pure $ Left err + MCError -> pure $ Left err + where + err = MEDowngrade $ map downName dms + where + confirm err = confirmOrExit $ migrationErrorDescription err + +runWithBackup :: DBStore -> MigrationsToRun -> IO (Either a ()) +#if defined(dbPostgres) +runWithBackup st ms = Migrations.run st ms $> Right () +#else +runWithBackup st ms = do + let f = dbFilePath st + copyFile f (f <> ".bak") + Migrations.run st ms + pure $ Right () +#endif + +confirmOrExit :: String -> IO () +confirmOrExit s = do + putStrLn s + putStr "Continue (y/N): " + hFlush stdout + ok <- getLine + when (map toLower ok /= "y") exitFailure diff --git a/src/Simplex/Messaging/Agent/Store/Postgres.hs b/src/Simplex/Messaging/Agent/Store/Postgres.hs new file mode 100644 index 000000000..a4c8a52bb --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres.hs @@ -0,0 +1,94 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Agent.Store.Postgres + ( createDBStore, + defaultSimplexConnectInfo, + closeDBStore, + execSQL + ) +where + +import Control.Exception (throwIO) +import Control.Monad (unless, void) +import Data.Functor (($>)) +import Data.String (fromString) +import Data.Text (Text) +import Database.PostgreSQL.Simple (ConnectInfo (..), Only (..), defaultConnectInfo) +import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Simplex.Messaging.Agent.Store.Migrations (migrateSchema) +import Simplex.Messaging.Agent.Store.Postgres.Common +import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB +import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists) +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Util (ifM) +import UnliftIO.Exception (onException) +import UnliftIO.MVar +import UnliftIO.STM + +defaultSimplexConnectInfo :: ConnectInfo +defaultSimplexConnectInfo = + defaultConnectInfo + { connectUser = "simplex", + connectDatabase = "simplex_v6_3_client_db" + } + +-- | Create a new Postgres DBStore with the given connection info, schema name and migrations. +-- This function creates the user and/or database passed in connectInfo if they do not exist +-- (expects the default 'postgres' user and 'postgres' db to exist). +-- If passed schema does not exist in connectInfo database, it will be created. +-- Applies necessary migrations to schema. +-- TODO [postgres] authentication / user password, db encryption (?) 
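+-- An illustrative call (a sketch, not part of this patch: the schema name below is an
+-- example, and Migrations refers to Simplex.Messaging.Agent.Store.Postgres.Migrations,
+-- whose 'app' list is defined later in this diff):
+--
+-- > Right st <- createDBStore defaultSimplexConnectInfo "simplex_agent" Migrations.app MCYesUp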
+createDBStore :: ConnectInfo -> String -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createDBStore connectInfo schema migrations confirmMigrations = do + createDBAndUserIfNotExists connectInfo + st <- connectPostgresStore connectInfo schema + r <- migrateSchema st migrations confirmMigrations `onException` closeDBStore st + case r of + Right () -> pure $ Right st + Left e -> closeDBStore st $> Left e + +connectPostgresStore :: ConnectInfo -> String -> IO DBStore +connectPostgresStore dbConnectInfo schema = do + (dbConn, dbNew) <- connectDB dbConnectInfo schema -- TODO [postgres] analogue for dbBusyLoop? + dbConnection <- newMVar dbConn + dbClosed <- newTVarIO False + pure DBStore {dbConnectInfo, dbConnection, dbNew, dbClosed} + +connectDB :: ConnectInfo -> String -> IO (DB.Connection, Bool) +connectDB dbConnectInfo schema = do + db <- PSQL.connect dbConnectInfo + schemaExists <- prepare db `onException` PSQL.close db + let dbNew = not schemaExists + pure (db, dbNew) + where + prepare db = do + void $ PSQL.execute_ db "SET client_min_messages TO WARNING" + [Only schemaExists] <- + PSQL.query + db + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_namespace + WHERE nspname = ? + ) + |] + (Only schema) + unless schemaExists $ void $ PSQL.execute_ db (fromString $ "CREATE SCHEMA " <> schema) + void $ PSQL.execute_ db (fromString $ "SET search_path TO " <> schema) + pure schemaExists + +-- can share with SQLite +closeDBStore :: DBStore -> IO () +closeDBStore st@DBStore {dbClosed} = + ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $ + withConnection st $ \conn -> do + DB.close conn + atomically $ writeTVar dbClosed True + +-- TODO [postgres] not necessary for postgres (used for ExecAgentStoreSQL, ExecChatStoreSQL) +execSQL :: PSQL.Connection -> Text -> IO [Text] +execSQL _db _query = throwIO (userError "not implemented") diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs new file mode 100644 index 000000000..b23dcf9c8 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Common.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE NamedFieldPuns #-} + +module Simplex.Messaging.Agent.Store.Postgres.Common + ( DBStore (..), + withConnection, + withConnection', + withTransaction, + withTransaction', + withTransactionPriority, + ) +where + +import qualified Database.PostgreSQL.Simple as PSQL +import UnliftIO.MVar +import UnliftIO.STM + +-- TODO [postgres] use log_min_duration_statement instead of custom slow queries (SQLite's Connection type) +data DBStore = DBStore + { dbConnectInfo :: PSQL.ConnectInfo, + dbConnection :: MVar PSQL.Connection, + dbClosed :: TVar Bool, + dbNew :: Bool + } + +-- TODO [postgres] connection pool +withConnectionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a +withConnectionPriority DBStore {dbConnection} _priority action = + withMVar dbConnection action + +withConnection :: DBStore -> (PSQL.Connection -> IO a) -> IO a +withConnection st = withConnectionPriority st False + +withConnection' :: DBStore -> (PSQL.Connection -> IO a) -> IO a +withConnection' = withConnection + +withTransaction' :: DBStore -> (PSQL.Connection -> IO a) -> IO a +withTransaction' = withTransaction + +withTransaction :: DBStore -> (PSQL.Connection -> IO a) -> IO a +withTransaction st = withTransactionPriority st False +{-# INLINE withTransaction #-} + +-- TODO [postgres] analogue for dbBusyLoop? 
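+-- Usage sketch (illustrative, mirroring agent store functions such as getDeletedConnIds
+-- elsewhere in this diff): each query runs against the single shared connection inside
+-- one PostgreSQL transaction, e.g.
+--
+-- > withTransaction st $ \db ->
+-- >   DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only (BI True))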
+withTransactionPriority :: DBStore -> Bool -> (PSQL.Connection -> IO a) -> IO a +withTransactionPriority st priority action = withConnectionPriority st priority transaction + where + transaction conn = PSQL.withTransaction conn $ action conn diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs b/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs new file mode 100644 index 000000000..9e597aef7 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs @@ -0,0 +1,63 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Agent.Store.Postgres.DB + ( BoolInt (..), + PSQL.Binary (..), + PSQL.Connection, + PSQL.connect, + PSQL.close, + execute, + execute_, + executeMany, + PSQL.query, + PSQL.query_, + ) +where + +import Control.Monad (void) +import Data.Int (Int32, Int64) +import Data.Word (Word16, Word32) +import Database.PostgreSQL.Simple (ResultError (..)) +import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.FromField (FromField (..), returnError) +import Database.PostgreSQL.Simple.ToField (ToField (..)) + +newtype BoolInt = BI {unBI :: Bool} + +instance FromField BoolInt where + fromField field dat = BI . (/= (0 :: Int)) <$> fromField field dat + {-# INLINE fromField #-} + +instance ToField BoolInt where + toField (BI b) = toField ((if b then 1 else 0) :: Int) + {-# INLINE toField #-} + +execute :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> q -> IO () +execute db q qs = void $ PSQL.execute db q qs +{-# INLINE execute #-} + +execute_ :: PSQL.Connection -> PSQL.Query -> IO () +execute_ db q = void $ PSQL.execute_ db q +{-# INLINE execute_ #-} + +executeMany :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> [q] -> IO () +executeMany db q qs = void $ PSQL.executeMany db q qs +{-# INLINE executeMany #-} + +-- orphan instances + +-- used in FileSize +instance FromField Word32 where + fromField field dat = do + i <- fromField field dat + if i >= (0 :: Int64) + then pure (fromIntegral i :: Word32) + else returnError ConversionFailed field "Negative value can't be converted to Word32" + +-- used in Version +instance FromField Word16 where + fromField field dat = do + i <- fromField field dat + if i >= (0 :: Int32) + then pure (fromIntegral i :: Word16) + else returnError ConversionFailed field "Negative value can't be converted to Word16" diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs new file mode 100644 index 000000000..bf8d56caa --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TupleSections #-} + +module Simplex.Messaging.Agent.Store.Postgres.Migrations + ( app, + initialize, + run, + getCurrent, + ) +where + +import Control.Monad (void) +import Data.List (sortOn) +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as TE +import Data.Time.Clock (getCurrentTime) +import qualified Database.PostgreSQL.LibPQ as LibPQ +import Database.PostgreSQL.Simple (Only (..)) +import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.Internal (Connection (..)) +import Database.PostgreSQL.Simple.SqlQQ (sql) +import Simplex.Messaging.Agent.Store.Postgres.Common +import Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial +import Simplex.Messaging.Agent.Store.Shared +import UnliftIO.MVar + 
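+-- Extending the schema (a sketch of the workflow implied by the structure below): add a
+-- module like Migrations.M20241210_initial exporting the up statements as Text (and,
+-- optionally, a down script), then list it in schemaMigrations; 'app' sorts by name, so
+-- the YYYYMMDD prefix determines execution order.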
+schemaMigrations :: [(String, Text, Maybe Text)] +schemaMigrations = + [ ("20241210_initial", m20241210_initial, Nothing) + ] + +-- | The list of migrations in ascending order by date +app :: [Migration] +app = sortOn name $ map migration schemaMigrations + where + migration (name, up, down) = Migration {name, up, down = down} + +initialize :: DBStore -> IO () +initialize st = withTransaction' st $ \db -> + void $ + PSQL.execute_ + db + [sql| + CREATE TABLE IF NOT EXISTS migrations ( + name TEXT NOT NULL, + ts TIMESTAMP NOT NULL, + down TEXT, + PRIMARY KEY (name) + ) + |] + +run :: DBStore -> MigrationsToRun -> IO () +run st = \case + MTRUp [] -> pure () + MTRUp ms -> mapM_ runUp ms + MTRDown ms -> mapM_ runDown $ reverse ms + MTRNone -> pure () + where + runUp Migration {name, up, down} = withTransaction' st $ \db -> do + insert db + execSQL db up + where + insert db = void $ PSQL.execute db "INSERT INTO migrations (name, down, ts) VALUES (?,?,?)" . (name,down,) =<< getCurrentTime + runDown DownMigration {downName, downQuery} = withTransaction' st $ \db -> do + execSQL db downQuery + void $ PSQL.execute db "DELETE FROM migrations WHERE name = ?" (Only downName) + execSQL db query = + withMVar (connectionHandle db) $ \pqConn -> + void $ LibPQ.exec pqConn (TE.encodeUtf8 query) + +getCurrent :: PSQL.Connection -> IO [Migration] +getCurrent db = map toMigration <$> PSQL.query_ db "SELECT name, down FROM migrations ORDER BY name ASC;" + where + toMigration (name, down) = Migration {name, up = T.pack "", down} diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs new file mode 100644 index 000000000..15574d313 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Migrations/M20241210_initial.hs @@ -0,0 +1,545 @@ +{-# LANGUAGE QuasiQuotes #-} + +module Simplex.Messaging.Agent.Store.Postgres.Migrations.M20241210_initial where + +import Data.Text (Text) +import qualified Data.Text as T +import Text.RawString.QQ (r) + +m20241210_initial :: Text +m20241210_initial = + T.pack + [r| +CREATE TABLE users( + user_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + deleted SMALLINT NOT NULL DEFAULT 0 +); +CREATE TABLE servers( + host TEXT NOT NULL, + port TEXT NOT NULL, + key_hash BYTEA NOT NULL, + PRIMARY KEY(host, port) +); +CREATE TABLE connections( + conn_id BYTEA NOT NULL PRIMARY KEY, + conn_mode TEXT NOT NULL, + last_internal_msg_id BIGINT NOT NULL DEFAULT 0, + last_internal_rcv_msg_id BIGINT NOT NULL DEFAULT 0, + last_internal_snd_msg_id BIGINT NOT NULL DEFAULT 0, + last_external_snd_msg_id BIGINT NOT NULL DEFAULT 0, + last_rcv_msg_hash BYTEA NOT NULL DEFAULT ''::BYTEA, + last_snd_msg_hash BYTEA NOT NULL DEFAULT ''::BYTEA, + smp_agent_version INTEGER NOT NULL DEFAULT 1, + duplex_handshake SMALLINT NULL DEFAULT 0, + enable_ntfs SMALLINT, + deleted SMALLINT NOT NULL DEFAULT 0, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + ratchet_sync_state TEXT NOT NULL DEFAULT 'ok', + deleted_at_wait_delivery TIMESTAMPTZ, + pq_support SMALLINT NOT NULL DEFAULT 0 +); +CREATE TABLE rcv_queues( + host TEXT NOT NULL, + port TEXT NOT NULL, + rcv_id BYTEA NOT NULL, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + rcv_private_key BYTEA NOT NULL, + rcv_dh_secret BYTEA NOT NULL, + e2e_priv_key BYTEA NOT NULL, + e2e_dh_secret BYTEA, + snd_id BYTEA NOT NULL, + snd_key BYTEA, + status TEXT NOT NULL, + smp_server_version INTEGER NOT NULL DEFAULT 1, + smp_client_version 
INTEGER, + ntf_public_key BYTEA, + ntf_private_key BYTEA, + ntf_id BYTEA, + rcv_ntf_dh_secret BYTEA, + rcv_queue_id BIGINT NOT NULL, + rcv_primary SMALLINT NOT NULL, + replace_rcv_queue_id BIGINT NULL, + delete_errors BIGINT NOT NULL DEFAULT 0, + server_key_hash BYTEA, + switch_status TEXT, + deleted SMALLINT NOT NULL DEFAULT 0, + snd_secure SMALLINT NOT NULL DEFAULT 0, + last_broker_ts TIMESTAMPTZ, + PRIMARY KEY(host, port, rcv_id), + FOREIGN KEY(host, port) REFERENCES servers + ON DELETE RESTRICT ON UPDATE CASCADE, + UNIQUE(host, port, snd_id) +); +CREATE TABLE snd_queues( + host TEXT NOT NULL, + port TEXT NOT NULL, + snd_id BYTEA NOT NULL, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + snd_private_key BYTEA NOT NULL, + e2e_dh_secret BYTEA NOT NULL, + status TEXT NOT NULL, + smp_server_version INTEGER NOT NULL DEFAULT 1, + smp_client_version INTEGER NOT NULL DEFAULT 1, + snd_public_key BYTEA, + e2e_pub_key BYTEA, + snd_queue_id BIGINT NOT NULL, + snd_primary SMALLINT NOT NULL, + replace_snd_queue_id BIGINT NULL, + server_key_hash BYTEA, + switch_status TEXT, + snd_secure SMALLINT NOT NULL DEFAULT 0, + PRIMARY KEY(host, port, snd_id), + FOREIGN KEY(host, port) REFERENCES servers + ON DELETE RESTRICT ON UPDATE CASCADE +); +CREATE TABLE messages( + conn_id BYTEA NOT NULL REFERENCES connections(conn_id) + ON DELETE CASCADE, + internal_id BIGINT NOT NULL, + internal_ts TIMESTAMPTZ NOT NULL, + internal_rcv_id BIGINT, + internal_snd_id BIGINT, + msg_type BYTEA NOT NULL, + msg_body BYTEA NOT NULL DEFAULT ''::BYTEA, + msg_flags TEXT NULL, + pq_encryption SMALLINT NOT NULL DEFAULT 0, + PRIMARY KEY(conn_id, internal_id) +); +CREATE TABLE rcv_messages( + conn_id BYTEA NOT NULL, + internal_rcv_id BIGINT NOT NULL, + internal_id BIGINT NOT NULL, + external_snd_id BIGINT NOT NULL, + broker_id BYTEA NOT NULL, + broker_ts TIMESTAMPTZ NOT NULL, + internal_hash BYTEA NOT NULL, + external_prev_snd_hash BYTEA NOT NULL, + integrity BYTEA NOT NULL, + user_ack SMALLINT NULL DEFAULT 0, + rcv_queue_id BIGINT NOT NULL, + PRIMARY KEY(conn_id, internal_rcv_id), + FOREIGN KEY(conn_id, internal_id) REFERENCES messages + ON DELETE CASCADE +); +ALTER TABLE messages +ADD CONSTRAINT fk_messages_rcv_messages + FOREIGN KEY (conn_id, internal_rcv_id) REFERENCES rcv_messages + ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED; +CREATE TABLE snd_messages( + conn_id BYTEA NOT NULL, + internal_snd_id BIGINT NOT NULL, + internal_id BIGINT NOT NULL, + internal_hash BYTEA NOT NULL, + previous_msg_hash BYTEA NOT NULL DEFAULT ''::BYTEA, + retry_int_slow BIGINT, + retry_int_fast BIGINT, + rcpt_internal_id BIGINT, + rcpt_status TEXT, + PRIMARY KEY(conn_id, internal_snd_id), + FOREIGN KEY(conn_id, internal_id) REFERENCES messages + ON DELETE CASCADE +); +ALTER TABLE messages +ADD CONSTRAINT fk_messages_snd_messages + FOREIGN KEY (conn_id, internal_snd_id) REFERENCES snd_messages + ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED; +CREATE TABLE conn_confirmations( + confirmation_id BYTEA NOT NULL PRIMARY KEY, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + e2e_snd_pub_key BYTEA NOT NULL, + sender_key BYTEA, + ratchet_state BYTEA NOT NULL, + sender_conn_info BYTEA NOT NULL, + accepted SMALLINT NOT NULL, + own_conn_info BYTEA, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + smp_reply_queues BYTEA NULL, + smp_client_version INTEGER +); +CREATE TABLE conn_invitations( + invitation_id BYTEA NOT NULL PRIMARY KEY, + contact_conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + 
cr_invitation BYTEA NOT NULL, + recipient_conn_info BYTEA NOT NULL, + accepted SMALLINT NOT NULL DEFAULT 0, + own_conn_info BYTEA, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE ratchets( + conn_id BYTEA NOT NULL PRIMARY KEY REFERENCES connections + ON DELETE CASCADE, + x3dh_priv_key_1 BYTEA, + x3dh_priv_key_2 BYTEA, + ratchet_state BYTEA, + e2e_version INTEGER NOT NULL DEFAULT 1, + x3dh_pub_key_1 BYTEA, + x3dh_pub_key_2 BYTEA, + pq_priv_kem BYTEA, + pq_pub_kem BYTEA +); +CREATE TABLE skipped_messages( + skipped_message_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + conn_id BYTEA NOT NULL REFERENCES ratchets + ON DELETE CASCADE, + header_key BYTEA NOT NULL, + msg_n BIGINT NOT NULL, + msg_key BYTEA NOT NULL +); +CREATE TABLE ntf_servers( + ntf_host TEXT NOT NULL, + ntf_port TEXT NOT NULL, + ntf_key_hash BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + PRIMARY KEY(ntf_host, ntf_port) +); +CREATE TABLE ntf_tokens( + provider TEXT NOT NULL, + device_token TEXT NOT NULL, + ntf_host TEXT NOT NULL, + ntf_port TEXT NOT NULL, + tkn_id BYTEA, + tkn_pub_key BYTEA NOT NULL, + tkn_priv_key BYTEA NOT NULL, + tkn_pub_dh_key BYTEA NOT NULL, + tkn_priv_dh_key BYTEA NOT NULL, + tkn_dh_secret BYTEA, + tkn_status TEXT NOT NULL, + tkn_action BYTEA, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + ntf_mode TEXT NULL, + PRIMARY KEY(provider, device_token, ntf_host, ntf_port), + FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers + ON DELETE RESTRICT ON UPDATE CASCADE +); +CREATE TABLE ntf_subscriptions( + conn_id BYTEA NOT NULL, + smp_host TEXT NULL, + smp_port TEXT NULL, + smp_ntf_id BYTEA, + ntf_host TEXT NOT NULL, + ntf_port TEXT NOT NULL, + ntf_sub_id BYTEA, + ntf_sub_status TEXT NOT NULL, + ntf_sub_action TEXT, + ntf_sub_smp_action TEXT, + ntf_sub_action_ts TIMESTAMPTZ, + updated_by_supervisor SMALLINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + smp_server_key_hash BYTEA, + ntf_failed SMALLINT DEFAULT 0, + smp_failed SMALLINT DEFAULT 0, + PRIMARY KEY(conn_id), + FOREIGN KEY(smp_host, smp_port) REFERENCES servers(host, port) + ON DELETE SET NULL ON UPDATE CASCADE, + FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers + ON DELETE RESTRICT ON UPDATE CASCADE +); +CREATE TABLE commands( + command_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + host TEXT, + port TEXT, + corr_id BYTEA NOT NULL, + command_tag BYTEA NOT NULL, + command BYTEA NOT NULL, + agent_version INTEGER NOT NULL DEFAULT 1, + server_key_hash BYTEA, + created_at TIMESTAMPTZ NOT NULL DEFAULT '1970-01-01 00:00:00', + failed SMALLINT DEFAULT 0, + FOREIGN KEY(host, port) REFERENCES servers + ON DELETE RESTRICT ON UPDATE CASCADE +); +CREATE TABLE snd_message_deliveries( + snd_message_delivery_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + snd_queue_id BIGINT NOT NULL, + internal_id BIGINT NOT NULL, + failed SMALLINT DEFAULT 0, + FOREIGN KEY(conn_id, internal_id) REFERENCES messages ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED +); +CREATE TABLE xftp_servers( + xftp_server_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + xftp_host TEXT NOT NULL, + xftp_port TEXT NOT NULL, + xftp_key_hash BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at 
TIMESTAMPTZ NOT NULL DEFAULT (now()), + UNIQUE(xftp_host, xftp_port, xftp_key_hash) +); +CREATE TABLE rcv_files( + rcv_file_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rcv_file_entity_id BYTEA NOT NULL, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + size BIGINT NOT NULL, + digest BYTEA NOT NULL, + key BYTEA NOT NULL, + nonce BYTEA NOT NULL, + chunk_size BIGINT NOT NULL, + prefix_path TEXT NOT NULL, + tmp_path TEXT, + save_path TEXT NOT NULL, + status TEXT NOT NULL, + deleted SMALLINT NOT NULL DEFAULT 0, + error TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + save_file_key BYTEA, + save_file_nonce BYTEA, + failed SMALLINT DEFAULT 0, + redirect_id BIGINT REFERENCES rcv_files ON DELETE SET NULL, + redirect_entity_id BYTEA, + redirect_size BIGINT, + redirect_digest BYTEA, + approved_relays SMALLINT NOT NULL DEFAULT 0, + UNIQUE(rcv_file_entity_id) +); +CREATE TABLE rcv_file_chunks( + rcv_file_chunk_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rcv_file_id BIGINT NOT NULL REFERENCES rcv_files ON DELETE CASCADE, + chunk_no BIGINT NOT NULL, + chunk_size BIGINT NOT NULL, + digest BYTEA NOT NULL, + tmp_path TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE rcv_file_chunk_replicas( + rcv_file_chunk_replica_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rcv_file_chunk_id BIGINT NOT NULL REFERENCES rcv_file_chunks ON DELETE CASCADE, + replica_number BIGINT NOT NULL, + xftp_server_id BIGINT NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, + replica_id BYTEA NOT NULL, + replica_key BYTEA NOT NULL, + received SMALLINT NOT NULL DEFAULT 0, + delay BIGINT, + retries BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE snd_files( + snd_file_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_entity_id BYTEA NOT NULL, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + num_recipients BIGINT NOT NULL, + digest BYTEA, + key BYTEA NOT NUll, + nonce BYTEA NOT NUll, + path TEXT NOT NULL, + prefix_path TEXT, + status TEXT NOT NULL, + deleted SMALLINT NOT NULL DEFAULT 0, + error TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + src_file_key BYTEA, + src_file_nonce BYTEA, + failed SMALLINT DEFAULT 0, + redirect_size BIGINT, + redirect_digest BYTEA +); +CREATE TABLE snd_file_chunks( + snd_file_chunk_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_id BIGINT NOT NULL REFERENCES snd_files ON DELETE CASCADE, + chunk_no BIGINT NOT NULL, + chunk_offset BIGINT NOT NULL, + chunk_size BIGINT NOT NULL, + digest BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE snd_file_chunk_replicas( + snd_file_chunk_replica_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_chunk_id BIGINT NOT NULL REFERENCES snd_file_chunks ON DELETE CASCADE, + replica_number BIGINT NOT NULL, + xftp_server_id BIGINT NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, + replica_id BYTEA NOT NULL, + replica_key BYTEA NOT NULL, + replica_status TEXT NOT NULL, + delay BIGINT, + retries BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE snd_file_chunk_replica_recipients( + snd_file_chunk_replica_recipient_id 
BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + snd_file_chunk_replica_id BIGINT NOT NULL REFERENCES snd_file_chunk_replicas ON DELETE CASCADE, + rcv_replica_id BYTEA NOT NULL, + rcv_replica_key BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE deleted_snd_chunk_replicas( + deleted_snd_chunk_replica_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + user_id BIGINT NOT NULL REFERENCES users ON DELETE CASCADE, + xftp_server_id BIGINT NOT NULL REFERENCES xftp_servers ON DELETE CASCADE, + replica_id BYTEA NOT NULL, + replica_key BYTEA NOT NULL, + chunk_digest BYTEA NOT NULL, + delay BIGINT, + retries BIGINT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + failed SMALLINT DEFAULT 0 +); +CREATE TABLE encrypted_rcv_message_hashes( + encrypted_rcv_message_hash_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + hash BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE processed_ratchet_key_hashes( + processed_ratchet_key_hash_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + conn_id BYTEA NOT NULL REFERENCES connections ON DELETE CASCADE, + hash BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE TABLE servers_stats( + servers_stats_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + servers_stats TEXT, + started_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()), + updated_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +INSERT INTO servers_stats DEFAULT VALUES; +CREATE TABLE ntf_tokens_to_delete( + ntf_token_to_delete_id BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + ntf_host TEXT NOT NULL, + ntf_port TEXT NOT NULL, + ntf_key_hash BYTEA NOT NULL, + tkn_id BYTEA NOT NULL, + tkn_priv_key BYTEA NOT NULL, + del_failed SMALLINT DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT (now()) +); +CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id); +CREATE UNIQUE INDEX idx_rcv_queue_id ON rcv_queues(conn_id, rcv_queue_id); +CREATE UNIQUE INDEX idx_snd_queue_id ON snd_queues(conn_id, snd_queue_id); +CREATE INDEX idx_snd_message_deliveries ON snd_message_deliveries( + conn_id, + snd_queue_id +); +CREATE INDEX idx_connections_user ON connections(user_id); +CREATE INDEX idx_commands_conn_id ON commands(conn_id); +CREATE INDEX idx_commands_host_port ON commands(host, port); +CREATE INDEX idx_conn_confirmations_conn_id ON conn_confirmations(conn_id); +CREATE INDEX idx_conn_invitations_contact_conn_id ON conn_invitations( + contact_conn_id +); +CREATE INDEX idx_messages_conn_id_internal_snd_id ON messages( + conn_id, + internal_snd_id +); +CREATE INDEX idx_messages_conn_id_internal_rcv_id ON messages( + conn_id, + internal_rcv_id +); +CREATE INDEX idx_messages_conn_id ON messages(conn_id); +CREATE INDEX idx_ntf_subscriptions_ntf_host_ntf_port ON ntf_subscriptions( + ntf_host, + ntf_port +); +CREATE INDEX idx_ntf_subscriptions_smp_host_smp_port ON ntf_subscriptions( + smp_host, + smp_port +); +CREATE INDEX idx_ntf_tokens_ntf_host_ntf_port ON ntf_tokens( + ntf_host, + ntf_port +); +CREATE INDEX idx_ratchets_conn_id ON ratchets(conn_id); +CREATE INDEX idx_rcv_messages_conn_id_internal_id ON rcv_messages( + conn_id, + internal_id +); +CREATE INDEX 
idx_skipped_messages_conn_id ON skipped_messages(conn_id); +CREATE INDEX idx_snd_message_deliveries_conn_id_internal_id ON snd_message_deliveries( + conn_id, + internal_id +); +CREATE INDEX idx_snd_messages_conn_id_internal_id ON snd_messages( + conn_id, + internal_id +); +CREATE INDEX idx_snd_queues_host_port ON snd_queues(host, port); +CREATE INDEX idx_rcv_files_user_id ON rcv_files(user_id); +CREATE INDEX idx_rcv_file_chunks_rcv_file_id ON rcv_file_chunks(rcv_file_id); +CREATE INDEX idx_rcv_file_chunk_replicas_rcv_file_chunk_id ON rcv_file_chunk_replicas( + rcv_file_chunk_id +); +CREATE INDEX idx_rcv_file_chunk_replicas_xftp_server_id ON rcv_file_chunk_replicas( + xftp_server_id +); +CREATE INDEX idx_snd_files_user_id ON snd_files(user_id); +CREATE INDEX idx_snd_file_chunks_snd_file_id ON snd_file_chunks(snd_file_id); +CREATE INDEX idx_snd_file_chunk_replicas_snd_file_chunk_id ON snd_file_chunk_replicas( + snd_file_chunk_id +); +CREATE INDEX idx_snd_file_chunk_replicas_xftp_server_id ON snd_file_chunk_replicas( + xftp_server_id +); +CREATE INDEX idx_snd_file_chunk_replica_recipients_snd_file_chunk_replica_id ON snd_file_chunk_replica_recipients( + snd_file_chunk_replica_id +); +CREATE INDEX idx_deleted_snd_chunk_replicas_user_id ON deleted_snd_chunk_replicas( + user_id +); +CREATE INDEX idx_deleted_snd_chunk_replicas_xftp_server_id ON deleted_snd_chunk_replicas( + xftp_server_id +); +CREATE INDEX idx_rcv_file_chunk_replicas_pending ON rcv_file_chunk_replicas( + received, + replica_number +); +CREATE INDEX idx_snd_file_chunk_replicas_pending ON snd_file_chunk_replicas( + replica_status, + replica_number +); +CREATE INDEX idx_deleted_snd_chunk_replicas_pending ON deleted_snd_chunk_replicas( + created_at +); +CREATE INDEX idx_encrypted_rcv_message_hashes_hash ON encrypted_rcv_message_hashes( + conn_id, + hash +); +CREATE INDEX idx_processed_ratchet_key_hashes_hash ON processed_ratchet_key_hashes( + conn_id, + hash +); +CREATE INDEX idx_snd_messages_rcpt_internal_id ON snd_messages( + conn_id, + rcpt_internal_id +); +CREATE INDEX idx_processed_ratchet_key_hashes_created_at ON processed_ratchet_key_hashes( + created_at +); +CREATE INDEX idx_encrypted_rcv_message_hashes_created_at ON encrypted_rcv_message_hashes( + created_at +); +CREATE INDEX idx_messages_internal_ts ON messages(internal_ts); +CREATE INDEX idx_commands_server_commands ON commands( + host, + port, + created_at, + command_id +); +CREATE INDEX idx_rcv_files_status_created_at ON rcv_files(status, created_at); +CREATE INDEX idx_snd_files_status_created_at ON snd_files(status, created_at); +CREATE INDEX idx_snd_files_snd_file_entity_id ON snd_files(snd_file_entity_id); +CREATE INDEX idx_messages_snd_expired ON messages( + conn_id, + internal_snd_id, + internal_ts +); +CREATE INDEX idx_snd_message_deliveries_expired ON snd_message_deliveries( + conn_id, + snd_queue_id, + failed, + internal_id +); +CREATE INDEX idx_rcv_files_redirect_id on rcv_files(redirect_id); +|] diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs b/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs new file mode 100644 index 000000000..98c0024f3 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Postgres/Util.hs @@ -0,0 +1,101 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Simplex.Messaging.Agent.Store.Postgres.Util + ( createDBAndUserIfNotExists, + -- for tests + dropSchema, + dropAllSchemasExceptSystem, + dropDatabaseAndUser, + ) +where + 
+import Control.Exception (bracket) +import Control.Monad (forM_, unless, void, when) +import Data.String (fromString) +import Database.PostgreSQL.Simple (ConnectInfo (..), Only (..), defaultConnectInfo) +import qualified Database.PostgreSQL.Simple as PSQL +import Database.PostgreSQL.Simple.SqlQQ (sql) + +createDBAndUserIfNotExists :: ConnectInfo -> IO () +createDBAndUserIfNotExists ConnectInfo {connectUser = user, connectDatabase = dbName} = do + -- connect to the default "postgres" maintenance database + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ + \postgresDB -> do + void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING" + -- check if the user exists, create if not + [Only userExists] <- + PSQL.query + postgresDB + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_roles + WHERE rolname = ? + ) + |] + (Only user) + unless userExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE USER " <> user) + -- check if the database exists, create if not + dbExists <- checkDBExists postgresDB dbName + unless dbExists $ void $ PSQL.execute_ postgresDB (fromString $ "CREATE DATABASE " <> dbName <> " OWNER " <> user) + +checkDBExists :: PSQL.Connection -> String -> IO Bool +checkDBExists postgresDB dbName = do + [Only dbExists] <- + PSQL.query + postgresDB + [sql| + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_database + WHERE datname = ? + ) + |] + (Only dbName) + pure dbExists + +dropSchema :: ConnectInfo -> String -> IO () +dropSchema connectInfo schema = + bracket (PSQL.connect connectInfo) PSQL.close $ + \db -> do + void $ PSQL.execute_ db "SET client_min_messages TO WARNING" + void $ PSQL.execute_ db (fromString $ "DROP SCHEMA IF EXISTS " <> schema <> " CASCADE") + +dropAllSchemasExceptSystem :: ConnectInfo -> IO () +dropAllSchemasExceptSystem connectInfo = + bracket (PSQL.connect connectInfo) PSQL.close $ + \db -> do + void $ PSQL.execute_ db "SET client_min_messages TO WARNING" + schemaNames :: [Only String] <- + PSQL.query_ + db + [sql| + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('public', 'pg_catalog', 'information_schema') + |] + forM_ schemaNames $ \(Only schema) -> + PSQL.execute_ db (fromString $ "DROP SCHEMA " <> schema <> " CASCADE") + +dropDatabaseAndUser :: ConnectInfo -> IO () +dropDatabaseAndUser ConnectInfo {connectUser = user, connectDatabase = dbName} = + bracket (PSQL.connect defaultConnectInfo {connectUser = "postgres", connectDatabase = "postgres"}) PSQL.close $ + \postgresDB -> do + void $ PSQL.execute_ postgresDB "SET client_min_messages TO WARNING" + dbExists <- checkDBExists postgresDB dbName + when dbExists $ do + void $ PSQL.execute_ postgresDB (fromString $ "ALTER DATABASE " <> dbName <> " WITH ALLOW_CONNECTIONS false") + -- terminate all connections to the database + _r :: [Only Bool] <- + PSQL.query + postgresDB + [sql| + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE datname = ? 
+ AND pid <> pg_backend_pid() + |] + (Only dbName) + void $ PSQL.execute_ postgresDB (fromString $ "DROP DATABASE " <> dbName) + void $ PSQL.execute_ postgresDB (fromString $ "DROP USER IF EXISTS " <> user) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite.hs b/src/Simplex/Messaging/Agent/Store/SQLite.hs index 59b1d8687..816968208 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite.hs @@ -18,7 +18,6 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -26,385 +25,57 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} module Simplex.Messaging.Agent.Store.SQLite - ( SQLiteStore (..), - MigrationConfirmation (..), - MigrationError (..), - UpMigration (..), - createSQLiteStore, - connectSQLiteStore, - closeSQLiteStore, - openSQLiteStore, - reopenSQLiteStore, + ( createDBStore, + closeDBStore, + execSQL, + -- used in Simplex.Chat.Archive sqlString, keyString, storeKey, - execSQL, - upMigration, -- used in tests - - -- * Users - createUserRecord, - deleteUserRecord, - setUserDeleted, - deleteUserWithoutConns, - deleteUsersWithoutConns, - checkUser, - - -- * Queues and connections - createNewConn, - updateNewConnRcv, - updateNewConnSnd, - createSndConn, - getConn, - getDeletedConn, - getConns, - getDeletedConns, - getConnData, - setConnDeleted, - setConnUserId, - setConnAgentVersion, - setConnPQSupport, - getDeletedConnIds, - getDeletedWaitingDeliveryConnIds, - setConnRatchetSync, - addProcessedRatchetKeyHash, - checkRatchetKeyHashExists, - deleteRatchetKeyHashesExpired, - getRcvConn, - getRcvQueueById, - getSndQueueById, - deleteConn, - upgradeRcvConnToDuplex, - upgradeSndConnToDuplex, - addConnRcvQueue, - addConnSndQueue, - setRcvQueueStatus, - setRcvSwitchStatus, - setRcvQueueDeleted, - setRcvQueueConfirmedE2E, - setSndQueueStatus, - setSndSwitchStatus, - setRcvQueuePrimary, - setSndQueuePrimary, - deleteConnRcvQueue, - incRcvDeleteErrors, - deleteConnSndQueue, - getPrimaryRcvQueue, - getRcvQueue, - getDeletedRcvQueue, - setRcvQueueNtfCreds, - -- Confirmations - createConfirmation, - acceptConfirmation, - getAcceptedConfirmation, - removeConfirmations, - -- Invitations - sent via Contact connections - createInvitation, - getInvitation, - acceptInvitation, - unacceptInvitation, - deleteInvitation, - -- Messages - updateRcvIds, - createRcvMsg, - updateRcvMsgHash, - updateSndIds, - createSndMsg, - updateSndMsgHash, - createSndMsgDelivery, - getSndMsgViaRcpt, - updateSndMsgRcpt, - getPendingQueueMsg, - getConnectionsForDelivery, - updatePendingMsgRIState, - deletePendingMsgs, - getExpiredSndMessages, - setMsgUserAck, - getRcvMsg, - getLastMsg, - checkRcvMsgHashExists, - getRcvMsgBrokerTs, - deleteMsg, - deleteDeliveredSndMsg, - deleteSndMsgDelivery, - deleteRcvMsgHashesExpired, - deleteSndMsgsExpired, - -- Double ratchet persistence - createRatchetX3dhKeys, - getRatchetX3dhKeys, - createSndRatchet, - getSndRatchet, - setRatchetX3dhKeys, - createRatchet, - deleteRatchet, - getRatchet, - getSkippedMsgKeys, - updateRatchet, - -- Async commands - createCommand, - getPendingCommandServers, - getPendingServerCommand, - updateCommandServer, - deleteCommand, - -- Notification device token persistence - createNtfToken, - getSavedNtfToken, - updateNtfTokenRegistration, - updateDeviceToken, - updateNtfMode, - updateNtfToken, - removeNtfToken, - addNtfTokenToDelete, - 
deleteExpiredNtfTokensToDelete, - NtfTokenToDelete, - getNextNtfTokenToDelete, - markNtfTokenToDeleteFailed_, -- exported for tests - getPendingDelTknServers, - deleteNtfTokenToDelete, - -- Notification subscription persistence - NtfSupervisorSub, - getNtfSubscription, - createNtfSubscription, - supervisorUpdateNtfSub, - supervisorUpdateNtfAction, - updateNtfSubscription, - setNullNtfSubscriptionAction, - deleteNtfSubscription, - deleteNtfSubscription', - getNextNtfSubNTFActions, - markNtfSubActionNtfFailed_, -- exported for tests - getNextNtfSubSMPActions, - markNtfSubActionSMPFailed_, -- exported for tests - getActiveNtfToken, - getNtfRcvQueue, - setConnectionNtfs, - - -- * File transfer - - -- Rcv files - createRcvFile, - createRcvFileRedirect, - getRcvFile, - getRcvFileByEntityId, - getRcvFileRedirects, - updateRcvChunkReplicaDelay, - updateRcvFileChunkReceived, - updateRcvFileStatus, - updateRcvFileError, - updateRcvFileComplete, - updateRcvFileRedirect, - updateRcvFileNoTmpPath, - updateRcvFileDeleted, - deleteRcvFile', - getNextRcvChunkToDownload, - getNextRcvFileToDecrypt, - getPendingRcvFilesServers, - getCleanupRcvFilesTmpPaths, - getCleanupRcvFilesDeleted, - getRcvFilesExpired, - -- Snd files - createSndFile, - getSndFile, - getSndFileByEntityId, - getNextSndFileToPrepare, - updateSndFileError, - updateSndFileStatus, - updateSndFileEncrypted, - updateSndFileComplete, - updateSndFileNoPrefixPath, - updateSndFileDeleted, - deleteSndFile', - getSndFileDeleted, - createSndFileReplica, - createSndFileReplica_, -- exported for tests - getNextSndChunkToUpload, - updateSndChunkReplicaDelay, - addSndChunkReplicaRecipients, - updateSndChunkReplicaStatus, - getPendingSndFilesServers, - getCleanupSndFilesPrefixPaths, - getCleanupSndFilesDeleted, - getSndFilesExpired, - createDeletedSndChunkReplica, - getNextDeletedSndChunkReplica, - updateDeletedSndChunkReplicaDelay, - deleteDeletedSndChunkReplica, - getPendingDelFilesServers, - deleteDeletedSndChunkReplicasExpired, - -- Stats - updateServersStats, - getServersStats, - resetServersStats, - - -- * utilities - withConnection, - withTransaction, - withTransactionPriority, - firstRow, - firstRow', - maybeFirstRow, + -- used in Simplex.Chat.Mobile and tests + reopenSQLiteStore, + -- used in tests + connectSQLiteStore, + openSQLiteStore, ) where -import Control.Logger.Simple import Control.Monad -import Control.Monad.IO.Class -import Control.Monad.Trans.Except -import Crypto.Random (ChaChaDRG) -import qualified Data.Aeson.TH as J -import qualified Data.Attoparsec.ByteString.Char8 as A -import Data.Bifunctor (first, second) import Data.ByteArray (ScrubbedBytes) import qualified Data.ByteArray as BA -import Data.ByteString (ByteString) -import qualified Data.ByteString.Base64.URL as U -import qualified Data.ByteString.Char8 as B -import Data.Char (toLower) import Data.Functor (($>)) import Data.IORef -import Data.Int (Int64) -import Data.List (foldl', intercalate, sortBy) -import Data.List.NonEmpty (NonEmpty (..)) -import qualified Data.List.NonEmpty as L -import qualified Data.Map.Strict as M -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, listToMaybe) -import Data.Ord (Down (..)) +import Data.Maybe (fromMaybe) import Data.Text (Text) import qualified Data.Text as T -import Data.Text.Encoding (decodeLatin1, encodeUtf8) -import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) -import Data.Word (Word32) -import Database.SQLite.Simple (FromRow (..), NamedParam (..), Only (..), Query (..), SQLError, ToRow (..), 
field, (:.) (..)) +import Database.SQLite.Simple (Query (..)) import qualified Database.SQLite.Simple as SQL -import Database.SQLite.Simple.FromField import Database.SQLite.Simple.QQ (sql) -import Database.SQLite.Simple.ToField (ToField (..)) import qualified Database.SQLite3 as SQLite3 -import Network.Socket (ServiceName) -import Simplex.FileTransfer.Client (XFTPChunkSpec (..)) -import Simplex.FileTransfer.Description -import Simplex.FileTransfer.Protocol (FileParty (..), SFileParty (..)) -import Simplex.FileTransfer.Types -import Simplex.Messaging.Agent.Protocol -import Simplex.Messaging.Agent.RetryInterval (RI2State (..)) -import Simplex.Messaging.Agent.Stats -import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.Migrations (migrateSchema) import Simplex.Messaging.Agent.Store.SQLite.Common import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Agent.Store.SQLite.Migrations (DownMigration (..), MTRError, Migration (..), MigrationsToRun (..), mtrErrorDescription) -import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations -import qualified Simplex.Messaging.Crypto as C -import Simplex.Messaging.Crypto.File (CryptoFile (..), CryptoFileArgs (..)) -import Simplex.Messaging.Crypto.Ratchet (PQEncryption (..), PQSupport (..), RatchetX448, SkippedMsgDiff (..), SkippedMsgKeys) -import qualified Simplex.Messaging.Crypto.Ratchet as CR -import Simplex.Messaging.Encoding -import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), NtfSubscriptionId, NtfTknStatus (..), NtfTokenId, SMPQueueNtf (..)) -import Simplex.Messaging.Notifications.Types -import Simplex.Messaging.Parsers (blobFieldParser, defaultJSON, dropPrefix, fromTextField_, sumTypeJSON) -import Simplex.Messaging.Protocol -import qualified Simplex.Messaging.Protocol as SMP -import Simplex.Messaging.Transport.Client (TransportHost) -import Simplex.Messaging.Util (bshow, catchAllErrors, eitherToMaybe, ifM, safeDecodeUtf8, tshow, ($>>=), (<$$>)) -import Simplex.Messaging.Version.Internal -import System.Directory (copyFile, createDirectoryIfMissing, doesFileExist) -import System.Exit (exitFailure) +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationError (..)) +import Simplex.Messaging.Util (ifM, safeDecodeUtf8) +import System.Directory (createDirectoryIfMissing, doesFileExist) import System.FilePath (takeDirectory) -import System.IO (hFlush, stdout) import UnliftIO.Exception (bracketOnError, onException) -import qualified UnliftIO.Exception as E import UnliftIO.MVar import UnliftIO.STM -- * SQLite Store implementation -data MigrationError - = MEUpgrade {upMigrations :: [UpMigration]} - | MEDowngrade {downMigrations :: [String]} - | MigrationError {mtrError :: MTRError} - deriving (Eq, Show) - -migrationErrorDescription :: MigrationError -> String -migrationErrorDescription = \case - MEUpgrade ums -> - "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map upName ums) - MEDowngrade dms -> - "Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms - MigrationError err -> mtrErrorDescription err - -data UpMigration = UpMigration {upName :: String, withDown :: Bool} - deriving (Eq, Show) - -upMigration :: Migration -> UpMigration -upMigration Migration {name, down} = UpMigration name $ isJust down - -data MigrationConfirmation = MCYesUp | 
MCYesUpDown | MCConsole | MCError - deriving (Eq, Show) - -instance StrEncoding MigrationConfirmation where - strEncode = \case - MCYesUp -> "yesUp" - MCYesUpDown -> "yesUpDown" - MCConsole -> "console" - MCError -> "error" - strP = - A.takeByteString >>= \case - "yesUp" -> pure MCYesUp - "yesUpDown" -> pure MCYesUpDown - "console" -> pure MCConsole - "error" -> pure MCError - _ -> fail "invalid MigrationConfirmation" - -createSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError SQLiteStore) -createSQLiteStore dbFilePath dbKey keepKey migrations confirmMigrations = do +createDBStore :: FilePath -> ScrubbedBytes -> Bool -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createDBStore dbFilePath dbKey keepKey migrations confirmMigrations = do let dbDir = takeDirectory dbFilePath createDirectoryIfMissing True dbDir st <- connectSQLiteStore dbFilePath dbKey keepKey - r <- migrateSchema st migrations confirmMigrations `onException` closeSQLiteStore st + r <- migrateSchema st migrations confirmMigrations `onException` closeDBStore st case r of Right () -> pure $ Right st - Left e -> closeSQLiteStore st $> Left e - -migrateSchema :: SQLiteStore -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError ()) -migrateSchema st migrations confirmMigrations = do - Migrations.initialize st - Migrations.get st migrations >>= \case - Left e -> do - when (confirmMigrations == MCConsole) $ confirmOrExit ("Database state error: " <> mtrErrorDescription e) - pure . Left $ MigrationError e - Right MTRNone -> pure $ Right () - Right ms@(MTRUp ums) - | dbNew st -> Migrations.run st ms $> Right () - | otherwise -> case confirmMigrations of - MCYesUp -> run ms - MCYesUpDown -> run ms - MCConsole -> confirm err >> run ms - MCError -> pure $ Left err - where - err = MEUpgrade $ map upMigration ums -- "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map name ums) - Right ms@(MTRDown dms) -> case confirmMigrations of - MCYesUpDown -> run ms - MCConsole -> confirm err >> run ms - MCYesUp -> pure $ Left err - MCError -> pure $ Left err - where - err = MEDowngrade $ map downName dms - where - confirm err = confirmOrExit $ migrationErrorDescription err - run ms = do - let f = dbFilePath st - copyFile f (f <> ".bak") - Migrations.run st ms - pure $ Right () + Left e -> closeDBStore st $> Left e -confirmOrExit :: String -> IO () -confirmOrExit s = do - putStrLn s - putStr "Continue (y/N): " - hFlush stdout - ok <- getLine - when (map toLower ok /= "y") exitFailure - -connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> IO SQLiteStore +connectSQLiteStore :: FilePath -> ScrubbedBytes -> Bool -> IO DBStore connectSQLiteStore dbFilePath key keepKey = do dbNew <- not <$> doesFileExist dbFilePath dbConn <- dbBusyLoop (connectDB dbFilePath key) @@ -412,7 +83,7 @@ connectSQLiteStore dbFilePath key keepKey = do dbKey <- newTVarIO $! 
storeKey key keepKey dbClosed <- newTVarIO False dbSem <- newTVarIO 0 - pure SQLiteStore {dbFilePath, dbKey, dbSem, dbConnection, dbNew, dbClosed} + pure DBStore {dbFilePath, dbKey, dbSem, dbConnection, dbNew, dbClosed} connectDB :: FilePath -> ScrubbedBytes -> IO DB.Connection connectDB path key = do @@ -433,19 +104,19 @@ connectDB path key = do PRAGMA auto_vacuum = FULL; |] -closeSQLiteStore :: SQLiteStore -> IO () -closeSQLiteStore st@SQLiteStore {dbClosed} = - ifM (readTVarIO dbClosed) (putStrLn "closeSQLiteStore: already closed") $ +closeDBStore :: DBStore -> IO () +closeDBStore st@DBStore {dbClosed} = + ifM (readTVarIO dbClosed) (putStrLn "closeDBStore: already closed") $ withConnection st $ \conn -> do DB.close conn atomically $ writeTVar dbClosed True -openSQLiteStore :: SQLiteStore -> ScrubbedBytes -> Bool -> IO () -openSQLiteStore st@SQLiteStore {dbClosed} key keepKey = +openSQLiteStore :: DBStore -> ScrubbedBytes -> Bool -> IO () +openSQLiteStore st@DBStore {dbClosed} key keepKey = ifM (readTVarIO dbClosed) (openSQLiteStore_ st key keepKey) (putStrLn "openSQLiteStore: already opened") -openSQLiteStore_ :: SQLiteStore -> ScrubbedBytes -> Bool -> IO () -openSQLiteStore_ SQLiteStore {dbConnection, dbFilePath, dbKey, dbClosed} key keepKey = +openSQLiteStore_ :: DBStore -> ScrubbedBytes -> Bool -> IO () +openSQLiteStore_ DBStore {dbConnection, dbFilePath, dbKey, dbClosed} key keepKey = bracketOnError (takeMVar dbConnection) (tryPutMVar dbConnection) @@ -456,8 +127,8 @@ openSQLiteStore_ SQLiteStore {dbConnection, dbFilePath, dbKey, dbClosed} key kee writeTVar dbKey $! storeKey key keepKey putMVar dbConnection DB.Connection {conn, slow} -reopenSQLiteStore :: SQLiteStore -> IO () -reopenSQLiteStore st@SQLiteStore {dbKey, dbClosed} = +reopenSQLiteStore :: DBStore -> IO () +reopenSQLiteStore st@DBStore {dbKey, dbClosed} = ifM (readTVarIO dbClosed) open (putStrLn "reopenSQLiteStore: already opened") where open = @@ -497,2725 +168,3 @@ addSQLResultRow rs _count names values = modifyIORef' rs $ \case rs' -> showValues values : rs' where showValues = T.intercalate "|" . map (fromMaybe "") - -checkConstraint :: StoreError -> IO (Either StoreError a) -> IO (Either StoreError a) -checkConstraint err action = action `E.catch` (pure . Left . handleSQLError err) - -handleSQLError :: StoreError -> SQLError -> StoreError -handleSQLError err e - | SQL.sqlError e == SQL.ErrorConstraint = err - | otherwise = SEInternal $ bshow e - -createUserRecord :: DB.Connection -> IO UserId -createUserRecord db = do - DB.execute_ db "INSERT INTO users DEFAULT VALUES" - insertedRowId db - -checkUser :: DB.Connection -> UserId -> IO (Either StoreError ()) -checkUser db userId = - firstRow (\(_ :: Only Int64) -> ()) SEUserNotFound $ - DB.query db "SELECT user_id FROM users WHERE user_id = ? AND deleted = ?" (userId, False) - -deleteUserRecord :: DB.Connection -> UserId -> IO (Either StoreError ()) -deleteUserRecord db userId = runExceptT $ do - ExceptT $ checkUser db userId - liftIO $ DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) - -setUserDeleted :: DB.Connection -> UserId -> IO (Either StoreError [ConnId]) -setUserDeleted db userId = runExceptT $ do - ExceptT $ checkUser db userId - liftIO $ do - DB.execute db "UPDATE users SET deleted = ? WHERE user_id = ?" (True, userId) - map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE user_id = ?" 
(Only userId) - -deleteUserWithoutConns :: DB.Connection -> UserId -> IO Bool -deleteUserWithoutConns db userId = do - userId_ :: Maybe Int64 <- - maybeFirstRow fromOnly $ - DB.query - db - [sql| - SELECT user_id FROM users u - WHERE u.user_id = ? - AND u.deleted = ? - AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id) - |] - (userId, True) - case userId_ of - Just _ -> DB.execute db "DELETE FROM users WHERE user_id = ?" (Only userId) $> True - _ -> pure False - -deleteUsersWithoutConns :: DB.Connection -> IO [Int64] -deleteUsersWithoutConns db = do - userIds <- - map fromOnly - <$> DB.query - db - [sql| - SELECT user_id FROM users u - WHERE u.deleted = ? - AND NOT EXISTS (SELECT c.conn_id FROM connections c WHERE c.user_id = u.user_id) - |] - (Only True) - forM_ userIds $ DB.execute db "DELETE FROM users WHERE user_id = ?" . Only - pure userIds - -createConn_ :: - TVar ChaChaDRG -> - ConnData -> - (ConnId -> IO a) -> - IO (Either StoreError (ConnId, a)) -createConn_ gVar cData create = checkConstraint SEConnDuplicate $ case cData of - ConnData {connId = ""} -> createWithRandomId' gVar create - ConnData {connId} -> Right . (connId,) <$> create connId - -createNewConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> SConnectionMode c -> IO (Either StoreError ConnId) -createNewConn db gVar cData cMode = do - fst <$$> createConn_ gVar cData (\connId -> createConnRecord db connId cData cMode) - -updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) -updateNewConnRcv db connId rq = - getConn db connId $>>= \case - (SomeConn _ NewConnection {}) -> updateConn - (SomeConn _ RcvConnection {}) -> updateConn -- to allow retries - (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - where - updateConn :: IO (Either StoreError RcvQueue) - updateConn = Right <$> addConnRcvQueue_ db connId rq - -updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) -updateNewConnSnd db connId sq = - getConn db connId $>>= \case - (SomeConn _ NewConnection {}) -> updateConn - (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - where - updateConn :: IO (Either StoreError SndQueue) - updateConn = Right <$> addConnSndQueue_ db connId sq - -createSndConn :: DB.Connection -> TVar ChaChaDRG -> ConnData -> NewSndQueue -> IO (Either StoreError (ConnId, SndQueue)) -createSndConn db gVar cData q@SndQueue {server} = - -- check confirmed snd queue doesn't already exist, to prevent it being deleted by REPLACE in insertSndQueue_ - ifM (liftIO $ checkConfirmedSndQueueExists_ db q) (pure $ Left SESndQueueExists) $ - createConn_ gVar cData $ \connId -> do - serverKeyHash_ <- createServer_ db server - createConnRecord db connId cData SCMInvitation - insertSndQueue_ db connId q serverKeyHash_ - -createConnRecord :: DB.Connection -> ConnId -> ConnData -> SConnectionMode c -> IO () -createConnRecord db connId ConnData {userId, connAgentVersion, enableNtfs, pqSupport} cMode = - DB.execute - db - [sql| - INSERT INTO connections - (user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, pq_support, duplex_handshake) VALUES (?,?,?,?,?,?,?) - |] - (userId, connId, cMode, connAgentVersion, enableNtfs, pqSupport, True) - -checkConfirmedSndQueueExists_ :: DB.Connection -> NewSndQueue -> IO Bool -checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = do - fromMaybe False - <$> maybeFirstRow - fromOnly - ( DB.query - db - "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? 
LIMIT 1" - (host server, port server, sndId, New) - ) - -getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) -getRcvConn db ProtocolServer {host, port} rcvId = runExceptT $ do - rq@RcvQueue {connId} <- - ExceptT . firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> " WHERE q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (host, port, rcvId) - (rq,) <$> ExceptT (getConn db connId) - --- | Deletes connection, optionally checking for pending snd message deliveries; returns connection id if it was deleted -deleteConn :: DB.Connection -> Maybe NominalDiffTime -> ConnId -> IO (Maybe ConnId) -deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of - Nothing -> delete - Just timeout -> - ifM - checkNoPendingDeliveries_ - delete - ( ifM - (checkWaitDeliveryTimeout_ timeout) - delete - (pure Nothing) - ) - where - delete = DB.execute db "DELETE FROM connections WHERE conn_id = ?" (Only connId) $> Just connId - checkNoPendingDeliveries_ = do - r :: (Maybe Int64) <- - maybeFirstRow fromOnly $ - DB.query db "SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND failed = 0 LIMIT 1" (Only connId) - pure $ isNothing r - checkWaitDeliveryTimeout_ timeout = do - cutoffTs <- addUTCTime (-timeout) <$> getCurrentTime - r :: (Maybe Int64) <- - maybeFirstRow fromOnly $ - DB.query db "SELECT 1 FROM connections WHERE conn_id = ? AND deleted_at_wait_delivery < ? LIMIT 1" (connId, cutoffTs) - pure $ isJust r - -upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) -upgradeRcvConnToDuplex db connId sq = - getConn db connId $>>= \case - (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq - (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - -upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) -upgradeSndConnToDuplex db connId rq = - getConn db connId >>= \case - Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq - Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - _ -> pure $ Left SEConnNotFound - -addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> IO (Either StoreError RcvQueue) -addConnRcvQueue db connId rq = - getConn db connId >>= \case - Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq - Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - _ -> pure $ Left SEConnNotFound - -addConnRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> IO RcvQueue -addConnRcvQueue_ db connId rq@RcvQueue {server} = do - serverKeyHash_ <- createServer_ db server - insertRcvQueue_ db connId rq serverKeyHash_ - -addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) -addConnSndQueue db connId sq = - getConn db connId >>= \case - Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq - Right (SomeConn c _) -> pure . Left . SEBadConnType $ connType c - _ -> pure $ Left SEConnNotFound - -addConnSndQueue_ :: DB.Connection -> ConnId -> NewSndQueue -> IO SndQueue -addConnSndQueue_ db connId sq@SndQueue {server} = do - serverKeyHash_ <- createServer_ db server - insertSndQueue_ db connId sq serverKeyHash_ - -setRcvQueueStatus :: DB.Connection -> RcvQueue -> QueueStatus -> IO () -setRcvQueueStatus db RcvQueue {rcvId, server = ProtocolServer {host, port}} status = - -- ? return error if queue does not exist? 
- DB.executeNamed - db - [sql| - UPDATE rcv_queues - SET status = :status - WHERE host = :host AND port = :port AND rcv_id = :rcv_id; - |] - [":status" := status, ":host" := host, ":port" := port, ":rcv_id" := rcvId] - -setRcvSwitchStatus :: DB.Connection -> RcvQueue -> Maybe RcvSwitchStatus -> IO RcvQueue -setRcvSwitchStatus db rq@RcvQueue {rcvId, server = ProtocolServer {host, port}} rcvSwchStatus = do - DB.execute - db - [sql| - UPDATE rcv_queues - SET switch_status = ? - WHERE host = ? AND port = ? AND rcv_id = ? - |] - (rcvSwchStatus, host, port, rcvId) - pure rq {rcvSwchStatus} - -setRcvQueueDeleted :: DB.Connection -> RcvQueue -> IO () -setRcvQueueDeleted db RcvQueue {rcvId, server = ProtocolServer {host, port}} = do - DB.execute - db - [sql| - UPDATE rcv_queues - SET deleted = 1 - WHERE host = ? AND port = ? AND rcv_id = ? - |] - (host, port, rcvId) - -setRcvQueueConfirmedE2E :: DB.Connection -> RcvQueue -> C.DhSecretX25519 -> VersionSMPC -> IO () -setRcvQueueConfirmedE2E db RcvQueue {rcvId, server = ProtocolServer {host, port}} e2eDhSecret smpClientVersion = - DB.executeNamed - db - [sql| - UPDATE rcv_queues - SET e2e_dh_secret = :e2e_dh_secret, - status = :status, - smp_client_version = :smp_client_version - WHERE host = :host AND port = :port AND rcv_id = :rcv_id - |] - [ ":status" := Confirmed, - ":e2e_dh_secret" := e2eDhSecret, - ":smp_client_version" := smpClientVersion, - ":host" := host, - ":port" := port, - ":rcv_id" := rcvId - ] - -setSndQueueStatus :: DB.Connection -> SndQueue -> QueueStatus -> IO () -setSndQueueStatus db SndQueue {sndId, server = ProtocolServer {host, port}} status = - -- ? return error if queue does not exist? - DB.executeNamed - db - [sql| - UPDATE snd_queues - SET status = :status - WHERE host = :host AND port = :port AND snd_id = :snd_id; - |] - [":status" := status, ":host" := host, ":port" := port, ":snd_id" := sndId] - -setSndSwitchStatus :: DB.Connection -> SndQueue -> Maybe SndSwitchStatus -> IO SndQueue -setSndSwitchStatus db sq@SndQueue {sndId, server = ProtocolServer {host, port}} sndSwchStatus = do - DB.execute - db - [sql| - UPDATE snd_queues - SET switch_status = ? - WHERE host = ? AND port = ? AND snd_id = ? - |] - (sndSwchStatus, host, port, sndId) - pure sq {sndSwchStatus} - -setRcvQueuePrimary :: DB.Connection -> ConnId -> RcvQueue -> IO () -setRcvQueuePrimary db connId RcvQueue {dbQueueId} = do - DB.execute db "UPDATE rcv_queues SET rcv_primary = ? WHERE conn_id = ?" (False, connId) - DB.execute - db - "UPDATE rcv_queues SET rcv_primary = ?, replace_rcv_queue_id = ? WHERE conn_id = ? AND rcv_queue_id = ?" - (True, Nothing :: Maybe Int64, connId, dbQueueId) - -setSndQueuePrimary :: DB.Connection -> ConnId -> SndQueue -> IO () -setSndQueuePrimary db connId SndQueue {dbQueueId} = do - DB.execute db "UPDATE snd_queues SET snd_primary = ? WHERE conn_id = ?" (False, connId) - DB.execute - db - "UPDATE snd_queues SET snd_primary = ?, replace_snd_queue_id = ? WHERE conn_id = ? AND snd_queue_id = ?" - (True, Nothing :: Maybe Int64, connId, dbQueueId) - -incRcvDeleteErrors :: DB.Connection -> RcvQueue -> IO () -incRcvDeleteErrors db RcvQueue {connId, dbQueueId} = - DB.execute db "UPDATE rcv_queues SET delete_errors = delete_errors + 1 WHERE conn_id = ? AND rcv_queue_id = ?" (connId, dbQueueId) - -deleteConnRcvQueue :: DB.Connection -> RcvQueue -> IO () -deleteConnRcvQueue db RcvQueue {connId, dbQueueId} = - DB.execute db "DELETE FROM rcv_queues WHERE conn_id = ? AND rcv_queue_id = ?" 
(connId, dbQueueId) - -deleteConnSndQueue :: DB.Connection -> ConnId -> SndQueue -> IO () -deleteConnSndQueue db connId SndQueue {dbQueueId} = do - DB.execute db "DELETE FROM snd_queues WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) - DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId) - -getPrimaryRcvQueue :: DB.Connection -> ConnId -> IO (Either StoreError RcvQueue) -getPrimaryRcvQueue db connId = - maybe (Left SEConnNotFound) (Right . L.head) <$> getRcvQueuesByConnId_ db connId - -getRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) -getRcvQueue db connId (SMPServer host port _) rcvId = - firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 0") (connId, host, port, rcvId) - -getDeletedRcvQueue :: DB.Connection -> ConnId -> SMPServer -> SMP.RecipientId -> IO (Either StoreError RcvQueue) -getDeletedRcvQueue db connId (SMPServer host port _) rcvId = - firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.host = ? AND q.port = ? AND q.rcv_id = ? AND q.deleted = 1") (connId, host, port, rcvId) - -setRcvQueueNtfCreds :: DB.Connection -> ConnId -> Maybe ClientNtfCreds -> IO () -setRcvQueueNtfCreds db connId clientNtfCreds = - DB.execute - db - [sql| - UPDATE rcv_queues - SET ntf_public_key = ?, ntf_private_key = ?, ntf_id = ?, rcv_ntf_dh_secret = ? - WHERE conn_id = ? - |] - (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_, connId) - where - (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) = case clientNtfCreds of - Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} -> (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) - Nothing -> (Nothing, Nothing, Nothing, Nothing) - -type SMPConfirmationRow = (Maybe SndPublicAuthKey, C.PublicKeyX25519, ConnInfo, Maybe [SMPQueueInfo], Maybe VersionSMPC) - -smpConfirmation :: SMPConfirmationRow -> SMPConfirmation -smpConfirmation (senderKey, e2ePubKey, connInfo, smpReplyQueues_, smpClientVersion_) = - SMPConfirmation - { senderKey, - e2ePubKey, - connInfo, - smpReplyQueues = fromMaybe [] smpReplyQueues_, - smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ - } - -createConfirmation :: DB.Connection -> TVar ChaChaDRG -> NewConfirmation -> IO (Either StoreError ConfirmationId) -createConfirmation db gVar NewConfirmation {connId, senderConf = SMPConfirmation {senderKey, e2ePubKey, connInfo, smpReplyQueues, smpClientVersion}, ratchetState} = - createWithRandomId gVar $ \confirmationId -> - DB.execute - db - [sql| - INSERT INTO conn_confirmations - (confirmation_id, conn_id, sender_key, e2e_snd_pub_key, ratchet_state, sender_conn_info, smp_reply_queues, smp_client_version, accepted) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0); - |] - (confirmationId, connId, senderKey, e2ePubKey, ratchetState, connInfo, smpReplyQueues, smpClientVersion) - -acceptConfirmation :: DB.Connection -> ConfirmationId -> ConnInfo -> IO (Either StoreError AcceptedConfirmation) -acceptConfirmation db confirmationId ownConnInfo = do - DB.executeNamed - db - [sql| - UPDATE conn_confirmations - SET accepted = 1, - own_conn_info = :own_conn_info - WHERE confirmation_id = :confirmation_id; - |] - [ ":own_conn_info" := ownConnInfo, - ":confirmation_id" := confirmationId - ] - firstRow confirmation SEConfirmationNotFound $ - DB.query - db - [sql| - 
SELECT conn_id, ratchet_state, sender_key, e2e_snd_pub_key, sender_conn_info, smp_reply_queues, smp_client_version - FROM conn_confirmations - WHERE confirmation_id = ?; - |] - (Only confirmationId) - where - confirmation ((connId, ratchetState) :. confRow) = - AcceptedConfirmation - { confirmationId, - connId, - senderConf = smpConfirmation confRow, - ratchetState, - ownConnInfo - } - -getAcceptedConfirmation :: DB.Connection -> ConnId -> IO (Either StoreError AcceptedConfirmation) -getAcceptedConfirmation db connId = - firstRow confirmation SEConfirmationNotFound $ - DB.query - db - [sql| - SELECT confirmation_id, ratchet_state, own_conn_info, sender_key, e2e_snd_pub_key, sender_conn_info, smp_reply_queues, smp_client_version - FROM conn_confirmations - WHERE conn_id = ? AND accepted = 1; - |] - (Only connId) - where - confirmation ((confirmationId, ratchetState, ownConnInfo) :. confRow) = - AcceptedConfirmation - { confirmationId, - connId, - senderConf = smpConfirmation confRow, - ratchetState, - ownConnInfo - } - -removeConfirmations :: DB.Connection -> ConnId -> IO () -removeConfirmations db connId = - DB.executeNamed - db - [sql| - DELETE FROM conn_confirmations - WHERE conn_id = :conn_id; - |] - [":conn_id" := connId] - -createInvitation :: DB.Connection -> TVar ChaChaDRG -> NewInvitation -> IO (Either StoreError InvitationId) -createInvitation db gVar NewInvitation {contactConnId, connReq, recipientConnInfo} = - createWithRandomId gVar $ \invitationId -> - DB.execute - db - [sql| - INSERT INTO conn_invitations - (invitation_id, contact_conn_id, cr_invitation, recipient_conn_info, accepted) VALUES (?, ?, ?, ?, 0); - |] - (invitationId, contactConnId, connReq, recipientConnInfo) - -getInvitation :: DB.Connection -> String -> InvitationId -> IO (Either StoreError Invitation) -getInvitation db cxt invitationId = - firstRow invitation (SEInvitationNotFound cxt invitationId) $ - DB.query - db - [sql| - SELECT contact_conn_id, cr_invitation, recipient_conn_info, own_conn_info, accepted - FROM conn_invitations - WHERE invitation_id = ? - AND accepted = 0 - |] - (Only invitationId) - where - invitation (contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted) = - Invitation {invitationId, contactConnId, connReq, recipientConnInfo, ownConnInfo, accepted} - -acceptInvitation :: DB.Connection -> InvitationId -> ConnInfo -> IO () -acceptInvitation db invitationId ownConnInfo = - DB.executeNamed - db - [sql| - UPDATE conn_invitations - SET accepted = 1, - own_conn_info = :own_conn_info - WHERE invitation_id = :invitation_id - |] - [ ":own_conn_info" := ownConnInfo, - ":invitation_id" := invitationId - ] - -unacceptInvitation :: DB.Connection -> InvitationId -> IO () -unacceptInvitation db invitationId = - DB.execute db "UPDATE conn_invitations SET accepted = 0, own_conn_info = NULL WHERE invitation_id = ?" (Only invitationId) - -deleteInvitation :: DB.Connection -> ConnId -> InvitationId -> IO (Either StoreError ()) -deleteInvitation db contactConnId invId = - getConn db contactConnId $>>= \case - SomeConn SCContact _ -> - Right <$> DB.execute db "DELETE FROM conn_invitations WHERE contact_conn_id = ? AND invitation_id = ?" 
(contactConnId, invId) - _ -> pure $ Left SEConnNotFound - -updateRcvIds :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) -updateRcvIds db connId = do - (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) <- retrieveLastIdsAndHashRcv_ db connId - let internalId = InternalId $ unId lastInternalId + 1 - internalRcvId = InternalRcvId $ unRcvId lastInternalRcvId + 1 - updateLastIdsRcv_ db connId internalId internalRcvId - pure (internalId, internalRcvId, lastExternalSndId, lastRcvHash) - -createRcvMsg :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () -createRcvMsg db connId rq@RcvQueue {dbQueueId} rcvMsgData@RcvMsgData {msgMeta = MsgMeta {sndMsgId, broker = (_, brokerTs)}, internalRcvId, internalHash} = do - insertRcvMsgBase_ db connId rcvMsgData - insertRcvMsgDetails_ db connId rq rcvMsgData - updateRcvMsgHash db connId sndMsgId internalRcvId internalHash - DB.execute db "UPDATE rcv_queues SET last_broker_ts = ? WHERE conn_id = ? AND rcv_queue_id = ?" (brokerTs, connId, dbQueueId) - -updateSndIds :: DB.Connection -> ConnId -> IO (Either StoreError (InternalId, InternalSndId, PrevSndMsgHash)) -updateSndIds db connId = runExceptT $ do - (lastInternalId, lastInternalSndId, prevSndHash) <- ExceptT $ retrieveLastIdsAndHashSnd_ db connId - let internalId = InternalId $ unId lastInternalId + 1 - internalSndId = InternalSndId $ unSndId lastInternalSndId + 1 - liftIO $ updateLastIdsSnd_ db connId internalId internalSndId - pure (internalId, internalSndId, prevSndHash) - -createSndMsg :: DB.Connection -> ConnId -> SndMsgData -> IO () -createSndMsg db connId sndMsgData@SndMsgData {internalSndId, internalHash} = do - insertSndMsgBase_ db connId sndMsgData - insertSndMsgDetails_ db connId sndMsgData - updateSndMsgHash db connId internalSndId internalHash - -createSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> IO () -createSndMsgDelivery db connId SndQueue {dbQueueId} msgId = - DB.execute db "INSERT INTO snd_message_deliveries (conn_id, snd_queue_id, internal_id) VALUES (?, ?, ?)" (connId, dbQueueId, msgId) - -getSndMsgViaRcpt :: DB.Connection -> ConnId -> InternalSndId -> IO (Either StoreError SndMsg) -getSndMsgViaRcpt db connId sndMsgId = - firstRow toSndMsg SEMsgNotFound $ - DB.query - db - [sql| - SELECT s.internal_id, m.msg_type, s.internal_hash, s.rcpt_internal_id, s.rcpt_status - FROM snd_messages s - JOIN messages m ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id - WHERE s.conn_id = ? AND s.internal_snd_id = ? - |] - (connId, sndMsgId) - where - toSndMsg :: (InternalId, AgentMessageType, MsgHash, Maybe AgentMsgId, Maybe MsgReceiptStatus) -> SndMsg - toSndMsg (internalId, msgType, internalHash, rcptInternalId_, rcptStatus_) = - let msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ - in SndMsg {internalId, internalSndId = sndMsgId, msgType, internalHash, msgReceipt} - -updateSndMsgRcpt :: DB.Connection -> ConnId -> InternalSndId -> MsgReceipt -> IO () -updateSndMsgRcpt db connId sndMsgId MsgReceipt {agentMsgId, msgRcptStatus} = - DB.execute - db - "UPDATE snd_messages SET rcpt_internal_id = ?, rcpt_status = ? WHERE conn_id = ? AND internal_snd_id = ?" 
- (agentMsgId, msgRcptStatus, connId, sndMsgId) - -getConnectionsForDelivery :: DB.Connection -> IO [ConnId] -getConnectionsForDelivery db = - map fromOnly <$> DB.query_ db "SELECT DISTINCT conn_id FROM snd_message_deliveries WHERE failed = 0" - -getPendingQueueMsg :: DB.Connection -> ConnId -> SndQueue -> IO (Either StoreError (Maybe (Maybe RcvQueue, PendingMsgData))) -getPendingQueueMsg db connId SndQueue {dbQueueId} = - getWorkItem "message" getMsgId getMsgData markMsgFailed - where - getMsgId :: IO (Maybe InternalId) - getMsgId = - maybeFirstRow fromOnly $ - DB.query - db - [sql| - SELECT internal_id - FROM snd_message_deliveries d - WHERE conn_id = ? AND snd_queue_id = ? AND failed = 0 - ORDER BY internal_id ASC - LIMIT 1 - |] - (connId, dbQueueId) - getMsgData :: InternalId -> IO (Either StoreError (Maybe RcvQueue, PendingMsgData)) - getMsgData msgId = runExceptT $ do - msg <- ExceptT $ firstRow pendingMsgData err getMsgData_ - rq_ <- liftIO $ L.head <$$> getRcvQueuesByConnId_ db connId - pure (rq_, msg) - where - getMsgData_ = - DB.query - db - [sql| - SELECT m.msg_type, m.msg_flags, m.msg_body, m.pq_encryption, m.internal_ts, s.retry_int_slow, s.retry_int_fast - FROM messages m - JOIN snd_messages s ON s.conn_id = m.conn_id AND s.internal_id = m.internal_id - WHERE m.conn_id = ? AND m.internal_id = ? - |] - (connId, msgId) - err = SEInternal $ "msg delivery " <> bshow msgId <> " returned []" - pendingMsgData :: (AgentMessageType, Maybe MsgFlags, MsgBody, PQEncryption, InternalTs, Maybe Int64, Maybe Int64) -> PendingMsgData - pendingMsgData (msgType, msgFlags_, msgBody, pqEncryption, internalTs, riSlow_, riFast_) = - let msgFlags = fromMaybe SMP.noMsgFlags msgFlags_ - msgRetryState = RI2State <$> riSlow_ <*> riFast_ - in PendingMsgData {msgId, msgType, msgFlags, msgBody, pqEncryption, msgRetryState, internalTs} - markMsgFailed msgId = DB.execute db "UPDATE snd_message_deliveries SET failed = 1 WHERE conn_id = ? AND internal_id = ?" (connId, msgId) - -getWorkItem :: Show i => ByteString -> IO (Maybe i) -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError (Maybe a)) -getWorkItem itemName getId getItem markFailed = - runExceptT $ handleWrkErr itemName "getId" getId >>= mapM (tryGetItem itemName getItem markFailed) - -getWorkItems :: Show i => ByteString -> IO [i] -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> IO (Either StoreError [Either StoreError a]) -getWorkItems itemName getIds getItem markFailed = - runExceptT $ handleWrkErr itemName "getIds" getIds >>= mapM (tryE . tryGetItem itemName getItem markFailed) - -tryGetItem :: Show i => ByteString -> (i -> IO (Either StoreError a)) -> (i -> IO ()) -> i -> ExceptT StoreError IO a -tryGetItem itemName getItem markFailed itemId = ExceptT (getItem itemId) `catchStoreError` \e -> mark >> throwE e - where - mark = handleWrkErr itemName ("markFailed ID " <> bshow itemId) $ markFailed itemId - -catchStoreError :: ExceptT StoreError IO a -> (StoreError -> ExceptT StoreError IO a) -> ExceptT StoreError IO a -catchStoreError = catchAllErrors (SEInternal . 
bshow)
-
--- Errors caught by this function will suspend worker as if there is no more work,
-handleWrkErr :: ByteString -> ByteString -> IO a -> ExceptT StoreError IO a
-handleWrkErr itemName opName action = ExceptT $ first mkError <$> E.try action
-  where
-    mkError :: E.SomeException -> StoreError
-    mkError e = SEWorkItemError $ itemName <> " " <> opName <> " error: " <> bshow e
-
-updatePendingMsgRIState :: DB.Connection -> ConnId -> InternalId -> RI2State -> IO ()
-updatePendingMsgRIState db connId msgId RI2State {slowInterval, fastInterval} =
-  DB.execute db "UPDATE snd_messages SET retry_int_slow = ?, retry_int_fast = ? WHERE conn_id = ? AND internal_id = ?" (slowInterval, fastInterval, connId, msgId)
-
-deletePendingMsgs :: DB.Connection -> ConnId -> SndQueue -> IO ()
-deletePendingMsgs db connId SndQueue {dbQueueId} =
-  DB.execute db "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ?" (connId, dbQueueId)
-
-getExpiredSndMessages :: DB.Connection -> ConnId -> SndQueue -> UTCTime -> IO [InternalId]
-getExpiredSndMessages db connId SndQueue {dbQueueId} expireTs = do
-  -- type is Maybe InternalId because MAX always returns one row, possibly with NULL value
-  maxId :: [Maybe InternalId] <-
-    map fromOnly
-      <$> DB.query
-        db
-        [sql|
-          SELECT MAX(internal_id)
-          FROM messages
-          WHERE conn_id = ? AND internal_snd_id IS NOT NULL AND internal_ts < ?
-        |]
-        (connId, expireTs)
-  case maxId of
-    Just msgId : _ ->
-      map fromOnly
-        <$> DB.query
-          db
-          [sql|
-            SELECT internal_id
-            FROM snd_message_deliveries
-            WHERE conn_id = ? AND snd_queue_id = ? AND failed = 0 AND internal_id <= ?
-            ORDER BY internal_id ASC
-          |]
-          (connId, dbQueueId, msgId)
-    _ -> pure []
-
-setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError (RcvQueue, SMP.MsgId))
-setMsgUserAck db connId agentMsgId = runExceptT $ do
-  (dbRcvId, srvMsgId) <-
-    ExceptT . firstRow id SEMsgNotFound $
-      DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId)
-  rq <- ExceptT $ getRcvQueueById db connId dbRcvId
-  liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (True, connId, agentMsgId)
-  pure (rq, srvMsgId)
-
-getRcvMsg :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError RcvMsg)
-getRcvMsg db connId agentMsgId =
-  firstRow toRcvMsg SEMsgNotFound $
-    DB.query
-      db
-      [sql|
-        SELECT
-          r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
-          m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
-        FROM rcv_messages r
-        JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
-        LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id
-        WHERE r.conn_id = ? AND r.internal_id = ?
-      |]
-      (connId, agentMsgId)
-
-getLastMsg :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Maybe RcvMsg)
-getLastMsg db connId msgId =
-  maybeFirstRow toRcvMsg $
-    DB.query
-      db
-      [sql|
-        SELECT
-          r.internal_id, m.internal_ts, r.broker_id, r.broker_ts, r.external_snd_id, r.integrity, r.internal_hash,
-          m.msg_type, m.msg_body, m.pq_encryption, s.internal_id, s.rcpt_status, r.user_ack
-        FROM rcv_messages r
-        JOIN messages m ON r.conn_id = m.conn_id AND r.internal_id = m.internal_id
-        JOIN connections c ON r.conn_id = c.conn_id AND c.last_internal_msg_id = r.internal_id
-        LEFT JOIN snd_messages s ON s.conn_id = r.conn_id AND s.rcpt_internal_id = r.internal_id
-        WHERE r.conn_id = ?
AND r.broker_id = ? - |] - (connId, msgId) - -toRcvMsg :: (Int64, InternalTs, BrokerId, BrokerTs) :. (AgentMsgId, MsgIntegrity, MsgHash, AgentMessageType, MsgBody, PQEncryption, Maybe AgentMsgId, Maybe MsgReceiptStatus, Bool) -> RcvMsg -toRcvMsg ((agentMsgId, internalTs, brokerId, brokerTs) :. (sndMsgId, integrity, internalHash, msgType, msgBody, pqEncryption, rcptInternalId_, rcptStatus_, userAck)) = - let msgMeta = MsgMeta {recipient = (agentMsgId, internalTs), broker = (brokerId, brokerTs), sndMsgId, integrity, pqEncryption} - msgReceipt = MsgReceipt <$> rcptInternalId_ <*> rcptStatus_ - in RcvMsg {internalId = InternalId agentMsgId, msgMeta, msgType, msgBody, internalHash, msgReceipt, userAck} - -checkRcvMsgHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool -checkRcvMsgHashExists db connId hash = do - fromMaybe False - <$> maybeFirstRow - fromOnly - ( DB.query - db - "SELECT 1 FROM encrypted_rcv_message_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" - (connId, hash) - ) - -getRcvMsgBrokerTs :: DB.Connection -> ConnId -> SMP.MsgId -> IO (Either StoreError BrokerTs) -getRcvMsgBrokerTs db connId msgId = - firstRow fromOnly SEMsgNotFound $ - DB.query db "SELECT broker_ts FROM rcv_messages WHERE conn_id = ? AND broker_id = ?" (connId, msgId) - -deleteMsg :: DB.Connection -> ConnId -> InternalId -> IO () -deleteMsg db connId msgId = - DB.execute db "DELETE FROM messages WHERE conn_id = ? AND internal_id = ?;" (connId, msgId) - -deleteMsgContent :: DB.Connection -> ConnId -> InternalId -> IO () -deleteMsgContent db connId msgId = - DB.execute db "UPDATE messages SET msg_body = x'' WHERE conn_id = ? AND internal_id = ?;" (connId, msgId) - -deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO () -deleteDeliveredSndMsg db connId msgId = do - cnt <- countPendingSndDeliveries_ db connId msgId - when (cnt == 0) $ deleteMsg db connId msgId - -deleteSndMsgDelivery :: DB.Connection -> ConnId -> SndQueue -> InternalId -> Bool -> IO () -deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do - DB.execute - db - "DELETE FROM snd_message_deliveries WHERE conn_id = ? AND snd_queue_id = ? AND internal_id = ?" - (connId, dbQueueId, msgId) - cnt <- countPendingSndDeliveries_ db connId msgId - when (cnt == 0) $ do - del <- - maybeFirstRow id (DB.query db "SELECT rcpt_internal_id, rcpt_status FROM snd_messages WHERE conn_id = ? AND internal_id = ?" (connId, msgId)) >>= \case - Just (Just (_ :: Int64), Just MROk) -> pure deleteMsg - _ -> pure $ if keepForReceipt then deleteMsgContent else deleteMsg - del db connId msgId - -countPendingSndDeliveries_ :: DB.Connection -> ConnId -> InternalId -> IO Int -countPendingSndDeliveries_ db connId msgId = do - (Only cnt : _) <- DB.query db "SELECT count(*) FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0" (connId, msgId) - pure cnt - -deleteRcvMsgHashesExpired :: DB.Connection -> NominalDiffTime -> IO () -deleteRcvMsgHashesExpired db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.execute db "DELETE FROM encrypted_rcv_message_hashes WHERE created_at < ?" (Only cutoffTs) - -deleteSndMsgsExpired :: DB.Connection -> NominalDiffTime -> IO () -deleteSndMsgsExpired db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.execute - db - "DELETE FROM messages WHERE internal_ts < ? 
AND internal_snd_id IS NOT NULL" - (Only cutoffTs) - -createRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () -createRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = - DB.execute db "INSERT INTO ratchets (conn_id, x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem) VALUES (?, ?, ?, ?)" (connId, x3dhPrivKey1, x3dhPrivKey2, pqPrivKem) - -getRatchetX3dhKeys :: DB.Connection -> ConnId -> IO (Either StoreError (C.PrivateKeyX448, C.PrivateKeyX448, Maybe CR.RcvPrivRKEMParams)) -getRatchetX3dhKeys db connId = - firstRow' keys SEX3dhKeysNotFound $ - DB.query db "SELECT x3dh_priv_key_1, x3dh_priv_key_2, pq_priv_kem FROM ratchets WHERE conn_id = ?" (Only connId) - where - keys = \case - (Just k1, Just k2, pKem) -> Right (k1, k2, pKem) - _ -> Left SEX3dhKeysNotFound - -createSndRatchet :: DB.Connection -> ConnId -> RatchetX448 -> CR.AE2ERatchetParams 'C.X448 -> IO () -createSndRatchet db connId ratchetState (CR.AE2ERatchetParams s (CR.E2ERatchetParams _ x3dhPubKey1 x3dhPubKey2 pqPubKem)) = - DB.execute - db - [sql| - INSERT INTO ratchets - (conn_id, ratchet_state, x3dh_pub_key_1, x3dh_pub_key_2, pq_pub_kem) VALUES (?, ?, ?, ?, ?) - ON CONFLICT (conn_id) DO UPDATE SET - ratchet_state = EXCLUDED.ratchet_state, - x3dh_priv_key_1 = NULL, - x3dh_priv_key_2 = NULL, - x3dh_pub_key_1 = EXCLUDED.x3dh_pub_key_1, - x3dh_pub_key_2 = EXCLUDED.x3dh_pub_key_2, - pq_priv_kem = NULL, - pq_pub_kem = EXCLUDED.pq_pub_kem - |] - (connId, ratchetState, x3dhPubKey1, x3dhPubKey2, CR.ARKP s <$> pqPubKem) - -getSndRatchet :: DB.Connection -> ConnId -> CR.VersionE2E -> IO (Either StoreError (RatchetX448, CR.AE2ERatchetParams 'C.X448)) -getSndRatchet db connId v = - firstRow' result SEX3dhKeysNotFound $ - DB.query db "SELECT ratchet_state, x3dh_pub_key_1, x3dh_pub_key_2, pq_pub_kem FROM ratchets WHERE conn_id = ?" (Only connId) - where - result = \case - (Just ratchetState, Just k1, Just k2, pKem_) -> - let params = case pKem_ of - Nothing -> CR.AE2ERatchetParams CR.SRKSProposed (CR.E2ERatchetParams v k1 k2 Nothing) - Just (CR.ARKP s pKem) -> CR.AE2ERatchetParams s (CR.E2ERatchetParams v k1 k2 (Just pKem)) - in Right (ratchetState, params) - _ -> Left SEX3dhKeysNotFound - --- used to remember new keys when starting ratchet re-synchronization -setRatchetX3dhKeys :: DB.Connection -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> Maybe CR.RcvPrivRKEMParams -> IO () -setRatchetX3dhKeys db connId x3dhPrivKey1 x3dhPrivKey2 pqPrivKem = - DB.execute - db - [sql| - UPDATE ratchets - SET x3dh_priv_key_1 = ?, x3dh_priv_key_2 = ?, pq_priv_kem = ? - WHERE conn_id = ? - |] - (x3dhPrivKey1, x3dhPrivKey2, pqPrivKem, connId) - --- TODO remove the columns for public keys in v5.7. -createRatchet :: DB.Connection -> ConnId -> RatchetX448 -> IO () -createRatchet db connId rc = - DB.executeNamed - db - [sql| - INSERT INTO ratchets (conn_id, ratchet_state) - VALUES (:conn_id, :ratchet_state) - ON CONFLICT (conn_id) DO UPDATE SET - ratchet_state = :ratchet_state, - x3dh_priv_key_1 = NULL, - x3dh_priv_key_2 = NULL, - x3dh_pub_key_1 = NULL, - x3dh_pub_key_2 = NULL, - pq_priv_kem = NULL, - pq_pub_kem = NULL - |] - [":conn_id" := connId, ":ratchet_state" := rc] - -deleteRatchet :: DB.Connection -> ConnId -> IO () -deleteRatchet db connId = - DB.execute db "DELETE FROM ratchets WHERE conn_id = ?" 
(Only connId) - -getRatchet :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448) -getRatchet db connId = - firstRow' ratchet SERatchetNotFound $ DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId) - where - ratchet = maybe (Left SERatchetNotFound) Right . fromOnly - -getSkippedMsgKeys :: DB.Connection -> ConnId -> IO SkippedMsgKeys -getSkippedMsgKeys db connId = - skipped <$> DB.query db "SELECT header_key, msg_n, msg_key FROM skipped_messages WHERE conn_id = ?" (Only connId) - where - skipped = foldl' addSkippedKey M.empty - addSkippedKey smks (hk, msgN, mk) = M.alter (Just . addMsgKey) hk smks - where - addMsgKey = maybe (M.singleton msgN mk) (M.insert msgN mk) - -updateRatchet :: DB.Connection -> ConnId -> RatchetX448 -> SkippedMsgDiff -> IO () -updateRatchet db connId rc skipped = do - DB.execute db "UPDATE ratchets SET ratchet_state = ? WHERE conn_id = ?" (rc, connId) - case skipped of - SMDNoChange -> pure () - SMDRemove hk msgN -> - DB.execute db "DELETE FROM skipped_messages WHERE conn_id = ? AND header_key = ? AND msg_n = ?" (connId, hk, msgN) - SMDAdd smks -> - forM_ (M.assocs smks) $ \(hk, mks) -> - forM_ (M.assocs mks) $ \(msgN, mk) -> - DB.execute db "INSERT INTO skipped_messages (conn_id, header_key, msg_n, msg_key) VALUES (?, ?, ?, ?)" (connId, hk, msgN, mk) - -createCommand :: DB.Connection -> ACorrId -> ConnId -> Maybe SMPServer -> AgentCommand -> IO (Either StoreError ()) -createCommand db corrId connId srv_ cmd = runExceptT $ do - (host_, port_, serverKeyHash_) <- serverFields - createdAt <- liftIO getCurrentTime - liftIO . E.handle handleErr $ - DB.execute - db - "INSERT INTO commands (host, port, corr_id, conn_id, command_tag, command, server_key_hash, created_at) VALUES (?,?,?,?,?,?,?,?)" - (host_, port_, corrId, connId, cmdTag, cmd, serverKeyHash_, createdAt) - where - cmdTag = agentCommandTag cmd - handleErr e - | SQL.sqlError e == SQL.ErrorConstraint = logError $ "tried to create command " <> tshow cmdTag <> " for deleted connection" - | otherwise = E.throwIO e - serverFields :: ExceptT StoreError IO (Maybe (NonEmpty TransportHost), Maybe ServiceName, Maybe C.KeyHash) - serverFields = case srv_ of - Just srv@(SMPServer host port _) -> - (Just host,Just port,) <$> ExceptT (getServerKeyHash_ db srv) - Nothing -> pure (Nothing, Nothing, Nothing) - -insertedRowId :: DB.Connection -> IO Int64 -insertedRowId db = fromOnly . head <$> DB.query_ db "SELECT last_insert_rowid()" - -getPendingCommandServers :: DB.Connection -> ConnId -> IO [Maybe SMPServer] -getPendingCommandServers db connId = do - -- TODO review whether this can break if, e.g., the server has another key hash. - map smpServer - <$> DB.query - db - [sql| - SELECT DISTINCT c.host, c.port, COALESCE(c.server_key_hash, s.key_hash) - FROM commands c - LEFT JOIN servers s ON s.host = c.host AND s.port = c.port - WHERE conn_id = ? - |] - (Only connId) - where - smpServer (host, port, keyHash) = SMPServer <$> host <*> port <*> keyHash - -getPendingServerCommand :: DB.Connection -> ConnId -> Maybe SMPServer -> IO (Either StoreError (Maybe PendingCommand)) -getPendingServerCommand db connId srv_ = getWorkItem "command" getCmdId getCommand markCommandFailed - where - getCmdId :: IO (Maybe Int64) - getCmdId = - maybeFirstRow fromOnly $ case srv_ of - Nothing -> - DB.query - db - [sql| - SELECT command_id FROM commands - WHERE conn_id = ? 
AND host IS NULL AND port IS NULL AND failed = 0 - ORDER BY created_at ASC, command_id ASC - LIMIT 1 - |] - (Only connId) - Just (SMPServer host port _) -> - DB.query - db - [sql| - SELECT command_id FROM commands - WHERE conn_id = ? AND host = ? AND port = ? AND failed = 0 - ORDER BY created_at ASC, command_id ASC - LIMIT 1 - |] - (connId, host, port) - getCommand :: Int64 -> IO (Either StoreError PendingCommand) - getCommand cmdId = - firstRow pendingCommand err $ - DB.query - db - [sql| - SELECT c.corr_id, cs.user_id, c.command - FROM commands c - JOIN connections cs USING (conn_id) - WHERE c.command_id = ? - |] - (Only cmdId) - where - err = SEInternal $ "command " <> bshow cmdId <> " returned []" - pendingCommand (corrId, userId, command) = PendingCommand {cmdId, corrId, userId, connId, command} - markCommandFailed cmdId = DB.execute db "UPDATE commands SET failed = 1 WHERE command_id = ?" (Only cmdId) - -updateCommandServer :: DB.Connection -> AsyncCmdId -> SMPServer -> IO (Either StoreError ()) -updateCommandServer db cmdId srv@(SMPServer host port _) = runExceptT $ do - serverKeyHash_ <- ExceptT $ getServerKeyHash_ db srv - liftIO $ - DB.execute - db - [sql| - UPDATE commands - SET host = ?, port = ?, server_key_hash = ? - WHERE command_id = ? - |] - (host, port, serverKeyHash_, cmdId) - -deleteCommand :: DB.Connection -> AsyncCmdId -> IO () -deleteCommand db cmdId = - DB.execute db "DELETE FROM commands WHERE command_id = ?" (Only cmdId) - -createNtfToken :: DB.Connection -> NtfToken -> IO () -createNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = srv@ProtocolServer {host, port}, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey), ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} = do - upsertNtfServer_ db srv - DB.execute - db - [sql| - INSERT INTO ntf_tokens - (provider, device_token, ntf_host, ntf_port, tkn_id, tkn_pub_key, tkn_priv_key, tkn_pub_dh_key, tkn_priv_dh_key, tkn_dh_secret, tkn_status, tkn_action, ntf_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - |] - ((provider, token, host, port, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode)) - -getSavedNtfToken :: DB.Connection -> IO (Maybe NtfToken) -getSavedNtfToken db = do - maybeFirstRow ntfToken $ - DB.query_ - db - [sql| - SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, - t.provider, t.device_token, t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret, - t.tkn_status, t.tkn_action, t.ntf_mode - FROM ntf_tokens t - JOIN ntf_servers s USING (ntf_host, ntf_port) - |] - where - ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) = - let ntfServer = NtfServer host port keyHash - ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) - ntfMode = fromMaybe NMPeriodic ntfMode_ - in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} - -updateNtfTokenRegistration :: DB.Connection -> NtfToken -> NtfTokenId -> C.DhSecretX25519 -> IO () -updateNtfTokenRegistration db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknId ntfDhSecret = do - updatedAt <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_tokens - SET tkn_id = ?, tkn_dh_secret = ?, tkn_status = ?, tkn_action = ?, updated_at = ? - WHERE provider = ? 
AND device_token = ? AND ntf_host = ? AND ntf_port = ? - |] - (tknId, ntfDhSecret, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) - -updateDeviceToken :: DB.Connection -> NtfToken -> DeviceToken -> IO () -updateDeviceToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} (DeviceToken toProvider toToken) = do - updatedAt <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_tokens - SET provider = ?, device_token = ?, tkn_status = ?, tkn_action = ?, updated_at = ? - WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? - |] - (toProvider, toToken, NTRegistered, Nothing :: Maybe NtfTknAction, updatedAt, provider, token, host, port) - -updateNtfMode :: DB.Connection -> NtfToken -> NotificationsMode -> IO () -updateNtfMode db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} ntfMode = do - updatedAt <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_tokens - SET ntf_mode = ?, updated_at = ? - WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? - |] - (ntfMode, updatedAt, provider, token, host, port) - -updateNtfToken :: DB.Connection -> NtfToken -> NtfTknStatus -> Maybe NtfTknAction -> IO () -updateNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} tknStatus tknAction = do - updatedAt <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_tokens - SET tkn_status = ?, tkn_action = ?, updated_at = ? - WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? - |] - (tknStatus, tknAction, updatedAt, provider, token, host, port) - -removeNtfToken :: DB.Connection -> NtfToken -> IO () -removeNtfToken db NtfToken {deviceToken = DeviceToken provider token, ntfServer = ProtocolServer {host, port}} = - DB.execute - db - [sql| - DELETE FROM ntf_tokens - WHERE provider = ? AND device_token = ? AND ntf_host = ? AND ntf_port = ? - |] - (provider, token, host, port) - -addNtfTokenToDelete :: DB.Connection -> NtfServer -> C.APrivateAuthKey -> NtfTokenId -> IO () -addNtfTokenToDelete db ProtocolServer {host, port, keyHash} ntfPrivKey tknId = - DB.execute db "INSERT INTO ntf_tokens_to_delete (ntf_host, ntf_port, ntf_key_hash, tkn_id, tkn_priv_key) VALUES (?,?,?,?,?)" (host, port, keyHash, tknId, ntfPrivKey) - -deleteExpiredNtfTokensToDelete :: DB.Connection -> NominalDiffTime -> IO () -deleteExpiredNtfTokensToDelete db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.execute db "DELETE FROM ntf_tokens_to_delete WHERE created_at < ?" (Only cutoffTs) - -type NtfTokenToDelete = (Int64, C.APrivateAuthKey, NtfTokenId) - -getNextNtfTokenToDelete :: DB.Connection -> NtfServer -> IO (Either StoreError (Maybe NtfTokenToDelete)) -getNextNtfTokenToDelete db (NtfServer ntfHost ntfPort _) = - getWorkItem "ntf tkn del" getNtfTknDbId getNtfTknToDelete (markNtfTokenToDeleteFailed_ db) - where - getNtfTknDbId :: IO (Maybe Int64) - getNtfTknDbId = - maybeFirstRow fromOnly $ - DB.query - db - [sql| - SELECT ntf_token_to_delete_id - FROM ntf_tokens_to_delete - WHERE ntf_host = ? AND ntf_port = ? - AND del_failed = 0 - ORDER BY created_at ASC - LIMIT 1 - |] - (ntfHost, ntfPort) - getNtfTknToDelete :: Int64 -> IO (Either StoreError NtfTokenToDelete) - getNtfTknToDelete tknDbId = - firstRow ntfTokenToDelete err $ - DB.query - db - [sql| - SELECT tkn_priv_key, tkn_id - FROM ntf_tokens_to_delete - WHERE ntf_token_to_delete_id = ? 
- |] - (Only tknDbId) - where - err = SEInternal $ "ntf token to delete " <> bshow tknDbId <> " returned []" - ntfTokenToDelete (tknPrivKey, tknId) = (tknDbId, tknPrivKey, tknId) - -markNtfTokenToDeleteFailed_ :: DB.Connection -> Int64 -> IO () -markNtfTokenToDeleteFailed_ db tknDbId = - DB.execute db "UPDATE ntf_tokens_to_delete SET del_failed = 1 where ntf_token_to_delete_id = ?" (Only tknDbId) - -getPendingDelTknServers :: DB.Connection -> IO [NtfServer] -getPendingDelTknServers db = - map toNtfServer - <$> DB.query_ - db - [sql| - SELECT DISTINCT ntf_host, ntf_port, ntf_key_hash - FROM ntf_tokens_to_delete - |] - where - toNtfServer (host, port, keyHash) = NtfServer host port keyHash - -deleteNtfTokenToDelete :: DB.Connection -> Int64 -> IO () -deleteNtfTokenToDelete db tknDbId = - DB.execute db "DELETE FROM ntf_tokens_to_delete WHERE ntf_token_to_delete_id = ?" (Only tknDbId) - -type NtfSupervisorSub = (NtfSubscription, Maybe (NtfSubAction, NtfActionTs)) - -getNtfSubscription :: DB.Connection -> ConnId -> IO (Maybe NtfSupervisorSub) -getNtfSubscription db connId = - maybeFirstRow ntfSubscription $ - DB.query - db - [sql| - SELECT c.user_id, s.host, s.port, COALESCE(nsb.smp_server_key_hash, s.key_hash), ns.ntf_host, ns.ntf_port, ns.ntf_key_hash, - nsb.smp_ntf_id, nsb.ntf_sub_id, nsb.ntf_sub_status, nsb.ntf_sub_action, nsb.ntf_sub_smp_action, nsb.ntf_sub_action_ts - FROM ntf_subscriptions nsb - JOIN connections c USING (conn_id) - JOIN servers s ON s.host = nsb.smp_host AND s.port = nsb.smp_port - JOIN ntf_servers ns USING (ntf_host, ntf_port) - WHERE nsb.conn_id = ? - |] - (Only connId) - where - ntfSubscription ((userId, smpHost, smpPort, smpKeyHash, ntfHost, ntfPort, ntfKeyHash) :. (ntfQueueId, ntfSubId, ntfSubStatus, ntfAction_, smpAction_, actionTs_)) = - let smpServer = SMPServer smpHost smpPort smpKeyHash - ntfServer = NtfServer ntfHost ntfPort ntfKeyHash - action = case (ntfAction_, smpAction_, actionTs_) of - (Just ntfAction, Nothing, Just actionTs) -> Just (NSANtf ntfAction, actionTs) - (Nothing, Just smpAction, Just actionTs) -> Just (NSASMP smpAction, actionTs) - _ -> Nothing - in (NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus}, action) - -createNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO (Either StoreError ()) -createNtfSubscription db ntfSubscription action = runExceptT $ do - let NtfSubscription {connId, smpServer = smpServer@(SMPServer host port _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} = ntfSubscription - smpServerKeyHash_ <- ExceptT $ getServerKeyHash_ db smpServer - actionTs <- liftIO getCurrentTime - liftIO $ - DB.execute - db - [sql| - INSERT INTO ntf_subscriptions - (conn_id, smp_host, smp_port, smp_ntf_id, ntf_host, ntf_port, ntf_sub_id, - ntf_sub_status, ntf_sub_action, ntf_sub_smp_action, ntf_sub_action_ts, smp_server_key_hash) - VALUES (?,?,?,?,?,?,?,?,?,?,?,?) - |] - ( (connId, host, port, ntfQueueId, ntfHost, ntfPort, ntfSubId) - :. 
(ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, smpServerKeyHash_) - ) - where - (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action - -supervisorUpdateNtfSub :: DB.Connection -> NtfSubscription -> NtfSubAction -> IO () -supervisorUpdateNtfSub db NtfSubscription {connId, smpServer = (SMPServer smpHost smpPort _), ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action = do - ts <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_subscriptions - SET smp_host = ?, smp_port = ?, smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, - ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? - WHERE conn_id = ? - |] - ( (smpHost, smpPort, ntfQueueId, ntfHost, ntfPort, ntfSubId) - :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, ts, True, ts, connId) - ) - where - (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action - -supervisorUpdateNtfAction :: DB.Connection -> ConnId -> NtfSubAction -> IO () -supervisorUpdateNtfAction db connId action = do - ts <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_subscriptions - SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? - WHERE conn_id = ? - |] - (ntfSubAction, ntfSubSMPAction, ts, True, ts, connId) - where - (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action - -updateNtfSubscription :: DB.Connection -> NtfSubscription -> NtfSubAction -> NtfActionTs -> IO () -updateNtfSubscription db NtfSubscription {connId, ntfQueueId, ntfServer = (NtfServer ntfHost ntfPort _), ntfSubId, ntfSubStatus} action actionTs = do - r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) - forM_ r $ \updatedBySupervisor -> do - updatedAt <- getCurrentTime - if updatedBySupervisor - then - DB.execute - db - [sql| - UPDATE ntf_subscriptions - SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ? - WHERE conn_id = ? - |] - (ntfQueueId, ntfSubId, ntfSubStatus, False, updatedAt, connId) - else - DB.execute - db - [sql| - UPDATE ntf_subscriptions - SET smp_ntf_id = ?, ntf_host = ?, ntf_port = ?, ntf_sub_id = ?, ntf_sub_status = ?, ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? - WHERE conn_id = ? - |] - ((ntfQueueId, ntfHost, ntfPort, ntfSubId) :. (ntfSubStatus, ntfSubAction, ntfSubSMPAction, actionTs, False, updatedAt, connId)) - where - (ntfSubAction, ntfSubSMPAction) = ntfSubAndSMPAction action - -setNullNtfSubscriptionAction :: DB.Connection -> ConnId -> IO () -setNullNtfSubscriptionAction db connId = do - r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) - forM_ r $ \updatedBySupervisor -> - unless updatedBySupervisor $ do - updatedAt <- getCurrentTime - DB.execute - db - [sql| - UPDATE ntf_subscriptions - SET ntf_sub_action = ?, ntf_sub_smp_action = ?, ntf_sub_action_ts = ?, updated_by_supervisor = ?, updated_at = ? - WHERE conn_id = ? - |] - (Nothing :: Maybe NtfSubNTFAction, Nothing :: Maybe NtfSubSMPAction, Nothing :: Maybe UTCTime, False, updatedAt, connId) - -deleteNtfSubscription :: DB.Connection -> ConnId -> IO () -deleteNtfSubscription db connId = do - r <- maybeFirstRow fromOnly $ DB.query db "SELECT updated_by_supervisor FROM ntf_subscriptions WHERE conn_id = ?" 
(Only connId) - forM_ r $ \updatedBySupervisor -> do - updatedAt <- getCurrentTime - if updatedBySupervisor - then - DB.execute - db - [sql| - UPDATE ntf_subscriptions - SET smp_ntf_id = ?, ntf_sub_id = ?, ntf_sub_status = ?, updated_by_supervisor = ?, updated_at = ? - WHERE conn_id = ? - |] - (Nothing :: Maybe SMP.NotifierId, Nothing :: Maybe NtfSubscriptionId, NASDeleted, False, updatedAt, connId) - else deleteNtfSubscription' db connId - -deleteNtfSubscription' :: DB.Connection -> ConnId -> IO () -deleteNtfSubscription' db connId = do - DB.execute db "DELETE FROM ntf_subscriptions WHERE conn_id = ?" (Only connId) - -getNextNtfSubNTFActions :: DB.Connection -> NtfServer -> Int -> IO (Either StoreError [Either StoreError (NtfSubNTFAction, NtfSubscription, NtfActionTs)]) -getNextNtfSubNTFActions db ntfServer@(NtfServer ntfHost ntfPort _) ntfBatchSize = - getWorkItems "ntf NTF" getNtfConnIds getNtfSubAction (markNtfSubActionNtfFailed_ db) - where - getNtfConnIds :: IO [ConnId] - getNtfConnIds = - map fromOnly - <$> DB.query - db - [sql| - SELECT conn_id - FROM ntf_subscriptions - WHERE ntf_host = ? AND ntf_port = ? AND ntf_sub_action IS NOT NULL - AND (ntf_failed = 0 OR updated_by_supervisor = 1) - ORDER BY ntf_sub_action_ts ASC - LIMIT ? - |] - (ntfHost, ntfPort, ntfBatchSize) - getNtfSubAction :: ConnId -> IO (Either StoreError (NtfSubNTFAction, NtfSubscription, NtfActionTs)) - getNtfSubAction connId = do - markUpdatedByWorker db connId - firstRow ntfSubAction err $ - DB.query - db - [sql| - SELECT c.user_id, s.host, s.port, COALESCE(ns.smp_server_key_hash, s.key_hash), - ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_action_ts, ns.ntf_sub_action - FROM ntf_subscriptions ns - JOIN connections c USING (conn_id) - JOIN servers s ON s.host = ns.smp_host AND s.port = ns.smp_port - WHERE ns.conn_id = ? - |] - (Only connId) - where - err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" - ntfSubAction (userId, smpHost, smpPort, smpKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, actionTs, action) = - let smpServer = SMPServer smpHost smpPort smpKeyHash - ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} - in (action, ntfSubscription, actionTs) - -markNtfSubActionNtfFailed_ :: DB.Connection -> ConnId -> IO () -markNtfSubActionNtfFailed_ db connId = - DB.execute db "UPDATE ntf_subscriptions SET ntf_failed = 1 where conn_id = ?" (Only connId) - -getNextNtfSubSMPActions :: DB.Connection -> SMPServer -> Int -> IO (Either StoreError [Either StoreError (NtfSubSMPAction, NtfSubscription)]) -getNextNtfSubSMPActions db smpServer@(SMPServer smpHost smpPort _) ntfBatchSize = - getWorkItems "ntf SMP" getNtfConnIds getNtfSubAction (markNtfSubActionSMPFailed_ db) - where - getNtfConnIds :: IO [ConnId] - getNtfConnIds = - map fromOnly - <$> DB.query - db - [sql| - SELECT conn_id - FROM ntf_subscriptions ns - WHERE smp_host = ? AND smp_port = ? AND ntf_sub_smp_action IS NOT NULL AND ntf_sub_action_ts IS NOT NULL - AND (smp_failed = 0 OR updated_by_supervisor = 1) - ORDER BY ntf_sub_action_ts ASC - LIMIT ? 
- |] - (smpHost, smpPort, ntfBatchSize) - getNtfSubAction :: ConnId -> IO (Either StoreError (NtfSubSMPAction, NtfSubscription)) - getNtfSubAction connId = do - markUpdatedByWorker db connId - firstRow ntfSubAction err $ - DB.query - db - [sql| - SELECT c.user_id, s.ntf_host, s.ntf_port, s.ntf_key_hash, - ns.smp_ntf_id, ns.ntf_sub_id, ns.ntf_sub_status, ns.ntf_sub_smp_action - FROM ntf_subscriptions ns - JOIN connections c USING (conn_id) - JOIN ntf_servers s USING (ntf_host, ntf_port) - WHERE ns.conn_id = ? - |] - (Only connId) - where - err = SEInternal $ "ntf subscription " <> bshow connId <> " returned []" - ntfSubAction (userId, ntfHost, ntfPort, ntfKeyHash, ntfQueueId, ntfSubId, ntfSubStatus, action) = - let ntfServer = NtfServer ntfHost ntfPort ntfKeyHash - ntfSubscription = NtfSubscription {userId, connId, smpServer, ntfQueueId, ntfServer, ntfSubId, ntfSubStatus} - in (action, ntfSubscription) - -markNtfSubActionSMPFailed_ :: DB.Connection -> ConnId -> IO () -markNtfSubActionSMPFailed_ db connId = - DB.execute db "UPDATE ntf_subscriptions SET smp_failed = 1 where conn_id = ?" (Only connId) - -markUpdatedByWorker :: DB.Connection -> ConnId -> IO () -markUpdatedByWorker db connId = - DB.execute db "UPDATE ntf_subscriptions SET updated_by_supervisor = 0 WHERE conn_id = ?" (Only connId) - -getActiveNtfToken :: DB.Connection -> IO (Maybe NtfToken) -getActiveNtfToken db = - maybeFirstRow ntfToken $ - DB.query - db - [sql| - SELECT s.ntf_host, s.ntf_port, s.ntf_key_hash, - t.provider, t.device_token, t.tkn_id, t.tkn_pub_key, t.tkn_priv_key, t.tkn_pub_dh_key, t.tkn_priv_dh_key, t.tkn_dh_secret, - t.tkn_status, t.tkn_action, t.ntf_mode - FROM ntf_tokens t - JOIN ntf_servers s USING (ntf_host, ntf_port) - WHERE t.tkn_status = ? - |] - (Only NTActive) - where - ntfToken ((host, port, keyHash) :. (provider, dt, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhPubKey, ntfDhPrivKey, ntfDhSecret) :. (ntfTknStatus, ntfTknAction, ntfMode_)) = - let ntfServer = NtfServer host port keyHash - ntfDhKeys = (ntfDhPubKey, ntfDhPrivKey) - ntfMode = fromMaybe NMPeriodic ntfMode_ - in NtfToken {deviceToken = DeviceToken provider dt, ntfServer, ntfTokenId, ntfPubKey, ntfPrivKey, ntfDhKeys, ntfDhSecret, ntfTknStatus, ntfTknAction, ntfMode} - -getNtfRcvQueue :: DB.Connection -> SMPQueueNtf -> IO (Either StoreError (ConnId, RcvNtfDhSecret, Maybe UTCTime)) -getNtfRcvQueue db SMPQueueNtf {smpServer = (SMPServer host port _), notifierId} = - firstRow' res SEConnNotFound $ - DB.query - db - [sql| - SELECT conn_id, rcv_ntf_dh_secret, last_broker_ts - FROM rcv_queues - WHERE host = ? AND port = ? AND ntf_id = ? AND deleted = 0 - |] - (host, port, notifierId) - where - res (connId, Just rcvNtfDhSecret, lastBrokerTs_) = Right (connId, rcvNtfDhSecret, lastBrokerTs_) - res _ = Left SEConnNotFound - -setConnectionNtfs :: DB.Connection -> ConnId -> Bool -> IO () -setConnectionNtfs db connId enableNtfs = - DB.execute db "UPDATE connections SET enable_ntfs = ? WHERE conn_id = ?" (enableNtfs, connId) - --- * Auxiliary helpers - -instance ToField QueueStatus where toField = toField . 
serializeQueueStatus
-
-instance FromField QueueStatus where fromField = fromTextField_ queueStatusT
-
-instance ToField (DBQueueId 'QSStored) where toField (DBQueueId qId) = toField qId
-
-instance FromField (DBQueueId 'QSStored) where fromField x = DBQueueId <$> fromField x
-
-instance ToField InternalRcvId where toField (InternalRcvId x) = toField x
-
-instance FromField InternalRcvId where fromField x = InternalRcvId <$> fromField x
-
-instance ToField InternalSndId where toField (InternalSndId x) = toField x
-
-instance FromField InternalSndId where fromField x = InternalSndId <$> fromField x
-
-instance ToField InternalId where toField (InternalId x) = toField x
-
-instance FromField InternalId where fromField x = InternalId <$> fromField x
-
-instance ToField AgentMessageType where toField = toField . smpEncode
-
-instance FromField AgentMessageType where fromField = blobFieldParser smpP
-
-instance ToField MsgIntegrity where toField = toField . strEncode
-
-instance FromField MsgIntegrity where fromField = blobFieldParser strP
-
-instance ToField SMPQueueUri where toField = toField . strEncode
-
-instance FromField SMPQueueUri where fromField = blobFieldParser strP
-
-instance ToField AConnectionRequestUri where toField = toField . strEncode
-
-instance FromField AConnectionRequestUri where fromField = blobFieldParser strP
-
-instance ConnectionModeI c => ToField (ConnectionRequestUri c) where toField = toField . strEncode
-
-instance (E.Typeable c, ConnectionModeI c) => FromField (ConnectionRequestUri c) where fromField = blobFieldParser strP
-
-instance ToField ConnectionMode where toField = toField . decodeLatin1 . strEncode
-
-instance FromField ConnectionMode where fromField = fromTextField_ connModeT
-
-instance ToField (SConnectionMode c) where toField = toField . connMode
-
-instance FromField AConnectionMode where fromField = fromTextField_ $ fmap connMode' . connModeT
-
-instance ToField MsgFlags where toField = toField . decodeLatin1 . smpEncode
-
-instance FromField MsgFlags where fromField = fromTextField_ $ eitherToMaybe . smpDecode . encodeUtf8
-
-instance ToField [SMPQueueInfo] where toField = toField . smpEncodeList
-
-instance FromField [SMPQueueInfo] where fromField = blobFieldParser smpListP
-
-instance ToField (NonEmpty TransportHost) where toField = toField . decodeLatin1 . strEncode
-
-instance FromField (NonEmpty TransportHost) where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
-
-instance ToField AgentCommand where toField = toField . strEncode
-
-instance FromField AgentCommand where fromField = blobFieldParser strP
-
-instance ToField AgentCommandTag where toField = toField . strEncode
-
-instance FromField AgentCommandTag where fromField = blobFieldParser strP
-
-instance ToField MsgReceiptStatus where toField = toField . decodeLatin1 . strEncode
-
-instance FromField MsgReceiptStatus where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
-
-instance ToField (Version v) where toField (Version v) = toField v
-
-instance FromField (Version v) where fromField f = Version <$> fromField f
-
-deriving newtype instance ToField EntityId
-
-deriving newtype instance FromField EntityId
-
-deriving newtype instance ToField ChunkReplicaId
-
-deriving newtype instance FromField ChunkReplicaId
-
-listToEither :: e -> [a] -> Either e a
-listToEither _ (x : _) = Right x
-listToEither e _ = Left e
-
-firstRow :: (a -> b) -> e -> IO [a] -> IO (Either e b)
-firstRow f e a = second f . listToEither e <$> a
-
-maybeFirstRow :: Functor f => (a -> b) -> f [a] -> f (Maybe b)
-maybeFirstRow f q = fmap f . listToMaybe <$> q
-
-firstRow' :: (a -> Either e b) -> e -> IO [a] -> IO (Either e b)
-firstRow' f e a = (f <=< listToEither e) <$> a
-
-{- ORMOLU_DISABLE -}
--- SQLite.Simple only has these up to 10 fields, which is insufficient for some of our queries
-instance (FromField a, FromField b, FromField c, FromField d, FromField e,
-          FromField f, FromField g, FromField h, FromField i, FromField j,
-          FromField k) =>
-         FromRow (a,b,c,d,e,f,g,h,i,j,k) where
-  fromRow = (,,,,,,,,,,) <$> field <*> field <*> field <*> field <*> field
-                         <*> field <*> field <*> field <*> field <*> field
-                         <*> field
-
-instance (FromField a, FromField b, FromField c, FromField d, FromField e,
-          FromField f, FromField g, FromField h, FromField i, FromField j,
-          FromField k, FromField l) =>
-         FromRow (a,b,c,d,e,f,g,h,i,j,k,l) where
-  fromRow = (,,,,,,,,,,,) <$> field <*> field <*> field <*> field <*> field
-                          <*> field <*> field <*> field <*> field <*> field
-                          <*> field <*> field
-
-instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f,
-          ToField g, ToField h, ToField i, ToField j, ToField k, ToField l) =>
-         ToRow (a,b,c,d,e,f,g,h,i,j,k,l) where
-  toRow (a,b,c,d,e,f,g,h,i,j,k,l) =
-    [ toField a, toField b, toField c, toField d, toField e, toField f,
-      toField g, toField h, toField i, toField j, toField k, toField l
-    ]
-
-{- ORMOLU_ENABLE -}
-
--- * Server helper
-
--- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored.
-createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash)
-createServer_ db newSrv@ProtocolServer {host, port, keyHash} =
-  getServerKeyHash_ db newSrv >>= \case
-    Right keyHash_ -> pure keyHash_
-    Left _ -> insertNewServer_ $> Nothing
-  where
-    insertNewServer_ =
-      DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash)
-
--- | Returns the passed server key hash if it is different from the stored one, or the error if the server does not exist.
-getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash))
-getServerKeyHash_ db ProtocolServer {host, port, keyHash} = do
-  firstRow useKeyHash SEServerNotFound $
-    DB.query db "SELECT key_hash FROM servers WHERE host = ? AND port = ?" (host, port)
-  where
-    useKeyHash (Only keyHash') = if keyHash /= keyHash' then Just keyHash else Nothing
-
-upsertNtfServer_ :: DB.Connection -> NtfServer -> IO ()
-upsertNtfServer_ db ProtocolServer {host, port, keyHash} = do
-  DB.executeNamed
-    db
-    [sql|
-      INSERT INTO ntf_servers (ntf_host, ntf_port, ntf_key_hash) VALUES (:host,:port,:key_hash)
-      ON CONFLICT (ntf_host, ntf_port) DO UPDATE SET
-        ntf_host=excluded.ntf_host,
-        ntf_port=excluded.ntf_port,
-        ntf_key_hash=excluded.ntf_key_hash;
-    |]
-    [":host" := host, ":port" := port, ":key_hash" := keyHash]
-
--- * createRcvConn helpers
-
-insertRcvQueue_ :: DB.Connection -> ConnId -> NewRcvQueue -> Maybe C.KeyHash -> IO RcvQueue
-insertRcvQueue_ db connId' rq@RcvQueue {..} serverKeyHash_ = do
-  -- to preserve ID if the queue already exists.
-  -- possibly, it can be done in one query.
-  currQId_ <- maybeFirstRow fromOnly $ DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ? AND host = ? AND port = ? AND snd_id = ?" (connId', host server, port server, sndId)
-  qId <- maybe (newQueueId_ <$> DB.query db "SELECT rcv_queue_id FROM rcv_queues WHERE conn_id = ?
ORDER BY rcv_queue_id DESC LIMIT 1" (Only connId')) pure currQId_ - DB.execute - db - [sql| - INSERT INTO rcv_queues - (host, port, rcv_id, conn_id, rcv_private_key, rcv_dh_secret, e2e_priv_key, e2e_dh_secret, snd_id, snd_secure, status, rcv_queue_id, rcv_primary, replace_rcv_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); - |] - ((host server, port server, rcvId, connId', rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret) :. (sndId, sndSecure, status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) - pure (rq :: NewRcvQueue) {connId = connId', dbQueueId = qId} - --- * createSndConn helpers - -insertSndQueue_ :: DB.Connection -> ConnId -> NewSndQueue -> Maybe C.KeyHash -> IO SndQueue -insertSndQueue_ db connId' sq@SndQueue {..} serverKeyHash_ = do - -- to preserve ID if the queue already exists. - -- possibly, it can be done in one query. - currQId_ <- maybeFirstRow fromOnly $ DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? AND host = ? AND port = ? AND snd_id = ?" (connId', host server, port server, sndId) - qId <- maybe (newQueueId_ <$> DB.query db "SELECT snd_queue_id FROM snd_queues WHERE conn_id = ? ORDER BY snd_queue_id DESC LIMIT 1" (Only connId')) pure currQId_ - DB.execute - db - [sql| - INSERT OR REPLACE INTO snd_queues - (host, port, snd_id, snd_secure, conn_id, snd_public_key, snd_private_key, e2e_pub_key, e2e_dh_secret, status, snd_queue_id, snd_primary, replace_snd_queue_id, smp_client_version, server_key_hash) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?); - |] - ((host server, port server, sndId, sndSecure, connId', sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret) :. (status, qId, primary, dbReplaceQueueId, smpClientVersion, serverKeyHash_)) - pure (sq :: NewSndQueue) {connId = connId', dbQueueId = qId} - -newQueueId_ :: [Only Int64] -> DBQueueId 'QSStored -newQueueId_ [] = DBQueueId 1 -newQueueId_ (Only maxId : _) = DBQueueId (maxId + 1) - --- * getConn helpers - -getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getConn = getAnyConn False -{-# INLINE getConn #-} - -getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getDeletedConn = getAnyConn True -{-# INLINE getDeletedConn #-} - -getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getAnyConn deleted' dbConn connId = - getConnData dbConn connId >>= \case - Nothing -> pure $ Left SEConnNotFound - Just (cData@ConnData {deleted}, cMode) - | deleted /= deleted' -> pure $ Left SEConnNotFound - | otherwise -> do - rQ <- getRcvQueuesByConnId_ dbConn connId - sQ <- getSndQueuesByConnId_ dbConn connId - pure $ case (rQ, sQ, cMode) of - (Just rqs, Just sqs, CMInvitation) -> Right $ SomeConn SCDuplex (DuplexConnection cData rqs sqs) - (Just (rq :| _), Nothing, CMInvitation) -> Right $ SomeConn SCRcv (RcvConnection cData rq) - (Nothing, Just (sq :| _), CMInvitation) -> Right $ SomeConn SCSnd (SndConnection cData sq) - (Just (rq :| _), Nothing, CMContact) -> Right $ SomeConn SCContact (ContactConnection cData rq) - (Nothing, Nothing, _) -> Right $ SomeConn SCNew (NewConnection cData) - _ -> Left SEConnNotFound - -getConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getConns = getAnyConns_ False -{-# INLINE getConns #-} - -getDeletedConns :: DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] -getDeletedConns = getAnyConns_ True -{-# INLINE getDeletedConns #-} - -getAnyConns_ :: Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError SomeConn] 
-getAnyConns_ deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn deleted' db - where - handleDBError :: E.SomeException -> IO (Either StoreError SomeConn) - handleDBError = pure . Left . SEInternal . bshow - -getConnData :: DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) -getConnData db connId' = - maybeFirstRow cData $ - DB.query - db - [sql| - SELECT - user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, - last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support - FROM connections - WHERE conn_id = ? - |] - (Only connId') - where - cData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport) = - (ConnData {userId, connId, connAgentVersion, enableNtfs = fromMaybe True enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) - -setConnDeleted :: DB.Connection -> Bool -> ConnId -> IO () -setConnDeleted db waitDelivery connId - | waitDelivery = do - currentTs <- getCurrentTime - DB.execute db "UPDATE connections SET deleted_at_wait_delivery = ? WHERE conn_id = ?" (currentTs, connId) - | otherwise = - DB.execute db "UPDATE connections SET deleted = ? WHERE conn_id = ?" (True, connId) - -setConnUserId :: DB.Connection -> UserId -> ConnId -> UserId -> IO () -setConnUserId db oldUserId connId newUserId = - DB.execute db "UPDATE connections SET user_id = ? WHERE conn_id = ? and user_id = ?" (newUserId, connId, oldUserId) - -setConnAgentVersion :: DB.Connection -> ConnId -> VersionSMPA -> IO () -setConnAgentVersion db connId aVersion = - DB.execute db "UPDATE connections SET smp_agent_version = ? WHERE conn_id = ?" (aVersion, connId) - -setConnPQSupport :: DB.Connection -> ConnId -> PQSupport -> IO () -setConnPQSupport db connId pqSupport = - DB.execute db "UPDATE connections SET pq_support = ? WHERE conn_id = ?" (pqSupport, connId) - -getDeletedConnIds :: DB.Connection -> IO [ConnId] -getDeletedConnIds db = map fromOnly <$> DB.query db "SELECT conn_id FROM connections WHERE deleted = ?" (Only True) - -getDeletedWaitingDeliveryConnIds :: DB.Connection -> IO [ConnId] -getDeletedWaitingDeliveryConnIds db = - map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted_at_wait_delivery IS NOT NULL" - -setConnRatchetSync :: DB.Connection -> ConnId -> RatchetSyncState -> IO () -setConnRatchetSync db connId ratchetSyncState = - DB.execute db "UPDATE connections SET ratchet_sync_state = ? WHERE conn_id = ?" (ratchetSyncState, connId) - -addProcessedRatchetKeyHash :: DB.Connection -> ConnId -> ByteString -> IO () -addProcessedRatchetKeyHash db connId hash = - DB.execute db "INSERT INTO processed_ratchet_key_hashes (conn_id, hash) VALUES (?,?)" (connId, hash) - -checkRatchetKeyHashExists :: DB.Connection -> ConnId -> ByteString -> IO Bool -checkRatchetKeyHashExists db connId hash = do - fromMaybe False - <$> maybeFirstRow - fromOnly - ( DB.query - db - "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" - (connId, hash) - ) - -deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO () -deleteRatchetKeyHashesExpired db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.execute db "DELETE FROM processed_ratchet_key_hashes WHERE created_at < ?" (Only cutoffTs) - --- | returns all connection queues, the first queue is the primary one -getRcvQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty RcvQueue)) -getRcvQueuesByConnId_ db connId = - L.nonEmpty . sortBy primaryFirst . 
map toRcvQueue - <$> DB.query db (rcvQueueQuery <> "WHERE q.conn_id = ? AND q.deleted = 0") (Only connId) - where - primaryFirst RcvQueue {primary = p, dbReplaceQueueId = i} RcvQueue {primary = p', dbReplaceQueueId = i'} = - -- the current primary queue is ordered first, the next primary - second - compare (Down p) (Down p') <> compare i i' - -rcvQueueQuery :: Query -rcvQueueQuery = - [sql| - SELECT c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.rcv_id, q.rcv_private_key, q.rcv_dh_secret, - q.e2e_priv_key, q.e2e_dh_secret, q.snd_id, q.snd_secure, q.status, - q.rcv_queue_id, q.rcv_primary, q.replace_rcv_queue_id, q.switch_status, q.smp_client_version, q.delete_errors, - q.ntf_public_key, q.ntf_private_key, q.ntf_id, q.rcv_ntf_dh_secret - FROM rcv_queues q - JOIN servers s ON q.host = s.host AND q.port = s.port - JOIN connections c ON q.conn_id = c.conn_id - |] - -toRcvQueue :: - (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SMP.RecipientId, SMP.RcvPrivateAuthKey, SMP.RcvDhSecret, C.PrivateKeyX25519, Maybe C.DhSecretX25519, SMP.SenderId, SenderCanSecure) - :. (QueueStatus, DBQueueId 'QSStored, Bool, Maybe Int64, Maybe RcvSwitchStatus, Maybe VersionSMPC, Int) - :. (Maybe SMP.NtfPublicAuthKey, Maybe SMP.NtfPrivateAuthKey, Maybe SMP.NotifierId, Maybe RcvNtfDhSecret) -> - RcvQueue -toRcvQueue ((userId, keyHash, connId, host, port, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndSecure) :. (status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion_, deleteErrors) :. (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_)) = - let server = SMPServer host port keyHash - smpClientVersion = fromMaybe initialSMPClientVersion smpClientVersion_ - clientNtfCreds = case (ntfPublicKey_, ntfPrivateKey_, notifierId_, rcvNtfDhSecret_) of - (Just ntfPublicKey, Just ntfPrivateKey, Just notifierId, Just rcvNtfDhSecret) -> Just $ ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret} - _ -> Nothing - in RcvQueue {userId, connId, server, rcvId, rcvPrivateKey, rcvDhSecret, e2ePrivKey, e2eDhSecret, sndId, sndSecure, status, dbQueueId, primary, dbReplaceQueueId, rcvSwchStatus, smpClientVersion, clientNtfCreds, deleteErrors} - -getRcvQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError RcvQueue) -getRcvQueueById db connId dbRcvId = - firstRow toRcvQueue SEConnNotFound $ - DB.query db (rcvQueueQuery <> " WHERE q.conn_id = ? AND q.rcv_queue_id = ? AND q.deleted = 0") (connId, dbRcvId) - --- | returns all connection queues, the first queue is the primary one -getSndQueuesByConnId_ :: DB.Connection -> ConnId -> IO (Maybe (NonEmpty SndQueue)) -getSndQueuesByConnId_ dbConn connId = - L.nonEmpty . sortBy primaryFirst . 
map toSndQueue - <$> DB.query dbConn (sndQueueQuery <> "WHERE q.conn_id = ?") (Only connId) - where - primaryFirst SndQueue {primary = p, dbReplaceQueueId = i} SndQueue {primary = p', dbReplaceQueueId = i'} = - -- the current primary queue is ordered first, the next primary - second - compare (Down p) (Down p') <> compare i i' - -sndQueueQuery :: Query -sndQueueQuery = - [sql| - SELECT - c.user_id, COALESCE(q.server_key_hash, s.key_hash), q.conn_id, q.host, q.port, q.snd_id, q.snd_secure, - q.snd_public_key, q.snd_private_key, q.e2e_pub_key, q.e2e_dh_secret, q.status, - q.snd_queue_id, q.snd_primary, q.replace_snd_queue_id, q.switch_status, q.smp_client_version - FROM snd_queues q - JOIN servers s ON q.host = s.host AND q.port = s.port - JOIN connections c ON q.conn_id = c.conn_id - |] - -toSndQueue :: - (UserId, C.KeyHash, ConnId, NonEmpty TransportHost, ServiceName, SenderId, SenderCanSecure) - :. (Maybe SndPublicAuthKey, SndPrivateAuthKey, Maybe C.PublicKeyX25519, C.DhSecretX25519, QueueStatus) - :. (DBQueueId 'QSStored, Bool, Maybe Int64, Maybe SndSwitchStatus, VersionSMPC) -> - SndQueue -toSndQueue - ( (userId, keyHash, connId, host, port, sndId, sndSecure) - :. (sndPubKey, sndPrivateKey@(C.APrivateAuthKey a pk), e2ePubKey, e2eDhSecret, status) - :. (dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion) - ) = - let server = SMPServer host port keyHash - sndPublicKey = fromMaybe (C.APublicAuthKey a (C.publicKey pk)) sndPubKey - in SndQueue {userId, connId, server, sndId, sndSecure, sndPublicKey, sndPrivateKey, e2ePubKey, e2eDhSecret, status, dbQueueId, primary, dbReplaceQueueId, sndSwchStatus, smpClientVersion} - -getSndQueueById :: DB.Connection -> ConnId -> Int64 -> IO (Either StoreError SndQueue) -getSndQueueById db connId dbSndId = - firstRow toSndQueue SEConnNotFound $ - DB.query db (sndQueueQuery <> " WHERE q.conn_id = ? 
AND q.snd_queue_id = ?") (connId, dbSndId) - --- * updateRcvIds helpers - -retrieveLastIdsAndHashRcv_ :: DB.Connection -> ConnId -> IO (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash) -retrieveLastIdsAndHashRcv_ dbConn connId = do - [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- - DB.queryNamed - dbConn - [sql| - SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash - FROM connections - WHERE conn_id = :conn_id; - |] - [":conn_id" := connId] - return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) - -updateLastIdsRcv_ :: DB.Connection -> ConnId -> InternalId -> InternalRcvId -> IO () -updateLastIdsRcv_ dbConn connId newInternalId newInternalRcvId = - DB.executeNamed - dbConn - [sql| - UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, - last_internal_rcv_msg_id = :last_internal_rcv_msg_id - WHERE conn_id = :conn_id; - |] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_rcv_msg_id" := newInternalRcvId, - ":conn_id" := connId - ] - --- * createRcvMsg helpers - -insertRcvMsgBase_ :: DB.Connection -> ConnId -> RcvMsgData -> IO () -insertRcvMsgBase_ dbConn connId RcvMsgData {msgMeta, msgType, msgFlags, msgBody, internalRcvId} = do - let MsgMeta {recipient = (internalId, internalTs), pqEncryption} = msgMeta - DB.execute - dbConn - [sql| - INSERT INTO messages - (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) - VALUES (?,?,?,?,?,?,?,?,?); - |] - (connId, internalId, internalTs, internalRcvId, Nothing :: Maybe Int64, msgType, msgFlags, msgBody, pqEncryption) - -insertRcvMsgDetails_ :: DB.Connection -> ConnId -> RcvQueue -> RcvMsgData -> IO () -insertRcvMsgDetails_ db connId RcvQueue {dbQueueId} RcvMsgData {msgMeta, internalRcvId, internalHash, externalPrevSndHash, encryptedMsgHash} = do - let MsgMeta {integrity, recipient, broker, sndMsgId} = msgMeta - DB.executeNamed - db - [sql| - INSERT INTO rcv_messages - ( conn_id, rcv_queue_id, internal_rcv_id, internal_id, external_snd_id, - broker_id, broker_ts, - internal_hash, external_prev_snd_hash, integrity) - VALUES - (:conn_id,:rcv_queue_id,:internal_rcv_id,:internal_id,:external_snd_id, - :broker_id,:broker_ts, - :internal_hash,:external_prev_snd_hash,:integrity); - |] - [ ":conn_id" := connId, - ":rcv_queue_id" := dbQueueId, - ":internal_rcv_id" := internalRcvId, - ":internal_id" := fst recipient, - ":external_snd_id" := sndMsgId, - ":broker_id" := fst broker, - ":broker_ts" := snd broker, - ":internal_hash" := internalHash, - ":external_prev_snd_hash" := externalPrevSndHash, - ":integrity" := integrity - ] - DB.execute db "INSERT INTO encrypted_rcv_message_hashes (conn_id, hash) VALUES (?,?)" (connId, encryptedMsgHash) - -updateRcvMsgHash :: DB.Connection -> ConnId -> AgentMsgId -> InternalRcvId -> MsgHash -> IO () -updateRcvMsgHash db connId sndMsgId internalRcvId internalHash = - DB.executeNamed - db - -- last_internal_rcv_msg_id equality check prevents race condition in case next id was reserved - [sql| - UPDATE connections - SET last_external_snd_msg_id = :last_external_snd_msg_id, - last_rcv_msg_hash = :last_rcv_msg_hash - WHERE conn_id = :conn_id - AND last_internal_rcv_msg_id = :last_internal_rcv_msg_id; - |] - [ ":last_external_snd_msg_id" := sndMsgId, - ":last_rcv_msg_hash" := internalHash, - ":conn_id" := connId, - ":last_internal_rcv_msg_id" := internalRcvId - ] - --- * updateSndIds helpers - -retrieveLastIdsAndHashSnd_ :: 
DB.Connection -> ConnId -> IO (Either StoreError (InternalId, InternalSndId, PrevSndMsgHash)) -retrieveLastIdsAndHashSnd_ dbConn connId = do - firstRow id SEConnNotFound $ - DB.queryNamed - dbConn - [sql| - SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash - FROM connections - WHERE conn_id = :conn_id; - |] - [":conn_id" := connId] - -updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO () -updateLastIdsSnd_ dbConn connId newInternalId newInternalSndId = - DB.executeNamed - dbConn - [sql| - UPDATE connections - SET last_internal_msg_id = :last_internal_msg_id, - last_internal_snd_msg_id = :last_internal_snd_msg_id - WHERE conn_id = :conn_id; - |] - [ ":last_internal_msg_id" := newInternalId, - ":last_internal_snd_msg_id" := newInternalSndId, - ":conn_id" := connId - ] - --- * createSndMsg helpers - -insertSndMsgBase_ :: DB.Connection -> ConnId -> SndMsgData -> IO () -insertSndMsgBase_ db connId SndMsgData {internalId, internalTs, internalSndId, msgType, msgFlags, msgBody, pqEncryption} = do - DB.execute - db - [sql| - INSERT INTO messages - (conn_id, internal_id, internal_ts, internal_rcv_id, internal_snd_id, msg_type, msg_flags, msg_body, pq_encryption) - VALUES - (?,?,?,?,?,?,?,?,?); - |] - (connId, internalId, internalTs, Nothing :: Maybe Int64, internalSndId, msgType, msgFlags, msgBody, pqEncryption) - -insertSndMsgDetails_ :: DB.Connection -> ConnId -> SndMsgData -> IO () -insertSndMsgDetails_ dbConn connId SndMsgData {..} = - DB.executeNamed - dbConn - [sql| - INSERT INTO snd_messages - ( conn_id, internal_snd_id, internal_id, internal_hash, previous_msg_hash) - VALUES - (:conn_id,:internal_snd_id,:internal_id,:internal_hash,:previous_msg_hash); - |] - [ ":conn_id" := connId, - ":internal_snd_id" := internalSndId, - ":internal_id" := internalId, - ":internal_hash" := internalHash, - ":previous_msg_hash" := prevMsgHash - ] - -updateSndMsgHash :: DB.Connection -> ConnId -> InternalSndId -> MsgHash -> IO () -updateSndMsgHash db connId internalSndId internalHash = - DB.executeNamed - db - -- last_internal_snd_msg_id equality check prevents race condition in case next id was reserved - [sql| - UPDATE connections - SET last_snd_msg_hash = :last_snd_msg_hash - WHERE conn_id = :conn_id - AND last_internal_snd_msg_id = :last_internal_snd_msg_id; - |] - [ ":last_snd_msg_hash" := internalHash, - ":conn_id" := connId, - ":last_internal_snd_msg_id" := internalSndId - ] - --- create record with a random ID -createWithRandomId :: TVar ChaChaDRG -> (ByteString -> IO ()) -> IO (Either StoreError ByteString) -createWithRandomId gVar create = fst <$$> createWithRandomId' gVar create - -createWithRandomId' :: forall a. TVar ChaChaDRG -> (ByteString -> IO a) -> IO (Either StoreError (ByteString, a)) -createWithRandomId' gVar create = tryCreate 3 - where - tryCreate :: Int -> IO (Either StoreError (ByteString, a)) - tryCreate 0 = pure $ Left SEUniqueID - tryCreate n = do - id' <- randomId gVar 12 - E.try (create id') >>= \case - Right r -> pure $ Right (id', r) - Left e - | SQL.sqlError e == SQL.ErrorConstraint -> tryCreate (n - 1) - | otherwise -> pure . Left . 
SEInternal $ bshow e - -randomId :: TVar ChaChaDRG -> Int -> IO ByteString -randomId gVar n = atomically $ U.encode <$> C.randomBytes n gVar - -ntfSubAndSMPAction :: NtfSubAction -> (Maybe NtfSubNTFAction, Maybe NtfSubSMPAction) -ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing) -ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action) - -createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64 -createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} = - getXFTPServerId_ db newSrv >>= \case - Right srvId -> pure srvId - Left _ -> insertNewServer_ - where - insertNewServer_ = do - DB.execute db "INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) VALUES (?,?,?)" (host, port, keyHash) - insertedRowId db - -getXFTPServerId_ :: DB.Connection -> XFTPServer -> IO (Either StoreError Int64) -getXFTPServerId_ db ProtocolServer {host, port, keyHash} = do - firstRow fromOnly SEXFTPServerNotFound $ - DB.query db "SELECT xftp_server_id FROM xftp_servers WHERE xftp_host = ? AND xftp_port = ? AND xftp_key_hash = ?" (host, port, keyHash) - -createRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId) -createRcvFile db gVar userId fd@FileDescription {chunks} prefixPath tmpPath file approvedRelays = runExceptT $ do - (rcvFileEntityId, rcvFileId) <- ExceptT $ insertRcvFile db gVar userId fd prefixPath tmpPath file Nothing Nothing approvedRelays - liftIO $ - forM_ chunks $ \fc@FileChunk {replicas} -> do - chunkId <- insertRcvFileChunk db fc rcvFileId - forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId - pure rcvFileEntityId - -createRcvFileRedirect :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId) -createRcvFileRedirect _ _ _ FileDescription {redirect = Nothing} _ _ _ _ _ _ = pure $ Left $ SEInternal "createRcvFileRedirect called without redirect" -createRcvFileRedirect db gVar userId redirectFd@FileDescription {chunks = redirectChunks, redirect = Just RedirectFileInfo {size, digest}} prefixPath redirectPath redirectFile dstPath dstFile approvedRelays = runExceptT $ do - (dstEntityId, dstId) <- ExceptT $ insertRcvFile db gVar userId dummyDst prefixPath dstPath dstFile Nothing Nothing approvedRelays - (_, redirectId) <- ExceptT $ insertRcvFile db gVar userId redirectFd prefixPath redirectPath redirectFile (Just dstId) (Just dstEntityId) approvedRelays - liftIO $ - forM_ redirectChunks $ \fc@FileChunk {replicas} -> do - chunkId <- insertRcvFileChunk db fc redirectId - forM_ (zip [1 ..] 
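-- Illustrative sketch (not part of this patch): createWithRandomId' above retries when
-- sqlite-simple reports SQL.ErrorConstraint for a duplicate random ID. A PostgreSQL-backed
-- store (the dbPostgres builds introduced later in this diff) would presumably need an
-- equivalent check against postgresql-simple's SqlError; a hedged helper is sketched below
-- (isUniqueViolation is hypothetical; SQLSTATE 23505 is unique_violation; assumes
-- OverloadedStrings):
--
--   import qualified Database.PostgreSQL.Simple as PSQL
--
--   isUniqueViolation :: PSQL.SqlError -> Bool
--   isUniqueViolation e = PSQL.sqlState e == "23505"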
replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId - pure dstEntityId - where - dummyDst = - FileDescription - { party = SFRecipient, - size, - digest, - redirect = Nothing, - -- updated later with updateRcvFileRedirect - key = C.unsafeSbKey $ B.replicate 32 '#', - nonce = C.cbNonce "", - chunkSize = FileSize 0, - chunks = [] - } - -insertRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Maybe DBRcvFileId -> Maybe RcvFileId -> Bool -> IO (Either StoreError (RcvFileId, DBRcvFileId)) -insertRcvFile db gVar userId FileDescription {size, digest, key, nonce, chunkSize, redirect} prefixPath tmpPath (CryptoFile savePath cfArgs) redirectId_ redirectEntityId_ approvedRelays = runExceptT $ do - let (redirectDigest_, redirectSize_) = case redirect of - Just RedirectFileInfo {digest = d, size = s} -> (Just d, Just s) - Nothing -> (Nothing, Nothing) - rcvFileEntityId <- ExceptT $ - createWithRandomId gVar $ \rcvFileEntityId -> - DB.execute - db - "INSERT INTO rcv_files (rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, redirect_id, redirect_entity_id, redirect_digest, redirect_size, approved_relays) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)" - ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, fileKey <$> cfArgs, fileNonce <$> cfArgs, RFSReceiving, redirectId_, redirectEntityId_, redirectDigest_, redirectSize_, approvedRelays)) - rcvFileId <- liftIO $ insertedRowId db - pure (rcvFileEntityId, rcvFileId) - -insertRcvFileChunk :: DB.Connection -> FileChunk -> DBRcvFileId -> IO Int64 -insertRcvFileChunk db FileChunk {chunkNo, chunkSize, digest} rcvFileId = do - DB.execute - db - "INSERT INTO rcv_file_chunks (rcv_file_id, chunk_no, chunk_size, digest) VALUES (?,?,?,?)" - (rcvFileId, chunkNo, chunkSize, digest) - insertedRowId db - -insertRcvFileChunkReplica :: DB.Connection -> Int -> FileChunkReplica -> Int64 -> IO () -insertRcvFileChunkReplica db replicaNo FileChunkReplica {server, replicaId, replicaKey} chunkId = do - srvId <- createXFTPServer_ db server - DB.execute - db - "INSERT INTO rcv_file_chunk_replicas (replica_number, rcv_file_chunk_id, xftp_server_id, replica_id, replica_key) VALUES (?,?,?,?,?)" - (replicaNo, chunkId, srvId, replicaId, replicaKey) - -getRcvFileByEntityId :: DB.Connection -> RcvFileId -> IO (Either StoreError RcvFile) -getRcvFileByEntityId db rcvFileEntityId = runExceptT $ do - rcvFileId <- ExceptT $ getRcvFileIdByEntityId_ db rcvFileEntityId - ExceptT $ getRcvFile db rcvFileId - -getRcvFileIdByEntityId_ :: DB.Connection -> RcvFileId -> IO (Either StoreError DBRcvFileId) -getRcvFileIdByEntityId_ db rcvFileEntityId = - firstRow fromOnly SEFileNotFound $ - DB.query db "SELECT rcv_file_id FROM rcv_files WHERE rcv_file_entity_id = ?" (Only rcvFileEntityId) - -getRcvFileRedirects :: DB.Connection -> DBRcvFileId -> IO [RcvFile] -getRcvFileRedirects db rcvFileId = do - redirects <- fromOnly <$$> DB.query db "SELECT rcv_file_id FROM rcv_files WHERE redirect_id = ?" (Only rcvFileId) - fmap catMaybes . forM redirects $ getRcvFile db >=> either (const $ pure Nothing) (pure . Just) - -getRcvFile :: DB.Connection -> DBRcvFileId -> IO (Either StoreError RcvFile) -getRcvFile db rcvFileId = runExceptT $ do - f@RcvFile {rcvFileEntityId, userId, tmpPath} <- ExceptT getFile - chunks <- maybe (pure []) (liftIO . 
getChunks rcvFileEntityId userId) tmpPath - pure (f {chunks} :: RcvFile) - where - getFile :: IO (Either StoreError RcvFile) - getFile = do - firstRow toFile SEFileNotFound $ - DB.query - db - [sql| - SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest - FROM rcv_files - WHERE rcv_file_id = ? - |] - (Only rcvFileId) - where - toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, Bool, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile - toFile ((rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, prefixPath, tmpPath) :. (savePath, saveKey_, saveNonce_, status, deleted, redirectDbId, redirectEntityId, redirectSize_, redirectDigest_)) = - let cfArgs = CFArgs <$> saveKey_ <*> saveNonce_ - saveFile = CryptoFile savePath cfArgs - redirect = - RcvFileRedirect - <$> redirectDbId - <*> redirectEntityId - <*> (RedirectFileInfo <$> redirectSize_ <*> redirectDigest_) - in RcvFile {rcvFileId, rcvFileEntityId, userId, size, digest, key, nonce, chunkSize, redirect, prefixPath, tmpPath, saveFile, status, deleted, chunks = []} - getChunks :: RcvFileId -> UserId -> FilePath -> IO [RcvFileChunk] - getChunks rcvFileEntityId userId fileTmpPath = do - chunks <- - map toChunk - <$> DB.query - db - [sql| - SELECT rcv_file_chunk_id, chunk_no, chunk_size, digest, tmp_path - FROM rcv_file_chunks - WHERE rcv_file_id = ? - |] - (Only rcvFileId) - forM chunks $ \chunk@RcvFileChunk {rcvChunkId} -> do - replicas' <- getChunkReplicas rcvChunkId - pure (chunk {replicas = replicas'} :: RcvFileChunk) - where - toChunk :: (Int64, Int, FileSize Word32, FileDigest, Maybe FilePath) -> RcvFileChunk - toChunk (rcvChunkId, chunkNo, chunkSize, digest, chunkTmpPath) = - RcvFileChunk {rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath, replicas = []} - getChunkReplicas :: Int64 -> IO [RcvFileChunkReplica] - getChunkReplicas chunkId = do - map toReplica - <$> DB.query - db - [sql| - SELECT - r.rcv_file_chunk_replica_id, r.replica_id, r.replica_key, r.received, r.delay, r.retries, - s.xftp_host, s.xftp_port, s.xftp_key_hash - FROM rcv_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - WHERE r.rcv_file_chunk_id = ? - |] - (Only chunkId) - where - toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, Bool, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> RcvFileChunkReplica - toReplica (rcvChunkReplicaId, replicaId, replicaKey, received, delay, retries, host, port, keyHash) = - let server = XFTPServer host port keyHash - in RcvFileChunkReplica {rcvChunkReplicaId, server, replicaId, replicaKey, received, delay, retries} - -updateRcvChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO () -updateRcvChunkReplicaDelay db replicaId delay = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_file_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE rcv_file_chunk_replica_id = ?" 
(delay, updatedAt, replicaId) - -updateRcvFileChunkReceived :: DB.Connection -> Int64 -> Int64 -> FilePath -> IO () -updateRcvFileChunkReceived db replicaId chunkId chunkTmpPath = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_file_chunk_replicas SET received = 1, updated_at = ? WHERE rcv_file_chunk_replica_id = ?" (updatedAt, replicaId) - DB.execute db "UPDATE rcv_file_chunks SET tmp_path = ?, updated_at = ? WHERE rcv_file_chunk_id = ?" (chunkTmpPath, updatedAt, chunkId) - -updateRcvFileStatus :: DB.Connection -> DBRcvFileId -> RcvFileStatus -> IO () -updateRcvFileStatus db rcvFileId status = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_files SET status = ?, updated_at = ? WHERE rcv_file_id = ?" (status, updatedAt, rcvFileId) - -updateRcvFileError :: DB.Connection -> DBRcvFileId -> String -> IO () -updateRcvFileError db rcvFileId errStr = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_files SET tmp_path = NULL, error = ?, status = ?, updated_at = ? WHERE rcv_file_id = ?" (errStr, RFSError, updatedAt, rcvFileId) - -updateRcvFileComplete :: DB.Connection -> DBRcvFileId -> IO () -updateRcvFileComplete db rcvFileId = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_files SET tmp_path = NULL, status = ?, updated_at = ? WHERE rcv_file_id = ?" (RFSComplete, updatedAt, rcvFileId) - -updateRcvFileRedirect :: DB.Connection -> DBRcvFileId -> FileDescription 'FRecipient -> IO (Either StoreError ()) -updateRcvFileRedirect db rcvFileId FileDescription {key, nonce, chunkSize, chunks} = runExceptT $ do - updatedAt <- liftIO getCurrentTime - liftIO $ DB.execute db "UPDATE rcv_files SET key = ?, nonce = ?, chunk_size = ?, updated_at = ? WHERE rcv_file_id = ?" (key, nonce, chunkSize, updatedAt, rcvFileId) - liftIO $ forM_ chunks $ \fc@FileChunk {replicas} -> do - chunkId <- insertRcvFileChunk db fc rcvFileId - forM_ (zip [1 ..] replicas) $ \(rno, replica) -> insertRcvFileChunkReplica db rno replica chunkId - -updateRcvFileNoTmpPath :: DB.Connection -> DBRcvFileId -> IO () -updateRcvFileNoTmpPath db rcvFileId = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_files SET tmp_path = NULL, updated_at = ? WHERE rcv_file_id = ?" (updatedAt, rcvFileId) - -updateRcvFileDeleted :: DB.Connection -> DBRcvFileId -> IO () -updateRcvFileDeleted db rcvFileId = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE rcv_files SET deleted = 1, updated_at = ? WHERE rcv_file_id = ?" (updatedAt, rcvFileId) - -deleteRcvFile' :: DB.Connection -> DBRcvFileId -> IO () -deleteRcvFile' db rcvFileId = - DB.execute db "DELETE FROM rcv_files WHERE rcv_file_id = ?" (Only rcvFileId) - -getNextRcvChunkToDownload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe (RcvFileChunk, Bool, Maybe RcvFileId))) -getNextRcvChunkToDownload db server@ProtocolServer {host, port, keyHash} ttl = do - getWorkItem "rcv_file_download" getReplicaId getChunkData (markRcvFileFailed db . snd) - where - getReplicaId :: IO (Maybe (Int64, DBRcvFileId)) - getReplicaId = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - maybeFirstRow id $ - DB.query - db - [sql| - SELECT r.rcv_file_chunk_replica_id, f.rcv_file_id - FROM rcv_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id - JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id - WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ? - AND r.received = 0 AND r.replica_number = 1 - AND f.status = ? 
AND f.deleted = 0 AND f.created_at >= ? - AND f.failed = 0 - ORDER BY r.retries ASC, r.created_at ASC - LIMIT 1 - |] - (host, port, keyHash, RFSReceiving, cutoffTs) - getChunkData :: (Int64, DBRcvFileId) -> IO (Either StoreError (RcvFileChunk, Bool, Maybe RcvFileId)) - getChunkData (rcvFileChunkReplicaId, _fileId) = - firstRow toChunk SEFileNotFound $ - DB.query - db - [sql| - SELECT - f.rcv_file_id, f.rcv_file_entity_id, f.user_id, c.rcv_file_chunk_id, c.chunk_no, c.chunk_size, c.digest, f.tmp_path, c.tmp_path, - r.rcv_file_chunk_replica_id, r.replica_id, r.replica_key, r.received, r.delay, r.retries, - f.approved_relays, f.redirect_entity_id - FROM rcv_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id - JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id - WHERE r.rcv_file_chunk_replica_id = ? - |] - (Only rcvFileChunkReplicaId) - where - toChunk :: ((DBRcvFileId, RcvFileId, UserId, Int64, Int, FileSize Word32, FileDigest, FilePath, Maybe FilePath) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, Bool, Maybe Int64, Int) :. (Bool, Maybe RcvFileId)) -> (RcvFileChunk, Bool, Maybe RcvFileId) - toChunk ((rcvFileId, rcvFileEntityId, userId, rcvChunkId, chunkNo, chunkSize, digest, fileTmpPath, chunkTmpPath) :. (rcvChunkReplicaId, replicaId, replicaKey, received, delay, retries) :. (approvedRelays, redirectEntityId_)) = - ( RcvFileChunk - { rcvFileId, - rcvFileEntityId, - userId, - rcvChunkId, - chunkNo, - chunkSize, - digest, - fileTmpPath, - chunkTmpPath, - replicas = [RcvFileChunkReplica {rcvChunkReplicaId, server, replicaId, replicaKey, received, delay, retries}] - }, - approvedRelays, - redirectEntityId_ - ) - -getNextRcvFileToDecrypt :: DB.Connection -> NominalDiffTime -> IO (Either StoreError (Maybe RcvFile)) -getNextRcvFileToDecrypt db ttl = - getWorkItem "rcv_file_decrypt" getFileId (getRcvFile db) (markRcvFileFailed db) - where - getFileId :: IO (Maybe DBRcvFileId) - getFileId = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - maybeFirstRow fromOnly $ - DB.query - db - [sql| - SELECT rcv_file_id - FROM rcv_files - WHERE status IN (?,?) AND deleted = 0 AND created_at >= ? - AND failed = 0 - ORDER BY created_at ASC LIMIT 1 - |] - (RFSReceived, RFSDecrypting, cutoffTs) - -markRcvFileFailed :: DB.Connection -> DBRcvFileId -> IO () -markRcvFileFailed db fileId = do - DB.execute db "UPDATE rcv_files SET failed = 1 WHERE rcv_file_id = ?" (Only fileId) - -getPendingRcvFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] -getPendingRcvFilesServers db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - map toXFTPServer - <$> DB.query - db - [sql| - SELECT DISTINCT - s.xftp_host, s.xftp_port, s.xftp_key_hash - FROM rcv_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - JOIN rcv_file_chunks c ON c.rcv_file_chunk_id = r.rcv_file_chunk_id - JOIN rcv_files f ON f.rcv_file_id = c.rcv_file_id - WHERE r.received = 0 AND r.replica_number = 1 - AND f.status = ? AND f.deleted = 0 AND f.created_at >= ? - |] - (RFSReceiving, cutoffTs) - -toXFTPServer :: (NonEmpty TransportHost, ServiceName, C.KeyHash) -> XFTPServer -toXFTPServer (host, port, keyHash) = XFTPServer host port keyHash - -getCleanupRcvFilesTmpPaths :: DB.Connection -> IO [(DBRcvFileId, RcvFileId, FilePath)] -getCleanupRcvFilesTmpPaths db = - DB.query - db - [sql| - SELECT rcv_file_id, rcv_file_entity_id, tmp_path - FROM rcv_files - WHERE status IN (?,?) 
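-- Illustrative sketch (not part of this patch): getWorkItem, used by
-- getNextRcvChunkToDownload and getNextRcvFileToDecrypt above, is defined elsewhere in this
-- module. Conceptually it wires its three callbacks together roughly as in this hypothetical
-- sketch (getWorkItemSketch is illustrative only and omits the real helper's error-handling
-- details):
getWorkItemSketch ::
  String ->                                  -- name of the work queue, for logging
  IO (Maybe workId) ->                       -- select the id of the next pending item
  (workId -> IO (Either StoreError item)) -> -- load the full work item
  (workId -> IO ()) ->                       -- mark the item failed if it cannot be loaded
  IO (Either StoreError (Maybe item))
getWorkItemSketch _name getId getItem markFailed =
  getId >>= \case
    Nothing -> pure $ Right Nothing
    Just workId ->
      getItem workId >>= \case
        Right item -> pure . Right $ Just item
        Left e -> markFailed workId >> pure (Left e)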
AND tmp_path IS NOT NULL - |] - (RFSComplete, RFSError) - -getCleanupRcvFilesDeleted :: DB.Connection -> IO [(DBRcvFileId, RcvFileId, FilePath)] -getCleanupRcvFilesDeleted db = - DB.query_ - db - [sql| - SELECT rcv_file_id, rcv_file_entity_id, prefix_path - FROM rcv_files - WHERE deleted = 1 - |] - -getRcvFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBRcvFileId, RcvFileId, FilePath)] -getRcvFilesExpired db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.query - db - [sql| - SELECT rcv_file_id, rcv_file_entity_id, prefix_path - FROM rcv_files - WHERE created_at < ? - |] - (Only cutoffTs) - -createSndFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> CryptoFile -> Int -> FilePath -> C.SbKey -> C.CbNonce -> Maybe RedirectFileInfo -> IO (Either StoreError SndFileId) -createSndFile db gVar userId (CryptoFile path cfArgs) numRecipients prefixPath key nonce redirect_ = - createWithRandomId gVar $ \sndFileEntityId -> - DB.execute - db - "INSERT INTO snd_files (snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, prefix_path, key, nonce, status, redirect_size, redirect_digest) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)" - ((sndFileEntityId, userId, path, fileKey <$> cfArgs, fileNonce <$> cfArgs, numRecipients) :. (prefixPath, key, nonce, SFSNew, redirectSize_, redirectDigest_)) - where - (redirectSize_, redirectDigest_) = - case redirect_ of - Nothing -> (Nothing, Nothing) - Just RedirectFileInfo {size, digest} -> (Just size, Just digest) - -getSndFileByEntityId :: DB.Connection -> SndFileId -> IO (Either StoreError SndFile) -getSndFileByEntityId db sndFileEntityId = runExceptT $ do - sndFileId <- ExceptT $ getSndFileIdByEntityId_ db sndFileEntityId - ExceptT $ getSndFile db sndFileId - -getSndFileIdByEntityId_ :: DB.Connection -> SndFileId -> IO (Either StoreError DBSndFileId) -getSndFileIdByEntityId_ db sndFileEntityId = - firstRow fromOnly SEFileNotFound $ - DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only sndFileEntityId) - -getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile) -getSndFile db sndFileId = runExceptT $ do - f@SndFile {sndFileEntityId, userId, numRecipients, prefixPath} <- ExceptT getFile - chunks <- maybe (pure []) (liftIO . getChunks sndFileEntityId userId numRecipients) prefixPath - pure (f {chunks} :: SndFile) - where - getFile :: IO (Either StoreError SndFile) - getFile = do - firstRow toFile SEFileNotFound $ - DB.query - db - [sql| - SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest - FROM snd_files - WHERE snd_file_id = ? - |] - (Only sndFileId) - where - toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, Bool, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile - toFile ((sndFileEntityId, userId, srcPath, srcKey_, srcNonce_, numRecipients, digest, prefixPath, key, nonce) :. 
(status, deleted, redirectSize_, redirectDigest_)) = - let cfArgs = CFArgs <$> srcKey_ <*> srcNonce_ - srcFile = CryptoFile srcPath cfArgs - redirect = RedirectFileInfo <$> redirectSize_ <*> redirectDigest_ - in SndFile {sndFileId, sndFileEntityId, userId, srcFile, numRecipients, digest, prefixPath, key, nonce, status, deleted, redirect, chunks = []} - getChunks :: SndFileId -> UserId -> Int -> FilePath -> IO [SndFileChunk] - getChunks sndFileEntityId userId numRecipients filePrefixPath = do - chunks <- - map toChunk - <$> DB.query - db - [sql| - SELECT snd_file_chunk_id, chunk_no, chunk_offset, chunk_size, digest - FROM snd_file_chunks - WHERE snd_file_id = ? - |] - (Only sndFileId) - forM chunks $ \chunk@SndFileChunk {sndChunkId} -> do - replicas' <- getChunkReplicas sndChunkId - pure (chunk {replicas = replicas'} :: SndFileChunk) - where - toChunk :: (Int64, Int, Int64, Word32, FileDigest) -> SndFileChunk - toChunk (sndChunkId, chunkNo, chunkOffset, chunkSize, digest) = - let chunkSpec = XFTPChunkSpec {filePath = sndFileEncPath filePrefixPath, chunkOffset, chunkSize} - in SndFileChunk {sndFileId, sndFileEntityId, userId, numRecipients, sndChunkId, chunkNo, chunkSpec, filePrefixPath, digest, replicas = []} - getChunkReplicas :: Int64 -> IO [SndFileChunkReplica] - getChunkReplicas chunkId = do - replicas <- - map toReplica - <$> DB.query - db - [sql| - SELECT - r.snd_file_chunk_replica_id, r.replica_id, r.replica_key, r.replica_status, r.delay, r.retries, - s.xftp_host, s.xftp_port, s.xftp_key_hash - FROM snd_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - WHERE r.snd_file_chunk_id = ? - |] - (Only chunkId) - forM replicas $ \replica@SndFileChunkReplica {sndChunkReplicaId} -> do - rcvIdsKeys <- getChunkReplicaRecipients_ db sndChunkReplicaId - pure (replica :: SndFileChunkReplica) {rcvIdsKeys} - where - toReplica :: (Int64, ChunkReplicaId, C.APrivateAuthKey, SndFileReplicaStatus, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> SndFileChunkReplica - toReplica (sndChunkReplicaId, replicaId, replicaKey, replicaStatus, delay, retries, host, port, keyHash) = - let server = XFTPServer host port keyHash - in SndFileChunkReplica {sndChunkReplicaId, server, replicaId, replicaKey, replicaStatus, delay, retries, rcvIdsKeys = []} - -getChunkReplicaRecipients_ :: DB.Connection -> Int64 -> IO [(ChunkReplicaId, C.APrivateAuthKey)] -getChunkReplicaRecipients_ db replicaId = - DB.query - db - [sql| - SELECT rcv_replica_id, rcv_replica_key - FROM snd_file_chunk_replica_recipients - WHERE snd_file_chunk_replica_id = ? - |] - (Only replicaId) - -getNextSndFileToPrepare :: DB.Connection -> NominalDiffTime -> IO (Either StoreError (Maybe SndFile)) -getNextSndFileToPrepare db ttl = - getWorkItem "snd_file_prepare" getFileId (getSndFile db) (markSndFileFailed db) - where - getFileId :: IO (Maybe DBSndFileId) - getFileId = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - maybeFirstRow fromOnly $ - DB.query - db - [sql| - SELECT snd_file_id - FROM snd_files - WHERE status IN (?,?,?) AND deleted = 0 AND created_at >= ? - AND failed = 0 - ORDER BY created_at ASC LIMIT 1 - |] - (SFSNew, SFSEncrypting, SFSEncrypted, cutoffTs) - -markSndFileFailed :: DB.Connection -> DBSndFileId -> IO () -markSndFileFailed db fileId = - DB.execute db "UPDATE snd_files SET failed = 1 WHERE snd_file_id = ?" 
(Only fileId) - -updateSndFileError :: DB.Connection -> DBSndFileId -> String -> IO () -updateSndFileError db sndFileId errStr = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_files SET prefix_path = NULL, error = ?, status = ?, updated_at = ? WHERE snd_file_id = ?" (errStr, SFSError, updatedAt, sndFileId) - -updateSndFileStatus :: DB.Connection -> DBSndFileId -> SndFileStatus -> IO () -updateSndFileStatus db sndFileId status = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_files SET status = ?, updated_at = ? WHERE snd_file_id = ?" (status, updatedAt, sndFileId) - -updateSndFileEncrypted :: DB.Connection -> DBSndFileId -> FileDigest -> [(XFTPChunkSpec, FileDigest)] -> IO () -updateSndFileEncrypted db sndFileId digest chunkSpecsDigests = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_files SET status = ?, digest = ?, updated_at = ? WHERE snd_file_id = ?" (SFSEncrypted, digest, updatedAt, sndFileId) - forM_ (zip [1 ..] chunkSpecsDigests) $ \(chunkNo :: Int, (XFTPChunkSpec {chunkOffset, chunkSize}, chunkDigest)) -> - DB.execute db "INSERT INTO snd_file_chunks (snd_file_id, chunk_no, chunk_offset, chunk_size, digest) VALUES (?,?,?,?,?)" (sndFileId, chunkNo, chunkOffset, chunkSize, chunkDigest) - -updateSndFileComplete :: DB.Connection -> DBSndFileId -> IO () -updateSndFileComplete db sndFileId = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_files SET prefix_path = NULL, status = ?, updated_at = ? WHERE snd_file_id = ?" (SFSComplete, updatedAt, sndFileId) - -updateSndFileNoPrefixPath :: DB.Connection -> DBSndFileId -> IO () -updateSndFileNoPrefixPath db sndFileId = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_files SET prefix_path = NULL, updated_at = ? WHERE snd_file_id = ?" (updatedAt, sndFileId) - -updateSndFileDeleted :: DB.Connection -> DBSndFileId -> IO () -updateSndFileDeleted db sndFileId = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_files SET deleted = 1, updated_at = ? WHERE snd_file_id = ?" (updatedAt, sndFileId) - -deleteSndFile' :: DB.Connection -> DBSndFileId -> IO () -deleteSndFile' db sndFileId = - DB.execute db "DELETE FROM snd_files WHERE snd_file_id = ?" (Only sndFileId) - -getSndFileDeleted :: DB.Connection -> DBSndFileId -> IO Bool -getSndFileDeleted db sndFileId = - fromMaybe True - <$> maybeFirstRow fromOnly (DB.query db "SELECT deleted FROM snd_files WHERE snd_file_id = ?" (Only sndFileId)) - -createSndFileReplica :: DB.Connection -> SndFileChunk -> NewSndChunkReplica -> IO () -createSndFileReplica db SndFileChunk {sndChunkId} = createSndFileReplica_ db sndChunkId - -createSndFileReplica_ :: DB.Connection -> Int64 -> NewSndChunkReplica -> IO () -createSndFileReplica_ db sndChunkId NewSndChunkReplica {server, replicaId, replicaKey, rcvIdsKeys} = do - srvId <- createXFTPServer_ db server - DB.execute - db - [sql| - INSERT INTO snd_file_chunk_replicas - (snd_file_chunk_id, replica_number, xftp_server_id, replica_id, replica_key, replica_status) - VALUES (?,?,?,?,?,?) - |] - (sndChunkId, 1 :: Int, srvId, replicaId, replicaKey, SFRSCreated) - rId <- insertedRowId db - forM_ rcvIdsKeys $ \(rcvId, rcvKey) -> do - DB.execute - db - [sql| - INSERT INTO snd_file_chunk_replica_recipients - (snd_file_chunk_replica_id, rcv_replica_id, rcv_replica_key) - VALUES (?,?,?) 
- |] - (rId, rcvId, rcvKey) - -getNextSndChunkToUpload :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe SndFileChunk)) -getNextSndChunkToUpload db server@ProtocolServer {host, port, keyHash} ttl = do - getWorkItem "snd_file_upload" getReplicaId getChunkData (markSndFileFailed db . snd) - where - getReplicaId :: IO (Maybe (Int64, DBSndFileId)) - getReplicaId = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - maybeFirstRow id $ - DB.query - db - [sql| - SELECT r.snd_file_chunk_replica_id, f.snd_file_id - FROM snd_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id - JOIN snd_files f ON f.snd_file_id = c.snd_file_id - WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ? - AND r.replica_status = ? AND r.replica_number = 1 - AND (f.status = ? OR f.status = ?) AND f.deleted = 0 AND f.created_at >= ? - AND f.failed = 0 - ORDER BY r.retries ASC, r.created_at ASC - LIMIT 1 - |] - (host, port, keyHash, SFRSCreated, SFSEncrypted, SFSUploading, cutoffTs) - getChunkData :: (Int64, DBSndFileId) -> IO (Either StoreError SndFileChunk) - getChunkData (sndFileChunkReplicaId, _fileId) = do - chunk_ <- - firstRow toChunk SEFileNotFound $ - DB.query - db - [sql| - SELECT - f.snd_file_id, f.snd_file_entity_id, f.user_id, f.num_recipients, f.prefix_path, - c.snd_file_chunk_id, c.chunk_no, c.chunk_offset, c.chunk_size, c.digest, - r.snd_file_chunk_replica_id, r.replica_id, r.replica_key, r.replica_status, r.delay, r.retries - FROM snd_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id - JOIN snd_files f ON f.snd_file_id = c.snd_file_id - WHERE r.snd_file_chunk_replica_id = ? - |] - (Only sndFileChunkReplicaId) - forM chunk_ $ \chunk@SndFileChunk {replicas} -> do - replicas' <- forM replicas $ \replica@SndFileChunkReplica {sndChunkReplicaId} -> do - rcvIdsKeys <- getChunkReplicaRecipients_ db sndChunkReplicaId - pure (replica :: SndFileChunkReplica) {rcvIdsKeys} - pure (chunk {replicas = replicas'} :: SndFileChunk) - where - toChunk :: ((DBSndFileId, SndFileId, UserId, Int, FilePath) :. (Int64, Int, Int64, Word32, FileDigest) :. (Int64, ChunkReplicaId, C.APrivateAuthKey, SndFileReplicaStatus, Maybe Int64, Int)) -> SndFileChunk - toChunk ((sndFileId, sndFileEntityId, userId, numRecipients, filePrefixPath) :. (sndChunkId, chunkNo, chunkOffset, chunkSize, digest) :. (sndChunkReplicaId, replicaId, replicaKey, replicaStatus, delay, retries)) = - let chunkSpec = XFTPChunkSpec {filePath = sndFileEncPath filePrefixPath, chunkOffset, chunkSize} - in SndFileChunk - { sndFileId, - sndFileEntityId, - userId, - numRecipients, - sndChunkId, - chunkNo, - chunkSpec, - digest, - filePrefixPath, - replicas = [SndFileChunkReplica {sndChunkReplicaId, server, replicaId, replicaKey, replicaStatus, delay, retries, rcvIdsKeys = []}] - } - -updateSndChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO () -updateSndChunkReplicaDelay db replicaId delay = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_file_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE snd_file_chunk_replica_id = ?" 
(delay, updatedAt, replicaId) - -addSndChunkReplicaRecipients :: DB.Connection -> SndFileChunkReplica -> [(ChunkReplicaId, C.APrivateAuthKey)] -> IO SndFileChunkReplica -addSndChunkReplicaRecipients db r@SndFileChunkReplica {sndChunkReplicaId} rcvIdsKeys = do - forM_ rcvIdsKeys $ \(rcvId, rcvKey) -> do - DB.execute - db - [sql| - INSERT INTO snd_file_chunk_replica_recipients - (snd_file_chunk_replica_id, rcv_replica_id, rcv_replica_key) - VALUES (?,?,?) - |] - (sndChunkReplicaId, rcvId, rcvKey) - rcvIdsKeys' <- getChunkReplicaRecipients_ db sndChunkReplicaId - pure (r :: SndFileChunkReplica) {rcvIdsKeys = rcvIdsKeys'} - -updateSndChunkReplicaStatus :: DB.Connection -> Int64 -> SndFileReplicaStatus -> IO () -updateSndChunkReplicaStatus db replicaId status = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE snd_file_chunk_replicas SET replica_status = ?, updated_at = ? WHERE snd_file_chunk_replica_id = ?" (status, updatedAt, replicaId) - -getPendingSndFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] -getPendingSndFilesServers db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - map toXFTPServer - <$> DB.query - db - [sql| - SELECT DISTINCT - s.xftp_host, s.xftp_port, s.xftp_key_hash - FROM snd_file_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - JOIN snd_file_chunks c ON c.snd_file_chunk_id = r.snd_file_chunk_id - JOIN snd_files f ON f.snd_file_id = c.snd_file_id - WHERE r.replica_status = ? AND r.replica_number = 1 - AND (f.status = ? OR f.status = ?) AND f.deleted = 0 AND f.created_at >= ? - |] - (SFRSCreated, SFSEncrypted, SFSUploading, cutoffTs) - -getCleanupSndFilesPrefixPaths :: DB.Connection -> IO [(DBSndFileId, SndFileId, FilePath)] -getCleanupSndFilesPrefixPaths db = - DB.query - db - [sql| - SELECT snd_file_id, snd_file_entity_id, prefix_path - FROM snd_files - WHERE status IN (?,?) AND prefix_path IS NOT NULL - |] - (SFSComplete, SFSError) - -getCleanupSndFilesDeleted :: DB.Connection -> IO [(DBSndFileId, SndFileId, Maybe FilePath)] -getCleanupSndFilesDeleted db = - DB.query_ - db - [sql| - SELECT snd_file_id, snd_file_entity_id, prefix_path - FROM snd_files - WHERE deleted = 1 - |] - -getSndFilesExpired :: DB.Connection -> NominalDiffTime -> IO [(DBSndFileId, SndFileId, Maybe FilePath)] -getSndFilesExpired db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.query - db - [sql| - SELECT snd_file_id, snd_file_entity_id, prefix_path - FROM snd_files - WHERE created_at < ? - |] - (Only cutoffTs) - -createDeletedSndChunkReplica :: DB.Connection -> UserId -> FileChunkReplica -> FileDigest -> IO () -createDeletedSndChunkReplica db userId FileChunkReplica {server, replicaId, replicaKey} chunkDigest = do - srvId <- createXFTPServer_ db server - DB.execute - db - "INSERT INTO deleted_snd_chunk_replicas (user_id, xftp_server_id, replica_id, replica_key, chunk_digest) VALUES (?,?,?,?,?)" - (userId, srvId, replicaId, replicaKey, chunkDigest) - -getDeletedSndChunkReplica :: DB.Connection -> DBSndFileId -> IO (Either StoreError DeletedSndChunkReplica) -getDeletedSndChunkReplica db deletedSndChunkReplicaId = - firstRow toReplica SEDeletedSndChunkReplicaNotFound $ - DB.query - db - [sql| - SELECT - r.user_id, r.replica_id, r.replica_key, r.chunk_digest, r.delay, r.retries, - s.xftp_host, s.xftp_port, s.xftp_key_hash - FROM deleted_snd_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - WHERE r.deleted_snd_chunk_replica_id = ? 
- |] - (Only deletedSndChunkReplicaId) - where - toReplica :: (UserId, ChunkReplicaId, C.APrivateAuthKey, FileDigest, Maybe Int64, Int, NonEmpty TransportHost, ServiceName, C.KeyHash) -> DeletedSndChunkReplica - toReplica (userId, replicaId, replicaKey, chunkDigest, delay, retries, host, port, keyHash) = - let server = XFTPServer host port keyHash - in DeletedSndChunkReplica {deletedSndChunkReplicaId, userId, server, replicaId, replicaKey, chunkDigest, delay, retries} - -getNextDeletedSndChunkReplica :: DB.Connection -> XFTPServer -> NominalDiffTime -> IO (Either StoreError (Maybe DeletedSndChunkReplica)) -getNextDeletedSndChunkReplica db ProtocolServer {host, port, keyHash} ttl = - getWorkItem "deleted replica" getReplicaId (getDeletedSndChunkReplica db) markReplicaFailed - where - getReplicaId :: IO (Maybe Int64) - getReplicaId = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - maybeFirstRow fromOnly $ - DB.query - db - [sql| - SELECT r.deleted_snd_chunk_replica_id - FROM deleted_snd_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - WHERE s.xftp_host = ? AND s.xftp_port = ? AND s.xftp_key_hash = ? - AND r.created_at >= ? - AND failed = 0 - ORDER BY r.retries ASC, r.created_at ASC - LIMIT 1 - |] - (host, port, keyHash, cutoffTs) - markReplicaFailed :: Int64 -> IO () - markReplicaFailed replicaId = do - DB.execute db "UPDATE deleted_snd_chunk_replicas SET failed = 1 WHERE deleted_snd_chunk_replica_id = ?" (Only replicaId) - -updateDeletedSndChunkReplicaDelay :: DB.Connection -> Int64 -> Int64 -> IO () -updateDeletedSndChunkReplicaDelay db deletedSndChunkReplicaId delay = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE deleted_snd_chunk_replicas SET delay = ?, retries = retries + 1, updated_at = ? WHERE deleted_snd_chunk_replica_id = ?" (delay, updatedAt, deletedSndChunkReplicaId) - -deleteDeletedSndChunkReplica :: DB.Connection -> Int64 -> IO () -deleteDeletedSndChunkReplica db deletedSndChunkReplicaId = - DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE deleted_snd_chunk_replica_id = ?" (Only deletedSndChunkReplicaId) - -getPendingDelFilesServers :: DB.Connection -> NominalDiffTime -> IO [XFTPServer] -getPendingDelFilesServers db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - map toXFTPServer - <$> DB.query - db - [sql| - SELECT DISTINCT - s.xftp_host, s.xftp_port, s.xftp_key_hash - FROM deleted_snd_chunk_replicas r - JOIN xftp_servers s ON s.xftp_server_id = r.xftp_server_id - WHERE r.created_at >= ? - |] - (Only cutoffTs) - -deleteDeletedSndChunkReplicasExpired :: DB.Connection -> NominalDiffTime -> IO () -deleteDeletedSndChunkReplicasExpired db ttl = do - cutoffTs <- addUTCTime (-ttl) <$> getCurrentTime - DB.execute db "DELETE FROM deleted_snd_chunk_replicas WHERE created_at < ?" (Only cutoffTs) - -updateServersStats :: DB.Connection -> AgentPersistedServerStats -> IO () -updateServersStats db stats = do - updatedAt <- getCurrentTime - DB.execute db "UPDATE servers_stats SET servers_stats = ?, updated_at = ? WHERE servers_stats_id = 1" (stats, updatedAt) - -getServersStats :: DB.Connection -> IO (Either StoreError (UTCTime, Maybe AgentPersistedServerStats)) -getServersStats db = - firstRow id SEServersStatsNotFound $ - DB.query_ db "SELECT started_at, servers_stats FROM servers_stats WHERE servers_stats_id = 1" - -resetServersStats :: DB.Connection -> UTCTime -> IO () -resetServersStats db startedAt = - DB.execute db "UPDATE servers_stats SET servers_stats = NULL, started_at = ?, updated_at = ? 
WHERE servers_stats_id = 1" (startedAt, startedAt) - -$(J.deriveJSON defaultJSON ''UpMigration) - -$(J.deriveToJSON (sumTypeJSON $ dropPrefix "ME") ''MigrationError) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs index a7ad47f37..3b0c4d6c8 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Common.hs @@ -4,7 +4,7 @@ {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Agent.Store.SQLite.Common - ( SQLiteStore (..), + ( DBStore (..), withConnection, withConnection', withTransaction, @@ -30,7 +30,7 @@ import UnliftIO.STM storeKey :: ScrubbedBytes -> Bool -> Maybe ScrubbedBytes storeKey key keepKey = if keepKey || BA.null key then Just key else Nothing -data SQLiteStore = SQLiteStore +data DBStore = DBStore { dbFilePath :: FilePath, dbKey :: TVar (Maybe ScrubbedBytes), dbSem :: TVar Int, @@ -39,8 +39,8 @@ data SQLiteStore = SQLiteStore dbNew :: Bool } -withConnectionPriority :: SQLiteStore -> Bool -> (DB.Connection -> IO a) -> IO a -withConnectionPriority SQLiteStore {dbSem, dbConnection} priority action +withConnectionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a +withConnectionPriority DBStore {dbSem, dbConnection} priority action | priority = E.bracket_ signal release $ withMVar dbConnection action | otherwise = lowPriority where @@ -50,20 +50,20 @@ withConnectionPriority SQLiteStore {dbSem, dbConnection} priority action wait = unlessM free $ atomically $ unlessM ((0 ==) <$> readTVar dbSem) retry free = (0 ==) <$> readTVarIO dbSem -withConnection :: SQLiteStore -> (DB.Connection -> IO a) -> IO a +withConnection :: DBStore -> (DB.Connection -> IO a) -> IO a withConnection st = withConnectionPriority st False -withConnection' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a +withConnection' :: DBStore -> (SQL.Connection -> IO a) -> IO a withConnection' st action = withConnection st $ action . DB.conn -withTransaction' :: SQLiteStore -> (SQL.Connection -> IO a) -> IO a +withTransaction' :: DBStore -> (SQL.Connection -> IO a) -> IO a withTransaction' st action = withTransaction st $ action . DB.conn -withTransaction :: SQLiteStore -> (DB.Connection -> IO a) -> IO a +withTransaction :: DBStore -> (DB.Connection -> IO a) -> IO a withTransaction st = withTransactionPriority st False {-# INLINE withTransaction #-} -withTransactionPriority :: SQLiteStore -> Bool -> (DB.Connection -> IO a) -> IO a +withTransactionPriority :: DBStore -> Bool -> (DB.Connection -> IO a) -> IO a withTransactionPriority st priority action = withConnectionPriority st priority $ dbBusyLoop . 
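-- Illustrative sketch (not part of this patch): the store type is renamed
-- (SQLiteStore -> DBStore) but the signatures shown here are otherwise unchanged, so store
-- operations are still run inside withTransaction. A hypothetical call site combining it
-- with one of the queries shown earlier in this diff (getQueueExample is an illustrative
-- name):
getQueueExample :: DBStore -> ConnId -> Int64 -> IO (Either StoreError SndQueue)
getQueueExample st connId dbSndId =
  withTransaction st $ \db -> getSndQueueById db connId dbSndId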
transaction where transaction db@DB.Connection {conn} = SQL.withImmediateTransaction conn $ action db diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs index 4d9dbeb57..7e8406d5c 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/DB.hs @@ -1,40 +1,51 @@ {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE StrictData #-} {-# LANGUAGE TemplateHaskell #-} module Simplex.Messaging.Agent.Store.SQLite.DB - ( Connection (..), + ( BoolInt (..), + Binary (..), + Connection (..), SlowQueryStats (..), open, close, execute, execute_, - executeNamed, executeMany, query, query_, - queryNamed, ) where import Control.Concurrent.STM -import Control.Monad (when) import Control.Exception +import Control.Monad (when) import qualified Data.Aeson.TH as J +import Data.ByteString (ByteString) import Data.Int (Int64) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Text (Text) import Data.Time (diffUTCTime, getCurrentTime) -import Database.SQLite.Simple (FromRow, NamedParam, Query, ToRow) +import Database.SQLite.Simple (FromRow, Query, ToRow) import qualified Database.SQLite.Simple as SQL +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Parsers (defaultJSON) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM import Simplex.Messaging.Util (diffToMilliseconds, tshow) +newtype BoolInt = BI {unBI :: Bool} + deriving newtype (FromField, ToField) + +newtype Binary = Binary {fromBinary :: ByteString} + deriving newtype (FromField, ToField) + data Connection = Connection { conn :: SQL.Connection, slow :: TMap Query SlowQueryStats @@ -51,9 +62,10 @@ data SlowQueryStats = SlowQueryStats timeIt :: TMap Query SlowQueryStats -> Query -> IO a -> IO a timeIt slow sql a = do t <- getCurrentTime - r <- a `catch` \e -> do - atomically $ TM.alter (Just . updateQueryErrors e) sql slow - throwIO e + r <- + a `catch` \e -> do + atomically $ TM.alter (Just . updateQueryErrors e) sql slow + throwIO e t' <- getCurrentTime let diff = diffToMilliseconds $ diffUTCTime t' t when (diff > 1) $ atomically $ TM.alter (updateQueryStats diff) sql slow @@ -91,10 +103,6 @@ execute_ :: Connection -> Query -> IO () execute_ Connection {conn, slow} sql = timeIt slow sql $ SQL.execute_ conn sql {-# INLINE execute_ #-} -executeNamed :: Connection -> Query -> [NamedParam] -> IO () -executeNamed Connection {conn, slow} sql = timeIt slow sql . SQL.executeNamed conn sql -{-# INLINE executeNamed #-} - executeMany :: ToRow q => Connection -> Query -> [q] -> IO () executeMany Connection {conn, slow} sql = timeIt slow sql . SQL.executeMany conn sql {-# INLINE executeMany #-} @@ -107,8 +115,4 @@ query_ :: FromRow r => Connection -> Query -> IO [r] query_ Connection {conn, slow} sql = timeIt slow sql $ SQL.query_ conn sql {-# INLINE query_ #-} -queryNamed :: FromRow r => Connection -> Query -> [NamedParam] -> IO [r] -queryNamed Connection {conn, slow} sql = timeIt slow sql . 
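-- Illustrative sketch (not part of this patch): BoolInt (added above) newtype-wraps Bool,
-- presumably so that boolean columns keep their 0/1 integer representation on SQLite while
-- the same call sites can also compile against the PostgreSQL DB module. A hypothetical
-- call-site sketch using a column from the agent schema shown earlier in this diff
-- (setRcvFileDeleted is illustrative only; assumes OverloadedStrings for the query literal):
setRcvFileDeleted :: Connection -> Int64 -> Bool -> IO ()
setRcvFileDeleted db rcvFileId deleted =
  execute db "UPDATE rcv_files SET deleted = ? WHERE rcv_file_id = ?" (BI deleted, rcvFileId)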
SQL.queryNamed conn sql -{-# INLINE queryNamed #-} - $(J.deriveJSON defaultJSON ''SlowQueryStats) diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index ca78dbf42..2f3c6010e 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -5,40 +5,29 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StrictData #-} -{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} module Simplex.Messaging.Agent.Store.SQLite.Migrations - ( Migration (..), - MigrationsToRun (..), - MTRError (..), - DownMigration (..), - app, + ( app, initialize, - get, run, getCurrent, - mtrErrorDescription, - -- for unit tests - migrationsToRun, - toDownMigration, ) where import Control.Monad (forM_, when) -import qualified Data.Aeson.TH as J -import Data.List (intercalate, sortOn) +import Data.List (sortOn) import Data.List.NonEmpty (NonEmpty) import qualified Data.Map.Strict as M -import Data.Maybe (isNothing, mapMaybe) import Data.Text (Text) import Data.Text.Encoding (decodeLatin1) import Data.Time.Clock (getCurrentTime) -import Database.SQLite.Simple (Connection, Only (..), Query (..)) -import qualified Database.SQLite.Simple as DB +import Database.SQLite.Simple (Only (..), Query (..)) +import qualified Database.SQLite.Simple as SQL import Database.SQLite.Simple.QQ (sql) import qualified Database.SQLite3 as SQLite3 import Simplex.Messaging.Agent.Protocol (extraSMPServerHosts) +import qualified Simplex.Messaging.Agent.Store.DB as DB import Simplex.Messaging.Agent.Store.SQLite.Common import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys @@ -77,13 +66,10 @@ import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240702_servers_stats import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20240930_ntf_tokens_to_delete import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241007_rcv_queues_last_broker_ts import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20241224_ratchet_e2e_snd_params +import Simplex.Messaging.Agent.Store.Shared import Simplex.Messaging.Encoding.String -import Simplex.Messaging.Parsers (dropPrefix, sumTypeJSON) import Simplex.Messaging.Transport.Client (TransportHost) -data Migration = Migration {name :: String, up :: Text, down :: Maybe Text} - deriving (Eq, Show) - schemaMigrations :: [(String, Query, Maybe Query)] schemaMigrations = [ ("20220101_initial", m20220101_initial, Nothing), @@ -131,15 +117,12 @@ app = sortOn name $ map migration schemaMigrations where migration (name, up, down) = Migration {name, up = fromQuery up, down = fromQuery <$> down} -get :: SQLiteStore -> [Migration] -> IO (Either MTRError MigrationsToRun) -get st migrations = migrationsToRun migrations <$> withTransaction' st getCurrent - -getCurrent :: Connection -> IO [Migration] -getCurrent db = map toMigration <$> DB.query_ db "SELECT name, down FROM migrations ORDER BY name ASC;" +getCurrent :: DB.Connection -> IO [Migration] +getCurrent DB.Connection {DB.conn} = map toMigration <$> SQL.query_ conn "SELECT name, down FROM migrations ORDER BY name ASC;" where toMigration (name, down) = Migration {name, up = "", down} -run :: SQLiteStore -> MigrationsToRun -> IO () +run :: DBStore -> MigrationsToRun -> IO () run st = \case MTRUp [] -> pure () MTRUp ms -> mapM_ runUp ms >> withConnection' st (`execSQL` "VACUUM;") @@ -150,27 +133,27 @@ run 
st = \case when (name == "m20220811_onion_hosts") $ updateServers db insert db >> execSQL db up' where - insert db = DB.execute db "INSERT INTO migrations (name, down, ts) VALUES (?,?,?)" . (name,down,) =<< getCurrentTime + insert db = SQL.execute db "INSERT INTO migrations (name, down, ts) VALUES (?,?,?)" . (name,down,) =<< getCurrentTime up' | dbNew st && name == "m20230110_users" = fromQuery new_m20230110_users | otherwise = up updateServers db = forM_ (M.assocs extraSMPServerHosts) $ \(h, h') -> let hs = decodeLatin1 . strEncode $ ([h, h'] :: NonEmpty TransportHost) - in DB.execute db "UPDATE servers SET host = ? WHERE host = ?" (hs, decodeLatin1 $ strEncode h) + in SQL.execute db "UPDATE servers SET host = ? WHERE host = ?" (hs, decodeLatin1 $ strEncode h) runDown DownMigration {downName, downQuery} = withTransaction' st $ \db -> do execSQL db downQuery - DB.execute db "DELETE FROM migrations WHERE name = ?" (Only downName) - execSQL db = SQLite3.exec $ DB.connectionHandle db + SQL.execute db "DELETE FROM migrations WHERE name = ?" (Only downName) + execSQL db = SQLite3.exec $ SQL.connectionHandle db -initialize :: SQLiteStore -> IO () +initialize :: DBStore -> IO () initialize st = withTransaction' st $ \db -> do - cs :: [Text] <- map fromOnly <$> DB.query_ db "SELECT name FROM pragma_table_info('migrations')" + cs :: [Text] <- map fromOnly <$> SQL.query_ db "SELECT name FROM pragma_table_info('migrations')" case cs of [] -> createMigrations db - _ -> when ("down" `notElem` cs) $ DB.execute_ db "ALTER TABLE migrations ADD COLUMN down TEXT" + _ -> when ("down" `notElem` cs) $ SQL.execute_ db "ALTER TABLE migrations ADD COLUMN down TEXT" where createMigrations db = - DB.execute_ + SQL.execute_ db [sql| CREATE TABLE IF NOT EXISTS migrations ( @@ -180,37 +163,3 @@ initialize st = withTransaction' st $ \db -> do PRIMARY KEY (name) ); |] - -data DownMigration = DownMigration {downName :: String, downQuery :: Text} - deriving (Eq, Show) - -toDownMigration :: Migration -> Maybe DownMigration -toDownMigration Migration {name, down} = DownMigration name <$> down - -data MigrationsToRun = MTRUp [Migration] | MTRDown [DownMigration] | MTRNone - deriving (Eq, Show) - -data MTRError - = MTRENoDown {dbMigrations :: [String]} - | MTREDifferent {appMigration :: String, dbMigration :: String} - deriving (Eq, Show) - -mtrErrorDescription :: MTRError -> String -mtrErrorDescription = \case - MTRENoDown ms -> "database version is newer than the app, but no down migration for: " <> intercalate ", " ms - MTREDifferent a d -> "different migration in the app/database: " <> a <> " / " <> d - -migrationsToRun :: [Migration] -> [Migration] -> Either MTRError MigrationsToRun -migrationsToRun [] [] = Right MTRNone -migrationsToRun appMs [] = Right $ MTRUp appMs -migrationsToRun [] dbMs - | length dms == length dbMs = Right $ MTRDown dms - | otherwise = Left $ MTRENoDown $ mapMaybe nameNoDown dbMs - where - dms = mapMaybe toDownMigration dbMs - nameNoDown m = if isNothing (down m) then Just $ name m else Nothing -migrationsToRun (a : as) (d : ds) - | name a == name d = migrationsToRun as ds - | otherwise = Left $ MTREDifferent (name a) (name d) - -$(J.deriveJSON (sumTypeJSON $ dropPrefix "MTRE") ''MTRError) diff --git a/src/Simplex/Messaging/Agent/Store/Shared.hs b/src/Simplex/Messaging/Agent/Store/Shared.hs new file mode 100644 index 000000000..3921bf586 --- /dev/null +++ b/src/Simplex/Messaging/Agent/Store/Shared.hs @@ -0,0 +1,93 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE 
OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} + +module Simplex.Messaging.Agent.Store.Shared + ( Migration (..), + MigrationsToRun (..), + DownMigration (..), + MTRError (..), + mtrErrorDescription, + MigrationConfirmation (..), + MigrationError (..), + UpMigration (..), + migrationErrorDescription, + -- for tests + toDownMigration, + upMigration, + ) +where + +import qualified Data.Aeson.TH as J +import qualified Data.Attoparsec.ByteString.Char8 as A +import Data.List (intercalate) +import Data.Maybe (isJust) +import Data.Text (Text) +import Simplex.Messaging.Encoding.String +import Simplex.Messaging.Parsers (defaultJSON, dropPrefix, sumTypeJSON) + +data Migration = Migration {name :: String, up :: Text, down :: Maybe Text} + deriving (Eq, Show) + +data DownMigration = DownMigration {downName :: String, downQuery :: Text} + deriving (Eq, Show) + +toDownMigration :: Migration -> Maybe DownMigration +toDownMigration Migration {name, down} = DownMigration name <$> down + +data MigrationsToRun = MTRUp [Migration] | MTRDown [DownMigration] | MTRNone + deriving (Eq, Show) + +data MTRError + = MTRENoDown {dbMigrations :: [String]} + | MTREDifferent {appMigration :: String, dbMigration :: String} + deriving (Eq, Show) + +mtrErrorDescription :: MTRError -> String +mtrErrorDescription = \case + MTRENoDown ms -> "database version is newer than the app, but no down migration for: " <> intercalate ", " ms + MTREDifferent a d -> "different migration in the app/database: " <> a <> " / " <> d + +data MigrationError + = MEUpgrade {upMigrations :: [UpMigration]} + | MEDowngrade {downMigrations :: [String]} + | MigrationError {mtrError :: MTRError} + deriving (Eq, Show) + +migrationErrorDescription :: MigrationError -> String +migrationErrorDescription = \case + MEUpgrade ums -> + "The app has a newer version than the database.\nConfirm to back up and upgrade using these migrations: " <> intercalate ", " (map upName ums) + MEDowngrade dms -> + "Database version is newer than the app.\nConfirm to back up and downgrade using these migrations: " <> intercalate ", " dms + MigrationError err -> mtrErrorDescription err + +data UpMigration = UpMigration {upName :: String, withDown :: Bool} + deriving (Eq, Show) + +upMigration :: Migration -> UpMigration +upMigration Migration {name, down} = UpMigration name $ isJust down + +data MigrationConfirmation = MCYesUp | MCYesUpDown | MCConsole | MCError + deriving (Eq, Show) + +instance StrEncoding MigrationConfirmation where + strEncode = \case + MCYesUp -> "yesUp" + MCYesUpDown -> "yesUpDown" + MCConsole -> "console" + MCError -> "error" + strP = + A.takeByteString >>= \case + "yesUp" -> pure MCYesUp + "yesUpDown" -> pure MCYesUpDown + "console" -> pure MCConsole + "error" -> pure MCError + _ -> fail "invalid MigrationConfirmation" + +$(J.deriveJSON (sumTypeJSON $ dropPrefix "MTRE") ''MTRError) + +$(J.deriveJSON defaultJSON ''UpMigration) + +$(J.deriveToJSON (sumTypeJSON $ dropPrefix "ME") ''MigrationError) diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 05ba861bc..a955d0d8a 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -1,11 +1,14 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE 
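-- Usage sketch (not part of this patch) for the MigrationConfirmation StrEncoding instance
-- in the new Store.Shared module above; strEncode/strDecode come from
-- Simplex.Messaging.Encoding.String (illustrative GHCi session):
--
-- >>> strDecode "yesUpDown" :: Either String MigrationConfirmation
-- Right MCYesUpDown
-- >>> strEncode MCConsole
-- "console"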
GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} @@ -233,14 +236,20 @@ import Data.Typeable (Proxy (Proxy), Typeable) import Data.Word (Word32) import Data.X509 import Data.X509.Validation (Fingerprint (..), getFingerprint) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+)) import Network.Transport.Internal (decodeWord16, encodeWord16) +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers (blobFieldDecoder, parseAll, parseString) import Simplex.Messaging.Util ((<$?>)) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif -- | Cryptographic algorithms. data Algorithm = Ed25519 | Ed448 | X25519 | X448 @@ -721,23 +730,23 @@ generateKeyPair_ = case sAlgorithm @a of let k = X448.toPublic pk in pure (PublicKeyX448 k, PrivateKeyX448 pk k) -instance ToField APrivateSignKey where toField = toField . encodePrivKey +instance ToField APrivateSignKey where toField = toField . Binary . encodePrivKey -instance ToField APublicVerifyKey where toField = toField . encodePubKey +instance ToField APublicVerifyKey where toField = toField . Binary . encodePubKey -instance ToField APrivateAuthKey where toField = toField . encodePrivKey +instance ToField APrivateAuthKey where toField = toField . Binary . encodePrivKey -instance ToField APublicAuthKey where toField = toField . encodePubKey +instance ToField APublicAuthKey where toField = toField . Binary . encodePubKey -instance ToField APrivateDhKey where toField = toField . encodePrivKey +instance ToField APrivateDhKey where toField = toField . Binary . encodePrivKey -instance ToField APublicDhKey where toField = toField . encodePubKey +instance ToField APublicDhKey where toField = toField . Binary . encodePubKey -instance AlgorithmI a => ToField (PrivateKey a) where toField = toField . encodePrivKey +instance AlgorithmI a => ToField (PrivateKey a) where toField = toField . Binary . encodePrivKey -instance AlgorithmI a => ToField (PublicKey a) where toField = toField . encodePubKey +instance AlgorithmI a => ToField (PublicKey a) where toField = toField . Binary . encodePubKey -instance ToField (DhSecret a) where toField = toField . dhBytes' +instance ToField (DhSecret a) where toField = toField . Binary . dhBytes' instance FromField APrivateSignKey where fromField = blobFieldDecoder decodePrivKey @@ -888,10 +897,9 @@ validSignatureSize n = -- | AES key newtype. newtype Key = Key {unKey :: ByteString} deriving (Eq, Ord, Show) + deriving newtype (FromField) -instance ToField Key where toField = toField . unKey - -instance FromField Key where fromField f = Key <$> fromField f +instance ToField Key where toField (Key s) = toField $ Binary s instance ToJSON Key where toJSON = strToJSON . unKey @@ -952,7 +960,7 @@ instance FromJSON KeyHash where instance IsString KeyHash where fromString = parseString $ parseAll strP -instance ToField KeyHash where toField = toField . strEncode +instance ToField KeyHash where toField = toField . Binary . 
strEncode instance FromField KeyHash where fromField = blobFieldDecoder $ parseAll strP @@ -1162,10 +1170,14 @@ instance SignatureAlgorithmX509 pk => SignatureAlgorithmX509 (a, pk) where newtype SignedObject a = SignedObject {getSignedExact :: SignedExact a} instance (Typeable a, Eq a, Show a, ASN1Object a) => FromField (SignedObject a) where +#if defined(dbPostgres) + fromField f dat = SignedObject <$> blobFieldDecoder decodeSignedObject f dat +#else fromField = fmap SignedObject . blobFieldDecoder decodeSignedObject +#endif instance (Eq a, Show a, ASN1Object a) => ToField (SignedObject a) where - toField (SignedObject s) = toField $ encodeSignedObject s + toField (SignedObject s) = toField . Binary $ encodeSignedObject s instance (Eq a, Show a, ASN1Object a) => Encoding (SignedObject a) where smpEncode (SignedObject exact) = smpEncode . Large $ encodeSignedObject exact @@ -1265,6 +1277,9 @@ cbVerify k pk nonce (CbAuthenticator s) authorized = cbDecryptNoPad (dh' k pk) n newtype CbNonce = CryptoBoxNonce {unCbNonce :: ByteString} deriving (Eq, Show) + deriving newtype (FromField) + +instance ToField CbNonce where toField (CryptoBoxNonce s) = toField $ Binary s pattern CbNonce :: ByteString -> CbNonce pattern CbNonce s <- CryptoBoxNonce s @@ -1282,10 +1297,6 @@ instance ToJSON CbNonce where instance FromJSON CbNonce where parseJSON = strParseJSON "CbNonce" -instance FromField CbNonce where fromField f = CryptoBoxNonce <$> fromField f - -instance ToField CbNonce where toField (CryptoBoxNonce s) = toField s - cbNonce :: ByteString -> CbNonce cbNonce s | len == 24 = CryptoBoxNonce s @@ -1309,6 +1320,9 @@ instance Encoding CbNonce where newtype SbKey = SecretBoxKey {unSbKey :: ByteString} deriving (Eq, Show) + deriving newtype (FromField) + +instance ToField SbKey where toField (SecretBoxKey s) = toField $ Binary s pattern SbKey :: ByteString -> SbKey pattern SbKey s <- SecretBoxKey s @@ -1326,10 +1340,6 @@ instance ToJSON SbKey where instance FromJSON SbKey where parseJSON = strParseJSON "SbKey" -instance FromField SbKey where fromField f = SecretBoxKey <$> fromField f - -instance ToField SbKey where toField (SecretBoxKey s) = toField s - sbKey :: ByteString -> Either String SbKey sbKey s | B.length s == 32 = Right $ SecretBoxKey s diff --git a/src/Simplex/Messaging/Crypto/Ratchet.hs b/src/Simplex/Messaging/Crypto/Ratchet.hs index 1d22843ff..310893de5 100644 --- a/src/Simplex/Messaging/Crypto/Ratchet.hs +++ b/src/Simplex/Messaging/Crypto/Ratchet.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -109,9 +110,8 @@ import Data.Maybe (fromMaybe, isJust) import Data.Type.Equality import Data.Typeable (Typeable) import Data.Word (Word16, Word32) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.QueryString +import Simplex.Messaging.Agent.Store.DB (Binary (..), BoolInt (..)) import Simplex.Messaging.Crypto import Simplex.Messaging.Crypto.SNTRUP761.Bindings import Simplex.Messaging.Encoding @@ -121,6 +121,13 @@ import Simplex.Messaging.Util (($>>=), (<$?>)) import Simplex.Messaging.Version import Simplex.Messaging.Version.Internal import UnliftIO.STM +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif -- 
e2e encryption headers version history: -- 1 - binary protocol encoding (1/1/2022) @@ -206,7 +213,7 @@ instance Encoding ARKEMParams where 'A' -> ARKP SRKSAccepted .: RKParamsAccepted <$> smpP <*> smpP _ -> fail "bad ratchet KEM params" -instance ToField ARKEMParams where toField = toField . smpEncode +instance ToField ARKEMParams where toField = toField . Binary . smpEncode instance FromField ARKEMParams where fromField = blobFieldDecoder smpDecode @@ -363,7 +370,7 @@ instance Encoding APrivRKEMParams where 'A' -> APRKP SRKSAccepted .:. PrivateRKParamsAccepted <$> smpP <*> smpP <*> smpP _ -> fail "bad APrivRKEMParams" -instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . smpEncode +instance RatchetKEMStateI s => ToField (PrivRKEMParams s) where toField = toField . Binary . smpEncode instance (Typeable s, RatchetKEMStateI s) => FromField (PrivRKEMParams s) where fromField = blobFieldDecoder smpDecode @@ -580,7 +587,7 @@ instance ToJSON RatchetKey where instance FromJSON RatchetKey where parseJSON = fmap RatchetKey . strParseJSON "Key" -instance ToField MessageKey where toField = toField . smpEncode +instance ToField MessageKey where toField = toField . Binary . smpEncode instance FromField MessageKey where fromField = blobFieldDecoder smpDecode @@ -1124,14 +1131,24 @@ instance AlgorithmI a => ToJSON (Ratchet a) where instance AlgorithmI a => FromJSON (Ratchet a) where parseJSON = $(JQ.mkParseJSON defaultJSON ''Ratchet) -instance AlgorithmI a => ToField (Ratchet a) where toField = toField . LB.toStrict . J.encode +instance AlgorithmI a => ToField (Ratchet a) where toField = toField . Binary . LB.toStrict . J.encode instance (AlgorithmI a, Typeable a) => FromField (Ratchet a) where fromField = blobFieldDecoder J.eitherDecodeStrict' -instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField pqEnc +instance ToField PQEncryption where toField (PQEncryption pqEnc) = toField (BI pqEnc) -instance FromField PQEncryption where fromField f = PQEncryption <$> fromField f +instance FromField PQEncryption where +#if defined(dbPostgres) + fromField f dat = PQEncryption . unBI <$> fromField f dat +#else + fromField f = PQEncryption . unBI <$> fromField f +#endif -instance ToField PQSupport where toField (PQSupport pqEnc) = toField pqEnc +instance ToField PQSupport where toField (PQSupport pqEnc) = toField (BI pqEnc) -instance FromField PQSupport where fromField f = PQSupport <$> fromField f +instance FromField PQSupport where +#if defined(dbPostgres) + fromField f dat = PQSupport . unBI <$> fromField f dat +#else + fromField f = PQSupport . 
unBI <$> fromField f +#endif diff --git a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs index 3b2238086..35e46e3de 100644 --- a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs +++ b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE TypeApplications #-} module Simplex.Messaging.Crypto.SNTRUP761.Bindings where @@ -9,14 +10,19 @@ import Data.Bifunctor (bimap) import Data.ByteArray (ScrubbedBytes) import qualified Data.ByteArray as BA import Data.ByteString (ByteString) -import Database.SQLite.Simple.FromField -import Database.SQLite.Simple.ToField import Foreign (nullPtr) import Simplex.Messaging.Crypto.SNTRUP761.Bindings.Defines import Simplex.Messaging.Crypto.SNTRUP761.Bindings.FFI import Simplex.Messaging.Crypto.SNTRUP761.Bindings.RNG (withDRG) import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField +import Database.PostgreSQL.Simple.ToField +#else +import Database.SQLite.Simple.FromField +import Database.SQLite.Simple.ToField +#endif newtype KEMPublicKey = KEMPublicKey ByteString deriving (Eq, Show) @@ -121,7 +127,11 @@ instance ToField KEMSharedKey where toField (KEMSharedKey k) = toField (BA.convert k :: ByteString) instance FromField KEMSharedKey where +#if defined(dbPostgres) + fromField f dat = KEMSharedKey . BA.convert @ByteString <$> fromField f dat +#else fromField f = KEMSharedKey . BA.convert @ByteString <$> fromField f +#endif instance ToJSON KEMSharedKey where toJSON = strToJSON diff --git a/src/Simplex/Messaging/Notifications/Protocol.hs b/src/Simplex/Messaging/Notifications/Protocol.hs index af8987dcc..96f8b337e 100644 --- a/src/Simplex/Messaging/Notifications/Protocol.hs +++ b/src/Simplex/Messaging/Notifications/Protocol.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} @@ -27,8 +28,6 @@ import Data.Text.Encoding (decodeLatin1, encodeUtf8) import Data.Time.Clock.System import Data.Type.Equality import Data.Word (Word16) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol (updateSMPServerHosts) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -37,6 +36,13 @@ import Simplex.Messaging.Notifications.Transport (NTFVersion, ntfClientHandshake import Simplex.Messaging.Parsers (fromTextField_) import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..)) import Simplex.Messaging.Util (eitherToMaybe, (<$?>)) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif data NtfEntity = Token | Subscription deriving (Show) diff --git a/src/Simplex/Messaging/Notifications/Types.hs b/src/Simplex/Messaging/Notifications/Types.hs index 774f354bb..dd6e99733 100644 --- a/src/Simplex/Messaging/Notifications/Types.hs +++ b/src/Simplex/Messaging/Notifications/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE LambdaCase #-} @@ -9,14 +10,20 @@ module Simplex.Messaging.Notifications.Types where import qualified Data.Attoparsec.ByteString.Char8 as A import Data.Text.Encoding (decodeLatin1, encodeUtf8) 
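A minimal sketch (not part of this patch; the Token type and its instances are hypothetical) of the dual-backend field-instance pattern the hunks above apply throughout: ByteString-valued columns are wrapped in the Binary newtype imported from Simplex.Messaging.Agent.Store.DB when written, the BoolInt (BI) wrapper plays the analogous role for Bool columns, and FromField instances are CPP-split because sqlite-simple field parsers take a single Field argument while postgresql-simple parsers also receive the raw Maybe ByteString value.

{-# LANGUAGE CPP #-}

import Data.ByteString (ByteString)
import Simplex.Messaging.Agent.Store.DB (Binary (..))
#if defined(dbPostgres)
import Database.PostgreSQL.Simple.FromField (FromField (..))
import Database.PostgreSQL.Simple.ToField (ToField (..))
#else
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
#endif

-- hypothetical example type, not defined anywhere in this patch
newtype Token = Token ByteString

-- written as BLOB (SQLite) / bytea (PostgreSQL) via the Binary wrapper
instance ToField Token where toField (Token s) = toField $ Binary s

-- sqlite-simple: fromField :: Field -> Ok a
-- postgresql-simple: fromField :: Field -> Maybe ByteString -> Conversion a
instance FromField Token where
#if defined(dbPostgres)
  fromField f dat = Token <$> fromField f dat
#else
  fromField f = Token <$> fromField f
#endif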
import Data.Time (UTCTime) -import Database.SQLite.Simple.FromField (FromField (..)) -import Database.SQLite.Simple.ToField (ToField (..)) import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..), UserId) +import Simplex.Messaging.Agent.Store.DB (Binary (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Parsers (blobFieldDecoder, fromTextField_) import Simplex.Messaging.Protocol (NotifierId, NtfServer, SMPServer) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple.FromField (FromField (..)) +import Database.PostgreSQL.Simple.ToField (ToField (..)) +#else +import Database.SQLite.Simple.FromField (FromField (..)) +import Database.SQLite.Simple.ToField (ToField (..)) +#endif data NtfTknAction = NTARegister @@ -41,7 +48,7 @@ instance Encoding NtfTknAction where instance FromField NtfTknAction where fromField = blobFieldDecoder smpDecode -instance ToField NtfTknAction where toField = toField . smpEncode +instance ToField NtfTknAction where toField = toField . Binary . smpEncode data NtfToken = NtfToken { deviceToken :: DeviceToken, @@ -119,7 +126,7 @@ instance Encoding NtfSubNTFAction where instance FromField NtfSubNTFAction where fromField = blobFieldDecoder smpDecode -instance ToField NtfSubNTFAction where toField = toField . smpEncode +instance ToField NtfSubNTFAction where toField = toField . Binary . smpEncode data NtfSubSMPAction = NSASmpKey @@ -138,7 +145,7 @@ instance Encoding NtfSubSMPAction where instance FromField NtfSubSMPAction where fromField = blobFieldDecoder smpDecode -instance ToField NtfSubSMPAction where toField = toField . smpEncode +instance ToField NtfSubSMPAction where toField = toField . Binary . smpEncode data NtfAgentSubStatus = -- | subscription started diff --git a/src/Simplex/Messaging/Parsers.hs b/src/Simplex/Messaging/Parsers.hs index 6ad9f867d..a75efe0ee 100644 --- a/src/Simplex/Messaging/Parsers.hs +++ b/src/Simplex/Messaging/Parsers.hs @@ -1,5 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} @@ -20,12 +21,19 @@ import qualified Data.Text as T import Data.Time.Clock (UTCTime) import Data.Time.ISO8601 (parseISO8601) import Data.Typeable (Typeable) +import Simplex.Messaging.Util (safeDecodeUtf8, (<$?>)) +import Text.Read (readMaybe) +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (ResultError (..)) +import Database.PostgreSQL.Simple.FromField (FromField(..), FieldParser, returnError, Field (..)) +import Database.PostgreSQL.Simple.TypeInfo.Static (textOid, varcharOid) +import qualified Data.Text.Encoding as TE +#else import Database.SQLite.Simple (ResultError (..), SQLData (..)) import Database.SQLite.Simple.FromField (FieldParser, returnError) import Database.SQLite.Simple.Internal (Field (..)) import Database.SQLite.Simple.Ok (Ok (Ok)) -import Simplex.Messaging.Util (safeDecodeUtf8, (<$?>)) -import Text.Read (readMaybe) +#endif base64P :: Parser ByteString base64P = decode <$?> paddedBase64 rawBase64P @@ -77,6 +85,14 @@ parseString p = either error id . p . B.pack blobFieldParser :: Typeable k => Parser k -> FieldParser k blobFieldParser = blobFieldDecoder . 
parseAll +#if defined(dbPostgres) +blobFieldDecoder :: Typeable k => (ByteString -> Either String k) -> FieldParser k +blobFieldDecoder dec f val = do + x <- fromField f val + case dec x of + Right k -> pure k + Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e) +#else blobFieldDecoder :: Typeable k => (ByteString -> Either String k) -> FieldParser k blobFieldDecoder dec = \case f@(Field (SQLBlob b) _) -> @@ -84,7 +100,20 @@ blobFieldDecoder dec = \case Right k -> Ok k Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e) f -> returnError ConversionFailed f "expecting SQLBlob column type" +#endif +-- TODO [postgres] review +#if defined(dbPostgres) +fromTextField_ :: Typeable a => (Text -> Maybe a) -> FieldParser a +fromTextField_ fromText f val = + if typeOid f `elem` [textOid, varcharOid] + then case val of + Just t -> case fromText (TE.decodeUtf8 t) of + Just x -> pure x + _ -> returnError ConversionFailed f "invalid text value" + Nothing -> returnError UnexpectedNull f "NULL value found for non-NULL field" + else returnError Incompatible f "expecting TEXT or VARCHAR column type" +#else fromTextField_ :: Typeable a => (Text -> Maybe a) -> Field -> Ok a fromTextField_ fromText = \case f@(Field (SQLText t) _) -> @@ -92,6 +121,7 @@ fromTextField_ fromText = \case Just x -> Ok x _ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t) f -> returnError ConversionFailed f "expecting SQLText column type" +#endif fstToLower :: String -> String fstToLower "" = "" diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 56a7fef1f..dff6cd4b0 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -1,12 +1,10 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PostfixOperators #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} module AgentTests (agentTests) where @@ -14,18 +12,30 @@ import AgentTests.ConnectionRequestTests import AgentTests.DoubleRatchetTests (doubleRatchetTests) import AgentTests.FunctionalAPITests (functionalAPITests) import AgentTests.MigrationTests (migrationTests) -import AgentTests.NotificationTests (notificationTests) -import AgentTests.SQLiteTests (storeTests) import AgentTests.ServerChoice (serverChoiceTests) import Simplex.Messaging.Transport (ATransport (..)) import Test.Hspec +#if defined(dbPostgres) +import Fixtures +import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem) +#else +import AgentTests.NotificationTests (notificationTests) +import AgentTests.SQLiteTests (storeTests) +#endif agentTests :: ATransport -> Spec agentTests (ATransport t) = do + describe "Migration tests" migrationTests describe "Connection request" connectionRequestTests describe "Double ratchet tests" doubleRatchetTests +#if defined(dbPostgres) + after_ (dropAllSchemasExceptSystem testDBConnectInfo) $ do + describe "Functional API" $ functionalAPITests (ATransport t) + describe "Chosen servers" serverChoiceTests +#else describe "Functional API" $ functionalAPITests (ATransport t) + describe "Chosen servers" serverChoiceTests + -- notifications aren't tested with postgres, as we don't plan to use iOS client with it describe "Notification tests" $ notificationTests (ATransport t) describe "SQLite store" storeTests - describe "Chosen servers" serverChoiceTests - describe "Migration tests" migrationTests +#endif diff --git 
a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index b281d8001..9c3c5a972 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE FlexibleContexts #-} @@ -74,7 +75,6 @@ import Data.Time.Clock (diffUTCTime, getCurrentTime) import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Data.Type.Equality (testEquality, (:~:) (Refl)) import Data.Word (Word16) -import qualified Database.SQLite.Simple as SQL import GHC.Stack (withFrozenCallStack) import SMPAgentClient import SMPClient (cfg, prevRange, prevVersion, testPort, testPort2, testStoreLogFile2, testStoreMsgsDir2, withSmpServer, withSmpServerConfigOn, withSmpServerProxy, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) @@ -84,8 +84,9 @@ import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestSte import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..), createAgentStore) import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, REQ, SENT) import qualified Simplex.Messaging.Agent.Protocol as A -import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), SQLiteStore (dbNew)) -import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') +import Simplex.Messaging.Agent.Store.Common (DBStore (..), withTransaction) +import qualified Simplex.Messaging.Agent.Store.DB as DB +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), MigrationError (..)) import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), SMPProxyFallback (..), SMPProxyMode (..), TransportSessionMode (..), defaultClientConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), PQEncryption (..), PQSupport (..), pattern IKPQOff, pattern IKPQOn, pattern PQEncOff, pattern PQEncOn, pattern PQSupportOff, pattern PQSupportOn) @@ -108,6 +109,9 @@ import Test.Hspec import UnliftIO import Util import XFTPClient (testXFTPServer) +#if defined(dbPostgres) +import Fixtures +#endif type AEntityTransmission e = (ACorrId, ConnId, AEvent e) @@ -325,6 +329,8 @@ functionalAPITests t = do it "should expire multiple messages" $ testExpireManyMessages t it "should expire one message if quota is exceeded" $ testExpireMessageQuota t it "should expire multiple messages if quota is exceeded" $ testExpireManyMessagesQuota t +#if !defined(dbPostgres) + -- TODO [postgres] restore from outdated db backup (we use copyFile/renameFile for sqlite) describe "Ratchet synchronization" $ do it "should report ratchet de-synchronization, synchronize ratchets" $ testRatchetSync t @@ -336,6 +342,7 @@ functionalAPITests t = do testRatchetSyncSuspendForeground t it "should synchronize ratchets when clients start synchronization simultaneously" $ testRatchetSyncSimultaneous t +#endif describe "Subscription mode OnlyCreate" $ do it "messages delivered only when polled (v8 - slow handshake)" $ withSmpServer t testOnlyCreatePullSlowHandshake @@ -2561,7 +2568,7 @@ testSwitchAsync servers = do withB :: (AgentClient -> IO a) -> IO a withB = withAgent 2 agentCfg servers testDB2 -withAgent :: HasCallStack => Int -> AgentConfig -> InitialAgentServers -> FilePath -> (HasCallStack => AgentClient -> IO a) -> IO a +withAgent :: HasCallStack => Int -> AgentConfig -> InitialAgentServers -> String -> (HasCallStack => AgentClient -> IO a) -> IO a withAgent clientId 
cfg' servers dbPath = bracket (getSMPAgentClient' clientId cfg' servers dbPath) (\a -> disposeAgentClient a >> threadDelay 100000) sessionSubscribe :: (forall a. (AgentClient -> IO a) -> IO a) -> [ConnId] -> (AgentClient -> ExceptT AgentErrorType IO ()) -> IO () @@ -3091,13 +3098,27 @@ testTwoUsers = withAgentClients2 $ \a b -> do hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO () hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n -getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> FilePath -> IO AgentClient +getSMPAgentClient' :: Int -> AgentConfig -> InitialAgentServers -> String -> IO AgentClient getSMPAgentClient' clientId cfg' initServers dbPath = do - Right st <- liftIO $ createAgentStore dbPath "" False MCError + Right st <- liftIO $ createStore dbPath c <- getSMPAgentClient_ clientId cfg' initServers st False - when (dbNew st) $ withTransaction' st (`SQL.execute_` "INSERT INTO users (user_id) VALUES (1)") + when (dbNew st) $ insertUser st pure c +#if defined(dbPostgres) +createStore :: String -> IO (Either MigrationError DBStore) +createStore schema = createAgentStore testDBConnectInfo schema MCError + +insertUser :: DBStore -> IO () +insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES") +#else +createStore :: String -> IO (Either MigrationError DBStore) +createStore dbPath = createAgentStore dbPath "" False MCError + +insertUser :: DBStore -> IO () +insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users (user_id) VALUES (1)") +#endif + testServerMultipleIdentities :: HasCallStack => IO () testServerMultipleIdentities = withAgentClients2 $ \alice bob -> runRight_ $ do diff --git a/tests/AgentTests/MigrationTests.hs b/tests/AgentTests/MigrationTests.hs index 406bdef60..fb8550a7d 100644 --- a/tests/AgentTests/MigrationTests.hs +++ b/tests/AgentTests/MigrationTests.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} module AgentTests.MigrationTests (migrationTests) where @@ -5,13 +6,23 @@ module AgentTests.MigrationTests (migrationTests) where import Control.Monad import Data.Maybe (fromJust) import Data.Word (Word32) +import Simplex.Messaging.Agent.Store.Common (DBStore, withTransaction) +import Simplex.Messaging.Agent.Store.Migrations (migrationsToRun) +import Simplex.Messaging.Agent.Store.Shared +import System.Random (randomIO) +import Test.Hspec +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (fromOnly) +import Fixtures +import Simplex.Messaging.Agent.Store.Postgres (closeDBStore, createDBStore) +import Simplex.Messaging.Agent.Store.Postgres.Util (dropSchema) +import qualified Simplex.Messaging.Agent.Store.Postgres.DB as DB +#else import Database.SQLite.Simple (fromOnly) -import Simplex.Messaging.Agent.Store.SQLite (MigrationConfirmation (..), MigrationError (MEDowngrade, MEUpgrade, MigrationError), SQLiteStore, closeSQLiteStore, createSQLiteStore, upMigration, withTransaction) +import Simplex.Messaging.Agent.Store.SQLite (closeDBStore, createDBStore) import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB -import Simplex.Messaging.Agent.Store.SQLite.Migrations import System.Directory (removeFile) -import System.Random (randomIO) -import Test.Hspec +#endif migrationTests :: Spec migrationTests = do @@ -95,9 +106,6 @@ migrationTests = do ([m1, m2, m3, m4], [t1, t2, t3, t4]) ([m1, m2, m4], [MCYesUp, MCYesUpDown, MCError], Left . 
MigrationError $ MTREDifferent (name m4) (name m3)) -testDB :: FilePath -testDB = "tests/tmp/test_migrations.db" - m1 :: Migration m1 = Migration "20230301-migration1" "create table test1 (id1 integer primary key);" Nothing @@ -177,21 +185,46 @@ testMigration :: IO () testMigration (initMs, initTables) (finalMs, confirmModes, tablesOrError) = forM_ confirmModes $ \confirmMode -> do r <- randomIO :: IO Word32 - let dpPath = testDB <> show r - Right st <- createSQLiteStore dpPath "" False initMs MCError + Right st <- createStore r initMs MCError st `shouldHaveTables` initTables - closeSQLiteStore st + closeDBStore st case tablesOrError of Right tables -> do - Right st' <- createSQLiteStore dpPath "" False finalMs confirmMode + Right st' <- createStore r finalMs confirmMode st' `shouldHaveTables` tables - closeSQLiteStore st' + closeDBStore st' Left e -> do - Left e' <- createSQLiteStore dpPath "" False finalMs confirmMode + Left e' <- createStore r finalMs confirmMode e `shouldBe` e' - removeFile dpPath - where - shouldHaveTables :: SQLiteStore -> [String] -> IO () - st `shouldHaveTables` expected = do - tables <- map fromOnly <$> withTransaction st (`DB.query_` "SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY 1;") - tables `shouldBe` "migrations" : expected + cleanup r + +#if defined(dbPostgres) +testSchema :: Word32 -> String +testSchema randSuffix = "test_migrations_schema" <> show randSuffix + +createStore :: Word32 -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createStore randSuffix migrations confirmMigrations = + createDBStore testDBConnectInfo (testSchema randSuffix) migrations confirmMigrations + +cleanup :: Word32 -> IO () +cleanup randSuffix = dropSchema testDBConnectInfo (testSchema randSuffix) + +shouldHaveTables :: DBStore -> [String] -> IO () +st `shouldHaveTables` expected = do + tables <- map fromOnly <$> withTransaction st (`DB.query_` "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema() AND table_type = 'BASE TABLE' ORDER BY 1") + tables `shouldBe` "migrations" : expected +#else +testDB :: Word32 -> FilePath +testDB randSuffix = "tests/tmp/test_migrations.db" <> show randSuffix + +createStore :: Word32 -> [Migration] -> MigrationConfirmation -> IO (Either MigrationError DBStore) +createStore randSuffix = createDBStore (testDB randSuffix) "" False + +cleanup :: Word32 -> IO () +cleanup randSuffix = removeFile (testDB randSuffix) + +shouldHaveTables :: DBStore -> [String] -> IO () +st `shouldHaveTables` expected = do + tables <- map fromOnly <$> withTransaction st (`DB.query_` "SELECT name FROM sqlite_schema WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY 1") + tables `shouldBe` "migrations" : expected +#endif diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 1e7cebff9..33e15792e 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -61,7 +61,9 @@ import Simplex.Messaging.Agent hiding (createConnection, joinConnection, sendMes import Simplex.Messaging.Agent.Client (ProtocolTestFailure (..), ProtocolTestStep (..), withStore') import Simplex.Messaging.Agent.Env.SQLite (AgentConfig, Env (..), InitialAgentServers) import Simplex.Messaging.Agent.Protocol hiding (CON, CONF, INFO, SENT) -import Simplex.Messaging.Agent.Store.SQLite (closeSQLiteStore, getSavedNtfToken, reopenSQLiteStore, withTransaction) +import Simplex.Messaging.Agent.Store.AgentStore (getSavedNtfToken) 
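For orientation, a minimal sketch of the store-opening API that the migration tests above exercise (not part of this patch; SQLite backend assumed, file path made up, argument order taken from its use in the tests): createDBStore returns Either MigrationError DBStore, the MigrationConfirmation argument decides whether pending up/down migrations may be applied, and migrationErrorDescription renders a failure for display.

{-# LANGUAGE OverloadedStrings #-}

import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..), migrationErrorDescription)
import Simplex.Messaging.Agent.Store.SQLite (closeDBStore, createDBStore)
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations

openStoreExample :: IO ()
openStoreExample = do
  -- args: db file path, encryption key, keep key in memory, migrations, confirmation mode
  r <- createDBStore "tests/tmp/example.agent.db" "" False Migrations.app MCError
  case r of
    -- a fresh database is initialized; an existing one that needs up/down migrations
    -- is reported as Left (a MigrationError) under MCError instead of being migrated
    Right st -> closeDBStore st
    Left e -> putStrLn $ migrationErrorDescription e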
+import Simplex.Messaging.Agent.Store.SQLite (closeDBStore, reopenSQLiteStore) +import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction) import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String @@ -74,15 +76,9 @@ import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), NtfSer import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.Server.Env.STM (ServerConfig (..)) import Simplex.Messaging.Transport (ATransport) -import System.Directory (doesFileExist, removeFile) import Test.Hspec import UnliftIO -removeFileIfExists :: FilePath -> IO () -removeFileIfExists filePath = do - fileExists <- doesFileExist filePath - when fileExists $ removeFile filePath - notificationTests :: ATransport -> Spec notificationTests t = do describe "Managing notification tokens" $ do @@ -500,7 +496,7 @@ testNotificationSubscriptionExistingConnection apns baseId alice@AgentClient {ag threadDelay 500000 suspendAgent alice 0 - closeSQLiteStore store + closeDBStore store threadDelay 1000000 putStrLn "before opening the database from another agent" diff --git a/tests/AgentTests/SQLiteTests.hs b/tests/AgentTests/SQLiteTests.hs index 09356c5b6..3fb791af6 100644 --- a/tests/AgentTests/SQLiteTests.hs +++ b/tests/AgentTests/SQLiteTests.hs @@ -41,10 +41,12 @@ import Simplex.FileTransfer.Types import Simplex.Messaging.Agent.Client () import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.Store +import Simplex.Messaging.Agent.Store.AgentStore import Simplex.Messaging.Agent.Store.SQLite -import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') +import Simplex.Messaging.Agent.Store.SQLite.Common (DBStore (..), withTransaction') import qualified Simplex.Messaging.Agent.Store.SQLite.DB as DB import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations +import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Crypto.File (CryptoFile (..)) import Simplex.Messaging.Crypto.Ratchet (InitialKeys (..), pattern PQSupportOn) @@ -59,36 +61,36 @@ import UnliftIO.Directory (removeFile) testDB :: String testDB = "tests/tmp/smp-agent.test.db" -withStore :: SpecWith SQLiteStore -> Spec -withStore = before createStore . after removeStore +withStore :: SpecWith DBStore -> Spec +withStore = before createStore' . after removeStore -withStore2 :: SpecWith (SQLiteStore, SQLiteStore) -> Spec +withStore2 :: SpecWith (DBStore, DBStore) -> Spec withStore2 = before connect2 . after (removeStore . 
fst) where - connect2 :: IO (SQLiteStore, SQLiteStore) + connect2 :: IO (DBStore, DBStore) connect2 = do - s1 <- createStore + s1 <- createStore' s2 <- connectSQLiteStore (dbFilePath s1) "" False pure (s1, s2) -createStore :: IO SQLiteStore -createStore = createEncryptedStore "" False +createStore' :: IO DBStore +createStore' = createEncryptedStore "" False -createEncryptedStore :: ScrubbedBytes -> Bool -> IO SQLiteStore +createEncryptedStore :: ScrubbedBytes -> Bool -> IO DBStore createEncryptedStore key keepKey = do -- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous -- IO operations on multiple similarly named files; error seems to be environment specific r <- randomIO :: IO Word32 - Right st <- createSQLiteStore (testDB <> show r) key keepKey Migrations.app MCError + Right st <- createDBStore (testDB <> show r) key keepKey Migrations.app MCError withTransaction' st (`SQL.execute_` "INSERT INTO users (user_id) VALUES (1);") pure st -removeStore :: SQLiteStore -> IO () +removeStore :: DBStore -> IO () removeStore db = do close db removeFile $ dbFilePath db where - close :: SQLiteStore -> IO () + close :: DBStore -> IO () close st = mapM_ DB.close =<< tryTakeMVar (dbConnection st) storeTests :: Spec @@ -147,7 +149,7 @@ storeTests = do it "should close and re-open encrypted store" testCloseReopenEncryptedStore it "should close and re-open encrypted store (keep key)" testReopenEncryptedStoreKeepKey -testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore) +testConcurrentWrites :: SpecWith (DBStore, DBStore) testConcurrentWrites = it "should complete multiple concurrent write transactions w/t sqlite busy errors" $ \(s1, s2) -> do g <- C.newRandom @@ -156,22 +158,22 @@ testConcurrentWrites = let ConnData {connId} = cData1 concurrently_ (runTest s1 connId rq) (runTest s2 connId rq) where - runTest :: SQLiteStore -> ConnId -> RcvQueue -> IO () + runTest :: DBStore -> ConnId -> RcvQueue -> IO () runTest st connId rq = replicateM_ 100 . withTransaction st $ \db -> do (internalId, internalRcvId, _, _) <- updateRcvIds db connId let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy" createRcvMsg db connId rq rcvMsgData -testCompiledThreadsafe :: SpecWith SQLiteStore +testCompiledThreadsafe :: SpecWith DBStore testCompiledThreadsafe = it "compiled sqlite library should be threadsafe" . withStoreTransaction $ \db -> do compileOptions <- DB.query_ db "pragma COMPILE_OPTIONS;" :: IO [[T.Text]] compileOptions `shouldNotContain` [["THREADSAFE=0"]] -withStoreTransaction :: (DB.Connection -> IO a) -> SQLiteStore -> IO a +withStoreTransaction :: (DB.Connection -> IO a) -> DBStore -> IO a withStoreTransaction = flip withTransaction -testForeignKeysEnabled :: SpecWith SQLiteStore +testForeignKeysEnabled :: SpecWith DBStore testForeignKeysEnabled = it "foreign keys should be enabled" . withStoreTransaction $ \db -> do let inconsistentQuery = @@ -261,7 +263,7 @@ createRcvConn db g cData rq cMode = runExceptT $ do rq' <- ExceptT $ updateNewConnRcv db connId rq pure (connId, rq') -testCreateRcvConn :: SpecWith SQLiteStore +testCreateRcvConn :: SpecWith DBStore testCreateRcvConn = it "should create RcvConnection and add SndQueue" . 
withStoreTransaction $ \db -> do g <- C.newRandom @@ -275,7 +277,7 @@ testCreateRcvConn = getConn db "conn1" `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq])) -testCreateRcvConnRandomId :: SpecWith SQLiteStore +testCreateRcvConnRandomId :: SpecWith DBStore testCreateRcvConnRandomId = it "should create RcvConnection and add SndQueue with random ID" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -287,7 +289,7 @@ testCreateRcvConnRandomId = getConn db connId `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq] [sq])) -testCreateRcvConnDuplicate :: SpecWith SQLiteStore +testCreateRcvConnDuplicate :: SpecWith DBStore testCreateRcvConnDuplicate = it "should throw error on attempt to create duplicate RcvConnection" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -295,7 +297,7 @@ testCreateRcvConnDuplicate = createRcvConn db g cData1 rcvQueue1 SCMInvitation `shouldReturn` Left SEConnDuplicate -testCreateSndConn :: SpecWith SQLiteStore +testCreateSndConn :: SpecWith DBStore testCreateSndConn = it "should create SndConnection and add RcvQueue" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -309,7 +311,7 @@ testCreateSndConn = getConn db "conn1" `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 [rq] [sq])) -testCreateSndConnRandomID :: SpecWith SQLiteStore +testCreateSndConnRandomID :: SpecWith DBStore testCreateSndConnRandomID = it "should create SndConnection and add RcvQueue with random ID" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -321,7 +323,7 @@ testCreateSndConnRandomID = getConn db connId `shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} [rq] [sq])) -testCreateSndConnDuplicate :: SpecWith SQLiteStore +testCreateSndConnDuplicate :: SpecWith DBStore testCreateSndConnDuplicate = it "should throw error on attempt to create duplicate SndConnection" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -329,7 +331,7 @@ testCreateSndConnDuplicate = createSndConn db g cData1 sndQueue1 `shouldReturn` Left SEConnDuplicate -testGetRcvConn :: SpecWith SQLiteStore +testGetRcvConn :: SpecWith DBStore testGetRcvConn = it "should get connection using rcv queue id and server" . withStoreTransaction $ \db -> do let smpServer = SMPServer "smp.simplex.im" "5223" testKeyHash @@ -339,7 +341,7 @@ testGetRcvConn = getRcvConn db smpServer recipientId `shouldReturn` Right (rq, SomeConn SCRcv (RcvConnection cData1 rq)) -testSetConnUserIdNewConn :: SpecWith SQLiteStore +testSetConnUserIdNewConn :: SpecWith DBStore testSetConnUserIdNewConn = it "should set user id for new connection" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -352,9 +354,9 @@ testSetConnUserIdNewConn = let ConnData {userId} = connData userId `shouldBe` newUserId _ -> do - expectationFailure "Failed to get connection" + expectationFailure "Failed to get connection" -testDeleteRcvConn :: SpecWith SQLiteStore +testDeleteRcvConn :: SpecWith DBStore testDeleteRcvConn = it "should create RcvConnection and delete it" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -366,7 +368,7 @@ testDeleteRcvConn = getConn db "conn1" `shouldReturn` Left SEConnNotFound -testDeleteSndConn :: SpecWith SQLiteStore +testDeleteSndConn :: SpecWith DBStore testDeleteSndConn = it "should create SndConnection and delete it" . 
withStoreTransaction $ \db -> do g <- C.newRandom @@ -378,7 +380,7 @@ testDeleteSndConn = getConn db "conn1" `shouldReturn` Left SEConnNotFound -testDeleteDuplexConn :: SpecWith SQLiteStore +testDeleteDuplexConn :: SpecWith DBStore testDeleteDuplexConn = it "should create DuplexConnection and delete it" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -391,7 +393,7 @@ testDeleteDuplexConn = getConn db "conn1" `shouldReturn` Left SEConnNotFound -testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore +testUpgradeRcvConnToDuplex :: SpecWith DBStore testUpgradeRcvConnToDuplex = it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -420,7 +422,7 @@ testUpgradeRcvConnToDuplex = upgradeRcvConnToDuplex db "conn1" anotherSndQueue `shouldReturn` Left (SEBadConnType CDuplex) -testUpgradeSndConnToDuplex :: SpecWith SQLiteStore +testUpgradeSndConnToDuplex :: SpecWith DBStore testUpgradeSndConnToDuplex = it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -452,7 +454,7 @@ testUpgradeSndConnToDuplex = upgradeSndConnToDuplex db "conn1" anotherRcvQueue `shouldReturn` Left (SEBadConnType CDuplex) -testSetRcvQueueStatus :: SpecWith SQLiteStore +testSetRcvQueueStatus :: SpecWith DBStore testSetRcvQueueStatus = it "should update status of RcvQueue" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -464,7 +466,7 @@ testSetRcvQueueStatus = getConn db "conn1" `shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rq {status = Confirmed})) -testSetSndQueueStatus :: SpecWith SQLiteStore +testSetSndQueueStatus :: SpecWith DBStore testSetSndQueueStatus = it "should update status of SndQueue" . withStoreTransaction $ \db -> do g <- C.newRandom @@ -476,7 +478,7 @@ testSetSndQueueStatus = getConn db "conn1" `shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sq {status = Confirmed})) -testSetQueueStatusDuplex :: SpecWith SQLiteStore +testSetQueueStatusDuplex :: SpecWith DBStore testSetQueueStatusDuplex = it "should update statuses of RcvQueue and SndQueue in DuplexConnection" . 
withStoreTransaction $ \db -> do g <- C.newRandom @@ -529,7 +531,7 @@ testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rq rcvMsgData@Rcv createRcvMsg db connId rq rcvMsgData `shouldReturn` () -testCreateRcvMsg :: SpecWith SQLiteStore +testCreateRcvMsg :: SpecWith DBStore testCreateRcvMsg = it "should reserve internal ids and create a RcvMsg" $ \st -> do g <- C.newRandom @@ -563,7 +565,7 @@ testCreateSndMsg_ db expectedPrevHash connId sq sndMsgData@SndMsgData {..} = do createSndMsgDelivery db connId sq internalId `shouldReturn` () -testCreateSndMsg :: SpecWith SQLiteStore +testCreateSndMsg :: SpecWith DBStore testCreateSndMsg = it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \st -> do g <- C.newRandom @@ -574,7 +576,7 @@ testCreateSndMsg = testCreateSndMsg_ db "" connId sq $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy" testCreateSndMsg_ db "hash_dummy" connId sq $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy" -testCreateRcvAndSndMsgs :: SpecWith SQLiteStore +testCreateRcvAndSndMsgs :: SpecWith DBStore testCreateRcvAndSndMsgs = it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \st -> do let ConnData {connId} = cData1 @@ -592,15 +594,15 @@ testCreateRcvAndSndMsgs = testCloseReopenStore :: IO () testCloseReopenStore = do - st <- createStore + st <- createStore' hasMigrations st - closeSQLiteStore st - closeSQLiteStore st + closeDBStore st + closeDBStore st errorGettingMigrations st openSQLiteStore st "" False openSQLiteStore st "" False hasMigrations st - closeSQLiteStore st + closeDBStore st errorGettingMigrations st reopenSQLiteStore st hasMigrations st @@ -610,14 +612,14 @@ testCloseReopenEncryptedStore = do let key = "test_key" st <- createEncryptedStore key False hasMigrations st - closeSQLiteStore st - closeSQLiteStore st + closeDBStore st + closeDBStore st errorGettingMigrations st reopenSQLiteStore st `shouldThrow` \(e :: SomeException) -> "reopenSQLiteStore: no key" `isInfixOf` show e openSQLiteStore st key True openSQLiteStore st key True hasMigrations st - closeSQLiteStore st + closeDBStore st errorGettingMigrations st reopenSQLiteStore st hasMigrations st @@ -627,21 +629,21 @@ testReopenEncryptedStoreKeepKey = do let key = "test_key" st <- createEncryptedStore key True hasMigrations st - closeSQLiteStore st + closeDBStore st errorGettingMigrations st reopenSQLiteStore st hasMigrations st -getMigrations :: SQLiteStore -> IO Bool -getMigrations st = not . null <$> withTransaction st (Migrations.getCurrent . DB.conn) +getMigrations :: DBStore -> IO Bool +getMigrations st = not . 
null <$> withTransaction st Migrations.getCurrent -hasMigrations :: SQLiteStore -> Expectation +hasMigrations :: DBStore -> Expectation hasMigrations st = getMigrations st `shouldReturn` True -errorGettingMigrations :: SQLiteStore -> Expectation +errorGettingMigrations :: DBStore -> Expectation errorGettingMigrations st = getMigrations st `shouldThrow` \(e :: SomeException) -> "ErrorMisuse" `isInfixOf` show e -testGetPendingQueueMsg :: SQLiteStore -> Expectation +testGetPendingQueueMsg :: DBStore -> Expectation testGetPendingQueueMsg st = do g <- C.newRandom withTransaction st $ \db -> do @@ -658,7 +660,7 @@ testGetPendingQueueMsg st = do Right (Just (Nothing, PendingMsgData {msgId})) <- getPendingQueueMsg db connId sq msgId `shouldBe` InternalId 2 -testGetPendingServerCommand :: SQLiteStore -> Expectation +testGetPendingServerCommand :: DBStore -> Expectation testGetPendingServerCommand st = do g <- C.newRandom withTransaction st $ \db -> do @@ -728,7 +730,7 @@ testFileCbNonce = either error id $ strDecode "dPSF-wrQpDiK_K6sYv0BDBZ9S4dg-jmu" testFileReplicaKey :: C.APrivateAuthKey testFileReplicaKey = C.APrivateAuthKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe" -testGetNextRcvChunkToDownload :: SQLiteStore -> Expectation +testGetNextRcvChunkToDownload :: DBStore -> Expectation testGetNextRcvChunkToDownload st = do g <- C.newRandom withTransaction st $ \db -> do @@ -745,7 +747,7 @@ testGetNextRcvChunkToDownload st = do Right (Just (RcvFileChunk {rcvFileEntityId}, _, Nothing)) <- getNextRcvChunkToDownload db xftpServer1 86400 rcvFileEntityId `shouldBe` fId2 -testGetNextRcvFileToDecrypt :: SQLiteStore -> Expectation +testGetNextRcvFileToDecrypt :: DBStore -> Expectation testGetNextRcvFileToDecrypt st = do g <- C.newRandom withTransaction st $ \db -> do @@ -764,7 +766,7 @@ testGetNextRcvFileToDecrypt st = do Right (Just RcvFile {rcvFileEntityId}) <- getNextRcvFileToDecrypt db 86400 rcvFileEntityId `shouldBe` fId2 -testGetNextSndFileToPrepare :: SQLiteStore -> Expectation +testGetNextSndFileToPrepare :: DBStore -> Expectation testGetNextSndFileToPrepare st = do g <- C.newRandom withTransaction st $ \db -> do @@ -791,7 +793,7 @@ newSndChunkReplica1 = rcvIdsKeys = [(ChunkReplicaId $ EntityId "abc", testFileReplicaKey)] } -testGetNextSndChunkToUpload :: SQLiteStore -> Expectation +testGetNextSndChunkToUpload :: DBStore -> Expectation testGetNextSndChunkToUpload st = do g <- C.newRandom withTransaction st $ \db -> do @@ -814,7 +816,7 @@ testGetNextSndChunkToUpload st = do Right (Just SndFileChunk {sndFileEntityId}) <- getNextSndChunkToUpload db xftpServer1 86400 sndFileEntityId `shouldBe` fId2 -testGetNextDeletedSndChunkReplica :: SQLiteStore -> Expectation +testGetNextDeletedSndChunkReplica :: DBStore -> Expectation testGetNextDeletedSndChunkReplica st = do withTransaction st $ \db -> do Right Nothing <- getNextDeletedSndChunkReplica db xftpServer1 86400 @@ -830,17 +832,17 @@ testGetNextDeletedSndChunkReplica st = do Right (Just DeletedSndChunkReplica {deletedSndChunkReplicaId}) <- getNextDeletedSndChunkReplica db xftpServer1 86400 deletedSndChunkReplicaId `shouldBe` 2 -testMarkNtfSubActionNtfFailed :: SQLiteStore -> Expectation +testMarkNtfSubActionNtfFailed :: DBStore -> Expectation testMarkNtfSubActionNtfFailed st = do withTransaction st $ \db -> do markNtfSubActionNtfFailed_ db "abc" -testMarkNtfSubActionSMPFailed :: SQLiteStore -> Expectation +testMarkNtfSubActionSMPFailed :: DBStore -> Expectation testMarkNtfSubActionSMPFailed st = do withTransaction st $ 
\db -> do markNtfSubActionSMPFailed_ db "abc" -testMarkNtfTokenToDeleteFailed :: SQLiteStore -> Expectation +testMarkNtfTokenToDeleteFailed :: DBStore -> Expectation testMarkNtfTokenToDeleteFailed st = do withTransaction st $ \db -> do markNtfTokenToDeleteFailed_ db 1 diff --git a/tests/AgentTests/SchemaDump.hs b/tests/AgentTests/SchemaDump.hs index 3ee2774bc..b7fcce8ee 100644 --- a/tests/AgentTests/SchemaDump.hs +++ b/tests/AgentTests/SchemaDump.hs @@ -12,8 +12,8 @@ import Database.SQLite.Simple (Only (..)) import qualified Database.SQLite.Simple as SQL import Simplex.Messaging.Agent.Store.SQLite import Simplex.Messaging.Agent.Store.SQLite.Common (withTransaction') -import Simplex.Messaging.Agent.Store.SQLite.Migrations (Migration (..), MigrationsToRun (..), toDownMigration) import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations +import Simplex.Messaging.Agent.Store.Shared (Migration (..), MigrationConfirmation (..), MigrationsToRun (..), toDownMigration) import Simplex.Messaging.Util (ifM) import System.Directory (createDirectoryIfMissing, doesFileExist, removeDirectoryRecursive, removeFile) import System.Process (readCreateProcess, shell) @@ -49,7 +49,7 @@ testVerifySchemaDump :: IO () testVerifySchemaDump = do savedSchema <- ifM (doesFileExist appSchema) (readFile appSchema) (pure "") savedSchema `deepseq` pure () - void $ createSQLiteStore testDB "" False Migrations.app MCConsole + void $ createDBStore testDB "" False Migrations.app MCConsole getSchema testDB appSchema `shouldReturn` savedSchema removeFile testDB @@ -57,7 +57,7 @@ testVerifyLintFKeyIndexes :: IO () testVerifyLintFKeyIndexes = do savedLint <- ifM (doesFileExist appLint) (readFile appLint) (pure "") savedLint `deepseq` pure () - void $ createSQLiteStore testDB "" False Migrations.app MCConsole + void $ createDBStore testDB "" False Migrations.app MCConsole getLintFKeyIndexes testDB "tests/tmp/agent_lint.sql" `shouldReturn` savedLint removeFile testDB @@ -70,9 +70,9 @@ withTmpFiles = testSchemaMigrations :: IO () testSchemaMigrations = do let noDownMigrations = dropWhileEnd (\Migration {down} -> isJust down) Migrations.app - Right st <- createSQLiteStore testDB "" False noDownMigrations MCError + Right st <- createDBStore testDB "" False noDownMigrations MCError mapM_ (testDownMigration st) $ drop (length noDownMigrations) Migrations.app - closeSQLiteStore st + closeDBStore st removeFile testDB removeFile testSchema where @@ -93,22 +93,22 @@ testSchemaMigrations = do testUsersMigrationNew :: IO () testUsersMigrationNew = do - Right st <- createSQLiteStore testDB "" False Migrations.app MCError + Right st <- createDBStore testDB "" False Migrations.app MCError withTransaction' st (`SQL.query_` "SELECT user_id FROM users;") `shouldReturn` ([] :: [Only Int]) - closeSQLiteStore st + closeDBStore st testUsersMigrationOld :: IO () testUsersMigrationOld = do let beforeUsers = takeWhile (("m20230110_users" /=) . 
name) Migrations.app - Right st <- createSQLiteStore testDB "" False beforeUsers MCError + Right st <- createDBStore testDB "" False beforeUsers MCError withTransaction' st (`SQL.query_` "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users';") `shouldReturn` ([] :: [Only String]) - closeSQLiteStore st - Right st' <- createSQLiteStore testDB "" False Migrations.app MCYesUp + closeDBStore st + Right st' <- createDBStore testDB "" False Migrations.app MCYesUp withTransaction' st' (`SQL.query_` "SELECT user_id FROM users;") `shouldReturn` ([Only (1 :: Int)]) - closeSQLiteStore st' + closeDBStore st' skipComparisonForDownMigrations :: [String] skipComparisonForDownMigrations = diff --git a/tests/CoreTests/StoreLogTests.hs b/tests/CoreTests/StoreLogTests.hs index e24f9f1ea..90bea0192 100644 --- a/tests/CoreTests/StoreLogTests.hs +++ b/tests/CoreTests/StoreLogTests.hs @@ -10,13 +10,12 @@ module CoreTests.StoreLogTests where import Control.Concurrent.STM import Control.Monad +import CoreTests.MsgStoreTests import Crypto.Random (ChaChaDRG) import qualified Data.ByteString.Char8 as B import Data.Either (partitionEithers) import qualified Data.Map.Strict as M import SMPClient -import AgentTests.SQLiteTests -import CoreTests.MsgStoreTests import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol @@ -27,6 +26,9 @@ import Simplex.Messaging.Server.QueueStore import Simplex.Messaging.Server.StoreLog import Test.Hspec +testPublicAuthKey :: C.APublicAuthKey +testPublicAuthKey = C.APublicAuthKey C.SEd25519 (C.publicKey "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe") + testNtfCreds :: TVar ChaChaDRG -> IO NtfCreds testNtfCreds g = do (notifierKey, _) <- atomically $ C.generateAuthKeyPair C.SX25519 g @@ -54,7 +56,8 @@ storeLogTests = ((rId, qr), ntfCreds, date) <- runIO $ do g <- C.newRandom (,,) <$> testNewQueueRec g sndSecure <*> testNtfCreds g <*> getSystemDate - testSMPStoreLog ("SMP server store log, sndSecure = " <> show sndSecure) + testSMPStoreLog + ("SMP server store log, sndSecure = " <> show sndSecure) [ SLTC { name = "create new queue", saved = [CreateQueue qr], @@ -66,7 +69,7 @@ storeLogTests = saved = [CreateQueue qr, SecureQueue rId testPublicAuthKey], compacted = [CreateQueue qr {senderKey = Just testPublicAuthKey}], state = M.fromList [(rId, qr {senderKey = Just testPublicAuthKey})] - }, + }, SLTC { name = "create and delete queue", saved = [CreateQueue qr, DeleteQueue rId], @@ -90,7 +93,7 @@ storeLogTests = saved = [CreateQueue qr, UpdateTime rId date], compacted = [CreateQueue qr {updatedAt = Just date}], state = M.fromList [(rId, qr {updatedAt = Just date})] - } + } ] testSMPStoreLog :: String -> [SMPStoreLogTestCase] -> Spec diff --git a/tests/Fixtures.hs b/tests/Fixtures.hs new file mode 100644 index 000000000..a8e2542ec --- /dev/null +++ b/tests/Fixtures.hs @@ -0,0 +1,16 @@ +{-# LANGUAGE CPP #-} + +module Fixtures where + +#if defined(dbPostgres) +import Database.PostgreSQL.Simple (ConnectInfo (..), defaultConnectInfo) +#endif + +#if defined(dbPostgres) +testDBConnectInfo :: ConnectInfo +testDBConnectInfo = + defaultConnectInfo { + connectUser = "test_user", + connectDatabase = "test_agent_db" + } +#endif diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 5e5f91b09..fc66f2ab1 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs 
#-} @@ -25,6 +26,17 @@ import Simplex.Messaging.Protocol (NtfServer, ProtoServerWithAuth (..), Protocol import Simplex.Messaging.Transport import XFTPClient (testXFTPServer) +-- name fixtures are reused, but they are used as schema name instead of database file path +#if defined(dbPostgres) +testDB :: String +testDB = "smp_agent_test_protocol_schema" + +testDB2 :: String +testDB2 = "smp_agent2_test_protocol_schema" + +testDB3 :: String +testDB3 = "smp_agent3_test_protocol_schema" +#else testDB :: FilePath testDB = "tests/tmp/smp-agent.test.protocol.db" @@ -33,6 +45,7 @@ testDB2 = "tests/tmp/smp-agent2.test.protocol.db" testDB3 :: FilePath testDB3 = "tests/tmp/smp-agent3.test.protocol.db" +#endif testSMPServer :: SMPServer testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001" diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index b827edda2..cbdc7a3f5 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -1,4 +1,5 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE GADTs #-} @@ -46,6 +47,10 @@ import System.Random (randomRIO) import Test.Hspec import UnliftIO import Util +#if defined(dbPostgres) +import Fixtures +import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem) +#endif smpProxyTests :: Spec smpProxyTests = do @@ -101,7 +106,11 @@ smpProxyTests = do it "100x100 N4 C16" . twoServersMoreConc $ withNumCapabilities 4 $ 100 `inParrallel` deliver 100 it "100x100 N" . twoServersFirstProxy $ withNCPUCapabilities $ 100 `inParrallel` deliver 100 it "500x20" . twoServersFirstProxy $ 500 `inParrallel` deliver 20 +#if defined(dbPostgres) + after_ (dropAllSchemasExceptSystem testDBConnectInfo) . describe "agent API" $ do +#else describe "agent API" $ do +#endif describe "one server" $ do it "always via proxy" . 
oneServer $ agentDeliverMessageViaProxy ([srv1], SPMAlways, True) ([srv1], SPMAlways, True) C.SEd448 "hello 1" "hello 2" 1 diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index bdc5f4dc3..b0bd17dff 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -15,13 +15,12 @@ module ServerTests where -import AgentTests.NotificationTests (removeFileIfExists) -import CoreTests.MsgStoreTests (testJournalStoreCfg) import Control.Concurrent (ThreadId, killThread, threadDelay) import Control.Concurrent.STM import Control.Exception (SomeException, try) import Control.Monad import Control.Monad.IO.Class +import CoreTests.MsgStoreTests (testJournalStoreCfg) import Data.Bifunctor (first) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) @@ -51,9 +50,10 @@ import System.TimeIt (timeItT) import System.Timeout import Test.HUnit import Test.Hspec +import Util (removeFileIfExists) serverTests :: SpecWith (ATransport, AMSType) -serverTests = do +serverTests = do describe "SMP queues" $ do describe "NEW and KEY commands, SEND messages" testCreateSecure describe "NEW and SKEY commands" $ do diff --git a/tests/Test.hs b/tests/Test.hs index f8505b133..09fb856fd 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,8 +1,8 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TypeApplications #-} import AgentTests (agentTests) -import AgentTests.SchemaDump (schemaDumpTest) import CLITests import Control.Concurrent (threadDelay) import qualified Control.Exception as E @@ -34,6 +34,12 @@ import Test.Hspec import XFTPAgent import XFTPCLI import XFTPServerTests (xftpServerTests) +#if defined(dbPostgres) +import Fixtures +import Simplex.Messaging.Agent.Store.Postgres.Util (createDBAndUserIfNotExists, dropDatabaseAndUser) +#else +import AgentTests.SchemaDump (schemaDumpTest) +#endif logCfg :: LogConfig logCfg = LogConfig {lc_file = Nothing, lc_stderr = True} @@ -45,10 +51,17 @@ main = do setEnv "APNS_KEY_ID" "H82WD9K9AQ" setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8" hspec +#if defined(dbPostgres) + . beforeAll_ (dropDatabaseAndUser testDBConnectInfo >> createDBAndUserIfNotExists testDBConnectInfo) + . afterAll_ (dropDatabaseAndUser testDBConnectInfo) +#endif . before_ (createDirectoryIfMissing False "tests/tmp") . 
after_ (eventuallyRemove "tests/tmp" 3) $ do +-- TODO [postgres] schema dump for postgres +#if !defined(dbPostgres) describe "Agent SQLite schema dump" schemaDumpTest +#endif describe "Core tests" $ do describe "Batching tests" batchingTests describe "Encoding tests" encodingTests diff --git a/tests/Util.hs b/tests/Util.hs index 6ad6d054f..0ad371b69 100644 --- a/tests/Util.hs +++ b/tests/Util.hs @@ -1,9 +1,10 @@ module Util where -import Control.Monad (replicateM) +import Control.Monad (replicateM, when) import Data.Either (partitionEithers) import Data.List (tails) import GHC.Conc (getNumCapabilities, getNumProcessors, setNumCapabilities) +import System.Directory (doesFileExist, removeFile) import Test.Hspec import UnliftIO @@ -26,3 +27,8 @@ inParrallel n action = do combinations :: Int -> [a] -> [[a]] combinations 0 _ = [[]] combinations k xs = [y : ys | y : xs' <- tails xs, ys <- combinations (k - 1) xs'] + +removeFileIfExists :: FilePath -> IO () +removeFileIfExists filePath = do + fileExists <- doesFileExist filePath + when fileExists $ removeFile filePath diff --git a/tests/XFTPAgent.hs b/tests/XFTPAgent.hs index 6d6446959..f7e880083 100644 --- a/tests/XFTPAgent.hs +++ b/tests/XFTPAgent.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} @@ -46,40 +47,49 @@ import UnliftIO import UnliftIO.Concurrent import XFTPCLI import XFTPClient +#if defined(dbPostgres) +import Fixtures +import Simplex.Messaging.Agent.Store.Postgres.Util (dropAllSchemasExceptSystem) +#endif xftpAgentTests :: Spec -xftpAgentTests = around_ testBracket . describe "agent XFTP API" $ do - it "should send and receive file" $ withXFTPServer testXFTPAgentSendReceive - -- uncomment CPP option slow_servers and run hpack to run this test - xit "should send and receive file with slow server responses" $ - withXFTPServerCfg testXFTPServerConfig {responseDelay = 500000} $ - \_ -> testXFTPAgentSendReceive - it "should send and receive with encrypted local files" testXFTPAgentSendReceiveEncrypted - it "should send and receive large file with a redirect" testXFTPAgentSendReceiveRedirect - it "should send and receive small file without a redirect" testXFTPAgentSendReceiveNoRedirect - describe "sending and receiving with version negotiation" testXFTPAgentSendReceiveMatrix - it "should resume receiving file after restart" testXFTPAgentReceiveRestore - it "should cleanup rcv tmp path after permanent error" testXFTPAgentReceiveCleanup - it "should resume sending file after restart" testXFTPAgentSendRestore - xit'' "should cleanup snd prefix path after permanent error" testXFTPAgentSendCleanup - it "should delete sent file on server" testXFTPAgentDelete - it "should resume deleting file after restart" testXFTPAgentDeleteRestore - -- TODO when server is fixed to correctly send AUTH error, this test has to be modified to expect AUTH error - it "if file is deleted on server, should limit retries and continue receiving next file" testXFTPAgentDeleteOnServer - it "if file is expired on server, should report error and continue receiving next file" testXFTPAgentExpiredOnServer - it "should request additional recipient IDs when number of recipients exceeds maximum per request" testXFTPAgentRequestAdditionalRecipientIDs - describe "XFTP server test via agent API" $ do - it "should pass without basic auth" $ testXFTPServerTest Nothing (noAuthSrv testXFTPServer2) `shouldReturn` Nothing - let srv1 = testXFTPServer2 {keyHash = "1234"} - it "should fail with incorrect fingerprint" $ 
do - testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK) - describe "server with password" $ do - let auth = Just "abcd" - srv = ProtoServerWithAuth testXFTPServer2 - authErr = Just (ProtocolTestFailure TSCreateFile $ XFTP (B.unpack $ strEncode testXFTPServer2) AUTH) - it "should pass with correct password" $ testXFTPServerTest auth (srv auth) `shouldReturn` Nothing - it "should fail without password" $ testXFTPServerTest auth (srv Nothing) `shouldReturn` authErr - it "should fail with incorrect password" $ testXFTPServerTest auth (srv $ Just "wrong") `shouldReturn` authErr +xftpAgentTests = + around_ testBracket +#if defined(dbPostgres) + . after_ (dropAllSchemasExceptSystem testDBConnectInfo) +#endif + . describe "agent XFTP API" $ do + it "should send and receive file" $ withXFTPServer testXFTPAgentSendReceive + -- uncomment CPP option slow_servers and run hpack to run this test + xit "should send and receive file with slow server responses" $ + withXFTPServerCfg testXFTPServerConfig {responseDelay = 500000} $ + \_ -> testXFTPAgentSendReceive + it "should send and receive with encrypted local files" testXFTPAgentSendReceiveEncrypted + it "should send and receive large file with a redirect" testXFTPAgentSendReceiveRedirect + it "should send and receive small file without a redirect" testXFTPAgentSendReceiveNoRedirect + describe "sending and receiving with version negotiation" testXFTPAgentSendReceiveMatrix + it "should resume receiving file after restart" testXFTPAgentReceiveRestore + it "should cleanup rcv tmp path after permanent error" testXFTPAgentReceiveCleanup + it "should resume sending file after restart" testXFTPAgentSendRestore + xit'' "should cleanup snd prefix path after permanent error" testXFTPAgentSendCleanup + it "should delete sent file on server" testXFTPAgentDelete + it "should resume deleting file after restart" testXFTPAgentDeleteRestore + -- TODO when server is fixed to correctly send AUTH error, this test has to be modified to expect AUTH error + it "if file is deleted on server, should limit retries and continue receiving next file" testXFTPAgentDeleteOnServer + it "if file is expired on server, should report error and continue receiving next file" testXFTPAgentExpiredOnServer + it "should request additional recipient IDs when number of recipients exceeds maximum per request" testXFTPAgentRequestAdditionalRecipientIDs + describe "XFTP server test via agent API" $ do + it "should pass without basic auth" $ testXFTPServerTest Nothing (noAuthSrv testXFTPServer2) `shouldReturn` Nothing + let srv1 = testXFTPServer2 {keyHash = "1234"} + it "should fail with incorrect fingerprint" $ do + testXFTPServerTest Nothing (noAuthSrv srv1) `shouldReturn` Just (ProtocolTestFailure TSConnect $ BROKER (B.unpack $ strEncode srv1) NETWORK) + describe "server with password" $ do + let auth = Just "abcd" + srv = ProtoServerWithAuth testXFTPServer2 + authErr = Just (ProtocolTestFailure TSCreateFile $ XFTP (B.unpack $ strEncode testXFTPServer2) AUTH) + it "should pass with correct password" $ testXFTPServerTest auth (srv auth) `shouldReturn` Nothing + it "should fail without password" $ testXFTPServerTest auth (srv Nothing) `shouldReturn` authErr + it "should fail with incorrect password" $ testXFTPServerTest auth (srv $ Just "wrong") `shouldReturn` authErr rfProgress :: forall m. (HasCallStack, MonadIO m, MonadFail m) => AgentClient -> Int64 -> m () rfProgress c expected = loop 0