Skip to content

Commit

Permalink
feat: Add LoadMetrics support for virtual thread executor.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 19, 2025
1 parent e5d766b commit 4c3edf4
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -426,11 +426,18 @@ final class VirtualThreadExecutorConfigurator(config: Config, prerequisites: Dis
vt
}
}
case _ => VirtualThreadSupport.newVirtualThreadFactory(prerequisites.settings.name + "-" + id);
case _ => newVirtualThreadFactory(prerequisites.settings.name + "-" + id);
}
new ExecutorServiceFactory {
import VirtualThreadSupport._
override def createExecutorService: ExecutorService = newThreadPerTaskExecutor(tf)
override def createExecutorService: ExecutorService with LoadMetrics = {
val pool = getVirtualThreadDefaultScheduler // the default scheduler of virtual thread
new VirtualizedExecutorService(
tf,
pool, // the default scheduler of virtual thread
loadMetricsProvider = (_: Executor) => pool.getActiveThreadCount >= pool.getParallelism,
cascadeShutdown = false // we don't want to cascade shutdown the default virtual thread scheduler
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.pekko.annotation.InternalApi
import org.apache.pekko.util.JavaVersion

import java.lang.invoke.{ MethodHandles, MethodType }
import java.util.concurrent.{ ExecutorService, ThreadFactory }
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ThreadFactory }
import scala.util.control.NonFatal

@InternalApi
Expand Down Expand Up @@ -73,4 +73,21 @@ private[dispatch] object VirtualThreadSupport {
}
}

/**
* Try to get the default scheduler of virtual thread.
*/
def getVirtualThreadDefaultScheduler: ForkJoinPool =
try {
require(isSupported, "Virtual thread is not supported.")
val clazz = Class.forName("java.lang.VirtualThread")
val fieldName = "DEFAULT_SCHEDULER"
val field = clazz.getDeclaredField(fieldName)
field.setAccessible(true)
field.get(null).asInstanceOf[ForkJoinPool]
} catch {
case NonFatal(e) =>
// --add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi

import java.util
import java.util.concurrent.{ Callable, Executor, ExecutorService, Future, ThreadFactory, TimeUnit }

/**
* A virtualized executor service that creates a new virtual thread for each task.
* Will shut down the underlying executor service when this executor is being shutdown.
*
* INTERNAL API
*/
@InternalApi
final class VirtualizedExecutorService(
vtFactory: ThreadFactory,
underlying: ExecutorService,
loadMetricsProvider: Executor => Boolean,
cascadeShutdown: Boolean)
extends ExecutorService with LoadMetrics {
require(VirtualThreadSupport.isSupported, "Virtual thread is not supported.")
require(vtFactory != null, "Virtual thread factory must not be null")
require(underlying != null, "Underlying executor service must not be null")
require(loadMetricsProvider != null, "Load metrics provider must not be null")

def this(prefix: String,
underlying: ExecutorService,
loadMetricsProvider: Executor => Boolean,
cascadeShutdown: Boolean) = {
this(VirtualThreadSupport.newVirtualThreadFactory(prefix), underlying, loadMetricsProvider, cascadeShutdown)
}

private val executor = VirtualThreadSupport.newThreadPerTaskExecutor(vtFactory)

override def atFullThrottle(): Boolean = loadMetricsProvider(this)

override def shutdown(): Unit = {
executor.shutdown()
if (cascadeShutdown) {
underlying.shutdown()
}
}

override def shutdownNow(): util.List[Runnable] = {
val r = executor.shutdownNow()
if (cascadeShutdown) {
underlying.shutdownNow()
}
r
}

override def isShutdown: Boolean = {
if (cascadeShutdown) {
executor.isShutdown && underlying.isShutdown
} else {
executor.isShutdown
}
}

override def isTerminated: Boolean = {
if (cascadeShutdown) {
executor.isTerminated && underlying.isTerminated
} else {
executor.isTerminated
}
}

override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = {
if (cascadeShutdown) {
executor.awaitTermination(timeout, unit) && underlying.awaitTermination(timeout, unit)
} else {
executor.awaitTermination(timeout, unit)
}
}

override def submit[T](task: Callable[T]): Future[T] = {
executor.submit(task)
}

override def submit[T](task: Runnable, result: T): Future[T] = {
executor.submit(task, result)
}

override def submit(task: Runnable): Future[_] = {
executor.submit(task)
}

override def invokeAll[T](tasks: util.Collection[_ <: Callable[T]]): util.List[Future[T]] = {
executor.invokeAll(tasks)
}

override def invokeAll[T](
tasks: util.Collection[_ <: Callable[T]], timeout: Long, unit: TimeUnit): util.List[Future[T]] = {
executor.invokeAll(tasks, timeout, unit)
}

override def invokeAny[T](tasks: util.Collection[_ <: Callable[T]]): T = {
executor.invokeAny(tasks)
}

override def invokeAny[T](tasks: util.Collection[_ <: Callable[T]], timeout: Long, unit: TimeUnit): T = {
executor.invokeAny(tasks, timeout, unit)
}

override def execute(command: Runnable): Unit = {
executor.execute(command)
}
}
2 changes: 2 additions & 0 deletions project/JdkOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ object JdkOptions extends AutoPlugin {

lazy val versionSpecificJavaOptions =
if (isJdk17orHigher) {
// for virtual threads
"--add-opens=java.base/java.lang=ALL-UNNAMED" ::
// for aeron
"--add-opens=java.base/sun.nio.ch=ALL-UNNAMED" ::
// for LevelDB
Expand Down

0 comments on commit 4c3edf4

Please sign in to comment.