diff --git a/services/shhext/chat/encryption.go b/services/shhext/chat/encryption.go index 7df9ba0da11..ab20abf615c 100644 --- a/services/shhext/chat/encryption.go +++ b/services/shhext/chat/encryption.go @@ -40,6 +40,8 @@ type EncryptionServiceConfig struct { MaxKeep int // How many keys do we store in total per session. MaxMessageKeysPerSession int + // How long before we refresh the interval in milliseconds + BundleRefreshInterval int64 } type IdentityAndIDPair [2]string @@ -51,6 +53,7 @@ func DefaultEncryptionServiceConfig(installationID string) EncryptionServiceConf MaxSkip: 1000, MaxKeep: 3000, MaxMessageKeysPerSession: 2000, + BundleRefreshInterval: 14 * 24 * 60 * 60 * 1000, InstallationID: installationID, } } @@ -107,7 +110,7 @@ func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, } // If the bundle has expired we create a new one - if bundleContainer != nil && bundleContainer.GetBundle().Timestamp < time.Now().AddDate(0, 0, -14).UnixNano() { + if bundleContainer != nil && bundleContainer.GetBundle().Timestamp < time.Now().Add(-1*time.Duration(s.config.BundleRefreshInterval)*time.Millisecond).UnixNano() { // Mark sessions has expired if err := s.persistence.MarkBundleExpired(bundleContainer.GetBundle().GetIdentity()); err != nil { return nil, err diff --git a/services/shhext/chat/encryption_test.go b/services/shhext/chat/encryption_test.go index 29ae7662658..7c5ee825a50 100644 --- a/services/shhext/chat/encryption_test.go +++ b/services/shhext/chat/encryption_test.go @@ -4,6 +4,7 @@ import ( "crypto/ecdsa" "errors" "fmt" + "io/ioutil" "math/rand" "os" "reflect" @@ -25,16 +26,33 @@ func TestEncryptionServiceTestSuite(t *testing.T) { type EncryptionServiceTestSuite struct { suite.Suite - alice *EncryptionService - bob *EncryptionService + alice *EncryptionService + bob *EncryptionService + aliceDBPath string + bobDBPath string } -func (s *EncryptionServiceTestSuite) initDatabases() { +func (s *EncryptionServiceTestSuite) initDatabases(baseConfig *EncryptionServiceConfig) { + + aliceDBFile, err := ioutil.TempFile(os.TempDir(), "alice") + s.Require().NoError(err) + aliceDBPath := aliceDBFile.Name() + + bobDBFile, err := ioutil.TempFile(os.TempDir(), "bob") + s.Require().NoError(err) + bobDBPath := bobDBFile.Name() + + s.aliceDBPath = aliceDBPath + s.bobDBPath = bobDBPath + + if baseConfig == nil { + config := DefaultEncryptionServiceConfig(aliceInstallationID) + baseConfig = &config + } + const ( - aliceDBPath = "/tmp/alice.db" - aliceDBKey = "alice" - bobDBPath = "/tmp/bob.db" - bobDBKey = "bob" + aliceDBKey = "alice" + bobDBKey = "bob" ) alicePersistence, err := NewSQLLitePersistence(aliceDBPath, aliceDBKey) @@ -47,14 +65,20 @@ func (s *EncryptionServiceTestSuite) initDatabases() { panic(err) } - s.alice = NewEncryptionService(alicePersistence, DefaultEncryptionServiceConfig(aliceInstallationID)) - s.bob = NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig(bobInstallationID)) + baseConfig.InstallationID = aliceInstallationID + s.alice = NewEncryptionService(alicePersistence, *baseConfig) + + baseConfig.InstallationID = bobInstallationID + s.bob = NewEncryptionService(bobPersistence, *baseConfig) } func (s *EncryptionServiceTestSuite) SetupTest() { - os.Remove("/tmp/alice.db") - os.Remove("/tmp/bob.db") - s.initDatabases() + s.initDatabases(nil) +} + +func (s *EncryptionServiceTestSuite) TearDownTest() { + os.Remove(s.aliceDBPath) + os.Remove(s.bobDBPath) } func (s *EncryptionServiceTestSuite) TestCreateBundle() { @@ -749,6 +773,12 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() { // A new bundle has been received func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { + config := DefaultEncryptionServiceConfig("none") + // Set up refresh interval to "always" + config.BundleRefreshInterval = 1000 + + s.initDatabases(&config) + bobKey, err := crypto.GenerateKey() s.Require().NoError(err) @@ -756,23 +786,20 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { s.Require().NoError(err) // Create bundles - bobBundle1, err := NewBundleContainer(bobKey, bobInstallationID) + bobBundle1, err := s.bob.CreateBundle(bobKey) s.Require().NoError(err) + s.Require().Equal(uint32(1), bobBundle1.GetSignedPreKeys()[bobInstallationID].GetVersion()) - err = SignBundle(bobKey, bobBundle1) - s.Require().NoError(err) + // Sleep the required time so that bundle is refreshed + time.Sleep(time.Duration(config.BundleRefreshInterval) * time.Millisecond) - bobBundle2, err := NewBundleContainer(bobKey, bobInstallationID) - s.Require().NoError(err) - // We set the version - - bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].Version = 1 - - err = SignBundle(bobKey, bobBundle2) + // Create bundles + bobBundle2, err := s.bob.CreateBundle(bobKey) s.Require().NoError(err) + s.Require().Equal(uint32(2), bobBundle2.GetSignedPreKeys()[bobInstallationID].GetVersion()) // We add the first bob bundle - _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle1.GetBundle()) + _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle1) s.Require().NoError(err) // Alice sends a message @@ -786,10 +813,10 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { x3dhHeader1 := installationResponse1.GetX3DHHeader() s.NotNil(x3dhHeader1) - s.Equal(bobBundle1.GetBundle().GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId()) + s.Equal(bobBundle1.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader1.GetId()) // We add the second bob bundle - _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle2.GetBundle()) + _, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle2) s.Require().NoError(err) // Alice sends a message @@ -803,6 +830,6 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() { x3dhHeader2 := installationResponse2.GetX3DHHeader() s.NotNil(x3dhHeader2) - s.Equal(bobBundle2.GetBundle().GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId()) + s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId()) } diff --git a/services/shhext/chat/sql_lite_persistence.go b/services/shhext/chat/sql_lite_persistence.go index 16d20fab458..5f4c3969001 100644 --- a/services/shhext/chat/sql_lite_persistence.go +++ b/services/shhext/chat/sql_lite_persistence.go @@ -3,6 +3,8 @@ package chat import ( "crypto/ecdsa" "database/sql" + "fmt" + "os" "strings" "github.com/ethereum/go-ethereum/crypto" @@ -51,6 +53,64 @@ func NewSQLLitePersistence(path string, key string) (*SQLLitePersistence, error) return s, nil } +func MigrateDBFile(oldPath string, newPath string, key string) error { + _, err := os.Stat(oldPath) + + // No files, nothing to do + if os.IsNotExist(err) { + return nil + } + + // Any other error, throws + if err != nil { + return err + } + + if err := os.Rename(oldPath, newPath); err != nil { + return err + } + + // Migrate dev/nightly builds which used ON as a key for debugging + db, err := openDB(newPath, "ON") + if err != nil { + return err + } + + keyString := fmt.Sprintf("PRAGMA rekey=%s", key) + + if _, err = db.Exec(keyString); err != nil { + return err + } + + return nil + +} + +func openDB(path string, key string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, err + } + + keyString := fmt.Sprintf("PRAGMA key=%s", key) + + // Disable concurrent access as not supported by the driver + db.SetMaxOpenConns(1) + + if _, err = db.Exec("PRAGMA foreign_keys=ON"); err != nil { + return nil, err + } + + if _, err = db.Exec(keyString); err != nil { + return nil, err + } + + if _, err = db.Exec("PRAGMA cypher_page_size=4096"); err != nil { + return nil, err + } + return db, nil +} + // NewSQLLiteKeysStorage creates a new SQLLiteKeysStorage instance associated with the specified database func NewSQLLiteKeysStorage(db *sql.DB) *SQLLiteKeysStorage { return &SQLLiteKeysStorage{ @@ -77,26 +137,11 @@ func (s *SQLLitePersistence) GetSessionStorage() dr.SessionStorage { // Open opens a file at the specified path func (s *SQLLitePersistence) Open(path string, key string) error { - db, err := sql.Open("sqlite3", path) + db, err := openDB(path, key) if err != nil { return err } - // Disable concurrent access as not supported by the driver - db.SetMaxOpenConns(1) - - if _, err = db.Exec("PRAGMA foreign_keys=ON"); err != nil { - return err - } - - if _, err = db.Exec("PRAGMA key=ON"); err != nil { - return err - } - - if _, err = db.Exec("PRAGMA cypher_page_size=4096"); err != nil { - return err - } - s.db = db return s.setup() @@ -111,7 +156,11 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error { for installationID, signedPreKey := range bc.GetBundle().GetSignedPreKeys() { var version uint32 - stmt, err := tx.Prepare("SELECT version FROM bundles WHERE installation_id = ? AND identity = ? ORDER BY version DESC LIMIT 1") + stmt, err := tx.Prepare(`SELECT version + FROM bundles + WHERE installation_id = ? AND identity = ? + ORDER BY version DESC + LIMIT 1`) if err != nil { return err } @@ -123,7 +172,8 @@ func (s *SQLLitePersistence) AddPrivateBundle(bc *BundleContainer) error { return err } - stmt, err = tx.Prepare("INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) VALUES(?, ?, ?, ?, ?, ?)") + stmt, err = tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp) + VALUES(?, ?, ?, ?, ?, ?)`) if err != nil { return err } @@ -162,7 +212,8 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() { signedPreKey := signedPreKeyContainer.GetSignedPreKey() version := signedPreKeyContainer.GetVersion() - insertStmt, err := tx.Prepare("INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) VALUES( ?, ?, ?, ?, ?)") + insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp) + VALUES( ?, ?, ?, ?, ?)`) if err != nil { return err } @@ -180,7 +231,9 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { return err } // Mark old bundles as expired - updateStmt, err := tx.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND installation_id = ? AND version < ?") + updateStmt, err := tx.Prepare(`UPDATE bundles + SET expired = 1 + WHERE identity = ? AND installation_id = ? AND version < ?`) if err != nil { return err } @@ -205,7 +258,9 @@ func (s *SQLLitePersistence) AddPublicBundle(b *Bundle) error { func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installationIDs []string) (*BundleContainer, error) { /* #nosec */ - statement := "SELECT identity, private_key, signed_pre_key, installation_id, timestamp FROM bundles WHERE expired = 0 AND identity = ? AND installation_id IN (?" + strings.Repeat(",?", len(installationIDs)-1) + ")" + statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version + FROM bundles + WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + ")" stmt, err := s.db.Prepare(statement) if err != nil { return nil, err @@ -215,6 +270,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat var timestamp int64 var identity []byte var privateKey []byte + var version uint32 args := make([]interface{}, len(installationIDs)+1) args[0] = myIdentityKey @@ -249,6 +305,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat &signedPreKey, &installationID, ×tamp, + &version, ) if err != nil { return nil, err @@ -258,7 +315,7 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat bundle.Timestamp = timestamp } - bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey} + bundle.SignedPreKeys[installationID] = &SignedPreKey{SignedPreKey: signedPreKey, Version: version} bundle.Identity = identity } @@ -273,7 +330,9 @@ func (s *SQLLitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installat // GetPrivateKeyBundle retrieves a private key for a bundle from the database func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) { - stmt, err := s.db.Prepare("SELECT private_key FROM bundles WHERE expired = 0 AND signed_pre_key = ? LIMIT 1") + stmt, err := s.db.Prepare(`SELECT private_key + FROM bundles + WHERE expired = 0 AND signed_pre_key = ? LIMIT 1`) if err != nil { return nil, err } @@ -294,7 +353,9 @@ func (s *SQLLitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error // MarkBundleExpired expires any private bundle for a given identity func (s *SQLLitePersistence) MarkBundleExpired(identity []byte) error { - stmt, err := s.db.Prepare("UPDATE bundles SET expired = 1 WHERE identity = ? AND private_key IS NOT NULL") + stmt, err := s.db.Prepare(`UPDATE bundles + SET expired = 1 + WHERE identity = ? AND private_key IS NOT NULL`) if err != nil { return err } @@ -315,7 +376,10 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install identity := crypto.CompressPubkey(publicKey) /* #nosec */ - statement := "SELECT signed_pre_key,installation_id, version FROM bundles WHERE expired = 0 AND identity = ? AND installation_id IN (?" + strings.Repeat(",?", len(installationIDs)-1) + ") ORDER BY version DESC" + statement := `SELECT signed_pre_key,installation_id, version + FROM bundles + WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `) + ORDER BY version DESC` stmt, err := s.db.Prepare(statement) if err != nil { return nil, err @@ -373,7 +437,8 @@ func (s *SQLLitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, install // AddRatchetInfo persists the specified ratchet info into the database func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error { - stmt, err := s.db.Prepare("INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) VALUES(?, ?, ?, ?, ?)") + stmt, err := s.db.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id) + VALUES(?, ?, ?, ?, ?)`) if err != nil { return err } @@ -392,7 +457,10 @@ func (s *SQLLitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleI // GetRatchetInfo retrieves the existing RatchetInfo for a specified bundle ID and interlocutor public key from the database func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) { - stmt, err := s.db.Prepare("SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? LIMIT 1") + stmt, err := s.db.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id + FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key + WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ? + LIMIT 1`) if err != nil { return nil, err } @@ -423,7 +491,10 @@ func (s *SQLLitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byt // GetAnyRatchetInfo retrieves any existing RatchetInfo for a specified interlocutor public key from the database func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) { - stmt, err := s.db.Prepare("SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? LIMIT 1") + stmt, err := s.db.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key + FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key + WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? + LIMIT 1`) if err != nil { return nil, err } @@ -455,7 +526,9 @@ func (s *SQLLitePersistence) GetAnyRatchetInfo(identity []byte, installationID s // RatchetInfoConfirmed clears the ephemeral key in the RatchetInfo // associated with the specified bundle ID and interlocutor identity public key func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error { - stmt, err := s.db.Prepare("UPDATE ratchet_info_v2 SET ephemeral_key = NULL WHERE identity = ? AND bundle_id = ? AND installation_id = ?") + stmt, err := s.db.Prepare(`UPDATE ratchet_info_v2 + SET ephemeral_key = NULL + WHERE identity = ? AND bundle_id = ? AND installation_id = ?`) if err != nil { return err } @@ -474,7 +547,10 @@ func (s *SQLLitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) { var keyBytes []byte var key [32]byte - stmt, err := s.db.Prepare("SELECT message_key FROM keys WHERE public_key = ? AND msg_num = ? LIMIT 1") + stmt, err := s.db.Prepare(`SELECT message_key + FROM keys + WHERE public_key = ? AND msg_num = ? + LIMIT 1`) if err != nil { return key, false, err @@ -495,7 +571,8 @@ func (s *SQLLiteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, erro // Put stores a key with the specified public key, message number and message key func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error { - stmt, err := s.db.Prepare("insert into keys(session_id, public_key, msg_num, message_key, seq_num) values(?, ?, ?, ?, ?)") + stmt, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num) + VALUES(?, ?, ?, ?, ?)`) if err != nil { return err } @@ -514,7 +591,8 @@ func (s *SQLLiteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, m // DeleteOldMks caps remove any key < seq_num, included func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error { - stmt, err := s.db.Prepare("DELETE FROM keys WHERE session_id = ? AND seq_num <= ?") + stmt, err := s.db.Prepare(`DELETE FROM keys + WHERE session_id = ? AND seq_num <= ?`) if err != nil { return err } @@ -530,7 +608,8 @@ func (s *SQLLiteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) er // TruncateMks caps the number of keys to maxKeysPerSession deleting them in FIFO fashion func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error { - stmt, err := s.db.Prepare("DELETE FROM keys WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)") + stmt, err := s.db.Prepare(`DELETE FROM keys + WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`) if err != nil { return err } @@ -548,7 +627,8 @@ func (s *SQLLiteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int // DeleteMk deletes the key with the specified public key and message key func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error { - stmt, err := s.db.Prepare("DELETE FROM keys WHERE public_key = ? AND msg_num = ?") + stmt, err := s.db.Prepare(`DELETE FROM keys + WHERE public_key = ? AND msg_num = ?`) if err != nil { return err } @@ -564,7 +644,9 @@ func (s *SQLLiteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error { // Count returns the count of keys with the specified public key func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) { - stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys WHERE public_key = ?") + stmt, err := s.db.Prepare(`SELECT COUNT(1) + FROM keys + WHERE public_key = ?`) if err != nil { return 0, err } @@ -581,7 +663,8 @@ func (s *SQLLiteKeysStorage) Count(pubKey dr.Key) (uint, error) { // CountAll returns the count of keys with the specified public key func (s *SQLLiteKeysStorage) CountAll() (uint, error) { - stmt, err := s.db.Prepare("SELECT COUNT(1) FROM keys") + stmt, err := s.db.Prepare(`SELECT COUNT(1) + FROM keys`) if err != nil { return 0, err } @@ -619,7 +702,8 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error { recvChainKey := state.RecvCh.CK[:] recvChainN := state.RecvCh.N - stmt, err := s.db.Prepare("insert into sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + stmt, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) if err != nil { return err } @@ -645,7 +729,9 @@ func (s *SQLLiteSessionStorage) Save(id []byte, state *dr.State) error { // Load retrieves the double ratchet state for a given ID func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { - stmt, err := s.db.Prepare("SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count FROM sessions WHERE id = ?") + stmt, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count + FROM sessions + WHERE id = ?`) if err != nil { return nil, err } @@ -710,7 +796,11 @@ func (s *SQLLiteSessionStorage) Load(id []byte) (*dr.State, error) { // GetActiveInstallations returns the active installations for a given identity func (s *SQLLitePersistence) GetActiveInstallations(maxInstallations int, identity []byte) ([]string, error) { - stmt, err := s.db.Prepare("SELECT installation_id FROM installations WHERE enabled = 1 AND identity = ? ORDER BY timestamp DESC LIMIT ?") + stmt, err := s.db.Prepare(`SELECT installation_id + FROM installations + WHERE enabled = 1 AND identity = ? + ORDER BY timestamp DESC + LIMIT ?`) if err != nil { return nil, err } @@ -744,7 +834,10 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, } for _, installationID := range installationIDs { - stmt, err := tx.Prepare("SELECT enabled FROM installations WHERE identity = ? AND installation_id = ? LIMIT 1") + stmt, err := tx.Prepare(`SELECT enabled + FROM installations + WHERE identity = ? AND installation_id = ? + LIMIT 1`) if err != nil { return err } @@ -759,7 +852,9 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, // We update timestamp if present without changing enabled if err != sql.ErrNoRows { - stmt, err = tx.Prepare("UPDATE installations SET timestamp = ?, enabled = ? WHERE identity = ? AND installation_id = ?") + stmt, err = tx.Prepare(`UPDATE installations + SET timestamp = ?, enabled = ? + WHERE identity = ? AND installation_id = ?`) if err != nil { return err } @@ -776,7 +871,8 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, defer stmt.Close() } else { - stmt, err = tx.Prepare("INSERT INTO installations(identity, installation_id, timestamp, enabled) VALUES (?, ?, ?, ?)") + stmt, err = tx.Prepare(`INSERT INTO installations(identity, installation_id, timestamp, enabled) + VALUES (?, ?, ?, ?)`) if err != nil { return err } @@ -806,7 +902,9 @@ func (s *SQLLitePersistence) AddInstallations(identity []byte, timestamp int64, // EnableInstallation enables the installation func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID string) error { - stmt, err := s.db.Prepare("UPDATE installations SET enabled = 1 WHERE identity = ? AND installation_id = ?") + stmt, err := s.db.Prepare(`UPDATE installations + SET enabled = 1 + WHERE identity = ? AND installation_id = ?`) if err != nil { return err } @@ -819,7 +917,9 @@ func (s *SQLLitePersistence) EnableInstallation(identity []byte, installationID // DisableInstallation disable the installation func (s *SQLLitePersistence) DisableInstallation(identity []byte, installationID string) error { - stmt, err := s.db.Prepare("UPDATE installations SET enabled = 0 WHERE identity = ? AND installation_id = ?") + stmt, err := s.db.Prepare(`UPDATE installations + SET enabled = 0 + WHERE identity = ? AND installation_id = ?`) if err != nil { return err } diff --git a/services/shhext/chat/sql_lite_persistence_test.go b/services/shhext/chat/sql_lite_persistence_test.go index 16ac690d119..31e79e9a4b3 100644 --- a/services/shhext/chat/sql_lite_persistence_test.go +++ b/services/shhext/chat/sql_lite_persistence_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/ethereum/go-ethereum/crypto" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/suite" ) @@ -74,7 +73,7 @@ func (s *SQLLitePersistenceTestSuite) TestPrivateBundle() { anyPrivateBundle, err = s.service.GetAnyPrivateBundle(identity, []string{installationID}) s.Require().NoError(err) s.NotNil(anyPrivateBundle) - s.True(proto.Equal(bundle.GetBundle(), anyPrivateBundle.GetBundle()), "It returns the same bundle") + s.Equal(bundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, anyPrivateBundle.GetBundle().GetSignedPreKeys()[installationID].SignedPreKey, "It returns the same bundle") } func (s *SQLLitePersistenceTestSuite) TestPublicBundle() { diff --git a/services/shhext/service.go b/services/shhext/service.go index 91412924d37..00d51a0efa5 100644 --- a/services/shhext/service.go +++ b/services/shhext/service.go @@ -107,7 +107,14 @@ func (s *Service) InitProtocol(address string, password string) error { if err := os.MkdirAll(filepath.Clean(s.dataDir), os.ModePerm); err != nil { return err } - persistence, err := chat.NewSQLLitePersistence(filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address)), password) + oldPath := filepath.Join(s.dataDir, fmt.Sprintf("%x.db", address)) + newPath := filepath.Join(s.dataDir, fmt.Sprintf("%s.db", s.installationID)) + + if err := chat.MigrateDBFile(oldPath, newPath, password); err != nil { + return err + } + + persistence, err := chat.NewSQLLitePersistence(newPath, password) if err != nil { return err }