diff --git a/config/config.go b/config/config.go index e9ca311..c41e320 100644 --- a/config/config.go +++ b/config/config.go @@ -372,8 +372,12 @@ type Provider struct { Kaetzchen []*Kaetzchen // CBORPluginKaetzchen is the list of configured external CBOR Kaetzchen plugins - // for this provider. + // for this Provider. CBORPluginKaetzchen []*CBORPluginKaetzchen + + // PubsubPlugin is the list of configured external Publish Subscribe service plugins + // for this Provider. + PubsubPlugin []*PubsubPlugin } // SQLDB is the SQL database backend configuration. @@ -535,6 +539,56 @@ func (kCfg *CBORPluginKaetzchen) validate() error { return nil } +// PubsubPlugin is a Provider Publish Subscribe service agent. +type PubsubPlugin struct { + // Capability is the capability exposed by the agent. + Capability string + + // Endpoint is the provider side endpoint that the agent will accept + // requests at. While not required by the spec, this server only + // supports Endpoints that are lower-case local-parts of an e-mail + // address. + Endpoint string + + // Config is the extra per agent arguments to be passed to the agent's + // initialization routine. + Config map[string]interface{} + + // Command is the full file path to the external plugin program + // that implements this Kaetzchen service. + Command string + + // MaxConcurrency is the number of worker goroutines to start + // for this service. + MaxConcurrency int + + // Disable disabled a configured agent. + Disable bool +} + +func (kCfg *PubsubPlugin) validate() error { + if kCfg.Capability == "" { + return fmt.Errorf("config: Pubsub: Capability is invalid") + } + + // Ensure the endpoint is normalized. + epNorm, err := precis.UsernameCaseMapped.String(kCfg.Endpoint) + if err != nil { + return fmt.Errorf("config: Pubsub: '%v' has invalid endpoint: %v", kCfg.Capability, err) + } + if epNorm != kCfg.Endpoint { + return fmt.Errorf("config: Pubsub: '%v' has non-normalized endpoint %v", kCfg.Capability, kCfg.Endpoint) + } + if kCfg.Command == "" { + return fmt.Errorf("config: Pubsub: Command is invalid") + } + if _, err = mail.ParseAddress(kCfg.Endpoint + "@test.invalid"); err != nil { + return fmt.Errorf("config: Pubsub: '%v' has non local-part endpoint '%v': %v", kCfg.Capability, kCfg.Endpoint, err) + } + + return nil +} + func (pCfg *Provider) applyDefaults(sCfg *Server) { if pCfg.UserDB == nil { pCfg.UserDB = &UserDB{} @@ -656,6 +710,15 @@ func (pCfg *Provider) validate() error { } capaMap[v.Capability] = true } + for _, v := range pCfg.PubsubPlugin { + if err := v.validate(); err != nil { + return err + } + if capaMap[v.Capability] { + return fmt.Errorf("config: Kaetzchen: '%v' configured multiple times", v.Capability) + } + capaMap[v.Capability] = true + } return nil } diff --git a/internal/constants/constants.go b/internal/constants/constants.go index ac0594f..88166f0 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -34,6 +34,7 @@ const ( DecoySubsystem = "decoy" IncomingConnSubsystem = "incoming_conn" KaetzchenSubsystem = "kaetzchen" + PubsubPluginSubsystem = "pubsub_plugin" OutgoingConnSubsystem = "outgoing_conn" PKISubsystem = "pki" ProviderSubsystem = "provider" diff --git a/internal/glue/glue.go b/internal/glue/glue.go index dc47c45..e4c206c 100644 --- a/internal/glue/glue.go +++ b/internal/glue/glue.go @@ -66,6 +66,7 @@ type PKI interface { StartWorker() OutgoingDestinations() map[[constants.NodeIDLength]byte]*pki.MixDescriptor AuthenticateConnection(*wire.PeerCredentials, bool) (*pki.MixDescriptor, bool, bool) + GetCachedConsensusDoc(uint64) (*pki.Document, error) GetRawConsensus(uint64) ([]byte, error) } diff --git a/internal/packet/packet.go b/internal/packet/packet.go index 4a82a7d..679e2fd 100644 --- a/internal/packet/packet.go +++ b/internal/packet/packet.go @@ -19,11 +19,14 @@ package packet import ( "fmt" + mRand "math/rand" "sync" "sync/atomic" "time" "github.com/katzenpost/core/constants" + "github.com/katzenpost/core/crypto/rand" + "github.com/katzenpost/core/pki" "github.com/katzenpost/core/sphinx" "github.com/katzenpost/core/sphinx/commands" "github.com/katzenpost/core/utils" @@ -215,12 +218,10 @@ func newRedundantError(cmd commands.RoutingCommand) error { return fmt.Errorf("redundant command: %T", cmd) } -func ParseForwardPacket(pkt *Packet) ([]byte, []byte, error) { +func ParseForwardPacket(pkt *Packet) ([]byte, [][]byte, error) { const ( - hdrLength = constants.SphinxPlaintextHeaderLength + sphinx.SURBLength - flagsPadding = 0 - flagsSURB = 1 - reserved = 0 + hdrLength = constants.SphinxPlaintextHeaderLength + reserved = 0 ) // Sanity check the forward packet payload length. @@ -230,26 +231,21 @@ func ParseForwardPacket(pkt *Packet) ([]byte, []byte, error) { // Parse the payload, which should be a valid BlockSphinxPlaintext. b := pkt.Payload - if len(b) < hdrLength { - return nil, nil, fmt.Errorf("truncated message block") - } if b[1] != reserved { return nil, nil, fmt.Errorf("invalid message reserved: 0x%02x", b[1]) } - ct := b[hdrLength:] - var surb []byte - switch b[0] { - case flagsPadding: - case flagsSURB: - surb = b[constants.SphinxPlaintextHeaderLength:hdrLength] - default: - return nil, nil, fmt.Errorf("invalid message flags: 0x%02x", b[0]) + surbCount := int(b[0]) + if (surbCount * sphinx.SURBLength) >= (constants.ForwardPayloadLength - hdrLength) { + return nil, nil, fmt.Errorf("invalid message SURB count: %d", uint8(b[0])) } - if len(ct) != constants.UserForwardPayloadLength { - return nil, nil, fmt.Errorf("mis-sized user payload: %v", len(ct)) + surbs := make([][]byte, surbCount) + startOffset := 2 + for i := 0; i < surbCount; i++ { + surbs[i] = b[startOffset : startOffset+sphinx.SURBLength] + startOffset += sphinx.SURBLength } - - return ct, surb, nil + ct := b[hdrLength+(surbCount*sphinx.SURBLength):] + return ct, surbs, nil } func NewPacketFromSURB(pkt *Packet, surb, payload []byte) (*Packet, error) { @@ -303,3 +299,60 @@ func NewPacketFromSURB(pkt *Packet, surb, payload []byte) (*Packet, error) { return respPkt, nil } + +func NewProviderDelay(rng *mRand.Rand, doc *pki.Document) uint32 { + delay := uint64(rand.Exp(rng, doc.Mu)) + 1 + if doc.MuMaxDelay > 0 && delay > doc.MuMaxDelay { + delay = doc.MuMaxDelay + } + return uint32(delay) +} + +// NewDelayedPacketFromSURB creates a new Packet given a SURB, payload and, delay +// where the specified delay is for the first hop, the Provider. +func NewDelayedPacketFromSURB(delay uint32, surb, payload []byte) (*Packet, error) { + // Pad out payloads to the full packet size. + var respPayload [constants.ForwardPayloadLength]byte + switch { + case len(payload) == 0: + case len(payload) > constants.ForwardPayloadLength: + return nil, fmt.Errorf("oversized response payload: %v", len(payload)) + default: + copy(respPayload[:], payload) + } + + // Build a response packet using a SURB. + // + // TODO/perf: This is a crypto operation that is paralleizable, and + // could be handled by the crypto worker(s), since those are allocated + // based on hardware acceleration considerations. However the forward + // packet processing doesn't constantly utilize the AES-NI units due + // to the non-AEZ components of a Sphinx Unwrap operation. + rawRespPkt, firstHop, err := sphinx.NewPacketFromSURB(surb, respPayload[:]) + if err != nil { + return nil, err + } + + // Build the command vector for the SURB-ACK + cmds := make([]commands.RoutingCommand, 0, 2) + + nextHopCmd := new(commands.NextNodeHop) + copy(nextHopCmd.ID[:], firstHop[:]) + cmds = append(cmds, nextHopCmd) + + nodeDelayCmd := new(commands.NodeDelay) + nodeDelayCmd.Delay = delay + cmds = append(cmds, nodeDelayCmd) + + // Assemble the response packet. + respPkt, _ := New(rawRespPkt) + respPkt.Set(nil, cmds) + + respPkt.Delay = time.Duration(nodeDelayCmd.Delay) * time.Millisecond + respPkt.MustForward = true + + // XXX: This should probably fudge the delay to account for processing + // time. + + return respPkt, nil +} diff --git a/internal/packet/packet_test.go b/internal/packet/packet_test.go new file mode 100644 index 0000000..7f3f423 --- /dev/null +++ b/internal/packet/packet_test.go @@ -0,0 +1,104 @@ +// packet_test.go - Katzenpost server packet structure tests. +// Copyright (C) 2020 David Stainton. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package packet implements the Katzenpost server side packet structure. +package packet + +import ( + "testing" + + "github.com/katzenpost/core/constants" + "github.com/katzenpost/core/sphinx" + "github.com/stretchr/testify/require" +) + +func TestParseForwardPacket(t *testing.T) { + require := require.New(t) + + // test that wrong payload size is an error + wrongPayload := [constants.ForwardPayloadLength + 123]byte{} + pkt := &Packet{ + Payload: wrongPayload[:], + } + _, _, err := ParseForwardPacket(pkt) + require.Error(err) + + // test that the wrong reserved value is an error + payload := [constants.ForwardPayloadLength]byte{} + pkt = &Packet{ + Payload: payload[:], + } + pkt.Payload[1] = byte(1) + _, _, err = ParseForwardPacket(pkt) + require.Error(err) + + // test that an invalid SURB count is an error + payload = [constants.ForwardPayloadLength]byte{} + pkt = &Packet{ + Payload: payload[:], + } + pkt.Payload[0] = byte(255) + _, _, err = ParseForwardPacket(pkt) + require.Error(err) + + // test that an invalid SURB count is an error + payload = [constants.ForwardPayloadLength]byte{} + pkt = &Packet{ + Payload: payload[:], + } + pkt.Payload[0] = byte(93) + _, _, err = ParseForwardPacket(pkt) + require.Error(err) + + // test that the 1 SURB case is handled properly + payload = [constants.ForwardPayloadLength]byte{} + pkt = &Packet{ + Payload: payload[:], + } + pkt.Payload[0] = byte(1) + pkt.Payload[constants.SphinxPlaintextHeaderLength+sphinx.SURBLength] = 1 + ct, surbs, err := ParseForwardPacket(pkt) + require.NoError(err) + require.Equal(1, len(surbs)) + require.Equal(constants.UserForwardPayloadLength, len(ct)) + require.Equal(int(ct[0]), 1) + require.Equal(int(ct[1]), 0) + + // test that the 2 SURB case is handled properly + payload = [constants.ForwardPayloadLength]byte{} + pkt = &Packet{ + Payload: payload[:], + } + pkt.Payload[0] = byte(2) + pkt.Payload[constants.SphinxPlaintextHeaderLength+sphinx.SURBLength+sphinx.SURBLength] = 1 + ct, surbs, err = ParseForwardPacket(pkt) + require.NoError(err) + require.Equal(2, len(surbs)) + require.NotEqual(constants.UserForwardPayloadLength, len(ct)) + require.Equal(int(ct[0]), 1) + require.Equal(int(ct[1]), 0) + + // test that a large SURB count is OK + payload = [constants.ForwardPayloadLength]byte{} + pkt = &Packet{ + Payload: payload[:], + } + pkt.Payload[0] = byte(92) + ct, surbs, err = ParseForwardPacket(pkt) + require.NoError(err) + require.Equal(92, len(surbs)) + require.Equal((constants.ForwardPayloadLength-constants.SphinxPlaintextHeaderLength)-(92*sphinx.SURBLength), len(ct)) +} diff --git a/internal/pki/pki.go b/internal/pki/pki.go index 8cd9e22..67f0649 100644 --- a/internal/pki/pki.go +++ b/internal/pki/pki.go @@ -634,6 +634,17 @@ func (p *pki) OutgoingDestinations() map[[sConstants.NodeIDLength]byte]*cpki.Mix return descMap } +// GetCachedConsensusDoc returns a cache PKI doc for the given epoch. +func (p *pki) GetCachedConsensusDoc(epoch uint64) (*cpki.Document, error) { + p.RLock() + defer p.RUnlock() + entry, ok := p.docs[epoch] + if !ok { + return nil, errors.New("failed to retrieve cached pki doc") + } + return entry.Document(), nil +} + func (p *pki) GetRawConsensus(epoch uint64) ([]byte, error) { if ok, err := p.getFailedFetch(epoch); ok { p.log.Debugf("GetRawConsensus failure: no cached PKI document for epoch %v: %v", epoch, err) diff --git a/internal/provider/kaetzchen/cbor_plugins.go b/internal/provider/kaetzchen/cbor_plugins.go index 80cb8a4..fb26208 100644 --- a/internal/provider/kaetzchen/cbor_plugins.go +++ b/internal/provider/kaetzchen/cbor_plugins.go @@ -123,17 +123,25 @@ func (k *CBORPluginWorker) processKaetzchen(pkt *packet.Packet, pluginClient cbo defer kaetzchenRequestsTimer.ObserveDuration() defer pkt.Dispose() - ct, surb, err := packet.ParseForwardPacket(pkt) + ct, surbs, err := packet.ParseForwardPacket(pkt) if err != nil { k.log.Debugf("Dropping Kaetzchen request: %v (%v)", pkt.ID, err) kaetzchenRequestsDropped.Inc() return } - + if len(surbs) > 1 { + k.log.Debugf("Received multi-SURB payload, dropping Kaetzchen request: %v (%v)", pkt.ID, err) + kaetzchenRequestsDropped.Inc() + return + } + hasSURB := false + if len(surbs) == 1 { + hasSURB = true + } resp, err := pluginClient.OnRequest(&cborplugin.Request{ ID: pkt.ID, Payload: ct, - HasSURB: surb != nil, + HasSURB: hasSURB, }) switch err { case nil: @@ -151,10 +159,10 @@ func (k *CBORPluginWorker) processKaetzchen(pkt *packet.Packet, pluginClient cbo } // Iff there is a SURB, generate a SURB-Reply and schedule. - if surb != nil { + if hasSURB { // Prepend the response header. resp = append([]byte{0x01, 0x00}, resp...) - + surb := surbs[0] respPkt, err := packet.NewPacketFromSURB(pkt, surb, resp) if err != nil { k.log.Debugf("Failed to generate SURB-Reply: %v (%v)", pkt.ID, err) diff --git a/internal/provider/kaetzchen/kaetzchen.go b/internal/provider/kaetzchen/kaetzchen.go index 6da9208..12327e3 100644 --- a/internal/provider/kaetzchen/kaetzchen.go +++ b/internal/provider/kaetzchen/kaetzchen.go @@ -243,18 +243,27 @@ func (k *KaetzchenWorker) processKaetzchen(pkt *packet.Packet) { defer kaetzchenRequestsTimer.ObserveDuration() defer pkt.Dispose() - ct, surb, err := packet.ParseForwardPacket(pkt) + ct, surbs, err := packet.ParseForwardPacket(pkt) if err != nil { k.log.Debugf("Dropping Kaetzchen request: %v (%v)", pkt.ID, err) k.incrementDropCounter() kaetzchenRequestsDropped.Add(float64(k.getDropCounter())) return } - + if len(surbs) > 1 { + k.log.Debugf("Multi-SURB packet sent to Kaetzchen recipient, dropping Kaetzchen request: %v (%v)", pkt.ID, err) + k.incrementDropCounter() + kaetzchenRequestsDropped.Add(float64(k.getDropCounter())) + return + } var resp []byte dst, ok := k.kaetzchen[pkt.Recipient.ID] + hasSURB := false + if len(surbs) == 1 { + hasSURB = true + } if ok { - resp, err = dst.OnRequest(pkt.ID, ct, surb != nil) + resp, err = dst.OnRequest(pkt.ID, ct, hasSURB) } switch { case err == nil: @@ -269,7 +278,8 @@ func (k *KaetzchenWorker) processKaetzchen(pkt *packet.Packet) { } // Iff there is a SURB, generate a SURB-Reply and schedule. - if surb != nil { + if len(surbs) == 1 { + surb := surbs[0] // Prepend the response header. resp = append([]byte{0x01, 0x00}, resp...) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 5f863d1..5ca87ea 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -43,6 +43,7 @@ import ( "github.com/katzenpost/server/internal/glue" "github.com/katzenpost/server/internal/packet" "github.com/katzenpost/server/internal/provider/kaetzchen" + "github.com/katzenpost/server/internal/provider/pubsub" "github.com/katzenpost/server/internal/sqldb" "github.com/katzenpost/server/registration" "github.com/katzenpost/server/spool" @@ -75,6 +76,7 @@ type provider struct { kaetzchenWorker *kaetzchen.KaetzchenWorker cborPluginKaetzchenWorker *kaetzchen.CBORPluginWorker + pubsubPluginWorker *pubsub.PluginWorker httpServers []*http.Server } @@ -101,6 +103,7 @@ func (p *provider) Halt() { p.ch.Close() p.kaetzchenWorker.Halt() p.cborPluginKaetzchenWorker.Halt() + p.pubsubPluginWorker.Halt() if p.userDB != nil { p.userDB.Close() p.userDB = nil @@ -254,13 +257,26 @@ func (p *provider) worker() { packetsDropped.Inc() pkt.Dispose() } else { - // Note that we pass ownership of pkt to p.kaetzchenWorker + // Note that we pass ownership of pkt to p.cborPluginKaetzchenWorker // which will take care to dispose of it. p.cborPluginKaetzchenWorker.OnKaetzchen(pkt) } continue } + if p.pubsubPluginWorker.HasRecipient(pkt.Recipient.ID) { + if pkt.IsSURBReply() { + p.log.Debugf("Dropping packet: %v (SURB-Reply for pubsub service)", pkt.ID) + packetsDropped.Inc() + pkt.Dispose() + } else { + // Note that we pass ownership of pkt to p.pubsubPluginWorker + // which will take care to dispose of it. + p.pubsubPluginWorker.OnSubscribeRequest(pkt) + } + continue + } + // Post-process the recipient. recipient, err := p.fixupRecipient(pkt.Recipient.ID[:]) if err != nil { @@ -306,21 +322,25 @@ func (p *provider) onSURBReply(pkt *packet.Packet, recipient []byte) { } func (p *provider) onToUser(pkt *packet.Packet, recipient []byte) { - ct, surb, err := packet.ParseForwardPacket(pkt) + ct, surbs, err := packet.ParseForwardPacket(pkt) if err != nil { p.log.Debugf("Dropping packet: %v (%v)", pkt.ID, err) packetsDropped.Inc() return } - // Store the ciphertext in the spool. if err := p.spool.StoreMessage(recipient, ct); err != nil { p.log.Debugf("Failed to store message payload: %v (%v)", pkt.ID, err) return } - + if len(surbs) > 1 { + p.log.Debugf("Multi-SURB packet sent to user recipient, dropping packet: %v (%v)", pkt.ID, err) + packetsDropped.Inc() + return + } // Iff there is a SURB, generate a SURB-ACK and schedule. - if surb != nil { + if len(surbs) == 1 { + surb := surbs[0] ackPkt, err := packet.NewPacketFromSURB(pkt, surb, nil) if err != nil { p.log.Debugf("Failed to generate SURB-ACK: %v (%v)", pkt.ID, err) @@ -746,12 +766,17 @@ func New(glue glue.Glue) (glue.Provider, error) { if err != nil { return nil, err } + pubsubPluginWorker, err := pubsub.New(glue) + if err != nil { + return nil, err + } p := &provider{ glue: glue, log: glue.LogBackend().GetLogger("provider"), ch: channels.NewInfiniteChannel(), kaetzchenWorker: kaetzchenWorker, cborPluginKaetzchenWorker: cborPluginWorker, + pubsubPluginWorker: pubsubPluginWorker, } cfg := glue.Config() diff --git a/internal/provider/pubsub/plugins.go b/internal/provider/pubsub/plugins.go new file mode 100644 index 0000000..f1cbf17 --- /dev/null +++ b/internal/provider/pubsub/plugins.go @@ -0,0 +1,467 @@ +// plugins.go - publish subscribe plugin system for mix network services +// Copyright (C) 2020 David Stainton. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package pubsub implements support for provider side SURB based publish subscribe agents. +package pubsub + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/katzenpost/core/crypto/rand" + "github.com/katzenpost/core/epochtime" + "github.com/katzenpost/core/monotime" + sConstants "github.com/katzenpost/core/sphinx/constants" + "github.com/katzenpost/core/worker" + "github.com/katzenpost/server/internal/constants" + "github.com/katzenpost/server/internal/glue" + "github.com/katzenpost/server/internal/packet" + "github.com/katzenpost/server/pubsubplugin/client" + "github.com/katzenpost/server/pubsubplugin/common" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/text/secure/precis" + "gopkg.in/eapache/channels.v1" + "gopkg.in/op/go-logging.v1" +) + +var ( + packetsDropped = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: constants.Namespace, + Name: "dropped_packets_total", + Subsystem: constants.PubsubPluginSubsystem, + Help: "Number of dropped packets", + }, + ) + pubsubRequests = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: constants.Namespace, + Name: "requests_total", + Subsystem: constants.PubsubPluginSubsystem, + Help: "Number of Pubsub requests", + }, + ) + pubsubRequestsDuration = prometheus.NewSummary( + prometheus.SummaryOpts{ + Namespace: constants.Namespace, + Name: "requests_duration_seconds", + Subsystem: constants.PubsubPluginSubsystem, + Help: "Duration of a pubsub request in seconds", + }, + ) + pubsubRequestsDropped = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: constants.Namespace, + Name: "dropped_requests_total", + Subsystem: constants.PubsubPluginSubsystem, + Help: "Number of total dropped pubsub requests", + }, + ) + pubsubRequestsFailed = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: constants.Namespace, + Name: "failed_requests_total", + Subsystem: constants.PubsubPluginSubsystem, + Help: "Number of total failed pubsub requests", + }, + ) + pubsubRequestsTimer *prometheus.Timer +) + +func init() { + prometheus.MustRegister(packetsDropped) + prometheus.MustRegister(pubsubRequests) + prometheus.MustRegister(pubsubRequestsDropped) + prometheus.MustRegister(pubsubRequestsFailed) + prometheus.MustRegister(pubsubRequestsDuration) +} + +const ( + // ParameterEndpoint is the mandatory Parameter key indicationg the + // Kaetzchen's endpoint. + ParameterEndpoint = "endpoint" +) + +// GarbageCollectionInterval is the time interval between running our +// subscription garbage collection routine. We shall attempt to garbage collect +// 5 times per epoch. +var GarbageCollectionInterval = epochtime.Period / 5 + +// PluginChans maps from Recipient ID to channel. +type PluginChans = map[[sConstants.RecipientIDLength]byte]*channels.InfiniteChannel + +// PluginName is the name of a plugin. +type PluginName = string + +// PluginParameters maps from parameter key to value. +type PluginParameters = map[PluginName]interface{} + +// ServiceMap maps from plugin name to plugin parameters +// and is used by Mix Descriptors which describe Providers +// with plugins. Each plugin can optionally set one or more +// parameters. +type ServiceMap = map[PluginName]PluginParameters + +// SURBBundle facilitates garbage collection of subscriptions +// by keeping track of the Epoch that the SURBs were received. +type SURBBundle struct { + // Epoch is the epoch whence the SURBs were received. + Epoch uint64 + + // SURBs is one or more SURBs. + SURBs [][]byte +} + +func ensureEndpointSanitized(endpointStr, capa string) (*[sConstants.RecipientIDLength]byte, error) { + if endpointStr == "" { + return nil, fmt.Errorf("provider: Pubsub: '%v' provided no endpoint", capa) + } else if epNorm, err := precis.UsernameCaseMapped.String(endpointStr); err != nil { + return nil, fmt.Errorf("provider: Pubsub: '%v' invalid endpoint: %v", capa, err) + } else if epNorm != endpointStr { + return nil, fmt.Errorf("provider: Pubsub: '%v' invalid endpoint, not normalized", capa) + } + rawEp := []byte(endpointStr) + if len(rawEp) == 0 || len(rawEp) > sConstants.RecipientIDLength { + return nil, fmt.Errorf("provider: Pubsub: '%v' invalid endpoint, length out of bounds", capa) + } + var endpoint [sConstants.RecipientIDLength]byte + copy(endpoint[:], rawEp) + return &endpoint, nil +} + +// PluginWorker implements the publish subscribe plugin worker. +type PluginWorker struct { + worker.Worker + + glue glue.Glue + log *logging.Logger + + haltOnce sync.Once + subscriptions *sync.Map // [SubscriptionIDLength]byte -> *SURBBundle + pluginChans PluginChans + clients []*client.Dialer + forPKI ServiceMap +} + +// OnSubscribeRequest enqueues the pkt for processing by our thread pool of plugins. +func (k *PluginWorker) OnSubscribeRequest(pkt *packet.Packet) { + handlerCh, ok := k.pluginChans[pkt.Recipient.ID] + if !ok { + k.log.Debugf("Failed to find handler. Dropping PubsubPlugin request: %v", pkt.ID) + return + } + handlerCh.In() <- pkt +} + +func (k *PluginWorker) sendReply(surb, payload []byte) { + // Prepend the response header. + payload = append([]byte{0x01, 0x00}, payload...) + + // generate random delay for first hop of SURB-Reply on Provider + epoch, _, _ := epochtime.Now() + doc, err := k.glue.PKI().GetCachedConsensusDoc(epoch) + if err != nil { + k.log.Debugf("Failed to get PKI doc for generating SURB-Reply: %v", err) + return + } + delay := packet.NewProviderDelay(rand.NewMath(), doc) + + respPkt, err := packet.NewDelayedPacketFromSURB(delay, surb, payload) + if err != nil { + k.log.Debugf("Failed to generate SURB-Reply: %v", err) + return + } + + k.log.Debugf("Handing off newly generated SURB-Reply: %v", respPkt.ID) + k.glue.Scheduler().OnPacket(respPkt) + return +} + +func (k *PluginWorker) garbageCollect() { + k.log.Debug("Running garbage collection process.") + // [SubscriptionIDLength]byte -> *SURBBundle + surbsMapRange := func(rawSubscriptionID, rawSurbBundle interface{}) bool { + subscriptionID := rawSubscriptionID.([common.SubscriptionIDLength]byte) + surbBundle := rawSurbBundle.(*SURBBundle) + + epoch, _, _ := epochtime.Now() + if epoch-surbBundle.Epoch >= 2 { + k.subscriptions.Delete(subscriptionID) + } + return true + } + k.subscriptions.Range(surbsMapRange) +} + +func (k *PluginWorker) garbageCollectionWorker() { + timer := time.NewTimer(GarbageCollectionInterval) + defer timer.Stop() + for { + select { + case <-k.HaltCh(): + k.log.Debugf("Garbage collection worker terminating gracefully.") + return + case <-timer.C: + k.garbageCollect() + timer.Reset(GarbageCollectionInterval) + } + } +} + +func (k *PluginWorker) appMessagesWorker(pluginClient *client.Dialer) { + for { + select { + case <-k.HaltCh(): + return + case appMessages := <-pluginClient.IncomingCh(): + rawSURBs, ok := k.subscriptions.Load(appMessages.SubscriptionID) + if !ok { + k.log.Error("Error, failed load a subscription ID from sync.Map") + continue + } + surbBundle, ok := rawSURBs.(*SURBBundle) + if !ok { + k.log.Error("Error, failed type assertion for type *SURBBundle") + continue + } + messagesBlob, err := common.MessagesToBytes(appMessages.Messages) + if err != nil { + k.log.Errorf("Error, failed to encode app messages as CBOR blob: %s", err) + continue + } + surb := surbBundle.SURBs[0] + if len(surbBundle.SURBs) == 1 { + k.log.Debug("Using last SURB in subscription.") + k.subscriptions.Delete(appMessages.SubscriptionID) + pluginClient.Unsubscribe(appMessages.SubscriptionID) + } else { + surbBundle.SURBs = surbBundle.SURBs[1:] + k.subscriptions.Store(appMessages.SubscriptionID, surbBundle) + } + k.sendReply(surb, messagesBlob) + } + } +} + +func (k *PluginWorker) subscriptionWorker(recipient [sConstants.RecipientIDLength]byte, pluginClient *client.Dialer) { + + // Kaetzchen delay is our max dwell time. + maxDwell := time.Duration(k.glue.Config().Debug.KaetzchenDelay) * time.Millisecond + + defer k.haltOnce.Do(k.haltAllClients) + + handlerCh, ok := k.pluginChans[recipient] + if !ok { + k.log.Debugf("Failed to find handler. Dropping PubsubPlugin request: %v", recipient) + pubsubRequestsDropped.Inc() + return + } + ch := handlerCh.Out() + + for { + var pkt *packet.Packet + select { + case <-k.HaltCh(): + k.log.Debugf("Terminating gracefully.") + return + case e := <-ch: + pkt = e.(*packet.Packet) + if dwellTime := monotime.Now() - pkt.DispatchAt; dwellTime > maxDwell { + k.log.Debugf("Dropping packet: %v (Spend %v in queue)", pkt.ID, dwellTime) + packetsDropped.Inc() + pkt.Dispose() + continue + } + k.processPacket(pkt, pluginClient) + pubsubRequests.Inc() + } + } +} + +func (k *PluginWorker) haltAllClients() { + k.log.Debug("Halting plugin clients.") + for _, client := range k.clients { + go client.Halt() + } +} + +func (k *PluginWorker) processPacket(pkt *packet.Packet, pluginClient *client.Dialer) { + pubsubRequestsTimer = prometheus.NewTimer(pubsubRequestsDuration) + defer pubsubRequestsTimer.ObserveDuration() + defer pkt.Dispose() + + payload, surbs, err := packet.ParseForwardPacket(pkt) + if err != nil { + k.log.Debugf("Failed to parse forward packet. Dropping Pubsub request: %v (%v)", pkt.ID, err) + pubsubRequestsDropped.Inc() + return + } + if len(surbs) == 0 { + k.log.Debugf("Zero SURBs supplied. Dropping Pubsub request: %v (%v)", pkt.ID, err) + pubsubRequestsDropped.Inc() + return + } + clientSubscribe, err := common.ClientSubscribeFromBytes(payload) + if err != nil { + k.log.Debugf("Failed to decode payload. Dropping Pubsub request: %v (%v)", pkt.ID, err) + pubsubRequestsDropped.Inc() + return + } + subscriptionID := common.GenerateSubscriptionID() + epoch, _, _ := epochtime.Now() + surbBundle := &SURBBundle{ + Epoch: epoch, + SURBs: surbs, + } + k.subscriptions.Store(subscriptionID, surbBundle) + subscription := &common.Subscribe{ + PacketID: pkt.ID, + SURBCount: uint8(len(surbs)), + SubscriptionID: subscriptionID, + SpoolID: clientSubscribe.SpoolID, + LastSpoolIndex: clientSubscribe.LastSpoolIndex, + } + pluginClient.Subscribe(subscription) + if err != nil { + k.log.Debugf("Failed to handle Pubsub request: %v (%v)", pkt.ID, err) + return + } + return +} + +// PubsubForPKI returns the plugins Parameters map for publication in the PKI doc. +func (k *PluginWorker) PubsubForPKI() ServiceMap { + return k.forPKI +} + +// HasRecipient returns true if the given recipient is one of our workers. +func (k *PluginWorker) HasRecipient(recipient [sConstants.RecipientIDLength]byte) bool { + _, ok := k.pluginChans[recipient] + return ok +} + +func (k *PluginWorker) launch(command string, args []string) (*client.Dialer, error) { + k.log.Debugf("Launching plugin: %s", command) + plugin := client.New(command, k.glue.LogBackend()) + err := plugin.Launch(command, args) + return plugin, err +} + +// New returns a new PluginWorker +func New(glue glue.Glue) (*PluginWorker, error) { + + pluginWorker := PluginWorker{ + glue: glue, + log: glue.LogBackend().GetLogger("pubsub plugin worker"), + pluginChans: make(PluginChans), + clients: make([]*client.Dialer, 0), + forPKI: make(ServiceMap), + subscriptions: new(sync.Map), + } + + pluginWorker.Go(pluginWorker.garbageCollectionWorker) + + capaMap := make(map[string]bool) + + for i, pluginConf := range glue.Config().Provider.PubsubPlugin { + pluginWorker.log.Noticef("Configuring plugin handler for %s", pluginConf.Capability) + + // Ensure no duplicates. + capa := pluginConf.Capability + if capa == "" { + return nil, errors.New("pubsub plugin capability cannot be empty string") + } + if pluginConf.Disable { + pluginWorker.log.Noticef("Skipping disabled Pubsub: '%v'.", capa) + continue + } + if capaMap[capa] { + return nil, fmt.Errorf("provider: Pubsub '%v' registered more than once", capa) + } + + // Sanitize the endpoint. + endpoint, err := ensureEndpointSanitized(pluginConf.Endpoint, capa) + if err != nil { + return nil, err + } + + // Add an infinite channel for this plugin. + + pluginWorker.pluginChans[*endpoint] = channels.NewInfiniteChannel() + + // Add entry from this plugin for the PKI. + params := make(map[string]interface{}) + gotParams := false + + pluginWorker.log.Noticef("Starting Pubsub plugin client: %s %d", capa, i) + + var args []string + if len(pluginConf.Config) > 0 { + args = []string{} + for key, val := range pluginConf.Config { + args = append(args, fmt.Sprintf("-%s", key), val.(string)) + } + } + + pluginClient, err := pluginWorker.launch(pluginConf.Command, args) + if err != nil { + pluginWorker.log.Error("Failed to start a plugin client: %s", err) + return nil, err + } + + for i := 0; i < pluginConf.MaxConcurrency; i++ { + err := pluginClient.Dial() + if err != nil { + pluginWorker.log.Error("Failed to dial a plugin client: %s", err) + return nil, err + } + + pluginWorker.Go(func() { + pluginWorker.appMessagesWorker(pluginClient) + }) + + if !gotParams { + // just once we call the Parameters method on the plugin + // and use that info to populate our forPKI map which + // ends up populating the PKI document + p := pluginClient.Parameters() + if p != nil { + for key, value := range *p { + params[key] = value + } + } + params[ParameterEndpoint] = pluginConf.Endpoint + gotParams = true + } + + // Accumulate a list of all clients to facilitate clean shutdown. + pluginWorker.clients = append(pluginWorker.clients, pluginClient) + + // Start the subscriptionWorker _after_ we have added all of the entries to pluginChans + // otherwise the subscriptionWorker() goroutines race this thread. + defer pluginWorker.Go(func() { + pluginWorker.subscriptionWorker(*endpoint, pluginClient) + }) + } + + pluginWorker.forPKI[capa] = params + capaMap[capa] = true + } + + return &pluginWorker, nil +} diff --git a/internal/provider/pubsub/plugins_test.go b/internal/provider/pubsub/plugins_test.go new file mode 100644 index 0000000..743e338 --- /dev/null +++ b/internal/provider/pubsub/plugins_test.go @@ -0,0 +1,245 @@ +// plugins_test.go - tests for plugin system +// Copyright (C) 2020 David Stainton +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package pubsub + +import ( + "testing" + + "github.com/katzenpost/core/crypto/ecdh" + "github.com/katzenpost/core/crypto/eddsa" + "github.com/katzenpost/core/crypto/rand" + "github.com/katzenpost/core/log" + "github.com/katzenpost/core/sphinx/constants" + "github.com/katzenpost/core/thwack" + "github.com/katzenpost/core/wire" + "github.com/katzenpost/server/config" + "github.com/katzenpost/server/internal/glue" + "github.com/katzenpost/server/internal/packet" + "github.com/katzenpost/server/internal/pkicache" + "github.com/katzenpost/server/spool" + "github.com/katzenpost/server/userdb" + "github.com/stretchr/testify/require" +) + +type mockUserDB struct { + provider *mockProvider +} + +func (u *mockUserDB) Exists([]byte) bool { + return true +} + +func (u *mockUserDB) IsValid([]byte, *ecdh.PublicKey) bool { return true } + +func (u *mockUserDB) Add([]byte, *ecdh.PublicKey, bool) error { return nil } + +func (u *mockUserDB) SetIdentity([]byte, *ecdh.PublicKey) error { return nil } + +func (u *mockUserDB) Link([]byte) (*ecdh.PublicKey, error) { + return nil, nil +} + +func (u *mockUserDB) Identity([]byte) (*ecdh.PublicKey, error) { + return u.provider.userKey, nil +} + +func (u *mockUserDB) Remove([]byte) error { return nil } + +func (u *mockUserDB) Close() {} + +type mockSpool struct{} + +func (s *mockSpool) StoreMessage(u, msg []byte) error { return nil } + +func (s *mockSpool) StoreSURBReply(u []byte, id *[constants.SURBIDLength]byte, msg []byte) error { + return nil +} + +func (s *mockSpool) Get(u []byte, advance bool) (msg, surbID []byte, remaining int, err error) { + return []byte{1, 2, 3}, nil, 1, nil +} + +func (s *mockSpool) Remove(u []byte) error { return nil } + +func (s *mockSpool) Vacuum(udb userdb.UserDB) error { return nil } + +func (s *mockSpool) Close() {} + +type mockProvider struct { + userName string + userKey *ecdh.PublicKey +} + +func (p *mockProvider) Halt() {} + +func (p *mockProvider) UserDB() userdb.UserDB { + return &mockUserDB{ + provider: p, + } +} + +func (p *mockProvider) Spool() spool.Spool { + return &mockSpool{} +} + +func (p *mockProvider) AuthenticateClient(*wire.PeerCredentials) bool { + return true +} + +func (p *mockProvider) OnPacket(*packet.Packet) {} + +func (p *mockProvider) KaetzchenForPKI() (map[string]map[string]interface{}, error) { + return nil, nil +} + +func (p *mockProvider) AdvertiseRegistrationHTTPAddresses() []string { + return nil +} + +type mockDecoy struct{} + +func (d *mockDecoy) Halt() {} + +func (d *mockDecoy) OnNewDocument(*pkicache.Entry) {} + +func (d *mockDecoy) OnPacket(*packet.Packet) {} + +type mockServer struct { + cfg *config.Config + logBackend *log.Backend + identityKey *eddsa.PrivateKey + linkKey *ecdh.PrivateKey + management *thwack.Server + mixKeys glue.MixKeys + pki glue.PKI + provider glue.Provider + scheduler glue.Scheduler + connector glue.Connector + listeners []glue.Listener +} + +type mockGlue struct { + s *mockServer +} + +func (g *mockGlue) Config() *config.Config { + return g.s.cfg +} + +func (g *mockGlue) LogBackend() *log.Backend { + return g.s.logBackend +} + +func (g *mockGlue) IdentityKey() *eddsa.PrivateKey { + return g.s.identityKey +} + +func (g *mockGlue) LinkKey() *ecdh.PrivateKey { + return g.s.linkKey +} + +func (g *mockGlue) Management() *thwack.Server { + return g.s.management +} + +func (g *mockGlue) MixKeys() glue.MixKeys { + return g.s.mixKeys +} + +func (g *mockGlue) PKI() glue.PKI { + return g.s.pki +} + +func (g *mockGlue) Provider() glue.Provider { + return g.s.provider +} + +func (g *mockGlue) Scheduler() glue.Scheduler { + return g.s.scheduler +} + +func (g *mockGlue) Connector() glue.Connector { + return g.s.connector +} + +func (g *mockGlue) Listeners() []glue.Listener { + return g.s.listeners +} + +func (g *mockGlue) ReshadowCryptoWorkers() {} + +func (g *mockGlue) Decoy() glue.Decoy { + return &mockDecoy{} +} + +func getGlue(logBackend *log.Backend, provider *mockProvider, linkKey *ecdh.PrivateKey, idKey *eddsa.PrivateKey) *mockGlue { + goo := &mockGlue{ + s: &mockServer{ + logBackend: logBackend, + provider: provider, + linkKey: linkKey, + cfg: &config.Config{ + Server: &config.Server{}, + Logging: &config.Logging{}, + Provider: &config.Provider{}, + PKI: &config.PKI{}, + Management: &config.Management{}, + Debug: &config.Debug{ + NumKaetzchenWorkers: 3, + IdentityKey: idKey, + KaetzchenDelay: 300, + }, + }, + }, + } + return goo +} + +func TestPubsubInvalidCommand(t *testing.T) { + require := require.New(t) + + idKey, err := eddsa.NewKeypair(rand.Reader) + require.NoError(err) + + logBackend, err := log.New("", "DEBUG", false) + require.NoError(err) + + userKey, err := ecdh.NewKeypair(rand.Reader) + require.NoError(err) + + linkKey, err := ecdh.NewKeypair(rand.Reader) + require.NoError(err) + + mockProvider := &mockProvider{ + userName: "alice", + userKey: userKey.PublicKey(), + } + + goo := getGlue(logBackend, mockProvider, linkKey, idKey) + goo.s.cfg.Provider.PubsubPlugin = []*config.PubsubPlugin{ + &config.PubsubPlugin{ + Capability: "loop", + Endpoint: "loop", + Config: map[string]interface{}{}, + Disable: false, + Command: "non-existent command", + MaxConcurrency: 1, + }, + } + _, err = New(goo) + require.Error(err) +} diff --git a/pubsubplugin/client/dialer.go b/pubsubplugin/client/dialer.go new file mode 100644 index 0000000..f5879c4 --- /dev/null +++ b/pubsubplugin/client/dialer.go @@ -0,0 +1,196 @@ +// dialer.go - client plugin dialer, tracks multiple connections to a plugin +// Copyright (C) 2020 David Stainton. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package pubsubplugin client is the client module a plugin system allowing mix network services +// to be added in any language. It implements a publish subscribe interface. +// +package client + +import ( + "bufio" + "fmt" + "io" + "net" + "os/exec" + "sync" + "syscall" + + "github.com/katzenpost/core/log" + "github.com/katzenpost/server/pubsubplugin/common" + "gopkg.in/op/go-logging.v1" +) + +const unixSocketNetwork = "unix" + +// Dialer handles the launching and subsequent multiple dialings of a given plugin. +type Dialer struct { + sync.RWMutex + + logBackend *log.Backend + log *logging.Logger + + // for sending subscribe/unsubscribe commands to the plugin + outgoingCh chan interface{} + + // for receiving AppMessages from the plugin + incomingCh chan *common.AppMessages + + conns []*outgoingConn + cmd *exec.Cmd + socketPath string + params *common.Parameters + + haltOnce sync.Once +} + +// New creates a new plugin client instance which represents the single execution +// of the external plugin program. +func New(command string, logBackend *log.Backend) *Dialer { + return &Dialer{ + logBackend: logBackend, + log: logBackend.GetLogger(command), + + outgoingCh: make(chan interface{}), + incomingCh: make(chan *common.AppMessages), + conns: make([]*outgoingConn, 0), + } +} + +func (d *Dialer) IncomingCh() chan *common.AppMessages { + return d.incomingCh +} + +// Halt halts all of the outgoing connections and halts +// the execution of the plugin application. +func (d *Dialer) Halt() { + d.haltOnce.Do(d.doHalt) +} + +func (d *Dialer) doHalt() { + d.RLock() + defer d.RUnlock() + for _, outgoingConn := range d.conns { + outgoingConn.Halt() + } + d.cmd.Process.Signal(syscall.SIGTERM) + err := d.cmd.Wait() + if err != nil { + d.log.Errorf("Publish-subscript plugin worker, command halt exec error: %s\n", err) + } +} + +func (d *Dialer) logPluginStderr(stderr io.ReadCloser) { + logWriter := d.logBackend.GetLogWriter(d.cmd.Path, "DEBUG") + _, err := io.Copy(logWriter, stderr) + if err != nil { + d.log.Errorf("Failed to proxy pubsubplugin stderr to DEBUG log: %s", err) + } + d.Halt() +} + +// Launch executes the given command and args, reading the unix socket path +// from STDOUT and saving it for later use when dialing the socket. +func (d *Dialer) Launch(command string, args []string) error { + // exec plugin + d.cmd = exec.Command(command, args...) + stdout, err := d.cmd.StdoutPipe() + if err != nil { + d.log.Debugf("pipe failure: %s", err) + return err + } + stderr, err := d.cmd.StderrPipe() + if err != nil { + d.log.Debugf("pipe failure: %s", err) + return err + } + err = d.cmd.Start() + if err != nil { + d.log.Debugf("failed to exec: %s", err) + return err + } + + // proxy stderr to our debug log + go d.logPluginStderr(stderr) + + // read and decode plugin stdout + stdoutScanner := bufio.NewScanner(stdout) + stdoutScanner.Scan() + d.socketPath = stdoutScanner.Text() + d.log.Debugf("plugin socket path:'%s'\n", d.socketPath) + return nil +} + +// Dial dials the unix socket that was recorded during Launch. +func (d *Dialer) Dial() error { + conn, err := net.Dial(unixSocketNetwork, d.socketPath) + if err != nil { + d.log.Debugf("unix socket connect failure: %s", err) + return err + } + d.onNewConn(conn) + return nil +} + +func (d *Dialer) ensureParameters(outgoingConn *outgoingConn) error { + if d.params == nil { + d.log.Debug("requesting plugin Parameters for Mix Descriptor publication") + var err error + d.params, err = outgoingConn.getParameters() + if err != nil { + return err + } + } + return nil +} + +func (d *Dialer) onNewConn(conn net.Conn) { + connLog := d.logBackend.GetLogger(fmt.Sprintf("%s connection %d", d.cmd, len(d.conns)+1)) + outConn := newOutgoingConn(d, conn, connLog) + + err := d.ensureParameters(outConn) + if err != nil { + d.log.Error("failed to acquire plugin parameters, giving up on plugin connection") + return + } + + d.Lock() + defer func() { + d.Unlock() + outConn.Go(outConn.worker) + }() + d.conns = append(d.conns, outConn) +} + +// Parameters returns the Parameters whcih are used in Mix Descriptor +// publication to give service clients more information about the service. +func (d *Dialer) Parameters() *common.Parameters { + return d.params +} + +// Unsubscribe sends a subscription request to the plugin over +// the Unix domain socket protocol. +func (d *Dialer) Unsubscribe(subscriptionID [common.SubscriptionIDLength]byte) { + u := &common.Unsubscribe{ + SubscriptionID: subscriptionID, + } + d.outgoingCh <- u +} + +// Subscribe sends a subscription request to the plugin over +// the Unix domain socket protocol. +func (d *Dialer) Subscribe(subscribe *common.Subscribe) { + d.outgoingCh <- subscribe +} diff --git a/pubsubplugin/client/outgoing_conn.go b/pubsubplugin/client/outgoing_conn.go new file mode 100644 index 0000000..419ccaa --- /dev/null +++ b/pubsubplugin/client/outgoing_conn.go @@ -0,0 +1,158 @@ +// outgoing_conn.go - out going client plugin connection +// Copyright (C) 2020 David Stainton. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package pubsubplugin client is the client module a plugin system allowing mix network services +// to be added in any language. It implements a publish subscribe interface. +// +package client + +import ( + "errors" + "io" + "net" + + "github.com/katzenpost/core/worker" + "github.com/katzenpost/server/pubsubplugin/common" + "gopkg.in/op/go-logging.v1" +) + +type outgoingConn struct { + worker.Worker + + log *logging.Logger + + dialer *Dialer + conn net.Conn +} + +func newOutgoingConn(dialer *Dialer, conn net.Conn, log *logging.Logger) *outgoingConn { + return &outgoingConn{ + dialer: dialer, + conn: conn, + log: log, + } +} + +func (c *outgoingConn) readAppMessages() (*common.AppMessages, error) { + lenPrefixBuf := make([]byte, 2) + _, err := io.ReadFull(c.conn, lenPrefixBuf) + if err != nil { + return nil, err + } + lenPrefix := common.PrefixLengthDecode(lenPrefixBuf) + responseBuf := make([]byte, lenPrefix) + _, err = io.ReadFull(c.conn, responseBuf) + if err != nil { + return nil, err + } + egressCmd, err := common.EgressUnixSocketCommandFromBytes(responseBuf) + if err != nil { + return nil, err + } + if egressCmd.AppMessages == nil { + return nil, errors.New("expected EgressUnixSocketCommand AppMessages to not be nil") + } + return egressCmd.AppMessages, nil +} + +func (c *outgoingConn) unsubscribe(u *common.Unsubscribe) { + serializedUnsubscribe, err := u.ToBytes() + if err != nil { + c.log.Errorf("unsubscribe error: %s", err) + } + serializedUnsubscribe = common.PrefixLengthEncode(serializedUnsubscribe) + _, err = c.conn.Write(serializedUnsubscribe) + if err != nil { + c.log.Errorf("unsubscribe error: %s", err) + } +} + +func (c *outgoingConn) subscribe(subscribe *common.Subscribe) error { + serializedSubscribe, err := subscribe.ToBytes() + if err != nil { + c.log.Errorf("subscribe error: %s", err) + } + serializedSubscribe = common.PrefixLengthEncode(serializedSubscribe) + _, err = c.conn.Write(serializedSubscribe) + if err != nil { + c.log.Errorf("subscribe error: %s", err) + } + return err +} + +func (c *outgoingConn) worker() { + defer func() { + // XXX TODO: stuff to shutdown + }() + for { + newMessages, err := c.readAppMessages() + if err != nil { + c.log.Errorf("failure to read new messages from plugin: %s", err) + return + } + select { + case <-c.HaltCh(): + return + case c.dialer.incomingCh <- newMessages: + case rawCmd := <-c.dialer.outgoingCh: + switch cmd := rawCmd.(type) { + case *common.Unsubscribe: + c.unsubscribe(cmd) + case *common.Subscribe: + c.subscribe(cmd) + default: + c.log.Errorf("outgoingConn received invalid command type %T from Dialer", rawCmd) + } + } + } +} + +func (c *outgoingConn) getParameters() (*common.Parameters, error) { + ingressCmd := &common.IngressUnixSocketCommand{ + GetParameters: &common.GetParameters{}, + } + rawGetParams, err := ingressCmd.ToBytes() + if err != nil { + return nil, err + } + rawGetParams = common.PrefixLengthEncode(rawGetParams) + _, err = c.conn.Write(rawGetParams) + if err != nil { + return nil, err + } + + // read response + lenPrefixBuf := make([]byte, 2) + _, err = io.ReadFull(c.conn, lenPrefixBuf) + if err != nil { + return nil, err + } + lenPrefix := common.PrefixLengthDecode(lenPrefixBuf) + responseBuf := make([]byte, lenPrefix) + _, err = io.ReadFull(c.conn, responseBuf) + if err != nil { + return nil, err + } + + egressCmd, err := common.EgressUnixSocketCommandFromBytes(responseBuf) + if err != nil { + return nil, err + } + if egressCmd.Parameters == nil { + return nil, errors.New("expected EgressUnixSocketCommand Parameters to not be nil") + } + return egressCmd.Parameters, nil +} diff --git a/pubsubplugin/common/common.go b/pubsubplugin/common/common.go new file mode 100644 index 0000000..f5acb58 --- /dev/null +++ b/pubsubplugin/common/common.go @@ -0,0 +1,300 @@ +// common.go - common types shared between the pubsub client and server. +// Copyright (C) 2020 David Stainton. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package common include the common types used in the publish subscribe +// client and server modules. +// +package common + +import ( + "encoding/binary" + "errors" + + "github.com/fxamacker/cbor/v2" + "github.com/katzenpost/core/crypto/rand" +) + +const ( + // SubscriptionIDLength is the length of the Subscription ID. + SubscriptionIDLength = 8 + + // SpoolIDLength is the length of the spool identity. + SpoolIDLength = 8 +) + +// PrefixLengthEncode encodes the given byte slice with +// two byte big endian length prefix encoding. +func PrefixLengthEncode(b []byte) []byte { + lenPrefix := make([]byte, 2) + binary.BigEndian.PutUint16(lenPrefix, uint16(len(b))) + b = append(lenPrefix, b...) + return b +} + +// PrefixLengthDecode decodes the first two bytes of the +// given byte slice as a uint16, big endian encoded. +func PrefixLengthDecode(b []byte) uint16 { + return binary.BigEndian.Uint16(b[:2]) +} + +// SubscriptionID is a subscription identity. +type SubscriptionID [SubscriptionIDLength]byte + +// SpoolID is a spool identity. +type SpoolID [SpoolIDLength]byte + +// Unsubscribe is used by the mix server to communicate an unsubscribe to +// the plugin. +type Unsubscribe struct { + // SubscriptionID is the server generated subscription identity. + SubscriptionID SubscriptionID +} + +// ToBytes returns a CBOR serialized Unsubscribe. +func (u *Unsubscribe) ToBytes() ([]byte, error) { + serializedunsubscribe, err := cbor.Marshal(u) + if err != nil { + return nil, err + } + return serializedunsubscribe, nil +} + +// Subscribe is the struct type used to establish a subscription with +// the plugin. +type Subscribe struct { + // PacketID is the packet identity. + PacketID uint64 + + // SURBCount is the number of SURBs available to this subscription. + SURBCount uint8 + + // SubscriptionID is the server generated subscription identity. + SubscriptionID SubscriptionID + + // SpoolID is the spool identity. + SpoolID SpoolID + + // LastSpoolIndex is the last spool index which was received by the client. + LastSpoolIndex uint64 +} + +// SubscribeToBytes encodes the given Subscribe as a CBOR byte blob. +func (s *Subscribe) ToBytes() ([]byte, error) { + serializedSubscribe, err := cbor.Marshal(s) + if err != nil { + return nil, err + } + return serializedSubscribe, nil +} + +// SubscribeFromBytes returns a Subscribe given a CBOR serialized Subscribe. +func SubscribeFromBytes(b []byte) (*Subscribe, error) { + subscribe := Subscribe{} + err := cbor.Unmarshal(b, &subscribe) + if err != nil { + return nil, err + } + return &subscribe, nil +} + +// GenerateSubscriptionID returns a random subscription ID. +func GenerateSubscriptionID() [SubscriptionIDLength]byte { + id := [SubscriptionIDLength]byte{} + rand.Reader.Read(id[:]) + return id +} + +// ClientSubscribe is used by the mixnet client to send a subscription +// request to the publish-subscribe application plugin. +type ClientSubscribe struct { + // SpoolID is the spool identity. + SpoolID [SpoolIDLength]byte + + // LastSpoolIndex is the last spool index which was received by the client. + LastSpoolIndex uint64 + + // Payload is the application specific payload which the client sends to + // the plugin. + Payload []byte +} + +// ClientSubscribeFromBytes decodes a ClientSubscribe from the +// given CBOR byte blob. +func ClientSubscribeFromBytes(b []byte) (*ClientSubscribe, error) { + clientSubscribe := ClientSubscribe{} + err := cbor.Unmarshal(b, &clientSubscribe) + if err != nil { + return nil, err + } + return &clientSubscribe, nil +} + +// AppMessages is the struct type used by the application plugin to +// send new messages to the server and eventually the subscribing client. +type AppMessages struct { + // SubscriptionID is the server generated subscription identity. + SubscriptionID SubscriptionID + + // Messages should contain one or more spool messages. + Messages []SpoolMessage +} + +// ToBytes serializes AppMessages into a CBOR byte blob +// or returns an error. +func (m *AppMessages) ToBytes() ([]byte, error) { + serialized, err := cbor.Marshal(m) + if err != nil { + return nil, err + } + return serialized, nil +} + +// SpoolMessage is a spool message from the application plugin. +type SpoolMessage struct { + // Index is the index value from whence the message came from. + Index uint64 + + // Payload contains the actual spool message contents which are + // application specific. + Payload []byte +} + +// MessagesToBytes returns a CBOR byte blob given a slice of type SpoolMessage. +func MessagesToBytes(messages []SpoolMessage) ([]byte, error) { + serialized, err := cbor.Marshal(messages) + if err != nil { + return nil, err + } + return serialized, nil +} + +// Parameters is an optional mapping that plugins can publish, these get +// advertised to clients in the MixDescriptor. +// The output of GetParameters() ends up being published in a map +// associating with the service names to service parameters map. +// This information is part of the Mix Descriptor which is defined here: +// https://github.com/katzenpost/core/blob/master/pki/pki.go +type Parameters map[string]string + +// GetParameters is used for querying the plugin over the unix socket +// to get the dynamic parameters after the plugin is started. +type GetParameters struct{} + +// IngressUnixSocketCommand wraps ingress unix socket wire protocol commands, +// that is commands used by the mix server, aka Provider to communicate with +// the application plugin. +type IngressUnixSocketCommand struct { + // GetParameters is used to retrieve the plugin parameters which + // can be dynamically selected by the plugin on startup. + GetParameters *GetParameters + + // Subscribe is used to establish a new SURB based subscription. + Subscribe *Subscribe + + // Unsubscribe is used to tear down an existing subscription. + Unsubscribe *Unsubscribe +} + +func (i *IngressUnixSocketCommand) validate() error { + notNilCount := 0 + if i.GetParameters != nil { + notNilCount += 1 + } + if i.Subscribe != nil { + notNilCount += 1 + } + if i.Unsubscribe != nil { + notNilCount += 1 + } + if notNilCount > 1 { + return errors.New("expected only one field to not be nil") + } + return nil +} + +// ToBytes serializes IngressUnixSocketCommand into a CBOR byte blob +// or returns an error. +func (i *IngressUnixSocketCommand) ToBytes() ([]byte, error) { + err := i.validate() + if err != nil { + return nil, err + } + serialized, err := cbor.Marshal(i) + if err != nil { + return nil, err + } + return serialized, nil +} + +func IngressUnixSocketCommandFromBytes(b []byte) (*IngressUnixSocketCommand, error) { + ingressCmds := &IngressUnixSocketCommand{} + err := cbor.Unmarshal(b, ingressCmds) + if err != nil { + return nil, err + } + err = ingressCmds.validate() + if err != nil { + return nil, err + } + return ingressCmds, nil +} + +// EgressUnixSocketCommand wraps egress unix socket wire protocol commands, +// that is commands used by the application plugin to communicate with the +// mix server, aka Provider. +type EgressUnixSocketCommand struct { + // Parameters is the plugin selected parameters which can be dynamically + // select at startup. + Parameters *Parameters + + // AppMessages contain the application messages from the plugin. + AppMessages *AppMessages +} + +func (e *EgressUnixSocketCommand) validate() error { + if e.Parameters != nil && e.AppMessages != nil { + return errors.New("expected only one field to not be nil") + } + return nil +} + +// ToBytes serializes EgressUnixSocketCommand into a CBOR byte blob +// or returns an error. +func (e *EgressUnixSocketCommand) ToBytes() ([]byte, error) { + err := e.validate() + if err != nil { + return nil, err + } + serialized, err := cbor.Marshal(e) + if err != nil { + return nil, err + } + return serialized, nil +} + +// EgressUnixSocketCommandFromBytes decodes a blob into a EgressUnixSocketCommand. +func EgressUnixSocketCommandFromBytes(b []byte) (*EgressUnixSocketCommand, error) { + egressCmds := &EgressUnixSocketCommand{} + err := cbor.Unmarshal(b, egressCmds) + if err != nil { + return nil, err + } + err = egressCmds.validate() + if err != nil { + return nil, err + } + return egressCmds, nil +} diff --git a/pubsubplugin/server/server.go b/pubsubplugin/server/server.go new file mode 100644 index 0000000..c045c01 --- /dev/null +++ b/pubsubplugin/server/server.go @@ -0,0 +1,278 @@ +// server.go - Publish subscribe server module for writing plugins. +// Copyright (C) 2020 David Stainton. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package server + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "path" + "path/filepath" + + "github.com/katzenpost/core/log" + "github.com/katzenpost/core/worker" + "github.com/katzenpost/server/pubsubplugin/common" + "gopkg.in/op/go-logging.v1" +) + +// Spool is an interface for spool implementations which +// will handle publishing new spool content to spool subscribers. +type Spool interface { + // Subscribe creates a new spool subscription. + Subscribe(subscriptionID *common.SubscriptionID, spoolID *common.SpoolID, lastSpoolIndex uint64) error + + // Unsubscribe removes an existing spool subscription. + Unsubscribe(subscriptionID *common.SubscriptionID) error +} + +// Config is used to configure a new Server instance. +type Config struct { + // Name is the name of the application service. + Name string + + // Parameters are the application service specified parameters which are advertized in the + // Katzenpost PKI document. + Parameters *common.Parameters + + // LogDir is the logging directory. + LogDir string + + // LogLevel is the log level and is set to one of: ERROR, WARNING, NOTICE, INFO, DEBUG, CRITICAL. + LogLevel string + + // Spool is the implementation of our Spool interface. + Spool Spool + + // AppMessagesCh is the application messages channel which is used by the application to send + // messages to the mix server for transit over the mix network to the destination client. + AppMessagesCh chan *common.AppMessages +} + +func validateConfig(config *Config) error { + if config == nil { + return errors.New("config must not be nil") + } + if config.Name == "" { + return errors.New("config.Name must not be empty") + } + if config.Parameters == nil { + return errors.New("config.Parameters must not be nil") + } + if config.LogDir == "" { + return errors.New("config.LogDir must not be empty") + } + if config.LogLevel == "" { + return errors.New("config.LogLevel must not be empty") + } + if config.Spool == nil { + return errors.New("config.Spool must not be nil") + } + if config.AppMessagesCh == nil { + return errors.New("config.AppMessagesCh must not be nil") + } + return nil +} + +// Server is used by applications to implement the application plugin which listens +// for connections from the mix server over a unix domain socket. Server handles the +// management of this unix domain socket as well as the wire protocol used. +type Server struct { + worker.Worker + + logBackend *log.Backend + log *logging.Logger + + listener net.Listener + socketFile string + + params *common.Parameters + appMessagesCh chan *common.AppMessages + spool Spool +} + +func (s *Server) sendParameters(conn net.Conn) error { + e := &common.EgressUnixSocketCommand{ + Parameters: s.params, + } + paramsBlob, err := e.ToBytes() + if err != nil { + return err + } + paramsBlob = common.PrefixLengthEncode(paramsBlob) + _, err = conn.Write(paramsBlob) + return err +} + +func (s *Server) readIngressCommands(conn net.Conn) (*common.IngressUnixSocketCommand, error) { + lenPrefixBuf := make([]byte, 2) + _, err := io.ReadFull(conn, lenPrefixBuf) + if err != nil { + return nil, err + } + lenPrefix := common.PrefixLengthDecode(lenPrefixBuf) + cmdBuf := make([]byte, lenPrefix) + _, err = io.ReadFull(conn, cmdBuf) + if err != nil { + return nil, err + } + ingressCmd, err := common.IngressUnixSocketCommandFromBytes(cmdBuf) + return ingressCmd, err +} + +func (s *Server) perpetualCommandReader(conn net.Conn) <-chan *common.IngressUnixSocketCommand { + readCh := make(chan *common.IngressUnixSocketCommand) + + s.Go(func() { + for { + cmd, err := s.readIngressCommands(conn) + if err != nil { + s.log.Errorf("failure to read new messages from plugin: %s", err) + return + } + select { + case <-s.HaltCh(): + return + case readCh <- cmd: + } + } + }) + + return readCh +} + +func (s *Server) connectionWorker(conn net.Conn) { + readCmdCh := s.perpetualCommandReader(conn) + + for { + select { + case <-s.HaltCh(): + s.log.Debugf("Worker terminating gracefully.") + return + case cmd := <-readCmdCh: + if cmd.GetParameters != nil { + s.sendParameters(conn) + continue + } + if cmd.Subscribe != nil { + s.spool.Subscribe(&cmd.Subscribe.SubscriptionID, &cmd.Subscribe.SpoolID, cmd.Subscribe.LastSpoolIndex) + continue + } + if cmd.Unsubscribe != nil { + s.spool.Unsubscribe(&cmd.Subscribe.SubscriptionID) + continue + } + case messages := <-s.appMessagesCh: + e := &common.EgressUnixSocketCommand{ + AppMessages: messages, + } + messagesBlob, err := e.ToBytes() + if err != nil { + s.log.Errorf("failed to serialize app messages: %s", err) + continue + } + messagesBlob = common.PrefixLengthEncode(messagesBlob) + _, err = conn.Write(messagesBlob) + if err != nil { + s.log.Errorf("failed to write AppMessages to socket: %s", err) + continue + } + } + } +} + +func (s *Server) worker() { + conn, err := s.listener.Accept() + if err != nil { + s.log.Errorf("error accepting connection: %s", err) + return + } + s.Go(func() { + s.connectionWorker(conn) + }) +} + +func (s *Server) setupListener(name string) error { + tmpDir, err := ioutil.TempDir("", name) + if err != nil { + return err + } + s.socketFile = filepath.Join(tmpDir, fmt.Sprintf("%s.socket", name)) + s.listener, err = net.Listen("unix", s.socketFile) + if err != nil { + return err + } + return nil +} + +func (s *Server) initLogging(name, logFile, logLevel string) error { + var err error + logDisable := false + s.logBackend, err = log.New(logFile, logLevel, logDisable) + if err != nil { + return err + } + s.log = s.logBackend.GetLogger(name) + return nil +} + +func (s *Server) ensureLogDir(logDir string) error { + stat, err := os.Stat(logDir) + if os.IsNotExist(err) { + return fmt.Errorf("Log directory '%s' doesn't exist.", logDir) + } + if !stat.IsDir() { + return fmt.Errorf("Log directory '%s' must be a directory.", logDir) + } + return nil +} + +// New creates a new Server instance and starts immediately listening for new connections. +func New(config *Config) (*Server, error) { + err := validateConfig(config) + if err != nil { + return nil, err + } + s := &Server{ + params: config.Parameters, + appMessagesCh: config.AppMessagesCh, + spool: config.Spool, + } + err = s.ensureLogDir(config.LogDir) + if err != nil { + return nil, err + } + logFile := path.Join(config.LogDir, fmt.Sprintf("%s.%d.log", config.Name, os.Getpid())) + err = s.initLogging(config.Name, logFile, config.LogLevel) + if err != nil { + return nil, err + } + s.log.Debug("starting listener") + err = s.setupListener(config.Name) + if err != nil { + return nil, err + } + fmt.Printf("%s\n", s.socketFile) + s.Go(s.worker) + go func() { + <-s.HaltCh() + os.Remove(s.socketFile) + }() + return s, nil +}