From 12e7b86d7624da9dc44b5b1cc95888f007257822 Mon Sep 17 00:00:00 2001 From: Alex Leong Date: Wed, 18 Oct 2023 12:34:38 -0700 Subject: [PATCH] Add update queue to endpoint translator (#11491) When a grpc client of the destination.Get API initiates a request but then doesn't read off of that stream, the HTTP2 stream flow control window will fill up and eventually exert backpressure on the destination controller. This manifests as calls to `Send` on the stream blocking. Since `Send` is called synchronously from the client-go informer callback (by way of the endpoint translator), this blocks the informer callback and prevents all further informer callbacks from firing. This causes the destination controller to stop sending updates to any of its clients. We add a queue in the endpoint translator so that when it gets an update from the informer callback, that update is queued and we avoid potentially blocking the informer callback. Each endpoint translator spawns a goroutine to process this queue and call `Send`. If there is no capacity in this queue (e.g. because a client has stopped reading and we are experiencing backpressure) then we terminate the stream. 
Signed-off-by: Alex Leong --- .../api/destination/destination_fuzzer.go | 4 +- .../api/destination/endpoint_translator.go | 186 ++++++++++---- .../destination/endpoint_translator_test.go | 85 +++++-- controller/api/destination/server.go | 11 + controller/api/destination/server_test.go | 231 +++++++++--------- controller/api/destination/test_util.go | 11 +- 6 files changed, 341 insertions(+), 187 deletions(-) diff --git a/controller/api/destination/destination_fuzzer.go b/controller/api/destination/destination_fuzzer.go index ff774763f8066..beba1adc6b766 100644 --- a/controller/api/destination/destination_fuzzer.go +++ b/controller/api/destination/destination_fuzzer.go @@ -25,6 +25,8 @@ func FuzzAdd(data []byte) int { } t := &testing.T{} _, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(set) translator.Remove(set) return 1 @@ -52,7 +54,7 @@ func FuzzGet(data []byte) int { server := makeServer(t) stream := &bufferingGetStream{ - updates: []*pb.Update{}, + updates: make(chan *pb.Update, 50), MockServerStream: util.NewMockServerStream(), } _ = server.Get(dest1, stream) diff --git a/controller/api/destination/endpoint_translator.go b/controller/api/destination/endpoint_translator.go index cde93eadff9be..1ba0908a0975e 100644 --- a/controller/api/destination/endpoint_translator.go +++ b/controller/api/destination/endpoint_translator.go @@ -5,7 +5,6 @@ import ( "reflect" "strconv" "strings" - "sync" pb "github.com/linkerd/linkerd2-proxy-api/go/destination" "github.com/linkerd/linkerd2-proxy-api/go/net" @@ -13,6 +12,8 @@ import ( "github.com/linkerd/linkerd2/controller/k8s" "github.com/linkerd/linkerd2/pkg/addr" pkgK8s "github.com/linkerd/linkerd2/pkg/k8s" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" logging "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" ) @@ -22,26 +23,55 @@ const ( // inboundListenAddr is the environment variable holding the 
inbound // listening address for the proxy container. envInboundListenAddr = "LINKERD2_PROXY_INBOUND_LISTEN_ADDR" + + updateQueueCapacity = 100 ) // endpointTranslator satisfies EndpointUpdateListener and translates updates // into Destination.Get messages. -type endpointTranslator struct { - controllerNS string - identityTrustDomain string - enableH2Upgrade bool - nodeTopologyZone string - nodeName string - defaultOpaquePorts map[uint32]struct{} - enableEndpointFiltering bool - - availableEndpoints watcher.AddressSet - filteredSnapshot watcher.AddressSet - stream pb.Destination_GetServer - log *logging.Entry - - mu sync.Mutex -} +type ( + endpointTranslator struct { + controllerNS string + identityTrustDomain string + enableH2Upgrade bool + nodeTopologyZone string + nodeName string + defaultOpaquePorts map[uint32]struct{} + enableEndpointFiltering bool + + availableEndpoints watcher.AddressSet + filteredSnapshot watcher.AddressSet + stream pb.Destination_GetServer + endStream chan struct{} + log *logging.Entry + overflowCounter prometheus.Counter + + updates chan interface{} + stop chan struct{} + } + + addUpdate struct { + set watcher.AddressSet + } + + removeUpdate struct { + set watcher.AddressSet + } + + noEndpointsUpdate struct { + exists bool + } +) + +var updatesQueueOverflowCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "endpoint_updates_queue_overflow", + Help: "A counter incremented whenever the endpoint updates queue overflows", + }, + []string{ + "service", + }, +) func newEndpointTranslator( controllerNS string, @@ -53,6 +83,7 @@ func newEndpointTranslator( enableEndpointFiltering bool, k8sAPI *k8s.MetadataAPI, stream pb.Destination_GetServer, + endStream chan struct{}, log *logging.Entry, ) *endpointTranslator { log = log.WithFields(logging.Fields{ @@ -79,15 +110,85 @@ func newEndpointTranslator( availableEndpoints, filteredSnapshot, stream, + endStream, log, - sync.Mutex{}, + 
updatesQueueOverflowCounter.With(prometheus.Labels{"service": service}), + make(chan interface{}, updateQueueCapacity), + make(chan struct{}), } } func (et *endpointTranslator) Add(set watcher.AddressSet) { - et.mu.Lock() - defer et.mu.Unlock() + et.enqueueUpdate(&addUpdate{set}) +} + +func (et *endpointTranslator) Remove(set watcher.AddressSet) { + et.enqueueUpdate(&removeUpdate{set}) +} + +func (et *endpointTranslator) NoEndpoints(exists bool) { + et.enqueueUpdate(&noEndpointsUpdate{exists}) +} + +// Add, Remove, and NoEndpoints are called from a client-go informer callback +// and therefore must not block. For each of these, we enqueue an update in +// a channel so that it can be processed asynchronously. To ensure that enqueuing +// does not block, we first check to see if there is capacity in the buffered +// channel. If there is not, we drop the update and signal to the stream that +// it has fallen too far behind and should be closed. +func (et *endpointTranslator) enqueueUpdate(update interface{}) { + select { + case et.updates <- update: + // Update has been successfully enqueued. + default: + // We are unable to enqueue because the channel does not have capacity. + // The stream has fallen too far behind and should be closed. + et.overflowCounter.Inc() + select { + case <-et.endStream: + // The endStream channel has already been closed so no action is + // necessary. + default: + et.log.Error("endpoint update queue full; aborting stream") + close(et.endStream) + } + } +} + +// Start initiates a goroutine which processes update events off of the +// endpointTranslator's internal queue and sends to the grpc stream as +// appropriate. The goroutine calls several non-thread-safe functions (including +// Send) and therefore, Start must not be called more than once. 
+func (et *endpointTranslator) Start() { + go func() { + for { + select { + case update := <-et.updates: + et.processUpdate(update) + case <-et.stop: + return + } + } + }() +} +// Stop terminates the goroutine started by Start. +func (et *endpointTranslator) Stop() { + close(et.stop) +} + +func (et *endpointTranslator) processUpdate(update interface{}) { + switch update := update.(type) { + case *addUpdate: + et.add(update.set) + case *removeUpdate: + et.remove(update.set) + case *noEndpointsUpdate: + et.noEndpoints(update.exists) + } +} + +func (et *endpointTranslator) add(set watcher.AddressSet) { for id, address := range set.Addresses { et.availableEndpoints.Addresses[id] = address } @@ -95,10 +196,7 @@ func (et *endpointTranslator) Add(set watcher.AddressSet) { et.sendFilteredUpdate(set) } -func (et *endpointTranslator) Remove(set watcher.AddressSet) { - et.mu.Lock() - defer et.mu.Unlock() - +func (et *endpointTranslator) remove(set watcher.AddressSet) { for id := range set.Addresses { delete(et.availableEndpoints.Addresses, id) } @@ -106,6 +204,26 @@ func (et *endpointTranslator) Remove(set watcher.AddressSet) { et.sendFilteredUpdate(set) } +func (et *endpointTranslator) noEndpoints(exists bool) { + et.log.Debugf("NoEndpoints(%+v)", exists) + + et.availableEndpoints.Addresses = map[watcher.ID]watcher.Address{} + et.filteredSnapshot.Addresses = map[watcher.ID]watcher.Address{} + + u := &pb.Update{ + Update: &pb.Update_NoEndpoints{ + NoEndpoints: &pb.NoEndpoints{ + Exists: exists, + }, + }, + } + + et.log.Debugf("Sending destination no endpoints: %+v", u) + if err := et.stream.Send(u); err != nil { + et.log.Debugf("Failed to send address update: %s", err) + } +} + func (et *endpointTranslator) sendFilteredUpdate(set watcher.AddressSet) { et.availableEndpoints = watcher.AddressSet{ Addresses: et.availableEndpoints.Addresses, @@ -244,26 +362,6 @@ func (et *endpointTranslator) diffEndpoints(filtered watcher.AddressSet) (watche } } -func (et *endpointTranslator) 
NoEndpoints(exists bool) { - et.log.Debugf("NoEndpoints(%+v)", exists) - - et.availableEndpoints.Addresses = map[watcher.ID]watcher.Address{} - et.filteredSnapshot.Addresses = map[watcher.ID]watcher.Address{} - - u := &pb.Update{ - Update: &pb.Update_NoEndpoints{ - NoEndpoints: &pb.NoEndpoints{ - Exists: exists, - }, - }, - } - - et.log.Debugf("Sending destination no endpoints: %+v", u) - if err := et.stream.Send(u); err != nil { - et.log.Debugf("Failed to send address update: %s", err) - } -} - func (et *endpointTranslator) sendClientAdd(set watcher.AddressSet) { addrs := []*pb.WeightedAddr{} for _, address := range set.Addresses { diff --git a/controller/api/destination/endpoint_translator_test.go b/controller/api/destination/endpoint_translator_test.go index 71f094bf96607..f2bdfb6d3f0f5 100644 --- a/controller/api/destination/endpoint_translator_test.go +++ b/controller/api/destination/endpoint_translator_test.go @@ -172,28 +172,37 @@ var ( func TestEndpointTranslatorForRemoteGateways(t *testing.T) { t.Run("Sends one update for add and another for remove", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(remoteGateway1, remoteGateway2)) translator.Remove(mkAddressSetForServices(remoteGateway2)) expectedNumUpdates := 2 - actualNumUpdates := len(mockGetServer.updatesReceived) - if actualNumUpdates != expectedNumUpdates { - t.Fatalf("Expecting [%d] updates, got [%d]. 
Updates: %v", expectedNumUpdates, actualNumUpdates, mockGetServer.updatesReceived) + <-mockGetServer.updatesReceived // Add + <-mockGetServer.updatesReceived // Remove + + if len(mockGetServer.updatesReceived) != 0 { + t.Fatalf("Expecting [%d] updates, got [%d].", expectedNumUpdates, expectedNumUpdates+len(mockGetServer.updatesReceived)) } }) t.Run("Recovers after emptying address et", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(remoteGateway1)) translator.Remove(mkAddressSetForServices(remoteGateway1)) translator.Add(mkAddressSetForServices(remoteGateway1)) expectedNumUpdates := 3 - actualNumUpdates := len(mockGetServer.updatesReceived) - if actualNumUpdates != expectedNumUpdates { - t.Fatalf("Expecting [%d] updates, got [%d]. Updates: %v", expectedNumUpdates, actualNumUpdates, mockGetServer.updatesReceived) + <-mockGetServer.updatesReceived // Add + <-mockGetServer.updatesReceived // Remove + <-mockGetServer.updatesReceived // Add + + if len(mockGetServer.updatesReceived) != 0 { + t.Fatalf("Expecting [%d] updates, got [%d].", expectedNumUpdates, expectedNumUpdates+len(mockGetServer.updatesReceived)) } }) @@ -209,10 +218,12 @@ func TestEndpointTranslatorForRemoteGateways(t *testing.T) { } mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(remoteGateway2)) - addrs := mockGetServer.updatesReceived[0].GetAdd().GetAddrs() + addrs := (<-mockGetServer.updatesReceived).GetAdd().GetAddrs() if len(addrs) != 1 { t.Fatalf("Expected [1] address returned, got %v", addrs) } @@ -244,10 +255,12 @@ func TestEndpointTranslatorForRemoteGateways(t *testing.T) { } mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(remoteGatewayAuthOverride)) - addrs := 
mockGetServer.updatesReceived[0].GetAdd().GetAddrs() + addrs := (<-mockGetServer.updatesReceived).GetAdd().GetAddrs() if len(addrs) != 1 { t.Fatalf("Expected [1] address returned, got %v", addrs) } @@ -270,10 +283,12 @@ func TestEndpointTranslatorForRemoteGateways(t *testing.T) { t.Run("Does not send TlsIdentity when not present", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(remoteGateway1)) - addrs := mockGetServer.updatesReceived[0].GetAdd().GetAddrs() + addrs := (<-mockGetServer.updatesReceived).GetAdd().GetAddrs() if len(addrs) != 1 { t.Fatalf("Expected [1] address returned, got %v", addrs) } @@ -291,31 +306,37 @@ func TestEndpointTranslatorForRemoteGateways(t *testing.T) { func TestEndpointTranslatorForPods(t *testing.T) { t.Run("Sends one update for add and another for remove", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForPods(pod1, pod2)) translator.Remove(mkAddressSetForPods(pod2)) expectedNumUpdates := 2 - actualNumUpdates := len(mockGetServer.updatesReceived) - if actualNumUpdates != expectedNumUpdates { - t.Fatalf("Expecting [%d] updates, got [%d]. 
Updates: %v", expectedNumUpdates, actualNumUpdates, mockGetServer.updatesReceived) + <-mockGetServer.updatesReceived // Add + <-mockGetServer.updatesReceived // Remove + + if len(mockGetServer.updatesReceived) != 0 { + t.Fatalf("Expecting [%d] updates, got [%d].", expectedNumUpdates, expectedNumUpdates+len(mockGetServer.updatesReceived)) } }) t.Run("Sends addresses as removed or added", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForPods(pod1, pod2, pod3)) translator.Remove(mkAddressSetForPods(pod3)) - addressesAdded := mockGetServer.updatesReceived[0].GetAdd().Addrs + addressesAdded := (<-mockGetServer.updatesReceived).GetAdd().Addrs actualNumberOfAdded := len(addressesAdded) expectedNumberOfAdded := 3 if actualNumberOfAdded != expectedNumberOfAdded { t.Fatalf("Expecting [%d] addresses to be added, got [%d]: %v", expectedNumberOfAdded, actualNumberOfAdded, addressesAdded) } - addressesRemoved := mockGetServer.updatesReceived[1].GetRemove().Addrs + addressesRemoved := (<-mockGetServer.updatesReceived).GetRemove().Addrs actualNumberOfRemoved := len(addressesRemoved) expectedNumberOfRemoved := 1 if actualNumberOfRemoved != expectedNumberOfRemoved { @@ -332,16 +353,20 @@ func TestEndpointTranslatorForPods(t *testing.T) { t.Run("Sends metric labels with added addresses", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForPods(pod1)) - actualGlobalMetricLabels := mockGetServer.updatesReceived[0].GetAdd().MetricLabels + update := <-mockGetServer.updatesReceived + + actualGlobalMetricLabels := update.GetAdd().MetricLabels expectedGlobalMetricLabels := map[string]string{"namespace": "service-ns", "service": "service-name"} if diff := deep.Equal(actualGlobalMetricLabels, expectedGlobalMetricLabels); diff != nil { t.Fatalf("Expected global metric labels sent to be [%v] 
but was [%v]", expectedGlobalMetricLabels, actualGlobalMetricLabels) } - actualAddedAddress1MetricLabels := mockGetServer.updatesReceived[0].GetAdd().Addrs[0].MetricLabels + actualAddedAddress1MetricLabels := update.GetAdd().Addrs[0].MetricLabels expectedAddedAddress1MetricLabels := map[string]string{ "pod": "pod1", "replicationcontroller": "rc-name", @@ -359,10 +384,12 @@ func TestEndpointTranslatorForPods(t *testing.T) { } mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForPods(pod1)) - addrs := mockGetServer.updatesReceived[0].GetAdd().GetAddrs() + addrs := (<-mockGetServer.updatesReceived).GetAdd().GetAddrs() if len(addrs) != 1 { t.Fatalf("Expected [1] address returned, got %v", addrs) } @@ -384,10 +411,12 @@ func TestEndpointTranslatorForPods(t *testing.T) { } mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(podOpaque)) - addrs := mockGetServer.updatesReceived[0].GetAdd().GetAddrs() + addrs := (<-mockGetServer.updatesReceived).GetAdd().GetAddrs() if len(addrs) != 1 { t.Fatalf("Expected [1] address returned, got %v", addrs) } @@ -402,6 +431,8 @@ func TestEndpointTranslatorForPods(t *testing.T) { func TestEndpointTranslatorForZonedAddresses(t *testing.T) { t.Run("Sends one update for add and none for remove", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() translator.Add(mkAddressSetForServices(west1aAddress, west1bAddress)) translator.Remove(mkAddressSetForServices(west1bAddress)) @@ -410,9 +441,10 @@ func TestEndpointTranslatorForZonedAddresses(t *testing.T) { // that when we try to remove the address meant for west-1b there // should be no remove update. 
expectedNumUpdates := 1 - actualNumUpdates := len(mockGetServer.updatesReceived) - if actualNumUpdates != expectedNumUpdates { - t.Fatalf("Expecting [%d] updates, got [%d]. Updates: %v", expectedNumUpdates, actualNumUpdates, mockGetServer.updatesReceived) + <-mockGetServer.updatesReceived // Add + + if len(mockGetServer.updatesReceived) != 0 { + t.Fatalf("Expecting [%d] updates, got [%d].", expectedNumUpdates, expectedNumUpdates+len(mockGetServer.updatesReceived)) } }) } @@ -420,6 +452,8 @@ func TestEndpointTranslatorForZonedAddresses(t *testing.T) { func TestEndpointTranslatorForLocalTrafficPolicy(t *testing.T) { t.Run("Sends one update for add and none for remove", func(t *testing.T) { mockGetServer, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() addressSet := mkAddressSetForServices(AddressOnTest123Node, AddressNotOnTest123Node) addressSet.LocalTrafficPolicy = true translator.Add(addressSet) @@ -429,9 +463,10 @@ func TestEndpointTranslatorForLocalTrafficPolicy(t *testing.T) { // that when we try to remove the address meant for AddressNotOnTest123Node there // should be no remove update. expectedNumUpdates := 1 - actualNumUpdates := len(mockGetServer.updatesReceived) - if actualNumUpdates != expectedNumUpdates { - t.Fatalf("Expecting [%d] updates, got [%d]. 
Updates: %v", expectedNumUpdates, actualNumUpdates, mockGetServer.updatesReceived) + <-mockGetServer.updatesReceived // Add + + if len(mockGetServer.updatesReceived) != 0 { + t.Fatalf("Expecting [%d] updates, got [%d].", expectedNumUpdates, expectedNumUpdates+len(mockGetServer.updatesReceived)) } }) } @@ -439,6 +474,8 @@ func TestEndpointTranslatorForLocalTrafficPolicy(t *testing.T) { // TestConcurrency, to be triggered with `go test -race`, shouldn't report a race condition func TestConcurrency(t *testing.T) { _, translator := makeEndpointTranslator(t) + translator.Start() + defer translator.Stop() var wg sync.WaitGroup for i := 0; i < 10; i++ { diff --git a/controller/api/destination/server.go b/controller/api/destination/server.go index af92cb5f3a6bc..2fb637b416567 100644 --- a/controller/api/destination/server.go +++ b/controller/api/destination/server.go @@ -133,6 +133,8 @@ func (s *server) Get(dest *pb.GetDestination, stream pb.Destination_GetServer) e } log.Debugf("Get %s", dest.GetPath()) + streamEnd := make(chan struct{}) + var token contextToken if dest.GetContextToken() != "" { token = s.parseContextToken(dest.GetContextToken()) @@ -189,8 +191,12 @@ func (s *server) Get(dest *pb.GetDestination, stream pb.Destination_GetServer) e false, // Disable endpoint filtering for remote discovery. 
s.metadataAPI, stream, + streamEnd, log, ) + translator.Start() + defer translator.Stop() + err = remoteWatcher.Subscribe(watcher.ServiceID{Namespace: service.Namespace, Name: remoteSvc}, port, instanceID, translator) if err != nil { var ise watcher.InvalidService @@ -215,8 +221,11 @@ func (s *server) Get(dest *pb.GetDestination, stream pb.Destination_GetServer) e true, s.metadataAPI, stream, + streamEnd, log, ) + translator.Start() + defer translator.Stop() err = s.endpoints.Subscribe(service, port, instanceID, translator) if err != nil { @@ -235,6 +244,8 @@ func (s *server) Get(dest *pb.GetDestination, stream pb.Destination_GetServer) e case <-s.shutdown: case <-stream.Context().Done(): log.Debugf("Get %s cancelled", dest.GetPath()) + case <-streamEnd: + log.Errorf("Get %s stream aborted", dest.GetPath()) } return nil diff --git a/controller/api/destination/server_test.go b/controller/api/destination/server_test.go index 20307c6392d57..dc4741c241c23 100644 --- a/controller/api/destination/server_test.go +++ b/controller/api/destination/server_test.go @@ -42,9 +42,10 @@ const skippedPort uint32 = 24224 func TestGet(t *testing.T) { t.Run("Returns error if not valid service name", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() stream := &bufferingGetStream{ - updates: []*pb.Update{}, + updates: make(chan *pb.Update, 50), MockServerStream: util.NewMockServerStream(), } @@ -52,113 +53,130 @@ func TestGet(t *testing.T) { if err == nil { t.Fatalf("Expecting error, got nothing") } - - server.clusterStore.UnregisterGauges() }) t.Run("Returns endpoints", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() stream := &bufferingGetStream{ - updates: []*pb.Update{}, + updates: make(chan *pb.Update, 50), MockServerStream: util.NewMockServerStream(), } + defer stream.Cancel() + errs := make(chan error) + + // server.Get blocks until the grpc stream is complete so we call it + // in a goroutine 
and watch stream.updates for updates. + go func() { + err := server.Get(&pb.GetDestination{Scheme: "k8s", Path: fmt.Sprintf("%s:%d", fullyQualifiedName, port)}, stream) + if err != nil { + errs <- err + } + }() - // We cancel the stream before even sending the request so that we don't - // need to call server.Get in a separate goroutine. By preemptively - // cancelling, the behavior of Get becomes effectively synchronous and - // we will get only the initial update, which is what we want for this - // test. - stream.Cancel() + select { + case update := <-stream.updates: + if updateAddAddress(t, update)[0] != fmt.Sprintf("%s:%d", podIP1, port) { + t.Fatalf("Expected %s but got %s", fmt.Sprintf("%s:%d", podIP1, port), updateAddAddress(t, update)[0]) + } - err := server.Get(&pb.GetDestination{Scheme: "k8s", Path: fmt.Sprintf("%s:%d", fullyQualifiedName, port)}, stream) - if err != nil { + if len(stream.updates) != 0 { + t.Fatalf("Expected 1 update but got %d: %v", 1+len(stream.updates), stream.updates) + } + case err := <-errs: t.Fatalf("Got error: %s", err) } - - if len(stream.updates) != 1 { - t.Fatalf("Expected 1 update but got %d: %v", len(stream.updates), stream.updates) - } - - if updateAddAddress(t, stream.updates[0])[0] != fmt.Sprintf("%s:%d", podIP1, port) { - t.Fatalf("Expected %s but got %s", fmt.Sprintf("%s:%d", podIP1, port), updateAddAddress(t, stream.updates[0])[0]) - } - - server.clusterStore.UnregisterGauges() }) t.Run("Return endpoint with unknown protocol hint and identity when service name contains skipped inbound port", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := &bufferingGetStream{ - updates: []*pb.Update{}, + updates: make(chan *pb.Update, 50), MockServerStream: util.NewMockServerStream(), } - stream.Cancel() + defer stream.Cancel() + errs := make(chan error) path := fmt.Sprintf("%s:%d", fullyQualifiedNameSkipped, skippedPort) - err := server.Get(&pb.GetDestination{ - Scheme: "k8s", - 
Path: path, - }, stream) - if err != nil { - t.Fatalf("Got error: %s", err) - } - update := assertSingleUpdate(t, stream.updates) - addrs := update.GetAdd().Addrs - if len(addrs) == 0 { - t.Fatalf("Expected len(addrs) to be > 0") - } + // server.Get blocks until the grpc stream is complete so we call it + // in a goroutine and watch stream.updates for updates. + go func() { + err := server.Get(&pb.GetDestination{ + Scheme: "k8s", + Path: path, + }, stream) + if err != nil { + errs <- err + } + }() - if addrs[0].GetProtocolHint().GetProtocol() != nil || addrs[0].GetProtocolHint().GetOpaqueTransport() != nil { - t.Fatalf("Expected protocol hint for %s to be nil but got %+v", path, addrs[0].ProtocolHint) - } + select { + case update := <-stream.updates: + addrs := update.GetAdd().Addrs + if len(addrs) == 0 { + t.Fatalf("Expected len(addrs) to be > 0") + } - if addrs[0].TlsIdentity != nil { - t.Fatalf("Expected TLS identity for %s to be nil but got %+v", path, addrs[0].TlsIdentity) - } + if addrs[0].GetProtocolHint().GetProtocol() != nil || addrs[0].GetProtocolHint().GetOpaqueTransport() != nil { + t.Fatalf("Expected protocol hint for %s to be nil but got %+v", path, addrs[0].ProtocolHint) + } - server.clusterStore.UnregisterGauges() + if addrs[0].TlsIdentity != nil { + t.Fatalf("Expected TLS identity for %s to be nil but got %+v", path, addrs[0].TlsIdentity) + } + case err := <-errs: + t.Fatalf("Got error: %s", err) + } }) t.Run("Remote discovery", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() // Wait for cluster store to be synced. time.Sleep(50 * time.Millisecond) stream := &bufferingGetStream{ - updates: []*pb.Update{}, + updates: make(chan *pb.Update, 50), MockServerStream: util.NewMockServerStream(), } + defer stream.Cancel() + errs := make(chan error) + + // server.Get blocks until the grpc stream is complete so we call it + // in a goroutine and watch stream.updates for updates. 
+ go func() { + err := server.Get(&pb.GetDestination{Scheme: "k8s", Path: fmt.Sprintf("%s:%d", "foo-target.ns.svc.mycluster.local", 80)}, stream) + if err != nil { + errs <- err + } + }() - // We cancel the stream before even sending the request so that we don't - // need to call server.Get in a separate goroutine. By preemptively - // cancelling, the behavior of Get becomes effectively synchronous and - // we will get only the initial update, which is what we want for this - // test. - stream.Cancel() - - err := server.Get(&pb.GetDestination{Scheme: "k8s", Path: fmt.Sprintf("%s:%d", "foo-target.ns.svc.mycluster.local", 80)}, stream) - if err != nil { - t.Fatalf("Got error: %s", err) - } + select { + case update := <-stream.updates: + if updateAddAddress(t, update)[0] != fmt.Sprintf("%s:%d", "172.17.55.1", 80) { + t.Fatalf("Expected %s but got %s", fmt.Sprintf("%s:%d", podIP1, port), updateAddAddress(t, update)[0]) + } - if len(stream.updates) != 1 { - t.Fatalf("Expected 1 update but got %d: %v", len(stream.updates), stream.updates) - } + if len(stream.updates) != 0 { + t.Fatalf("Expected 1 update but got %d: %v", 1+len(stream.updates), stream.updates) + } - if updateAddAddress(t, stream.updates[0])[0] != fmt.Sprintf("%s:%d", "172.17.55.1", 80) { - t.Fatalf("Expected %s but got %s", fmt.Sprintf("%s:%d", podIP1, port), updateAddAddress(t, stream.updates[0])[0]) + case err := <-errs: + t.Fatalf("Got error: %s", err) } - - server.clusterStore.UnregisterGauges() }) } func TestGetProfiles(t *testing.T) { t.Run("Returns error if not valid service name", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := &bufferingGetProfileStream{ updates: []*pb.DestinationProfile{}, MockServerStream: util.NewMockServerStream(), @@ -168,12 +186,12 @@ func TestGetProfiles(t *testing.T) { if err == nil { t.Fatalf("Expecting error, got nothing") } - - server.clusterStore.UnregisterGauges() }) t.Run("Returns server profile", func(t 
*testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, fullyQualifiedName, port, "ns:other") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -188,12 +206,12 @@ func TestGetProfiles(t *testing.T) { if len(routes) != 1 { t.Fatalf("Expected 0 routes but got %d: %v", len(routes), routes) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return service profile when using json token", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, fullyQualifiedName, port, `{"ns":"other"}`) defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -204,12 +222,12 @@ func TestGetProfiles(t *testing.T) { if len(routes) != 1 { t.Fatalf("Expected 1 route got %d: %v", len(routes), routes) } - - server.clusterStore.UnregisterGauges() }) t.Run("Returns client profile", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, fullyQualifiedName, port, `{"ns":"client-ns"}`) defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -220,12 +238,12 @@ func TestGetProfiles(t *testing.T) { if !routes[0].GetIsRetryable() { t.Fatalf("Expected route to be retryable, but it was not") } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile when using cluster IP", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, clusterIP, port, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -239,12 +257,12 @@ func TestGetProfiles(t *testing.T) { if len(routes) != 1 { t.Fatalf("Expected 1 route but got %d: %v", len(routes), routes) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with endpoint when using pod DNS", func(t *testing.T) { server := makeServer(t) + defer 
server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, fullyQualifiedPodDNS, port, "ns:ns") defer stream.Cancel() @@ -280,12 +298,12 @@ func TestGetProfiles(t *testing.T) { if first.Endpoint.Addr.String() != epAddr.String() { t.Fatalf("Expected endpoint IP to be %s, but it was %s", epAddr.Ip, first.Endpoint.Addr.Ip) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with endpoint when using pod IP", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, podIP1, port, "ns:ns") defer stream.Cancel() @@ -321,24 +339,24 @@ func TestGetProfiles(t *testing.T) { if first.Endpoint.Addr.String() != epAddr.String() { t.Fatalf("Expected endpoint IP to be %s, but it was %s", epAddr.Ip, first.Endpoint.Addr.Ip) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return default profile when IP does not map to service or pod", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, "172.0.0.0", 1234, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) if profile.RetryBudget == nil { t.Fatalf("Expected default profile to have a retry budget") } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with no protocol hint when pod does not have label", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, podIP2, port, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -348,12 +366,12 @@ func TestGetProfiles(t *testing.T) { if profile.Endpoint.GetProtocolHint().GetProtocol() != nil || profile.Endpoint.GetProtocolHint().GetOpaqueTransport() != nil { t.Fatalf("Expected no protocol hint but found one") } - - server.clusterStore.UnregisterGauges() }) t.Run("Return non-opaque protocol profile when using cluster IP and opaque protocol port", func(t *testing.T) { 
server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, clusterIPOpaque, opaquePort, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -363,12 +381,12 @@ func TestGetProfiles(t *testing.T) { if profile.OpaqueProtocol { t.Fatalf("Expected port %d to not be an opaque protocol, but it was", opaquePort) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return opaque protocol profile with endpoint when using pod IP and opaque protocol port", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, podIPOpaque, opaquePort, "") defer stream.Cancel() @@ -404,12 +422,12 @@ func TestGetProfiles(t *testing.T) { if profile.Endpoint.Addr.String() != epAddr.String() { t.Fatalf("Expected endpoint IP port to be %d, but it was %d", epAddr.Port, profile.Endpoint.Addr.Port) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return opaque protocol profile when using service name with opaque port annotation", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, fullyQualifiedNameOpaqueService, opaquePort, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -419,12 +437,12 @@ func TestGetProfiles(t *testing.T) { if !profile.OpaqueProtocol { t.Fatalf("Expected port %d to be an opaque protocol, but it was not", opaquePort) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with unknown protocol hint and identity when pod contains skipped inbound port", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, podIPSkipped, skippedPort, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -438,12 +456,12 @@ func TestGetProfiles(t *testing.T) { if addr.TlsIdentity != nil { t.Fatalf("Expected TLS identity for %s 
to be nil but got %+v", podIPSkipped, addr.TlsIdentity) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with opaque protocol when using Pod IP selected by a Server", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, podIPPolicy, 80, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -459,12 +477,12 @@ func TestGetProfiles(t *testing.T) { if profile.Endpoint.ProtocolHint.GetOpaqueTransport().GetInboundPort() != 4143 { t.Fatalf("Expected pod to support opaque traffic on port 4143") } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with opaque protocol when using an opaque port with an external IP", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, externalIP, 3306, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) @@ -472,25 +490,26 @@ func TestGetProfiles(t *testing.T) { t.Fatalf("Expected port %d to be an opaque protocol, but it was not", 3306) } - server.clusterStore.UnregisterGauges() }) t.Run("Return profile with non-opaque protocol when using an arbitrary port with an external IP", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, externalIP, 80, "") defer stream.Cancel() profile := assertSingleProfile(t, stream.Updates()) if profile.OpaqueProtocol { t.Fatalf("Expected port %d to be a non-opaque protocol, but it was opaque", 80) } - - server.clusterStore.UnregisterGauges() }) t.Run("Return profile for host port pods", func(t *testing.T) { hostPort := uint32(7777) containerPort := uint32(80) server, l5dClient := getServerWithClient(t) + defer server.clusterStore.UnregisterGauges() + stream := profileStream(t, server, externalIP, hostPort, "") defer stream.Cancel() @@ -637,14 +656,14 @@ func TestGetProfiles(t *testing.T) { if 
!profile.OpaqueProtocol { t.Fatal("Expected OpaqueProtocol=true") } - - server.clusterStore.UnregisterGauges() }) } func TestTokenStructure(t *testing.T) { t.Run("when JSON is valid", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + dest := &pb.GetDestination{ContextToken: "{\"ns\":\"ns-1\",\"nodeName\":\"node-1\"}\n"} token := server.parseContextToken(dest.ContextToken) @@ -655,30 +674,28 @@ func TestTokenStructure(t *testing.T) { if token.NodeName != "node-1" { t.Fatalf("Expected token nodeName to be %s got %s", "node-1", token.NodeName) } - - server.clusterStore.UnregisterGauges() }) t.Run("when JSON is invalid and old token format used", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + dest := &pb.GetDestination{ContextToken: "ns:ns-2"} token := server.parseContextToken(dest.ContextToken) if token.Ns != "ns-2" { t.Fatalf("Expected %s got %s", "ns-2", token.Ns) } - - server.clusterStore.UnregisterGauges() }) t.Run("when invalid JSON and invalid old format", func(t *testing.T) { server := makeServer(t) + defer server.clusterStore.UnregisterGauges() + dest := &pb.GetDestination{ContextToken: "123fa-test"} token := server.parseContextToken(dest.ContextToken) if token.Ns != "" || token.NodeName != "" { t.Fatalf("Expected context token to be empty, got %v", token) } - - server.clusterStore.UnregisterGauges() }) } @@ -772,18 +789,6 @@ func assertSingleProfile(t *testing.T, updates []*pb.DestinationProfile) *pb.Des return updates[0] } -func assertSingleUpdate(t *testing.T, updates []*pb.Update) *pb.Update { - t.Helper() - // Under normal conditions the creation of resources by the fake API will - // generate notifications that are discarded after the stream.Cancel() call, - // but very rarely those notifications might come after, in which case we'll - // get a second update. 
- if len(updates) == 0 || len(updates) > 2 { - t.Fatalf("Expected 1 or 2 updates but got %d: %v", len(updates), updates) - } - return updates[0] -} - func profileStream(t *testing.T, server *server, host string, port uint32, token string) *bufferingGetProfileStream { t.Helper() diff --git a/controller/api/destination/test_util.go b/controller/api/destination/test_util.go index b8a9b5b86ea65..1ab71eb1ad5ff 100644 --- a/controller/api/destination/test_util.go +++ b/controller/api/destination/test_util.go @@ -523,12 +523,12 @@ spec: } type bufferingGetStream struct { - updates []*pb.Update + updates chan *pb.Update util.MockServerStream } func (bgs *bufferingGetStream) Send(update *pb.Update) error { - bgs.updates = append(bgs.updates, update) + bgs.updates <- update return nil } @@ -553,11 +553,11 @@ func (bgps *bufferingGetProfileStream) Updates() []*pb.DestinationProfile { type mockDestinationGetServer struct { util.MockServerStream - updatesReceived []*pb.Update + updatesReceived chan *pb.Update } func (m *mockDestinationGetServer) Send(update *pb.Update) error { - m.updatesReceived = append(m.updatesReceived, update) + m.updatesReceived <- update return nil } @@ -600,7 +600,7 @@ metadata: } metadataAPI.Sync(nil) - mockGetServer := &mockDestinationGetServer{updatesReceived: []*pb.Update{}} + mockGetServer := &mockDestinationGetServer{updatesReceived: make(chan *pb.Update, 50)} translator := newEndpointTranslator( "linkerd", "trust.domain", @@ -611,6 +611,7 @@ metadata: true, metadataAPI, mockGetServer, + nil, logging.WithField("test", t.Name()), ) return mockGetServer, translator