From 03d81122b855fba8dd4658fab35259926feaa828 Mon Sep 17 00:00:00 2001 From: mazhen Date: Fri, 22 Dec 2023 11:12:25 +0800 Subject: [PATCH] Improve Thread Pool Management in VirtualThreadExecutorService --- .../VirtualThreadExecutorService.java | 35 +++++++++++++------ .../VirtualThreadExecutorServiceTest.java | 29 +++++++++++---- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/modules/grizzly/src/main/java/org/glassfish/grizzly/threadpool/VirtualThreadExecutorService.java b/modules/grizzly/src/main/java/org/glassfish/grizzly/threadpool/VirtualThreadExecutorService.java index a751531cd..f98b91703 100644 --- a/modules/grizzly/src/main/java/org/glassfish/grizzly/threadpool/VirtualThreadExecutorService.java +++ b/modules/grizzly/src/main/java/org/glassfish/grizzly/threadpool/VirtualThreadExecutorService.java @@ -23,7 +23,8 @@ public class VirtualThreadExecutorService extends AbstractExecutorService implem private static final Logger logger = Grizzly.logger(VirtualThreadExecutorService.class); private final ExecutorService internalExecutorService; - private Semaphore poolSemaphore; + private final Semaphore poolSemaphore; + private final Semaphore queueSemaphore; public static VirtualThreadExecutorService createInstance() { return createInstance(ThreadPoolConfig.defaultConfig().setMaxPoolSize(-1).setPoolName("Grizzly-virt-")); @@ -36,11 +37,18 @@ public static VirtualThreadExecutorService createInstance(ThreadPoolConfig cfg) protected VirtualThreadExecutorService(ThreadPoolConfig cfg) { internalExecutorService = Executors.newThreadPerTaskExecutor(getThreadFactory(cfg)); - if (cfg.getMaxPoolSize() > 0) { - poolSemaphore = new Semaphore(cfg.getMaxPoolSize()); + + int poolSizeLimit = cfg.getMaxPoolSize() > 0 ? cfg.getMaxPoolSize() : Integer.MAX_VALUE; + int queueLimit = cfg.getQueueLimit() >= 0 ? cfg.getQueueLimit() : Integer.MAX_VALUE; + // Check for integer overflow + long totalLimit = (long) poolSizeLimit + (long) queueLimit; + if (totalLimit > Integer.MAX_VALUE) { + // Handle the overflow case + queueSemaphore = new Semaphore(Integer.MAX_VALUE, true); } else { - poolSemaphore = new Semaphore(Integer.MAX_VALUE); + queueSemaphore = new Semaphore((int) totalLimit, true); } + poolSemaphore = new Semaphore(poolSizeLimit, true); } private ThreadFactory getThreadFactory(ThreadPoolConfig threadPoolConfig) { @@ -90,17 +98,24 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE @Override public void execute(Runnable command) { - if (poolSemaphore.tryAcquire()) { - internalExecutorService.execute(() -> { + if (!queueSemaphore.tryAcquire()) { + throw new RejectedExecutionException("Too Many Concurrent Requests"); + } + + internalExecutorService.execute(() -> { + try { + poolSemaphore.acquire(); try { command.run(); } finally { poolSemaphore.release(); } - }); - } else { - throw new RejectedExecutionException("Too Many Concurrent Requests"); - } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + queueSemaphore.release(); + } + }); } @Override diff --git a/modules/grizzly/src/test/java/org/glassfish/grizzly/VirtualThreadExecutorServiceTest.java b/modules/grizzly/src/test/java/org/glassfish/grizzly/VirtualThreadExecutorServiceTest.java index f707a0924..e45fa9c8c 100644 --- a/modules/grizzly/src/test/java/org/glassfish/grizzly/VirtualThreadExecutorServiceTest.java +++ b/modules/grizzly/src/test/java/org/glassfish/grizzly/VirtualThreadExecutorServiceTest.java @@ -29,14 +29,18 @@ public void testAwaitTermination() throws Exception { } public void testQueueLimit() throws Exception { - int poolSize = 10; - ThreadPoolConfig config = ThreadPoolConfig.defaultConfig().setMaxPoolSize(poolSize); + int maxPoolSize = 20; + int queueLimit = 10; + int queue = maxPoolSize + queueLimit; + ThreadPoolConfig config = ThreadPoolConfig.defaultConfig() + .setMaxPoolSize(maxPoolSize) + .setQueueLimit(queueLimit); VirtualThreadExecutorService r = VirtualThreadExecutorService.createInstance(config); - CyclicBarrier start = new CyclicBarrier(poolSize + 1); - CyclicBarrier hold = new CyclicBarrier(poolSize + 1); + CyclicBarrier start = new CyclicBarrier(maxPoolSize + 1); + CyclicBarrier hold = new CyclicBarrier(maxPoolSize + 1); AtomicInteger result = new AtomicInteger(); - for (int i = 0; i < poolSize; i++) { + for (int i = 0; i < maxPoolSize; i++) { int taskId = i; r.execute(() -> { try { @@ -44,22 +48,33 @@ public void testQueueLimit() throws Exception { start.await(); hold.await(); result.getAndIncrement(); + System.out.println("task " + taskId + " is completed"); } catch (Exception e) { } }); } start.await(); + for (int i = maxPoolSize; i < queue; i++) { + int taskId = i; + r.execute(() -> { + try { + result.getAndIncrement(); + System.out.println("task " + taskId + " is completed"); + } catch (Exception e) { + } + }); + } // Too Many Concurrent Requests Assert.assertThrows(RejectedExecutionException.class, () -> r.execute(() -> System.out.println("cannot be executed"))); hold.await(); while (true) { - if (result.intValue() == poolSize) { + if (result.intValue() == queue) { System.out.println("All tasks have been completed."); break; } } // The executor can accept new tasks - doTest(r, poolSize); + doTest(r, queue); } private void doTest(VirtualThreadExecutorService r, int tasks) throws Exception {