From b6d07a2fefedb7a312c21643d8fada19fc66321d Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Fri, 31 Jan 2025 18:22:55 +0300 Subject: [PATCH] fix: RandomizedWeightedBalancer chooses the first upstream with higher probability than others #666 (#667) --- .../server/upstream/RandomizedWeightedBalancer.java | 8 +++++--- .../upstream/RandomizedWeightedBalancerTest.java | 10 +++++----- .../core/server/upstream/TieredBalancerTest.java | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java b/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java index f4738ead..5c453fe9 100644 --- a/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java +++ b/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java @@ -50,16 +50,18 @@ public Upstream next() { if (availableUpstreams.isEmpty()) { return null; } + if (availableUpstreams.size() == 1) { + return availableUpstreams.get(0); + } int total = availableUpstreams.stream().map(Upstream::getWeight).reduce(0, Integer::sum); - // make sure the upper bound `total` is inclusive - int random = generator.nextInt(total + 1); + int random = generator.nextInt(total); int current = 0; Upstream result = null; for (Upstream upstream : availableUpstreams) { current += upstream.getWeight(); - if (current >= random) { + if (current > random) { result = upstream; break; } diff --git a/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java b/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java index c2d41559..e00fb914 100644 --- a/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java @@ -19,7 +19,7 @@ public class RandomizedWeightedBalancerTest { @Mock private Random generator; - + @Test void testWeightedLoadBalancer() { List upstreams = List.of( @@ -31,25 +31,25 @@ void testWeightedLoadBalancer() { RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator); - when(generator.nextInt(11)).thenReturn(0); + when(generator.nextInt(10)).thenReturn(0); Upstream upstream = balancer.next(); assertNotNull(upstream); assertEquals(upstreams.get(0), upstream); - when(generator.nextInt(11)).thenReturn(2); + when(generator.nextInt(10)).thenReturn(2); upstream = balancer.next(); assertNotNull(upstream); assertEquals(upstreams.get(1), upstream); - when(generator.nextInt(11)).thenReturn(6); + when(generator.nextInt(10)).thenReturn(5); upstream = balancer.next(); assertNotNull(upstream); assertEquals(upstreams.get(2), upstream); - when(generator.nextInt(11)).thenReturn(10); + when(generator.nextInt(10)).thenReturn(9); upstream = balancer.next(); assertNotNull(upstream); diff --git a/server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java b/server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java index 5dc7b0b4..84bb26fe 100644 --- a/server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java @@ -110,8 +110,8 @@ void testUpstreamFallback() { .map(index -> new Upstream("endpoint" + index, null, null, 1, 1)) .toList(); model.setUpstreams(upstreams); - AtomicInteger counter = new AtomicInteger(); - when(generator.nextInt(5)).thenAnswer(cb -> counter.incrementAndGet()); + AtomicInteger counter = new AtomicInteger(-1); + when(generator.nextInt(4)).thenAnswer(cb -> counter.incrementAndGet()); Supplier factory = () -> generator; UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx, factory);