diff --git a/config/config.go b/config/config.go
index e9ca311..c41e320 100644
--- a/config/config.go
+++ b/config/config.go
@@ -372,8 +372,12 @@ type Provider struct {
Kaetzchen []*Kaetzchen
// CBORPluginKaetzchen is the list of configured external CBOR Kaetzchen plugins
- // for this provider.
+ // for this Provider.
CBORPluginKaetzchen []*CBORPluginKaetzchen
+
+ // PubsubPlugin is the list of configured external Publish Subscribe service plugins
+ // for this Provider.
+ PubsubPlugin []*PubsubPlugin
}
// SQLDB is the SQL database backend configuration.
@@ -535,6 +539,56 @@ func (kCfg *CBORPluginKaetzchen) validate() error {
return nil
}
+// PubsubPlugin is a Provider Publish Subscribe service agent.
+type PubsubPlugin struct {
+ // Capability is the capability exposed by the agent.
+ Capability string
+
+ // Endpoint is the provider side endpoint that the agent will accept
+ // requests at. While not required by the spec, this server only
+ // supports Endpoints that are lower-case local-parts of an e-mail
+ // address.
+ Endpoint string
+
+ // Config is the extra per agent arguments to be passed to the agent's
+ // initialization routine.
+ Config map[string]interface{}
+
+ // Command is the full file path to the external plugin program
+	// that implements this Pubsub service.
+ Command string
+
+ // MaxConcurrency is the number of worker goroutines to start
+ // for this service.
+ MaxConcurrency int
+
+	// Disable disables a configured agent.
+ Disable bool
+}
+
+func (kCfg *PubsubPlugin) validate() error {
+ if kCfg.Capability == "" {
+ return fmt.Errorf("config: Pubsub: Capability is invalid")
+ }
+
+ // Ensure the endpoint is normalized.
+ epNorm, err := precis.UsernameCaseMapped.String(kCfg.Endpoint)
+ if err != nil {
+ return fmt.Errorf("config: Pubsub: '%v' has invalid endpoint: %v", kCfg.Capability, err)
+ }
+ if epNorm != kCfg.Endpoint {
+ return fmt.Errorf("config: Pubsub: '%v' has non-normalized endpoint %v", kCfg.Capability, kCfg.Endpoint)
+ }
+ if kCfg.Command == "" {
+ return fmt.Errorf("config: Pubsub: Command is invalid")
+ }
+ if _, err = mail.ParseAddress(kCfg.Endpoint + "@test.invalid"); err != nil {
+ return fmt.Errorf("config: Pubsub: '%v' has non local-part endpoint '%v': %v", kCfg.Capability, kCfg.Endpoint, err)
+ }
+
+ return nil
+}
+
func (pCfg *Provider) applyDefaults(sCfg *Server) {
if pCfg.UserDB == nil {
pCfg.UserDB = &UserDB{}
@@ -656,6 +710,15 @@ func (pCfg *Provider) validate() error {
}
capaMap[v.Capability] = true
}
+ for _, v := range pCfg.PubsubPlugin {
+ if err := v.validate(); err != nil {
+ return err
+ }
+ if capaMap[v.Capability] {
+			return fmt.Errorf("config: Pubsub: '%v' configured multiple times", v.Capability)
+ }
+ capaMap[v.Capability] = true
+ }
return nil
}
diff --git a/internal/constants/constants.go b/internal/constants/constants.go
index ac0594f..88166f0 100644
--- a/internal/constants/constants.go
+++ b/internal/constants/constants.go
@@ -34,6 +34,7 @@ const (
DecoySubsystem = "decoy"
IncomingConnSubsystem = "incoming_conn"
KaetzchenSubsystem = "kaetzchen"
+ PubsubPluginSubsystem = "pubsub_plugin"
OutgoingConnSubsystem = "outgoing_conn"
PKISubsystem = "pki"
ProviderSubsystem = "provider"
diff --git a/internal/glue/glue.go b/internal/glue/glue.go
index dc47c45..e4c206c 100644
--- a/internal/glue/glue.go
+++ b/internal/glue/glue.go
@@ -66,6 +66,7 @@ type PKI interface {
StartWorker()
OutgoingDestinations() map[[constants.NodeIDLength]byte]*pki.MixDescriptor
AuthenticateConnection(*wire.PeerCredentials, bool) (*pki.MixDescriptor, bool, bool)
+ GetCachedConsensusDoc(uint64) (*pki.Document, error)
GetRawConsensus(uint64) ([]byte, error)
}
diff --git a/internal/packet/packet.go b/internal/packet/packet.go
index 4a82a7d..679e2fd 100644
--- a/internal/packet/packet.go
+++ b/internal/packet/packet.go
@@ -19,11 +19,14 @@ package packet
import (
"fmt"
+ mRand "math/rand"
"sync"
"sync/atomic"
"time"
"github.com/katzenpost/core/constants"
+ "github.com/katzenpost/core/crypto/rand"
+ "github.com/katzenpost/core/pki"
"github.com/katzenpost/core/sphinx"
"github.com/katzenpost/core/sphinx/commands"
"github.com/katzenpost/core/utils"
@@ -215,12 +218,10 @@ func newRedundantError(cmd commands.RoutingCommand) error {
return fmt.Errorf("redundant command: %T", cmd)
}
-func ParseForwardPacket(pkt *Packet) ([]byte, []byte, error) {
+func ParseForwardPacket(pkt *Packet) ([]byte, [][]byte, error) {
const (
- hdrLength = constants.SphinxPlaintextHeaderLength + sphinx.SURBLength
- flagsPadding = 0
- flagsSURB = 1
- reserved = 0
+ hdrLength = constants.SphinxPlaintextHeaderLength
+ reserved = 0
)
// Sanity check the forward packet payload length.
@@ -230,26 +231,21 @@ func ParseForwardPacket(pkt *Packet) ([]byte, []byte, error) {
// Parse the payload, which should be a valid BlockSphinxPlaintext.
b := pkt.Payload
- if len(b) < hdrLength {
- return nil, nil, fmt.Errorf("truncated message block")
- }
if b[1] != reserved {
return nil, nil, fmt.Errorf("invalid message reserved: 0x%02x", b[1])
}
- ct := b[hdrLength:]
- var surb []byte
- switch b[0] {
- case flagsPadding:
- case flagsSURB:
- surb = b[constants.SphinxPlaintextHeaderLength:hdrLength]
- default:
- return nil, nil, fmt.Errorf("invalid message flags: 0x%02x", b[0])
+ surbCount := int(b[0])
+ if (surbCount * sphinx.SURBLength) >= (constants.ForwardPayloadLength - hdrLength) {
+ return nil, nil, fmt.Errorf("invalid message SURB count: %d", uint8(b[0]))
}
- if len(ct) != constants.UserForwardPayloadLength {
- return nil, nil, fmt.Errorf("mis-sized user payload: %v", len(ct))
+ surbs := make([][]byte, surbCount)
+ startOffset := 2
+ for i := 0; i < surbCount; i++ {
+ surbs[i] = b[startOffset : startOffset+sphinx.SURBLength]
+ startOffset += sphinx.SURBLength
}
-
- return ct, surb, nil
+ ct := b[hdrLength+(surbCount*sphinx.SURBLength):]
+ return ct, surbs, nil
}
func NewPacketFromSURB(pkt *Packet, surb, payload []byte) (*Packet, error) {
@@ -303,3 +299,60 @@ func NewPacketFromSURB(pkt *Packet, surb, payload []byte) (*Packet, error) {
return respPkt, nil
}
+
+func NewProviderDelay(rng *mRand.Rand, doc *pki.Document) uint32 {
+ delay := uint64(rand.Exp(rng, doc.Mu)) + 1
+ if doc.MuMaxDelay > 0 && delay > doc.MuMaxDelay {
+ delay = doc.MuMaxDelay
+ }
+ return uint32(delay)
+}
+
+// NewDelayedPacketFromSURB creates a new Packet given a SURB, payload and, delay
+// where the specified delay is for the first hop, the Provider.
+func NewDelayedPacketFromSURB(delay uint32, surb, payload []byte) (*Packet, error) {
+ // Pad out payloads to the full packet size.
+ var respPayload [constants.ForwardPayloadLength]byte
+ switch {
+ case len(payload) == 0:
+ case len(payload) > constants.ForwardPayloadLength:
+ return nil, fmt.Errorf("oversized response payload: %v", len(payload))
+ default:
+ copy(respPayload[:], payload)
+ }
+
+ // Build a response packet using a SURB.
+ //
+ // TODO/perf: This is a crypto operation that is paralleizable, and
+ // could be handled by the crypto worker(s), since those are allocated
+ // based on hardware acceleration considerations. However the forward
+ // packet processing doesn't constantly utilize the AES-NI units due
+ // to the non-AEZ components of a Sphinx Unwrap operation.
+ rawRespPkt, firstHop, err := sphinx.NewPacketFromSURB(surb, respPayload[:])
+ if err != nil {
+ return nil, err
+ }
+
+ // Build the command vector for the SURB-ACK
+ cmds := make([]commands.RoutingCommand, 0, 2)
+
+ nextHopCmd := new(commands.NextNodeHop)
+ copy(nextHopCmd.ID[:], firstHop[:])
+ cmds = append(cmds, nextHopCmd)
+
+ nodeDelayCmd := new(commands.NodeDelay)
+ nodeDelayCmd.Delay = delay
+ cmds = append(cmds, nodeDelayCmd)
+
+ // Assemble the response packet.
+ respPkt, _ := New(rawRespPkt)
+ respPkt.Set(nil, cmds)
+
+ respPkt.Delay = time.Duration(nodeDelayCmd.Delay) * time.Millisecond
+ respPkt.MustForward = true
+
+ // XXX: This should probably fudge the delay to account for processing
+ // time.
+
+ return respPkt, nil
+}
diff --git a/internal/packet/packet_test.go b/internal/packet/packet_test.go
new file mode 100644
index 0000000..7f3f423
--- /dev/null
+++ b/internal/packet/packet_test.go
@@ -0,0 +1,104 @@
+// packet_test.go - Katzenpost server packet structure tests.
+// Copyright (C) 2020 David Stainton.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+// Package packet implements the Katzenpost server side packet structure.
+package packet
+
+import (
+ "testing"
+
+ "github.com/katzenpost/core/constants"
+ "github.com/katzenpost/core/sphinx"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseForwardPacket(t *testing.T) {
+ require := require.New(t)
+
+ // test that wrong payload size is an error
+ wrongPayload := [constants.ForwardPayloadLength + 123]byte{}
+ pkt := &Packet{
+ Payload: wrongPayload[:],
+ }
+ _, _, err := ParseForwardPacket(pkt)
+ require.Error(err)
+
+ // test that the wrong reserved value is an error
+ payload := [constants.ForwardPayloadLength]byte{}
+ pkt = &Packet{
+ Payload: payload[:],
+ }
+ pkt.Payload[1] = byte(1)
+ _, _, err = ParseForwardPacket(pkt)
+ require.Error(err)
+
+ // test that an invalid SURB count is an error
+ payload = [constants.ForwardPayloadLength]byte{}
+ pkt = &Packet{
+ Payload: payload[:],
+ }
+ pkt.Payload[0] = byte(255)
+ _, _, err = ParseForwardPacket(pkt)
+ require.Error(err)
+
+ // test that an invalid SURB count is an error
+ payload = [constants.ForwardPayloadLength]byte{}
+ pkt = &Packet{
+ Payload: payload[:],
+ }
+ pkt.Payload[0] = byte(93)
+ _, _, err = ParseForwardPacket(pkt)
+ require.Error(err)
+
+ // test that the 1 SURB case is handled properly
+ payload = [constants.ForwardPayloadLength]byte{}
+ pkt = &Packet{
+ Payload: payload[:],
+ }
+ pkt.Payload[0] = byte(1)
+ pkt.Payload[constants.SphinxPlaintextHeaderLength+sphinx.SURBLength] = 1
+ ct, surbs, err := ParseForwardPacket(pkt)
+ require.NoError(err)
+ require.Equal(1, len(surbs))
+ require.Equal(constants.UserForwardPayloadLength, len(ct))
+ require.Equal(int(ct[0]), 1)
+ require.Equal(int(ct[1]), 0)
+
+ // test that the 2 SURB case is handled properly
+ payload = [constants.ForwardPayloadLength]byte{}
+ pkt = &Packet{
+ Payload: payload[:],
+ }
+ pkt.Payload[0] = byte(2)
+ pkt.Payload[constants.SphinxPlaintextHeaderLength+sphinx.SURBLength+sphinx.SURBLength] = 1
+ ct, surbs, err = ParseForwardPacket(pkt)
+ require.NoError(err)
+ require.Equal(2, len(surbs))
+ require.NotEqual(constants.UserForwardPayloadLength, len(ct))
+ require.Equal(int(ct[0]), 1)
+ require.Equal(int(ct[1]), 0)
+
+ // test that a large SURB count is OK
+ payload = [constants.ForwardPayloadLength]byte{}
+ pkt = &Packet{
+ Payload: payload[:],
+ }
+ pkt.Payload[0] = byte(92)
+ ct, surbs, err = ParseForwardPacket(pkt)
+ require.NoError(err)
+ require.Equal(92, len(surbs))
+ require.Equal((constants.ForwardPayloadLength-constants.SphinxPlaintextHeaderLength)-(92*sphinx.SURBLength), len(ct))
+}
diff --git a/internal/pki/pki.go b/internal/pki/pki.go
index 8cd9e22..67f0649 100644
--- a/internal/pki/pki.go
+++ b/internal/pki/pki.go
@@ -634,6 +634,17 @@ func (p *pki) OutgoingDestinations() map[[sConstants.NodeIDLength]byte]*cpki.Mix
return descMap
}
+// GetCachedConsensusDoc returns a cached PKI doc for the given epoch.
+func (p *pki) GetCachedConsensusDoc(epoch uint64) (*cpki.Document, error) {
+ p.RLock()
+ defer p.RUnlock()
+ entry, ok := p.docs[epoch]
+ if !ok {
+ return nil, errors.New("failed to retrieve cached pki doc")
+ }
+ return entry.Document(), nil
+}
+
func (p *pki) GetRawConsensus(epoch uint64) ([]byte, error) {
if ok, err := p.getFailedFetch(epoch); ok {
p.log.Debugf("GetRawConsensus failure: no cached PKI document for epoch %v: %v", epoch, err)
diff --git a/internal/provider/kaetzchen/cbor_plugins.go b/internal/provider/kaetzchen/cbor_plugins.go
index 80cb8a4..fb26208 100644
--- a/internal/provider/kaetzchen/cbor_plugins.go
+++ b/internal/provider/kaetzchen/cbor_plugins.go
@@ -123,17 +123,25 @@ func (k *CBORPluginWorker) processKaetzchen(pkt *packet.Packet, pluginClient cbo
defer kaetzchenRequestsTimer.ObserveDuration()
defer pkt.Dispose()
- ct, surb, err := packet.ParseForwardPacket(pkt)
+ ct, surbs, err := packet.ParseForwardPacket(pkt)
if err != nil {
k.log.Debugf("Dropping Kaetzchen request: %v (%v)", pkt.ID, err)
kaetzchenRequestsDropped.Inc()
return
}
-
+ if len(surbs) > 1 {
+		k.log.Debugf("Received multi-SURB payload, dropping Kaetzchen request: %v", pkt.ID)
+ kaetzchenRequestsDropped.Inc()
+ return
+ }
+ hasSURB := false
+ if len(surbs) == 1 {
+ hasSURB = true
+ }
resp, err := pluginClient.OnRequest(&cborplugin.Request{
ID: pkt.ID,
Payload: ct,
- HasSURB: surb != nil,
+ HasSURB: hasSURB,
})
switch err {
case nil:
@@ -151,10 +159,10 @@ func (k *CBORPluginWorker) processKaetzchen(pkt *packet.Packet, pluginClient cbo
}
// Iff there is a SURB, generate a SURB-Reply and schedule.
- if surb != nil {
+ if hasSURB {
// Prepend the response header.
resp = append([]byte{0x01, 0x00}, resp...)
-
+ surb := surbs[0]
respPkt, err := packet.NewPacketFromSURB(pkt, surb, resp)
if err != nil {
k.log.Debugf("Failed to generate SURB-Reply: %v (%v)", pkt.ID, err)
diff --git a/internal/provider/kaetzchen/kaetzchen.go b/internal/provider/kaetzchen/kaetzchen.go
index 6da9208..12327e3 100644
--- a/internal/provider/kaetzchen/kaetzchen.go
+++ b/internal/provider/kaetzchen/kaetzchen.go
@@ -243,18 +243,27 @@ func (k *KaetzchenWorker) processKaetzchen(pkt *packet.Packet) {
defer kaetzchenRequestsTimer.ObserveDuration()
defer pkt.Dispose()
- ct, surb, err := packet.ParseForwardPacket(pkt)
+ ct, surbs, err := packet.ParseForwardPacket(pkt)
if err != nil {
k.log.Debugf("Dropping Kaetzchen request: %v (%v)", pkt.ID, err)
k.incrementDropCounter()
kaetzchenRequestsDropped.Add(float64(k.getDropCounter()))
return
}
-
+ if len(surbs) > 1 {
+		k.log.Debugf("Multi-SURB packet sent to Kaetzchen recipient, dropping Kaetzchen request: %v", pkt.ID)
+ k.incrementDropCounter()
+ kaetzchenRequestsDropped.Add(float64(k.getDropCounter()))
+ return
+ }
var resp []byte
dst, ok := k.kaetzchen[pkt.Recipient.ID]
+ hasSURB := false
+ if len(surbs) == 1 {
+ hasSURB = true
+ }
if ok {
- resp, err = dst.OnRequest(pkt.ID, ct, surb != nil)
+ resp, err = dst.OnRequest(pkt.ID, ct, hasSURB)
}
switch {
case err == nil:
@@ -269,7 +278,8 @@ func (k *KaetzchenWorker) processKaetzchen(pkt *packet.Packet) {
}
// Iff there is a SURB, generate a SURB-Reply and schedule.
- if surb != nil {
+ if len(surbs) == 1 {
+ surb := surbs[0]
// Prepend the response header.
resp = append([]byte{0x01, 0x00}, resp...)
diff --git a/internal/provider/provider.go b/internal/provider/provider.go
index 5f863d1..5ca87ea 100644
--- a/internal/provider/provider.go
+++ b/internal/provider/provider.go
@@ -43,6 +43,7 @@ import (
"github.com/katzenpost/server/internal/glue"
"github.com/katzenpost/server/internal/packet"
"github.com/katzenpost/server/internal/provider/kaetzchen"
+ "github.com/katzenpost/server/internal/provider/pubsub"
"github.com/katzenpost/server/internal/sqldb"
"github.com/katzenpost/server/registration"
"github.com/katzenpost/server/spool"
@@ -75,6 +76,7 @@ type provider struct {
kaetzchenWorker *kaetzchen.KaetzchenWorker
cborPluginKaetzchenWorker *kaetzchen.CBORPluginWorker
+ pubsubPluginWorker *pubsub.PluginWorker
httpServers []*http.Server
}
@@ -101,6 +103,7 @@ func (p *provider) Halt() {
p.ch.Close()
p.kaetzchenWorker.Halt()
p.cborPluginKaetzchenWorker.Halt()
+ p.pubsubPluginWorker.Halt()
if p.userDB != nil {
p.userDB.Close()
p.userDB = nil
@@ -254,13 +257,26 @@ func (p *provider) worker() {
packetsDropped.Inc()
pkt.Dispose()
} else {
- // Note that we pass ownership of pkt to p.kaetzchenWorker
+ // Note that we pass ownership of pkt to p.cborPluginKaetzchenWorker
// which will take care to dispose of it.
p.cborPluginKaetzchenWorker.OnKaetzchen(pkt)
}
continue
}
+ if p.pubsubPluginWorker.HasRecipient(pkt.Recipient.ID) {
+ if pkt.IsSURBReply() {
+ p.log.Debugf("Dropping packet: %v (SURB-Reply for pubsub service)", pkt.ID)
+ packetsDropped.Inc()
+ pkt.Dispose()
+ } else {
+ // Note that we pass ownership of pkt to p.pubsubPluginWorker
+ // which will take care to dispose of it.
+ p.pubsubPluginWorker.OnSubscribeRequest(pkt)
+ }
+ continue
+ }
+
// Post-process the recipient.
recipient, err := p.fixupRecipient(pkt.Recipient.ID[:])
if err != nil {
@@ -306,21 +322,25 @@ func (p *provider) onSURBReply(pkt *packet.Packet, recipient []byte) {
}
func (p *provider) onToUser(pkt *packet.Packet, recipient []byte) {
- ct, surb, err := packet.ParseForwardPacket(pkt)
+ ct, surbs, err := packet.ParseForwardPacket(pkt)
if err != nil {
p.log.Debugf("Dropping packet: %v (%v)", pkt.ID, err)
packetsDropped.Inc()
return
}
-
// Store the ciphertext in the spool.
if err := p.spool.StoreMessage(recipient, ct); err != nil {
p.log.Debugf("Failed to store message payload: %v (%v)", pkt.ID, err)
return
}
-
+ if len(surbs) > 1 {
+		p.log.Debugf("Multi-SURB packet sent to user recipient, dropping packet: %v", pkt.ID)
+ packetsDropped.Inc()
+ return
+ }
// Iff there is a SURB, generate a SURB-ACK and schedule.
- if surb != nil {
+ if len(surbs) == 1 {
+ surb := surbs[0]
ackPkt, err := packet.NewPacketFromSURB(pkt, surb, nil)
if err != nil {
p.log.Debugf("Failed to generate SURB-ACK: %v (%v)", pkt.ID, err)
@@ -746,12 +766,17 @@ func New(glue glue.Glue) (glue.Provider, error) {
if err != nil {
return nil, err
}
+ pubsubPluginWorker, err := pubsub.New(glue)
+ if err != nil {
+ return nil, err
+ }
p := &provider{
glue: glue,
log: glue.LogBackend().GetLogger("provider"),
ch: channels.NewInfiniteChannel(),
kaetzchenWorker: kaetzchenWorker,
cborPluginKaetzchenWorker: cborPluginWorker,
+ pubsubPluginWorker: pubsubPluginWorker,
}
cfg := glue.Config()
diff --git a/internal/provider/pubsub/plugins.go b/internal/provider/pubsub/plugins.go
new file mode 100644
index 0000000..f1cbf17
--- /dev/null
+++ b/internal/provider/pubsub/plugins.go
@@ -0,0 +1,467 @@
+// plugins.go - publish subscribe plugin system for mix network services
+// Copyright (C) 2020 David Stainton.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+// Package pubsub implements support for provider side SURB based publish subscribe agents.
+package pubsub
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/katzenpost/core/crypto/rand"
+ "github.com/katzenpost/core/epochtime"
+ "github.com/katzenpost/core/monotime"
+ sConstants "github.com/katzenpost/core/sphinx/constants"
+ "github.com/katzenpost/core/worker"
+ "github.com/katzenpost/server/internal/constants"
+ "github.com/katzenpost/server/internal/glue"
+ "github.com/katzenpost/server/internal/packet"
+ "github.com/katzenpost/server/pubsubplugin/client"
+ "github.com/katzenpost/server/pubsubplugin/common"
+ "github.com/prometheus/client_golang/prometheus"
+ "golang.org/x/text/secure/precis"
+ "gopkg.in/eapache/channels.v1"
+ "gopkg.in/op/go-logging.v1"
+)
+
+var (
+ packetsDropped = prometheus.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: constants.Namespace,
+ Name: "dropped_packets_total",
+ Subsystem: constants.PubsubPluginSubsystem,
+ Help: "Number of dropped packets",
+ },
+ )
+ pubsubRequests = prometheus.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: constants.Namespace,
+ Name: "requests_total",
+ Subsystem: constants.PubsubPluginSubsystem,
+ Help: "Number of Pubsub requests",
+ },
+ )
+ pubsubRequestsDuration = prometheus.NewSummary(
+ prometheus.SummaryOpts{
+ Namespace: constants.Namespace,
+ Name: "requests_duration_seconds",
+ Subsystem: constants.PubsubPluginSubsystem,
+ Help: "Duration of a pubsub request in seconds",
+ },
+ )
+ pubsubRequestsDropped = prometheus.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: constants.Namespace,
+ Name: "dropped_requests_total",
+ Subsystem: constants.PubsubPluginSubsystem,
+ Help: "Number of total dropped pubsub requests",
+ },
+ )
+ pubsubRequestsFailed = prometheus.NewCounter(
+ prometheus.CounterOpts{
+ Namespace: constants.Namespace,
+ Name: "failed_requests_total",
+ Subsystem: constants.PubsubPluginSubsystem,
+ Help: "Number of total failed pubsub requests",
+ },
+ )
+ pubsubRequestsTimer *prometheus.Timer
+)
+
+func init() {
+ prometheus.MustRegister(packetsDropped)
+ prometheus.MustRegister(pubsubRequests)
+ prometheus.MustRegister(pubsubRequestsDropped)
+ prometheus.MustRegister(pubsubRequestsFailed)
+ prometheus.MustRegister(pubsubRequestsDuration)
+}
+
+const (
+	// ParameterEndpoint is the mandatory Parameter key indicating the
+	// pubsub agent's endpoint.
+ ParameterEndpoint = "endpoint"
+)
+
+// GarbageCollectionInterval is the time interval between running our
+// subscription garbage collection routine. We shall attempt to garbage collect
+// 5 times per epoch.
+var GarbageCollectionInterval = epochtime.Period / 5
+
+// PluginChans maps from Recipient ID to channel.
+type PluginChans = map[[sConstants.RecipientIDLength]byte]*channels.InfiniteChannel
+
+// PluginName is the name of a plugin.
+type PluginName = string
+
+// PluginParameters maps from parameter key to value.
+type PluginParameters = map[PluginName]interface{}
+
+// ServiceMap maps from plugin name to plugin parameters
+// and is used by Mix Descriptors which describe Providers
+// with plugins. Each plugin can optionally set one or more
+// parameters.
+type ServiceMap = map[PluginName]PluginParameters
+
+// SURBBundle facilitates garbage collection of subscriptions
+// by keeping track of the Epoch that the SURBs were received.
+type SURBBundle struct {
+ // Epoch is the epoch whence the SURBs were received.
+ Epoch uint64
+
+ // SURBs is one or more SURBs.
+ SURBs [][]byte
+}
+
+func ensureEndpointSanitized(endpointStr, capa string) (*[sConstants.RecipientIDLength]byte, error) {
+ if endpointStr == "" {
+ return nil, fmt.Errorf("provider: Pubsub: '%v' provided no endpoint", capa)
+ } else if epNorm, err := precis.UsernameCaseMapped.String(endpointStr); err != nil {
+ return nil, fmt.Errorf("provider: Pubsub: '%v' invalid endpoint: %v", capa, err)
+ } else if epNorm != endpointStr {
+ return nil, fmt.Errorf("provider: Pubsub: '%v' invalid endpoint, not normalized", capa)
+ }
+ rawEp := []byte(endpointStr)
+ if len(rawEp) == 0 || len(rawEp) > sConstants.RecipientIDLength {
+ return nil, fmt.Errorf("provider: Pubsub: '%v' invalid endpoint, length out of bounds", capa)
+ }
+ var endpoint [sConstants.RecipientIDLength]byte
+ copy(endpoint[:], rawEp)
+ return &endpoint, nil
+}
+
+// PluginWorker implements the publish subscribe plugin worker.
+type PluginWorker struct {
+ worker.Worker
+
+ glue glue.Glue
+ log *logging.Logger
+
+ haltOnce sync.Once
+ subscriptions *sync.Map // [SubscriptionIDLength]byte -> *SURBBundle
+ pluginChans PluginChans
+ clients []*client.Dialer
+ forPKI ServiceMap
+}
+
+// OnSubscribeRequest enqueues the pkt for processing by our thread pool of plugins.
+func (k *PluginWorker) OnSubscribeRequest(pkt *packet.Packet) {
+ handlerCh, ok := k.pluginChans[pkt.Recipient.ID]
+ if !ok {
+ k.log.Debugf("Failed to find handler. Dropping PubsubPlugin request: %v", pkt.ID)
+ return
+ }
+ handlerCh.In() <- pkt
+}
+
+func (k *PluginWorker) sendReply(surb, payload []byte) {
+ // Prepend the response header.
+ payload = append([]byte{0x01, 0x00}, payload...)
+
+ // generate random delay for first hop of SURB-Reply on Provider
+ epoch, _, _ := epochtime.Now()
+ doc, err := k.glue.PKI().GetCachedConsensusDoc(epoch)
+ if err != nil {
+ k.log.Debugf("Failed to get PKI doc for generating SURB-Reply: %v", err)
+ return
+ }
+ delay := packet.NewProviderDelay(rand.NewMath(), doc)
+
+ respPkt, err := packet.NewDelayedPacketFromSURB(delay, surb, payload)
+ if err != nil {
+ k.log.Debugf("Failed to generate SURB-Reply: %v", err)
+ return
+ }
+
+ k.log.Debugf("Handing off newly generated SURB-Reply: %v", respPkt.ID)
+ k.glue.Scheduler().OnPacket(respPkt)
+ return
+}
+
+func (k *PluginWorker) garbageCollect() {
+ k.log.Debug("Running garbage collection process.")
+ // [SubscriptionIDLength]byte -> *SURBBundle
+ surbsMapRange := func(rawSubscriptionID, rawSurbBundle interface{}) bool {
+ subscriptionID := rawSubscriptionID.([common.SubscriptionIDLength]byte)
+ surbBundle := rawSurbBundle.(*SURBBundle)
+
+ epoch, _, _ := epochtime.Now()
+ if epoch-surbBundle.Epoch >= 2 {
+ k.subscriptions.Delete(subscriptionID)
+ }
+ return true
+ }
+ k.subscriptions.Range(surbsMapRange)
+}
+
+func (k *PluginWorker) garbageCollectionWorker() {
+ timer := time.NewTimer(GarbageCollectionInterval)
+ defer timer.Stop()
+ for {
+ select {
+ case <-k.HaltCh():
+ k.log.Debugf("Garbage collection worker terminating gracefully.")
+ return
+ case <-timer.C:
+ k.garbageCollect()
+ timer.Reset(GarbageCollectionInterval)
+ }
+ }
+}
+
+func (k *PluginWorker) appMessagesWorker(pluginClient *client.Dialer) {
+ for {
+ select {
+ case <-k.HaltCh():
+ return
+ case appMessages := <-pluginClient.IncomingCh():
+ rawSURBs, ok := k.subscriptions.Load(appMessages.SubscriptionID)
+ if !ok {
+				k.log.Error("Error, failed to load a subscription ID from sync.Map")
+ continue
+ }
+ surbBundle, ok := rawSURBs.(*SURBBundle)
+ if !ok {
+ k.log.Error("Error, failed type assertion for type *SURBBundle")
+ continue
+ }
+ messagesBlob, err := common.MessagesToBytes(appMessages.Messages)
+ if err != nil {
+ k.log.Errorf("Error, failed to encode app messages as CBOR blob: %s", err)
+ continue
+ }
+ surb := surbBundle.SURBs[0]
+ if len(surbBundle.SURBs) == 1 {
+ k.log.Debug("Using last SURB in subscription.")
+ k.subscriptions.Delete(appMessages.SubscriptionID)
+ pluginClient.Unsubscribe(appMessages.SubscriptionID)
+ } else {
+ surbBundle.SURBs = surbBundle.SURBs[1:]
+ k.subscriptions.Store(appMessages.SubscriptionID, surbBundle)
+ }
+ k.sendReply(surb, messagesBlob)
+ }
+ }
+}
+
+func (k *PluginWorker) subscriptionWorker(recipient [sConstants.RecipientIDLength]byte, pluginClient *client.Dialer) {
+
+ // Kaetzchen delay is our max dwell time.
+ maxDwell := time.Duration(k.glue.Config().Debug.KaetzchenDelay) * time.Millisecond
+
+ defer k.haltOnce.Do(k.haltAllClients)
+
+ handlerCh, ok := k.pluginChans[recipient]
+ if !ok {
+ k.log.Debugf("Failed to find handler. Dropping PubsubPlugin request: %v", recipient)
+ pubsubRequestsDropped.Inc()
+ return
+ }
+ ch := handlerCh.Out()
+
+ for {
+ var pkt *packet.Packet
+ select {
+ case <-k.HaltCh():
+ k.log.Debugf("Terminating gracefully.")
+ return
+ case e := <-ch:
+ pkt = e.(*packet.Packet)
+ if dwellTime := monotime.Now() - pkt.DispatchAt; dwellTime > maxDwell {
+ k.log.Debugf("Dropping packet: %v (Spend %v in queue)", pkt.ID, dwellTime)
+ packetsDropped.Inc()
+ pkt.Dispose()
+ continue
+ }
+ k.processPacket(pkt, pluginClient)
+ pubsubRequests.Inc()
+ }
+ }
+}
+
+func (k *PluginWorker) haltAllClients() {
+ k.log.Debug("Halting plugin clients.")
+ for _, client := range k.clients {
+ go client.Halt()
+ }
+}
+
+func (k *PluginWorker) processPacket(pkt *packet.Packet, pluginClient *client.Dialer) {
+	timer := prometheus.NewTimer(pubsubRequestsDuration)
+	defer timer.ObserveDuration()
+ defer pkt.Dispose()
+
+ payload, surbs, err := packet.ParseForwardPacket(pkt)
+ if err != nil {
+ k.log.Debugf("Failed to parse forward packet. Dropping Pubsub request: %v (%v)", pkt.ID, err)
+ pubsubRequestsDropped.Inc()
+ return
+ }
+ if len(surbs) == 0 {
+		k.log.Debugf("Zero SURBs supplied. Dropping Pubsub request: %v", pkt.ID)
+ pubsubRequestsDropped.Inc()
+ return
+ }
+ clientSubscribe, err := common.ClientSubscribeFromBytes(payload)
+ if err != nil {
+ k.log.Debugf("Failed to decode payload. Dropping Pubsub request: %v (%v)", pkt.ID, err)
+ pubsubRequestsDropped.Inc()
+ return
+ }
+ subscriptionID := common.GenerateSubscriptionID()
+ epoch, _, _ := epochtime.Now()
+ surbBundle := &SURBBundle{
+ Epoch: epoch,
+ SURBs: surbs,
+ }
+ k.subscriptions.Store(subscriptionID, surbBundle)
+ subscription := &common.Subscribe{
+ PacketID: pkt.ID,
+ SURBCount: uint8(len(surbs)),
+ SubscriptionID: subscriptionID,
+ SpoolID: clientSubscribe.SpoolID,
+ LastSpoolIndex: clientSubscribe.LastSpoolIndex,
+ }
+	err = pluginClient.Subscribe(subscription)
+ if err != nil {
+ k.log.Debugf("Failed to handle Pubsub request: %v (%v)", pkt.ID, err)
+ return
+ }
+ return
+}
+
+// PubsubForPKI returns the plugins Parameters map for publication in the PKI doc.
+func (k *PluginWorker) PubsubForPKI() ServiceMap {
+ return k.forPKI
+}
+
+// HasRecipient returns true if the given recipient is one of our workers.
+func (k *PluginWorker) HasRecipient(recipient [sConstants.RecipientIDLength]byte) bool {
+ _, ok := k.pluginChans[recipient]
+ return ok
+}
+
+func (k *PluginWorker) launch(command string, args []string) (*client.Dialer, error) {
+ k.log.Debugf("Launching plugin: %s", command)
+ plugin := client.New(command, k.glue.LogBackend())
+ err := plugin.Launch(command, args)
+ return plugin, err
+}
+
+// New returns a new PluginWorker
+func New(glue glue.Glue) (*PluginWorker, error) {
+
+ pluginWorker := PluginWorker{
+ glue: glue,
+ log: glue.LogBackend().GetLogger("pubsub plugin worker"),
+ pluginChans: make(PluginChans),
+ clients: make([]*client.Dialer, 0),
+ forPKI: make(ServiceMap),
+ subscriptions: new(sync.Map),
+ }
+
+ pluginWorker.Go(pluginWorker.garbageCollectionWorker)
+
+ capaMap := make(map[string]bool)
+
+ for i, pluginConf := range glue.Config().Provider.PubsubPlugin {
+ pluginWorker.log.Noticef("Configuring plugin handler for %s", pluginConf.Capability)
+
+ // Ensure no duplicates.
+ capa := pluginConf.Capability
+ if capa == "" {
+ return nil, errors.New("pubsub plugin capability cannot be empty string")
+ }
+ if pluginConf.Disable {
+ pluginWorker.log.Noticef("Skipping disabled Pubsub: '%v'.", capa)
+ continue
+ }
+ if capaMap[capa] {
+ return nil, fmt.Errorf("provider: Pubsub '%v' registered more than once", capa)
+ }
+
+ // Sanitize the endpoint.
+ endpoint, err := ensureEndpointSanitized(pluginConf.Endpoint, capa)
+ if err != nil {
+ return nil, err
+ }
+
+ // Add an infinite channel for this plugin.
+
+ pluginWorker.pluginChans[*endpoint] = channels.NewInfiniteChannel()
+
+ // Add entry from this plugin for the PKI.
+ params := make(map[string]interface{})
+ gotParams := false
+
+ pluginWorker.log.Noticef("Starting Pubsub plugin client: %s %d", capa, i)
+
+ var args []string
+ if len(pluginConf.Config) > 0 {
+ args = []string{}
+ for key, val := range pluginConf.Config {
+ args = append(args, fmt.Sprintf("-%s", key), val.(string))
+ }
+ }
+
+ pluginClient, err := pluginWorker.launch(pluginConf.Command, args)
+ if err != nil {
+			pluginWorker.log.Errorf("Failed to start a plugin client: %s", err)
+ return nil, err
+ }
+
+ for i := 0; i < pluginConf.MaxConcurrency; i++ {
+ err := pluginClient.Dial()
+ if err != nil {
+				pluginWorker.log.Errorf("Failed to dial a plugin client: %s", err)
+ return nil, err
+ }
+
+ pluginWorker.Go(func() {
+ pluginWorker.appMessagesWorker(pluginClient)
+ })
+
+ if !gotParams {
+ // just once we call the Parameters method on the plugin
+ // and use that info to populate our forPKI map which
+ // ends up populating the PKI document
+ p := pluginClient.Parameters()
+ if p != nil {
+ for key, value := range *p {
+ params[key] = value
+ }
+ }
+ params[ParameterEndpoint] = pluginConf.Endpoint
+ gotParams = true
+ }
+
+ // Accumulate a list of all clients to facilitate clean shutdown.
+ pluginWorker.clients = append(pluginWorker.clients, pluginClient)
+
+ // Start the subscriptionWorker _after_ we have added all of the entries to pluginChans
+ // otherwise the subscriptionWorker() goroutines race this thread.
+ defer pluginWorker.Go(func() {
+ pluginWorker.subscriptionWorker(*endpoint, pluginClient)
+ })
+ }
+
+ pluginWorker.forPKI[capa] = params
+ capaMap[capa] = true
+ }
+
+ return &pluginWorker, nil
+}
diff --git a/internal/provider/pubsub/plugins_test.go b/internal/provider/pubsub/plugins_test.go
new file mode 100644
index 0000000..743e338
--- /dev/null
+++ b/internal/provider/pubsub/plugins_test.go
@@ -0,0 +1,245 @@
+// plugins_test.go - tests for plugin system
+// Copyright (C) 2020 David Stainton
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package pubsub
+
+import (
+ "testing"
+
+ "github.com/katzenpost/core/crypto/ecdh"
+ "github.com/katzenpost/core/crypto/eddsa"
+ "github.com/katzenpost/core/crypto/rand"
+ "github.com/katzenpost/core/log"
+ "github.com/katzenpost/core/sphinx/constants"
+ "github.com/katzenpost/core/thwack"
+ "github.com/katzenpost/core/wire"
+ "github.com/katzenpost/server/config"
+ "github.com/katzenpost/server/internal/glue"
+ "github.com/katzenpost/server/internal/packet"
+ "github.com/katzenpost/server/internal/pkicache"
+ "github.com/katzenpost/server/spool"
+ "github.com/katzenpost/server/userdb"
+ "github.com/stretchr/testify/require"
+)
+
// mockUserDB is a stub userdb.UserDB for tests: existence and validity
// checks always succeed and Identity returns the owning mockProvider's
// canned user key.
type mockUserDB struct {
	provider *mockProvider
}

// Exists reports every user as present.
func (u *mockUserDB) Exists([]byte) bool {
	return true
}

// IsValid accepts any user/key pair.
func (u *mockUserDB) IsValid([]byte, *ecdh.PublicKey) bool { return true }

func (u *mockUserDB) Add([]byte, *ecdh.PublicKey, bool) error { return nil }

func (u *mockUserDB) SetIdentity([]byte, *ecdh.PublicKey) error { return nil }

func (u *mockUserDB) Link([]byte) (*ecdh.PublicKey, error) {
	return nil, nil
}

// Identity returns the fixed user key configured on the mockProvider.
func (u *mockUserDB) Identity([]byte) (*ecdh.PublicKey, error) {
	return u.provider.userKey, nil
}

func (u *mockUserDB) Remove([]byte) error { return nil }

func (u *mockUserDB) Close() {}

// mockSpool is a stub spool.Spool: mutating operations succeed without
// doing anything and Get returns a fixed 3-byte message.
type mockSpool struct{}

func (s *mockSpool) StoreMessage(u, msg []byte) error { return nil }

func (s *mockSpool) StoreSURBReply(u []byte, id *[constants.SURBIDLength]byte, msg []byte) error {
	return nil
}

// Get returns a canned message with one message remaining on the spool.
func (s *mockSpool) Get(u []byte, advance bool) (msg, surbID []byte, remaining int, err error) {
	return []byte{1, 2, 3}, nil, 1, nil
}

func (s *mockSpool) Remove(u []byte) error { return nil }

func (s *mockSpool) Vacuum(udb userdb.UserDB) error { return nil }

func (s *mockSpool) Close() {}

// mockProvider is a stub glue.Provider wired to the mock user DB and spool.
type mockProvider struct {
	userName string
	userKey  *ecdh.PublicKey
}

func (p *mockProvider) Halt() {}

// UserDB returns a mockUserDB bound to this provider.
func (p *mockProvider) UserDB() userdb.UserDB {
	return &mockUserDB{
		provider: p,
	}
}

func (p *mockProvider) Spool() spool.Spool {
	return &mockSpool{}
}

// AuthenticateClient accepts all peer credentials.
func (p *mockProvider) AuthenticateClient(*wire.PeerCredentials) bool {
	return true
}

func (p *mockProvider) OnPacket(*packet.Packet) {}

func (p *mockProvider) KaetzchenForPKI() (map[string]map[string]interface{}, error) {
	return nil, nil
}

func (p *mockProvider) AdvertiseRegistrationHTTPAddresses() []string {
	return nil
}

// mockDecoy is a no-op glue.Decoy.
type mockDecoy struct{}

func (d *mockDecoy) Halt() {}

func (d *mockDecoy) OnNewDocument(*pkicache.Entry) {}

func (d *mockDecoy) OnPacket(*packet.Packet) {}

// mockServer carries the state handed out by mockGlue's accessors.
type mockServer struct {
	cfg         *config.Config
	logBackend  *log.Backend
	identityKey *eddsa.PrivateKey
	linkKey     *ecdh.PrivateKey
	management  *thwack.Server
	mixKeys     glue.MixKeys
	pki         glue.PKI
	provider    glue.Provider
	scheduler   glue.Scheduler
	connector   glue.Connector
	listeners   []glue.Listener
}

// mockGlue is a stub glue.Glue whose accessors return the corresponding
// fields of its mockServer.
type mockGlue struct {
	s *mockServer
}

func (g *mockGlue) Config() *config.Config {
	return g.s.cfg
}

func (g *mockGlue) LogBackend() *log.Backend {
	return g.s.logBackend
}

func (g *mockGlue) IdentityKey() *eddsa.PrivateKey {
	return g.s.identityKey
}

func (g *mockGlue) LinkKey() *ecdh.PrivateKey {
	return g.s.linkKey
}

func (g *mockGlue) Management() *thwack.Server {
	return g.s.management
}

func (g *mockGlue) MixKeys() glue.MixKeys {
	return g.s.mixKeys
}

func (g *mockGlue) PKI() glue.PKI {
	return g.s.pki
}

func (g *mockGlue) Provider() glue.Provider {
	return g.s.provider
}

func (g *mockGlue) Scheduler() glue.Scheduler {
	return g.s.scheduler
}

func (g *mockGlue) Connector() glue.Connector {
	return g.s.connector
}

func (g *mockGlue) Listeners() []glue.Listener {
	return g.s.listeners
}

func (g *mockGlue) ReshadowCryptoWorkers() {}

func (g *mockGlue) Decoy() glue.Decoy {
	return &mockDecoy{}
}

// getGlue builds a mockGlue with a minimally populated config suitable
// for exercising the pubsub plugin worker in tests.
func getGlue(logBackend *log.Backend, provider *mockProvider, linkKey *ecdh.PrivateKey, idKey *eddsa.PrivateKey) *mockGlue {
	goo := &mockGlue{
		s: &mockServer{
			logBackend: logBackend,
			provider:   provider,
			linkKey:    linkKey,
			cfg: &config.Config{
				Server:     &config.Server{},
				Logging:    &config.Logging{},
				Provider:   &config.Provider{},
				PKI:        &config.PKI{},
				Management: &config.Management{},
				Debug: &config.Debug{
					NumKaetzchenWorkers: 3,
					IdentityKey:         idKey,
					KaetzchenDelay:      300,
				},
			},
		},
	}
	return goo
}
+
+func TestPubsubInvalidCommand(t *testing.T) {
+ require := require.New(t)
+
+ idKey, err := eddsa.NewKeypair(rand.Reader)
+ require.NoError(err)
+
+ logBackend, err := log.New("", "DEBUG", false)
+ require.NoError(err)
+
+ userKey, err := ecdh.NewKeypair(rand.Reader)
+ require.NoError(err)
+
+ linkKey, err := ecdh.NewKeypair(rand.Reader)
+ require.NoError(err)
+
+ mockProvider := &mockProvider{
+ userName: "alice",
+ userKey: userKey.PublicKey(),
+ }
+
+ goo := getGlue(logBackend, mockProvider, linkKey, idKey)
+ goo.s.cfg.Provider.PubsubPlugin = []*config.PubsubPlugin{
+ &config.PubsubPlugin{
+ Capability: "loop",
+ Endpoint: "loop",
+ Config: map[string]interface{}{},
+ Disable: false,
+ Command: "non-existent command",
+ MaxConcurrency: 1,
+ },
+ }
+ _, err = New(goo)
+ require.Error(err)
+}
diff --git a/pubsubplugin/client/dialer.go b/pubsubplugin/client/dialer.go
new file mode 100644
index 0000000..f5879c4
--- /dev/null
+++ b/pubsubplugin/client/dialer.go
@@ -0,0 +1,196 @@
+// dialer.go - client plugin dialer, tracks multiple connections to a plugin
+// Copyright (C) 2020 David Stainton.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+// Package pubsubplugin client is the client module a plugin system allowing mix network services
+// to be added in any language. It implements a publish subscribe interface.
+//
+package client
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+ "os/exec"
+ "sync"
+ "syscall"
+
+ "github.com/katzenpost/core/log"
+ "github.com/katzenpost/server/pubsubplugin/common"
+ "gopkg.in/op/go-logging.v1"
+)
+
+const unixSocketNetwork = "unix"
+
// Dialer handles the launching and subsequent multiple dialings of a given plugin.
// It owns the plugin subprocess, the unix socket path it advertises, and the
// set of connections made to that socket.
type Dialer struct {
	sync.RWMutex

	logBackend *log.Backend
	log        *logging.Logger

	// for sending subscribe/unsubscribe commands to the plugin
	outgoingCh chan interface{}

	// for receiving AppMessages from the plugin
	incomingCh chan *common.AppMessages

	// conns tracks every dialed connection so Halt can stop them all;
	// guarded by the embedded RWMutex.
	conns []*outgoingConn
	// cmd is the running plugin subprocess; nil until Launch succeeds.
	cmd        *exec.Cmd
	// socketPath is read from the plugin's first stdout line in Launch.
	socketPath string
	// params caches the plugin's Parameters after the first connection
	// fetches them (see ensureParameters).
	params     *common.Parameters

	haltOnce sync.Once
}
+
+// New creates a new plugin client instance which represents the single execution
+// of the external plugin program.
+func New(command string, logBackend *log.Backend) *Dialer {
+ return &Dialer{
+ logBackend: logBackend,
+ log: logBackend.GetLogger(command),
+
+ outgoingCh: make(chan interface{}),
+ incomingCh: make(chan *common.AppMessages),
+ conns: make([]*outgoingConn, 0),
+ }
+}
+
+func (d *Dialer) IncomingCh() chan *common.AppMessages {
+ return d.incomingCh
+}
+
+// Halt halts all of the outgoing connections and halts
+// the execution of the plugin application.
+func (d *Dialer) Halt() {
+ d.haltOnce.Do(d.doHalt)
+}
+
+func (d *Dialer) doHalt() {
+ d.RLock()
+ defer d.RUnlock()
+ for _, outgoingConn := range d.conns {
+ outgoingConn.Halt()
+ }
+ d.cmd.Process.Signal(syscall.SIGTERM)
+ err := d.cmd.Wait()
+ if err != nil {
+ d.log.Errorf("Publish-subscript plugin worker, command halt exec error: %s\n", err)
+ }
+}
+
+func (d *Dialer) logPluginStderr(stderr io.ReadCloser) {
+ logWriter := d.logBackend.GetLogWriter(d.cmd.Path, "DEBUG")
+ _, err := io.Copy(logWriter, stderr)
+ if err != nil {
+ d.log.Errorf("Failed to proxy pubsubplugin stderr to DEBUG log: %s", err)
+ }
+ d.Halt()
+}
+
+// Launch executes the given command and args, reading the unix socket path
+// from STDOUT and saving it for later use when dialing the socket.
+func (d *Dialer) Launch(command string, args []string) error {
+ // exec plugin
+ d.cmd = exec.Command(command, args...)
+ stdout, err := d.cmd.StdoutPipe()
+ if err != nil {
+ d.log.Debugf("pipe failure: %s", err)
+ return err
+ }
+ stderr, err := d.cmd.StderrPipe()
+ if err != nil {
+ d.log.Debugf("pipe failure: %s", err)
+ return err
+ }
+ err = d.cmd.Start()
+ if err != nil {
+ d.log.Debugf("failed to exec: %s", err)
+ return err
+ }
+
+ // proxy stderr to our debug log
+ go d.logPluginStderr(stderr)
+
+ // read and decode plugin stdout
+ stdoutScanner := bufio.NewScanner(stdout)
+ stdoutScanner.Scan()
+ d.socketPath = stdoutScanner.Text()
+ d.log.Debugf("plugin socket path:'%s'\n", d.socketPath)
+ return nil
+}
+
+// Dial dials the unix socket that was recorded during Launch.
+func (d *Dialer) Dial() error {
+ conn, err := net.Dial(unixSocketNetwork, d.socketPath)
+ if err != nil {
+ d.log.Debugf("unix socket connect failure: %s", err)
+ return err
+ }
+ d.onNewConn(conn)
+ return nil
+}
+
+func (d *Dialer) ensureParameters(outgoingConn *outgoingConn) error {
+ if d.params == nil {
+ d.log.Debug("requesting plugin Parameters for Mix Descriptor publication")
+ var err error
+ d.params, err = outgoingConn.getParameters()
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (d *Dialer) onNewConn(conn net.Conn) {
+ connLog := d.logBackend.GetLogger(fmt.Sprintf("%s connection %d", d.cmd, len(d.conns)+1))
+ outConn := newOutgoingConn(d, conn, connLog)
+
+ err := d.ensureParameters(outConn)
+ if err != nil {
+ d.log.Error("failed to acquire plugin parameters, giving up on plugin connection")
+ return
+ }
+
+ d.Lock()
+ defer func() {
+ d.Unlock()
+ outConn.Go(outConn.worker)
+ }()
+ d.conns = append(d.conns, outConn)
+}
+
// Parameters returns the Parameters which are used in Mix Descriptor
// publication to give service clients more information about the service.
// It is nil until the first connection has fetched them.
func (d *Dialer) Parameters() *common.Parameters {
	return d.params
}

// Unsubscribe sends an unsubscribe request to the plugin over
// the Unix domain socket protocol.
func (d *Dialer) Unsubscribe(subscriptionID [common.SubscriptionIDLength]byte) {
	u := &common.Unsubscribe{
		SubscriptionID: subscriptionID,
	}
	d.outgoingCh <- u
}

// Subscribe sends a subscription request to the plugin over
// the Unix domain socket protocol.
func (d *Dialer) Subscribe(subscribe *common.Subscribe) {
	d.outgoingCh <- subscribe
}
diff --git a/pubsubplugin/client/outgoing_conn.go b/pubsubplugin/client/outgoing_conn.go
new file mode 100644
index 0000000..419ccaa
--- /dev/null
+++ b/pubsubplugin/client/outgoing_conn.go
@@ -0,0 +1,158 @@
+// outgoing_conn.go - out going client plugin connection
+// Copyright (C) 2020 David Stainton.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+// Package pubsubplugin client is the client module a plugin system allowing mix network services
+// to be added in any language. It implements a publish subscribe interface.
+//
+package client
+
+import (
+ "errors"
+ "io"
+ "net"
+
+ "github.com/katzenpost/core/worker"
+ "github.com/katzenpost/server/pubsubplugin/common"
+ "gopkg.in/op/go-logging.v1"
+)
+
// outgoingConn is a single unix domain socket connection to the plugin
// process. Its worker pumps AppMessages from the plugin to the Dialer
// and relays Subscribe/Unsubscribe commands from the Dialer to the plugin.
type outgoingConn struct {
	worker.Worker

	log *logging.Logger

	dialer *Dialer
	conn   net.Conn
}

// newOutgoingConn constructs an outgoingConn; the caller is responsible
// for starting its worker.
func newOutgoingConn(dialer *Dialer, conn net.Conn, log *logging.Logger) *outgoingConn {
	return &outgoingConn{
		dialer: dialer,
		conn:   conn,
		log:    log,
	}
}
+
+func (c *outgoingConn) readAppMessages() (*common.AppMessages, error) {
+ lenPrefixBuf := make([]byte, 2)
+ _, err := io.ReadFull(c.conn, lenPrefixBuf)
+ if err != nil {
+ return nil, err
+ }
+ lenPrefix := common.PrefixLengthDecode(lenPrefixBuf)
+ responseBuf := make([]byte, lenPrefix)
+ _, err = io.ReadFull(c.conn, responseBuf)
+ if err != nil {
+ return nil, err
+ }
+ egressCmd, err := common.EgressUnixSocketCommandFromBytes(responseBuf)
+ if err != nil {
+ return nil, err
+ }
+ if egressCmd.AppMessages == nil {
+ return nil, errors.New("expected EgressUnixSocketCommand AppMessages to not be nil")
+ }
+ return egressCmd.AppMessages, nil
+}
+
+func (c *outgoingConn) unsubscribe(u *common.Unsubscribe) {
+ serializedUnsubscribe, err := u.ToBytes()
+ if err != nil {
+ c.log.Errorf("unsubscribe error: %s", err)
+ }
+ serializedUnsubscribe = common.PrefixLengthEncode(serializedUnsubscribe)
+ _, err = c.conn.Write(serializedUnsubscribe)
+ if err != nil {
+ c.log.Errorf("unsubscribe error: %s", err)
+ }
+}
+
+func (c *outgoingConn) subscribe(subscribe *common.Subscribe) error {
+ serializedSubscribe, err := subscribe.ToBytes()
+ if err != nil {
+ c.log.Errorf("subscribe error: %s", err)
+ }
+ serializedSubscribe = common.PrefixLengthEncode(serializedSubscribe)
+ _, err = c.conn.Write(serializedSubscribe)
+ if err != nil {
+ c.log.Errorf("subscribe error: %s", err)
+ }
+ return err
+}
+
+func (c *outgoingConn) worker() {
+ defer func() {
+ // XXX TODO: stuff to shutdown
+ }()
+ for {
+ newMessages, err := c.readAppMessages()
+ if err != nil {
+ c.log.Errorf("failure to read new messages from plugin: %s", err)
+ return
+ }
+ select {
+ case <-c.HaltCh():
+ return
+ case c.dialer.incomingCh <- newMessages:
+ case rawCmd := <-c.dialer.outgoingCh:
+ switch cmd := rawCmd.(type) {
+ case *common.Unsubscribe:
+ c.unsubscribe(cmd)
+ case *common.Subscribe:
+ c.subscribe(cmd)
+ default:
+ c.log.Errorf("outgoingConn received invalid command type %T from Dialer", rawCmd)
+ }
+ }
+ }
+}
+
+func (c *outgoingConn) getParameters() (*common.Parameters, error) {
+ ingressCmd := &common.IngressUnixSocketCommand{
+ GetParameters: &common.GetParameters{},
+ }
+ rawGetParams, err := ingressCmd.ToBytes()
+ if err != nil {
+ return nil, err
+ }
+ rawGetParams = common.PrefixLengthEncode(rawGetParams)
+ _, err = c.conn.Write(rawGetParams)
+ if err != nil {
+ return nil, err
+ }
+
+ // read response
+ lenPrefixBuf := make([]byte, 2)
+ _, err = io.ReadFull(c.conn, lenPrefixBuf)
+ if err != nil {
+ return nil, err
+ }
+ lenPrefix := common.PrefixLengthDecode(lenPrefixBuf)
+ responseBuf := make([]byte, lenPrefix)
+ _, err = io.ReadFull(c.conn, responseBuf)
+ if err != nil {
+ return nil, err
+ }
+
+ egressCmd, err := common.EgressUnixSocketCommandFromBytes(responseBuf)
+ if err != nil {
+ return nil, err
+ }
+ if egressCmd.Parameters == nil {
+ return nil, errors.New("expected EgressUnixSocketCommand Parameters to not be nil")
+ }
+ return egressCmd.Parameters, nil
+}
diff --git a/pubsubplugin/common/common.go b/pubsubplugin/common/common.go
new file mode 100644
index 0000000..f5acb58
--- /dev/null
+++ b/pubsubplugin/common/common.go
@@ -0,0 +1,300 @@
+// common.go - common types shared between the pubsub client and server.
+// Copyright (C) 2020 David Stainton.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+// Package common include the common types used in the publish subscribe
+// client and server modules.
+//
+package common
+
+import (
+ "encoding/binary"
+ "errors"
+
+ "github.com/fxamacker/cbor/v2"
+ "github.com/katzenpost/core/crypto/rand"
+)
+
const (
	// SubscriptionIDLength is the length of the Subscription ID in bytes.
	SubscriptionIDLength = 8

	// SpoolIDLength is the length of the spool identity in bytes.
	SpoolIDLength = 8
)
+
// PrefixLengthEncode encodes the given byte slice with
// two byte big endian length prefix encoding.
// NOTE(review): inputs longer than 65535 bytes silently truncate the
// prefix (uint16 wrap) — confirm callers never frame larger payloads.
func PrefixLengthEncode(b []byte) []byte {
	out := make([]byte, 2, 2+len(b))
	binary.BigEndian.PutUint16(out, uint16(len(b)))
	return append(out, b...)
}
+
// PrefixLengthDecode decodes the first two bytes of the
// given byte slice as a uint16, big endian encoded.
// Panics if len(b) < 2.
func PrefixLengthDecode(b []byte) uint16 {
	var n uint16
	n = binary.BigEndian.Uint16(b)
	return n
}
+
// SubscriptionID is a subscription identity, generated server side
// (see GenerateSubscriptionID).
type SubscriptionID [SubscriptionIDLength]byte

// SpoolID is a spool identity.
type SpoolID [SpoolIDLength]byte
+
+// Unsubscribe is used by the mix server to communicate an unsubscribe to
+// the plugin.
+type Unsubscribe struct {
+ // SubscriptionID is the server generated subscription identity.
+ SubscriptionID SubscriptionID
+}
+
+// ToBytes returns a CBOR serialized Unsubscribe.
+func (u *Unsubscribe) ToBytes() ([]byte, error) {
+ serializedunsubscribe, err := cbor.Marshal(u)
+ if err != nil {
+ return nil, err
+ }
+ return serializedunsubscribe, nil
+}
+
+// Subscribe is the struct type used to establish a subscription with
+// the plugin.
+type Subscribe struct {
+ // PacketID is the packet identity.
+ PacketID uint64
+
+ // SURBCount is the number of SURBs available to this subscription.
+ SURBCount uint8
+
+ // SubscriptionID is the server generated subscription identity.
+ SubscriptionID SubscriptionID
+
+ // SpoolID is the spool identity.
+ SpoolID SpoolID
+
+ // LastSpoolIndex is the last spool index which was received by the client.
+ LastSpoolIndex uint64
+}
+
+// SubscribeToBytes encodes the given Subscribe as a CBOR byte blob.
+func (s *Subscribe) ToBytes() ([]byte, error) {
+ serializedSubscribe, err := cbor.Marshal(s)
+ if err != nil {
+ return nil, err
+ }
+ return serializedSubscribe, nil
+}
+
+// SubscribeFromBytes returns a Subscribe given a CBOR serialized Subscribe.
+func SubscribeFromBytes(b []byte) (*Subscribe, error) {
+ subscribe := Subscribe{}
+ err := cbor.Unmarshal(b, &subscribe)
+ if err != nil {
+ return nil, err
+ }
+ return &subscribe, nil
+}
+
+// GenerateSubscriptionID returns a random subscription ID.
+func GenerateSubscriptionID() [SubscriptionIDLength]byte {
+ id := [SubscriptionIDLength]byte{}
+ rand.Reader.Read(id[:])
+ return id
+}
+
+// ClientSubscribe is used by the mixnet client to send a subscription
+// request to the publish-subscribe application plugin.
+type ClientSubscribe struct {
+ // SpoolID is the spool identity.
+ SpoolID [SpoolIDLength]byte
+
+ // LastSpoolIndex is the last spool index which was received by the client.
+ LastSpoolIndex uint64
+
+ // Payload is the application specific payload which the client sends to
+ // the plugin.
+ Payload []byte
+}
+
+// ClientSubscribeFromBytes decodes a ClientSubscribe from the
+// given CBOR byte blob.
+func ClientSubscribeFromBytes(b []byte) (*ClientSubscribe, error) {
+ clientSubscribe := ClientSubscribe{}
+ err := cbor.Unmarshal(b, &clientSubscribe)
+ if err != nil {
+ return nil, err
+ }
+ return &clientSubscribe, nil
+}
+
+// AppMessages is the struct type used by the application plugin to
+// send new messages to the server and eventually the subscribing client.
+type AppMessages struct {
+ // SubscriptionID is the server generated subscription identity.
+ SubscriptionID SubscriptionID
+
+ // Messages should contain one or more spool messages.
+ Messages []SpoolMessage
+}
+
+// ToBytes serializes AppMessages into a CBOR byte blob
+// or returns an error.
+func (m *AppMessages) ToBytes() ([]byte, error) {
+ serialized, err := cbor.Marshal(m)
+ if err != nil {
+ return nil, err
+ }
+ return serialized, nil
+}
+
+// SpoolMessage is a spool message from the application plugin.
+type SpoolMessage struct {
+ // Index is the index value from whence the message came from.
+ Index uint64
+
+ // Payload contains the actual spool message contents which are
+ // application specific.
+ Payload []byte
+}
+
+// MessagesToBytes returns a CBOR byte blob given a slice of type SpoolMessage.
+func MessagesToBytes(messages []SpoolMessage) ([]byte, error) {
+ serialized, err := cbor.Marshal(messages)
+ if err != nil {
+ return nil, err
+ }
+ return serialized, nil
+}
+
// Parameters is an optional mapping that plugins can publish, these get
// advertised to clients in the MixDescriptor.
// The output of GetParameters() ends up being published in a map
// associating the service names with the service parameters map.
// This information is part of the Mix Descriptor which is defined here:
// https://github.com/katzenpost/core/blob/master/pki/pki.go
type Parameters map[string]string

// GetParameters is used for querying the plugin over the unix socket
// to get the dynamic parameters after the plugin is started.
type GetParameters struct{}
+
+// IngressUnixSocketCommand wraps ingress unix socket wire protocol commands,
+// that is commands used by the mix server, aka Provider to communicate with
+// the application plugin.
+type IngressUnixSocketCommand struct {
+ // GetParameters is used to retrieve the plugin parameters which
+ // can be dynamically selected by the plugin on startup.
+ GetParameters *GetParameters
+
+ // Subscribe is used to establish a new SURB based subscription.
+ Subscribe *Subscribe
+
+ // Unsubscribe is used to tear down an existing subscription.
+ Unsubscribe *Unsubscribe
+}
+
+func (i *IngressUnixSocketCommand) validate() error {
+ notNilCount := 0
+ if i.GetParameters != nil {
+ notNilCount += 1
+ }
+ if i.Subscribe != nil {
+ notNilCount += 1
+ }
+ if i.Unsubscribe != nil {
+ notNilCount += 1
+ }
+ if notNilCount > 1 {
+ return errors.New("expected only one field to not be nil")
+ }
+ return nil
+}
+
+// ToBytes serializes IngressUnixSocketCommand into a CBOR byte blob
+// or returns an error.
+func (i *IngressUnixSocketCommand) ToBytes() ([]byte, error) {
+ err := i.validate()
+ if err != nil {
+ return nil, err
+ }
+ serialized, err := cbor.Marshal(i)
+ if err != nil {
+ return nil, err
+ }
+ return serialized, nil
+}
+
// IngressUnixSocketCommandFromBytes decodes a CBOR byte blob into an
// IngressUnixSocketCommand, validating that at most one field is set.
func IngressUnixSocketCommandFromBytes(b []byte) (*IngressUnixSocketCommand, error) {
	ingressCmds := &IngressUnixSocketCommand{}
	err := cbor.Unmarshal(b, ingressCmds)
	if err != nil {
		return nil, err
	}
	err = ingressCmds.validate()
	if err != nil {
		return nil, err
	}
	return ingressCmds, nil
}
+
// EgressUnixSocketCommand wraps egress unix socket wire protocol commands,
// that is commands used by the application plugin to communicate with the
// mix server, aka Provider. At most one field may be set.
type EgressUnixSocketCommand struct {
	// Parameters is the plugin selected parameters which can be
	// dynamically selected by the plugin at startup.
	Parameters *Parameters

	// AppMessages contain the application messages from the plugin.
	AppMessages *AppMessages
}

// validate returns an error when both fields are set at once.
func (e *EgressUnixSocketCommand) validate() error {
	if e.Parameters != nil && e.AppMessages != nil {
		return errors.New("expected only one field to not be nil")
	}
	return nil
}
+
+// ToBytes serializes EgressUnixSocketCommand into a CBOR byte blob
+// or returns an error.
+func (e *EgressUnixSocketCommand) ToBytes() ([]byte, error) {
+ err := e.validate()
+ if err != nil {
+ return nil, err
+ }
+ serialized, err := cbor.Marshal(e)
+ if err != nil {
+ return nil, err
+ }
+ return serialized, nil
+}
+
+// EgressUnixSocketCommandFromBytes decodes a blob into a EgressUnixSocketCommand.
+func EgressUnixSocketCommandFromBytes(b []byte) (*EgressUnixSocketCommand, error) {
+ egressCmds := &EgressUnixSocketCommand{}
+ err := cbor.Unmarshal(b, egressCmds)
+ if err != nil {
+ return nil, err
+ }
+ err = egressCmds.validate()
+ if err != nil {
+ return nil, err
+ }
+ return egressCmds, nil
+}
diff --git a/pubsubplugin/server/server.go b/pubsubplugin/server/server.go
new file mode 100644
index 0000000..c045c01
--- /dev/null
+++ b/pubsubplugin/server/server.go
@@ -0,0 +1,278 @@
+// server.go - Publish subscribe server module for writing plugins.
+// Copyright (C) 2020 David Stainton.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package server
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "path/filepath"
+
+ "github.com/katzenpost/core/log"
+ "github.com/katzenpost/core/worker"
+ "github.com/katzenpost/server/pubsubplugin/common"
+ "gopkg.in/op/go-logging.v1"
+)
+
// Spool is an interface for spool implementations which
// will handle publishing new spool content to spool subscribers.
type Spool interface {
	// Subscribe creates a new spool subscription.
	Subscribe(subscriptionID *common.SubscriptionID, spoolID *common.SpoolID, lastSpoolIndex uint64) error

	// Unsubscribe removes an existing spool subscription.
	Unsubscribe(subscriptionID *common.SubscriptionID) error
}

// Config is used to configure a new Server instance.
// All fields are required; see validateConfig.
type Config struct {
	// Name is the name of the application service.
	Name string

	// Parameters are the application service specified parameters which are advertised in the
	// Katzenpost PKI document.
	Parameters *common.Parameters

	// LogDir is the logging directory; it must already exist.
	LogDir string

	// LogLevel is the log level and is set to one of: ERROR, WARNING, NOTICE, INFO, DEBUG, CRITICAL.
	LogLevel string

	// Spool is the implementation of our Spool interface.
	Spool Spool

	// AppMessagesCh is the application messages channel which is used by the application to send
	// messages to the mix server for transit over the mix network to the destination client.
	AppMessagesCh chan *common.AppMessages
}
+
+func validateConfig(config *Config) error {
+ if config == nil {
+ return errors.New("config must not be nil")
+ }
+ if config.Name == "" {
+ return errors.New("config.Name must not be empty")
+ }
+ if config.Parameters == nil {
+ return errors.New("config.Parameters must not be nil")
+ }
+ if config.LogDir == "" {
+ return errors.New("config.LogDir must not be empty")
+ }
+ if config.LogLevel == "" {
+ return errors.New("config.LogLevel must not be empty")
+ }
+ if config.Spool == nil {
+ return errors.New("config.Spool must not be nil")
+ }
+ if config.AppMessagesCh == nil {
+ return errors.New("config.AppMessagesCh must not be nil")
+ }
+ return nil
+}
+
// Server is used by applications to implement the application plugin which listens
// for connections from the mix server over a unix domain socket. Server handles the
// management of this unix domain socket as well as the wire protocol used.
type Server struct {
	worker.Worker

	logBackend *log.Backend
	log        *logging.Logger

	// listener is the unix domain socket listener; socketFile is its
	// on-disk path (printed to stdout for the parent process).
	listener   net.Listener
	socketFile string

	// params are advertised to the mix server on GetParameters queries.
	params        *common.Parameters
	// appMessagesCh carries application messages bound for the mix server.
	appMessagesCh chan *common.AppMessages
	// spool receives Subscribe/Unsubscribe commands from the mix server.
	spool         Spool
}
+
+func (s *Server) sendParameters(conn net.Conn) error {
+ e := &common.EgressUnixSocketCommand{
+ Parameters: s.params,
+ }
+ paramsBlob, err := e.ToBytes()
+ if err != nil {
+ return err
+ }
+ paramsBlob = common.PrefixLengthEncode(paramsBlob)
+ _, err = conn.Write(paramsBlob)
+ return err
+}
+
+func (s *Server) readIngressCommands(conn net.Conn) (*common.IngressUnixSocketCommand, error) {
+ lenPrefixBuf := make([]byte, 2)
+ _, err := io.ReadFull(conn, lenPrefixBuf)
+ if err != nil {
+ return nil, err
+ }
+ lenPrefix := common.PrefixLengthDecode(lenPrefixBuf)
+ cmdBuf := make([]byte, lenPrefix)
+ _, err = io.ReadFull(conn, cmdBuf)
+ if err != nil {
+ return nil, err
+ }
+ ingressCmd, err := common.IngressUnixSocketCommandFromBytes(cmdBuf)
+ return ingressCmd, err
+}
+
// perpetualCommandReader spawns a goroutine that continuously reads
// ingress commands from conn and delivers them on the returned channel.
// The goroutine exits on read error or when the server halts.
//
// NOTE(review): on halt the goroutine may remain blocked inside
// readIngressCommands until the peer closes conn — confirm the
// connection is torn down elsewhere during shutdown.
func (s *Server) perpetualCommandReader(conn net.Conn) <-chan *common.IngressUnixSocketCommand {
	readCh := make(chan *common.IngressUnixSocketCommand)

	s.Go(func() {
		for {
			cmd, err := s.readIngressCommands(conn)
			if err != nil {
				s.log.Errorf("failure to read new messages from plugin: %s", err)
				return
			}
			select {
			case <-s.HaltCh():
				return
			case readCh <- cmd:
			}
		}
	})

	return readCh
}
+
+func (s *Server) connectionWorker(conn net.Conn) {
+ readCmdCh := s.perpetualCommandReader(conn)
+
+ for {
+ select {
+ case <-s.HaltCh():
+ s.log.Debugf("Worker terminating gracefully.")
+ return
+ case cmd := <-readCmdCh:
+ if cmd.GetParameters != nil {
+ s.sendParameters(conn)
+ continue
+ }
+ if cmd.Subscribe != nil {
+ s.spool.Subscribe(&cmd.Subscribe.SubscriptionID, &cmd.Subscribe.SpoolID, cmd.Subscribe.LastSpoolIndex)
+ continue
+ }
+ if cmd.Unsubscribe != nil {
+ s.spool.Unsubscribe(&cmd.Subscribe.SubscriptionID)
+ continue
+ }
+ case messages := <-s.appMessagesCh:
+ e := &common.EgressUnixSocketCommand{
+ AppMessages: messages,
+ }
+ messagesBlob, err := e.ToBytes()
+ if err != nil {
+ s.log.Errorf("failed to serialize app messages: %s", err)
+ continue
+ }
+ messagesBlob = common.PrefixLengthEncode(messagesBlob)
+ _, err = conn.Write(messagesBlob)
+ if err != nil {
+ s.log.Errorf("failed to write AppMessages to socket: %s", err)
+ continue
+ }
+ }
+ }
+}
+
+func (s *Server) worker() {
+ conn, err := s.listener.Accept()
+ if err != nil {
+ s.log.Errorf("error accepting connection: %s", err)
+ return
+ }
+ s.Go(func() {
+ s.connectionWorker(conn)
+ })
+}
+
+func (s *Server) setupListener(name string) error {
+ tmpDir, err := ioutil.TempDir("", name)
+ if err != nil {
+ return err
+ }
+ s.socketFile = filepath.Join(tmpDir, fmt.Sprintf("%s.socket", name))
+ s.listener, err = net.Listen("unix", s.socketFile)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *Server) initLogging(name, logFile, logLevel string) error {
+ var err error
+ logDisable := false
+ s.logBackend, err = log.New(logFile, logLevel, logDisable)
+ if err != nil {
+ return err
+ }
+ s.log = s.logBackend.GetLogger(name)
+ return nil
+}
+
+func (s *Server) ensureLogDir(logDir string) error {
+ stat, err := os.Stat(logDir)
+ if os.IsNotExist(err) {
+ return fmt.Errorf("Log directory '%s' doesn't exist.", logDir)
+ }
+ if !stat.IsDir() {
+ return fmt.Errorf("Log directory '%s' must be a directory.", logDir)
+ }
+ return nil
+}
+
+// New creates a new Server instance and starts immediately listening for new connections.
+func New(config *Config) (*Server, error) {
+ err := validateConfig(config)
+ if err != nil {
+ return nil, err
+ }
+ s := &Server{
+ params: config.Parameters,
+ appMessagesCh: config.AppMessagesCh,
+ spool: config.Spool,
+ }
+ err = s.ensureLogDir(config.LogDir)
+ if err != nil {
+ return nil, err
+ }
+ logFile := path.Join(config.LogDir, fmt.Sprintf("%s.%d.log", config.Name, os.Getpid()))
+ err = s.initLogging(config.Name, logFile, config.LogLevel)
+ if err != nil {
+ return nil, err
+ }
+ s.log.Debug("starting listener")
+ err = s.setupListener(config.Name)
+ if err != nil {
+ return nil, err
+ }
+ fmt.Printf("%s\n", s.socketFile)
+ s.Go(s.worker)
+ go func() {
+ <-s.HaltCh()
+ os.Remove(s.socketFile)
+ }()
+ return s, nil
+}