Skip to content

Commit

Permalink
feat: Add support for switching scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 17, 2025
1 parent b20ec82 commit 5776a32
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* license agreements; and to You under the Apache License, version 2.0:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* This file is part of the Apache Pekko project, which was derived from Akka.
*/

/*
* Copyright (C) 2018-2022 Lightbend Inc. <https://www.lightbend.com>
*/

package org.apache.pekko.dispatch

import com.typesafe.config.ConfigFactory

import org.apache.pekko
import pekko.actor.{ Actor, Props }
import pekko.testkit.{ ImplicitSender, PekkoSpec }
import pekko.util.JavaVersion

object ForkJoinPoolVirtualThreadSpec {
val config = ConfigFactory.parseString("""
|virtual {
| task-dispatcher {
| mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox"
| throughput = 5
| fork-join-executor {
| parallelism-factor = 2
| parallelism-max = 2
| parallelism-min = 2
| virtualize = on
| }
| }
|}
""".stripMargin)

class ThreadNameActor extends Actor {

override def receive = {
case "ping" =>
sender() ! Thread.currentThread().getName
}
}

}

class ForkJoinPoolVirtualThreadSpec extends PekkoSpec(ForkJoinPoolVirtualThreadSpec.config) with ImplicitSender {
import ForkJoinPoolVirtualThreadSpec._

"PekkoForkJoinPool" must {

"support virtualization with Virtual Thread" in {
val actor = system.actorOf(Props(new ThreadNameActor).withDispatcher("virtual.task-dispatcher"))
for (_ <- 1 to 1000) {
// External task submission via the default dispatcher
actor ! "ping"
expectMsgPF() { case name: String => name should include("virtual-thread-") }
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ object VirtualThreadPoolDispatcherSpec {
class VirtualThreadPoolDispatcherSpec extends PekkoSpec(VirtualThreadPoolDispatcherSpec.config) with ImplicitSender {
import VirtualThreadPoolDispatcherSpec._

val Iterations = 1000

"VirtualThreadPool support" must {

"handle simple dispatch" in {
Expand Down
5 changes: 5 additions & 0 deletions actor/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,11 @@ pekko {
# This config is new in Pekko v1.1.0 and only has an effect if you are running with JDK 9 and above.
# Read the documentation on `java.util.concurrent.ForkJoinPool` to find out more. Default in hex is 0x7fff.
maximum-pool-size = 32767

# This config is new in Pekko v1.2.0 and only has an effect if you are running with JDK 21 and above.
# Virtualize this dispatcher as a virtual-thread-executor
# Valid values are: `on`, `off`
virtualize = off
}

# This will be used if you have set "executor = "thread-pool-executor""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,18 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
val threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
val parallelism: Int,
val asyncMode: Boolean,
val maxPoolSize: Int)
val maxPoolSize: Int,
val virtualize: Boolean)
extends ExecutorServiceFactory {

def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
parallelism: Int,
asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap)
asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap, false)

def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
parallelism: Int,
asyncMode: Boolean,
maxPoolSize: Int) = this(threadFactory, parallelism, asyncMode, maxPoolSize, false)

private def pekkoJdk9ForkJoinPoolClassOpt: Option[Class[_]] =
Try(Class.forName("org.apache.pekko.dispatch.PekkoJdk9ForkJoinPool")).toOption
Expand All @@ -116,12 +122,19 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, parallelism: Int) =
this(threadFactory, parallelism, asyncMode = true)

def createExecutorService: ExecutorService = pekkoJdk9ForkJoinPoolHandleOpt match {
case Some(handle) =>
handle.invoke(parallelism, threadFactory, maxPoolSize,
MonitorableThreadFactory.doNothing, asyncMode).asInstanceOf[ExecutorService]
case _ =>
new PekkoForkJoinPool(parallelism, threadFactory, MonitorableThreadFactory.doNothing, asyncMode)
def createExecutorService: ExecutorService = {
val forkJoinPool = pekkoJdk9ForkJoinPoolHandleOpt match {
case Some(handle) =>
handle.invoke(parallelism, threadFactory, maxPoolSize,
MonitorableThreadFactory.doNothing, asyncMode).asInstanceOf[ExecutorService]
case _ =>
new PekkoForkJoinPool(parallelism, threadFactory, MonitorableThreadFactory.doNothing, asyncMode)
}
if (virtualize) {
new VirtualizedExecutorService("pekko", forkJoinPool)
} else {
forkJoinPool
}
}
}

Expand Down Expand Up @@ -149,6 +162,7 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
config.getDouble("parallelism-factor"),
config.getInt("parallelism-max")),
asyncMode,
config.getInt("maximum-pool-size"))
config.getInt("maximum-pool-size"),
config.getBoolean("virtualize"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ object ThreadPoolConfig {
* Function0 without the fun stuff (mostly for the sake of the Java API side of things)
*/
trait ExecutorServiceFactory {

/**
* Create a new ExecutorService
*/
def createExecutorService: ExecutorService
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi
import org.apache.pekko.util.JavaVersion
import org.apache.pekko
import pekko.annotation.InternalApi
import pekko.util.JavaVersion

import java.lang.invoke.{ MethodHandles, MethodType }
import java.util.concurrent.{ ExecutorService, ThreadFactory }
Expand All @@ -34,8 +35,7 @@ private[dispatch] object VirtualThreadSupport {
val isSupported: Boolean = JavaVersion.majorVersion >= 21

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
* Create a virtual thread factory with the default Virtual Thread executor.
*/
def newVirtualThreadFactory(prefix: String): ThreadFactory = {
require(isSupported, "Virtual thread is not supported.")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi

import java.util.concurrent.{ ExecutorService, ThreadFactory }
import scala.util.control.NonFatal

/**
* TODO remove this class once we drop Java 8 support
*/
@InternalApi
private[dispatch] object VirtualThreadSupportReflect {

/**
* Create a virtual thread factory with given executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def newThreadPerTaskExecutor(prefix: String, executor: ExecutorService): ExecutorService = {
val factory = virtualThreadFactory(prefix, executor)
VirtualThreadSupport.newThreadPerTaskExecutor(factory)
}

private def virtualThreadFactory(prefix: String, executor: ExecutorService): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = classOf[Thread].getDeclaredMethod("ofVirtual")
var builder = ofVirtualMethod.invoke(null)
if (executor != null) {
val clazz = builder.getClass
val field = clazz.getDeclaredField("scheduler")
field.setAccessible(true)
field.set(builder, executor)
}
val nameMethod = ofVirtualClass.getDeclaredMethod("name", classOf[String], classOf[Long])
val factoryMethod = builderClass.getDeclaredMethod("factory")
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
throw new UnsupportedOperationException("Failed to create virtual thread factory", e)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi

import java.util
import java.util.concurrent.{ Callable, ExecutorService, Future, TimeUnit }

/**
* A virtualized executor service that creates a new virtual thread for each task.
* Will shut down the underlying executor service when this executor is being shutdown.
*
* INTERNAL API
*/
@InternalApi
final class VirtualizedExecutorService(prefix: String, underlying: ExecutorService) extends ExecutorService {
require(prefix ne null, "Parameter prefix must not be null or empty")
require(underlying ne null, "Parameter underlying must not be null")

private val executor = VirtualThreadSupportReflect.newThreadPerTaskExecutor(prefix, underlying)

override def shutdown(): Unit = {
executor.shutdown()
underlying.shutdown()
}

override def shutdownNow(): util.List[Runnable] = {
executor.shutdownNow()
underlying.shutdownNow()
}

override def isShutdown: Boolean = {
executor.isShutdown || underlying.isShutdown
}

override def isTerminated: Boolean = {
executor.isTerminated && underlying.isTerminated
}

override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = {
executor.awaitTermination(timeout, unit) && underlying.awaitTermination(timeout, unit)
}

override def submit[T](task: Callable[T]): Future[T] = {
executor.submit(task)
}

override def submit[T](task: Runnable, result: T): Future[T] = {
executor.submit(task, result)
}

override def submit(task: Runnable): Future[_] = {
executor.submit(task)
}

override def invokeAll[T](tasks: util.Collection[_ <: Callable[T]]): util.List[Future[T]] = {
executor.invokeAll(tasks)
}

override def invokeAll[T](
tasks: util.Collection[_ <: Callable[T]], timeout: Long, unit: TimeUnit): util.List[Future[T]] = {
executor.invokeAll(tasks, timeout, unit)
}

override def invokeAny[T](tasks: util.Collection[_ <: Callable[T]]): T = {
executor.invokeAny(tasks)
}

override def invokeAny[T](tasks: util.Collection[_ <: Callable[T]], timeout: Long, unit: TimeUnit): T = {
executor.invokeAny(tasks, timeout, unit)
}

override def execute(command: Runnable): Unit = {
executor.execute(command)
}
}

0 comments on commit 5776a32

Please sign in to comment.