diff --git a/services/shhext/chat/protocol.go b/services/shhext/chat/protocol.go index 8e74c6229e8..e234faeb813 100644 --- a/services/shhext/chat/protocol.go +++ b/services/shhext/chat/protocol.go @@ -24,7 +24,7 @@ func NewProtocolService(encryption *EncryptionService, addedBundlesHandler func( } } -func (p *ProtocolService) addBundleAndMarshal(myIdentityKey *ecdsa.PrivateKey, msg *ProtocolMessage) ([]byte, error) { +func (p *ProtocolService) addBundleAndMarshal(myIdentityKey *ecdsa.PrivateKey, msg *ProtocolMessage, sendSingle bool) ([]byte, error) { // Get a bundle bundle, err := p.encryption.CreateBundle(myIdentityKey) if err != nil { @@ -32,7 +32,13 @@ func (p *ProtocolService) addBundleAndMarshal(myIdentityKey *ecdsa.PrivateKey, m return nil, err } - msg.Bundles = []*Bundle{bundle} + if sendSingle { + // DEPRECATED: This is only for backward compatibility, remove once not + // an issue anymore + msg.Bundle = bundle + } else { + msg.Bundles = []*Bundle{bundle} + } // marshal for sending to wire marshaledMessage, err := proto.Marshal(msg) @@ -52,7 +58,7 @@ func (p *ProtocolService) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, pa PublicMessage: payload, } - return p.addBundleAndMarshal(myIdentityKey, protocolMessage) + return p.addBundleAndMarshal(myIdentityKey, protocolMessage, false) } // BuildDirectMessage marshals a 1:1 chat message given the user identity private key, the recipient's public key, and a payload @@ -72,7 +78,7 @@ func (p *ProtocolService) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, pa DirectMessage: encryptionResponse, } - payload, err := p.addBundleAndMarshal(myIdentityKey, protocolMessage) + payload, err := p.addBundleAndMarshal(myIdentityKey, protocolMessage, true) if err != nil { return nil, err } @@ -99,7 +105,7 @@ func (p *ProtocolService) BuildPairingMessage(myIdentityKey *ecdsa.PrivateKey, p DirectMessage: encryptionResponse, } - return p.addBundleAndMarshal(myIdentityKey, protocolMessage) + return p.addBundleAndMarshal(myIdentityKey, protocolMessage, true) } // ProcessPublicBundle processes a received X3DH bundle. diff --git a/services/shhext/chat/protocol_test.go b/services/shhext/chat/protocol_test.go index 0a3aff2577b..b12fcb7fb8c 100644 --- a/services/shhext/chat/protocol_test.go +++ b/services/shhext/chat/protocol_test.go @@ -44,6 +44,28 @@ func (s *ProtocolServiceTestSuite) SetupTest() { s.bob = NewProtocolService(NewEncryptionService(bobPersistence, DefaultEncryptionServiceConfig("2")), addedBundlesHandler) } +func (s *ProtocolServiceTestSuite) TestBuildPublicMessage() { + aliceKey, err := crypto.GenerateKey() + s.NoError(err) + + payload, err := proto.Marshal(&ChatMessagePayload{ + Content: "Test content", + ClockValue: 1, + ContentType: "a", + MessageType: "some type", + }) + s.NoError(err) + + marshaledMsg, err := s.alice.BuildPublicMessage(aliceKey, payload) + s.NoError(err) + s.NotNil(marshaledMsg, "It creates a message") + + unmarshaledMsg := &ProtocolMessage{} + err = proto.Unmarshal(marshaledMsg, unmarshaledMsg) + s.NoError(err) + s.NotNilf(unmarshaledMsg.GetBundles(), "It adds a bundle to the message") +} + func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() { bobKey, err := crypto.GenerateKey() s.NoError(err) @@ -59,24 +81,19 @@ func (s *ProtocolServiceTestSuite) TestBuildDirectMessage() { s.NoError(err) marshaledMsg, err := s.alice.BuildDirectMessage(aliceKey, payload, &bobKey.PublicKey, &aliceKey.PublicKey) - s.NoError(err) s.NotNil(marshaledMsg, "It creates a message") s.NotNil(marshaledMsg[&aliceKey.PublicKey], "It creates a single message") unmarshaledMsg := &ProtocolMessage{} err = proto.Unmarshal(marshaledMsg[&bobKey.PublicKey], unmarshaledMsg) - s.NoError(err) - - s.NotNilf(unmarshaledMsg.GetBundles(), "It adds a bundle to the message") + s.NotNilf(unmarshaledMsg.GetBundle(), "It adds a bundle to the message") directMessage := unmarshaledMsg.GetDirectMessage() - s.NotNilf(directMessage, "It sets the direct message") encryptedPayload := directMessage["none"].GetPayload() - s.NotNilf(encryptedPayload, "It sets the payload of the message") s.NotEqualf(payload, encryptedPayload, "It encrypts the payload")