From 39d1d5dc271f85a3e6cb86ea1629786f99fa1b09 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 16 Jun 2021 12:41:44 -0700 Subject: [PATCH] don't cache more than 1024 entries, to avoid DoS attacks --- client.go | 9 +++++++ service_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 270394ab..4fe1ddfc 100644 --- a/client.go +++ b/client.go @@ -26,6 +26,9 @@ const ( IPv4AndIPv6 = (IPv4 | IPv6) //< Default option. ) +// DoS protection: we won't cache more than 1024 entries when receiving entries. +var maxSentEntries = 1024 + type clientOpts struct { listenOn IPType ifaces []net.Interface @@ -293,6 +296,12 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { // This is also a point to possibly stop probing actively for a // service entry. params.Entries <- e + // DoS protection: don't cache more than maxSentEntries entries + if len(sentEntries) >= maxSentEntries { + for key := range sentEntries { + delete(sentEntries, key) + } + } sentEntries[k] = e if !params.isBrowsing { params.disableProbing() diff --git a/service_test.go b/service_test.go index 2c5a23ed..69918e6f 100644 --- a/service_test.go +++ b/service_test.go @@ -2,6 +2,7 @@ package zeroconf import ( "context" + "fmt" "log" "testing" "time" @@ -9,7 +10,7 @@ import ( "github.com/pkg/errors" ) -var ( +const ( mdnsName = "test--xxxxxxxxxxxx" mdnsService = "_test--xxxx._tcp" mdnsSubtype = "_test--xxxx._tcp,_fancy" @@ -163,4 +164,67 @@ func TestSubtype(t *testing.T) { t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port) } }) + + t.Run("DoS protection", func(t *testing.T) { + origMaxSentEntries := maxSentEntries + maxSentEntries = 10 + defer func() { maxSentEntries = origMaxSentEntries }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + const firstName = mdnsName + + go startMDNS(ctx, mdnsPort, firstName, mdnsSubtype, mdnsDomain) + time.Sleep(time.Second) + + resolver, err := NewResolver(nil) + if err != nil { + t.Fatalf("Expected create resolver success, but got %v", err) + } + entries := make(chan *ServiceEntry, maxSentEntries+1) + received := make(chan *ServiceEntry, 10) + go func() { + for { + select { + case entry := <-entries: + if entry.Instance == firstName { + received <- entry + } + case <-ctx.Done(): + return + } + } + }() + if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { + t.Fatalf("Expected browse success, but got %v", err) + } + select { + case <-received: + case <-time.NewTimer(time.Second).C: + t.Fatal("expected to discover service") + } + + for i := 1; i < maxSentEntries; i++ { + go startMDNS(ctx, mdnsPort, fmt.Sprintf("%s-%d", mdnsName, i), mdnsSubtype, mdnsDomain) + } + time.Sleep(time.Second) + + select { + case entry := <-entries: + t.Fatalf("didn't expect to receive an entry, got %v", entry) + default: + } + + // Announcing this service will cause the map to overflow. + go startMDNS(ctx, mdnsPort, fmt.Sprintf("%s-%d", mdnsName, maxSentEntries), mdnsSubtype, mdnsDomain) + + // wait for a re-announcement of the firstName service + select { + case <-received: + cancel() + case <-ctx.Done(): + t.Fatal("expected to discover service") + } + }) }