From 62a68b4785a9d6effc3b2a2368bca9e4b38768f1 Mon Sep 17 00:00:00 2001
From: Mateusz Charytoniuk
Date: Mon, 13 May 2024 22:52:01 +0200
Subject: [PATCH] chore: not all requests use slots

---
 loadbalancer/LoadBalancer.go                 | 12 ++++++++++--
 loadbalancer/LoadBalancerRequest.go          | 11 +++++++++++
 loadbalancer/LoadBalancerTargetCollection.go | 13 +++++++++++--
 loadbalancer/ReverseProxyServer.go           |  4 +++-
 4 files changed, 35 insertions(+), 5 deletions(-)
 create mode 100644 loadbalancer/LoadBalancerRequest.go

diff --git a/loadbalancer/LoadBalancer.go b/loadbalancer/LoadBalancer.go
index 9e4a55e..e44eab7 100644
--- a/loadbalancer/LoadBalancer.go
+++ b/loadbalancer/LoadBalancer.go
@@ -20,8 +20,8 @@ type LoadBalancer struct {
 	Logger hclog.Logger
 }
 
-func (self *LoadBalancer) Balance(request *http.Request) (*url.URL, error) {
-	headTarget := self.LoadBalancerTargetCollection.GetForBalancing()
+func (self *LoadBalancer) Balance(request *LoadBalancerRequest) (*url.URL, error) {
+	headTarget := self.GetLlamaCppTargetForRequest(request)
 
 	if headTarget == nil {
 		return nil, ErrorNoTargetsAvailable
@@ -42,6 +42,14 @@ func (self *LoadBalancer) Balance(request *http.Request) (*url.URL, error) {
 	return targetUrl, nil
 }
 
+func (self *LoadBalancer) GetLlamaCppTargetForRequest(request *LoadBalancerRequest) *LlamaCppTarget {
+	if request.IsSlottable() {
+		return self.LoadBalancerTargetCollection.GetForBalancingSlot()
+	}
+
+	return self.LoadBalancerTargetCollection.GetHead()
+}
+
 func (self *LoadBalancer) GetStatus() *LoadBalancerStatus {
 	return &LoadBalancerStatus{
 		RegisteredTargets: self.LoadBalancerTargetCollection.Len(),
diff --git a/loadbalancer/LoadBalancerRequest.go b/loadbalancer/LoadBalancerRequest.go
new file mode 100644
index 0000000..3c7f72f
--- /dev/null
+++ b/loadbalancer/LoadBalancerRequest.go
@@ -0,0 +1,11 @@
+package loadbalancer
+
+import "net/http"
+
+type LoadBalancerRequest struct {
+	HttpRequest *http.Request
+}
+
+func (self *LoadBalancerRequest) IsSlottable() bool {
+	return self.HttpRequest.Method == "POST" && self.HttpRequest.URL.Path == "/completion"
+}
diff --git a/loadbalancer/LoadBalancerTargetCollection.go b/loadbalancer/LoadBalancerTargetCollection.go
index ed1d7b7..aa1fd46 100644
--- a/loadbalancer/LoadBalancerTargetCollection.go
+++ b/loadbalancer/LoadBalancerTargetCollection.go
@@ -19,15 +19,24 @@ func (self *LoadBalancerTargetCollection) HasTargetConfiguration(
 	return ok
 }
 
-func (self *LoadBalancerTargetCollection) GetForBalancing() *LlamaCppTarget {
+func (self *LoadBalancerTargetCollection) GetHead() *LlamaCppTarget {
 	if self.targetHeap.Len() < 1 {
 		return nil
 	}
 
+	return (*self.targetHeap)[0]
+}
+
+func (self *LoadBalancerTargetCollection) GetForBalancingSlot() *LlamaCppTarget {
+	headTarget := self.GetHead()
+
+	if headTarget == nil {
+		return nil
+	}
+
 	self.mutex.Lock()
 	defer self.mutex.Unlock()
 
-	headTarget := (*self.targetHeap)[0]
 	headTarget.LlamaCppHealthStatus.SlotsIdle -= 1
 
 	heap.Fix(self.targetHeap, 0)
diff --git a/loadbalancer/ReverseProxyServer.go b/loadbalancer/ReverseProxyServer.go
index e504c90..d1896d6 100644
--- a/loadbalancer/ReverseProxyServer.go
+++ b/loadbalancer/ReverseProxyServer.go
@@ -26,7 +26,9 @@ func (self *ReverseProxyServer) Serve(serverEventsChannel chan goroutine.ResultM
 			InferLevels: true,
 		}),
 		Rewrite: func(request *httputil.ProxyRequest) {
-			targetUrl, err := self.LoadBalancer.Balance(request.In)
+			targetUrl, err := self.LoadBalancer.Balance(&LoadBalancerRequest{
+				HttpRequest: request.In,
+			})
 
 			if err != nil {
 				serverEventsChannel <- goroutine.ResultMessage{