diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index c1f35df5e1..3261b0eefc 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: java: [ '11', '17', '21' ] - name: Tests for Java ${{ matrix.Java }} + name: Tests local for Java ${{ matrix.Java }} steps: - uses: actions/checkout@v3 - name: Setup java @@ -27,14 +27,18 @@ jobs: - name: Run tests run: | set -eux - ./mill -ikj1 --disable-ticker __.testLocal + if [ "${{ matrix.java }}" == "21" ]; then + JAVA_OPTS='--add-opens java.base/java.lang=ALL-UNNAMED -Dcask.virtual-thread.enabled=true' ./mill -ikj1 --disable-ticker __.testLocal + else + ./mill -ikj1 --disable-ticker __.testLocal + fi test-examples: runs-on: ubuntu-latest strategy: matrix: java: [ '11', '17', '21' ] - name: Tests for Java ${{ matrix.Java }} + name: Tests examples for Java ${{ matrix.Java }} steps: - uses: actions/checkout@v3 - name: Setup java @@ -45,8 +49,13 @@ jobs: - name: Run tests run: | set -eux - ./mill __.publishLocal - ./mill -ikj1 --disable-ticker testExamples + if [ "${{ matrix.java }}" == "21" ]; then + ./mill __.publishLocal + JAVA_OPTS='--add-opens java.base/java.lang=ALL-UNNAMED -Dcask.virtual-thread.enabled=true' ./mill -ikj1 --disable-ticker testExamples + else + ./mill __.publishLocal + ./mill -ikj1 --disable-ticker testExamples + fi publish-sonatype: if: github.repository == 'com-lihaoyi/cask' && contains(github.ref, 'refs/tags/') diff --git a/build.mill b/build.mill index a02f4b502f..4b375b73dc 100644 --- a/build.mill +++ b/build.mill @@ -85,12 +85,116 @@ object cask extends Cross[CaskMainModule](scalaVersions) { } } -trait LocalModule extends CrossScalaModule{ - override def millSourcePath = super.millSourcePath / "app" +trait BenchmarkModule extends CrossScalaModule { def moduleDeps = Seq(cask(crossScalaVersion)) + def ivyDeps = Agg[Dep]( + ) } +object benchmark extends Cross[BenchmarkModule](build.scalaVersions) with RunModule { + + def waitForServer(url: String, maxAttempts: Int = 120): Boolean = { + (1 to maxAttempts).exists { attempt => + try { + Thread.sleep(3000) + println("Checking server... Attempt " + attempt) + os.proc("curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", url) + .call(check = false) + .exitCode == 0 + } catch { + case _: Throwable => + Thread.sleep(3000) + false + } + } + } + + def runBenchmark() = T.command { + if (os.proc("which", "wrk").call(check = false).exitCode != 0) { + println("Error: wrk is not installed. Please install wrk first.") + sys.exit(1) + } + + val duration = "30s" + val threads = "4" + val connections = "100" + val url = "http://localhost:8080/" + + println("Testing with regular threads...") + + val projectRoot = T.workspace + println("projectRoot: " + projectRoot) + + def runMillBackground(example: String, vt:Boolean) = { + println(s"Running $example with vt: $vt") + + os.proc( + "mill", + s"example.$example.app[$scala213].run") + .spawn( + cwd = projectRoot, + env = Map("CASK_VIRTUAL_THREAD" -> vt.toString), + stdout = os.Inherit, + stderr = os.Inherit) + } + + for (example <- Seq( + "staticFilesWithLoom", + "todoDbWithLoom", + "minimalApplicationWithLoom")) { + + val regularApp = runMillBackground(example, vt = false) + println("Waiting for regular server to start...") + if (!waitForServer(url)) { + regularApp.destroy() + println("Failed to start regular server") + sys.exit(1) + } + + println("target server started, starting run benchmark with wrk") + val regularResults = os.proc("wrk", + "-t", threads, + "-c", connections, + "-d", duration, + url + ).call(stderr = os.Pipe) + regularApp.destroy() + + println("\nRegular Threads Results:") + println(regularResults.out.text()) + + Thread.sleep(1000) + println("\nTesting with virtual threads, please use Java 21+...") + val virtualApp = runMillBackground(example, vt = true) + + println("Waiting for virtual server to start...") + if (!waitForServer(url)) { + virtualApp.destroy() + println("Failed to start virtual server") + sys.exit(1) + } + + println("target server started, starting run benchmark with wrk") + val virtualResults = os.proc("wrk", + "-t", threads, + "-c", connections, + "-d", duration, + url + ).call(stderr = os.Pipe) + virtualApp.destroy() + + println("\nVirtual Threads Results:") + println(virtualResults.out.text()) + } + + } +} + +trait LocalModule extends CrossScalaModule{ + override def millSourcePath = super.millSourcePath / "app" + def moduleDeps = Seq(cask(crossScalaVersion)) +} def zippedExamples = T { val vcsState = VcsVersion.vcsState() @@ -111,13 +215,16 @@ def zippedExamples = T { build.example.httpMethods.millSourcePath, build.example.minimalApplication.millSourcePath, build.example.minimalApplication2.millSourcePath, + build.example.minimalApplicationWithLoom.millSourcePath, build.example.redirectAbort.millSourcePath, build.example.scalatags.millSourcePath, build.example.staticFiles.millSourcePath, + build.example.staticFilesWithLoom.millSourcePath, build.example.staticFiles2.millSourcePath, build.example.todo.millSourcePath, build.example.todoApi.millSourcePath, build.example.todoDb.millSourcePath, + build.example.todoDbWithLoom.millSourcePath, build.example.twirl.millSourcePath, build.example.variableRoutes.millSourcePath, build.example.queryParams.millSourcePath, diff --git a/cask/src/cask/internal/ThreadBlockingHandler.scala b/cask/src/cask/internal/ThreadBlockingHandler.scala new file mode 100644 index 0000000000..bf4ff0b0e1 --- /dev/null +++ b/cask/src/cask/internal/ThreadBlockingHandler.scala @@ -0,0 +1,17 @@ +package cask.internal + +import io.undertow.server.{HttpHandler, HttpServerExchange} + +import java.util.concurrent.Executor + +/** + * A handler that dispatches the request to the given handler using the given executor. + * */ +final class ThreadBlockingHandler(executor: Executor, handler: HttpHandler) extends HttpHandler { + require(executor ne null, "Executor should not be null") + + def handleRequest(exchange: HttpServerExchange): Unit = { + exchange.startBlocking() + exchange.dispatch(executor, handler) + } +} diff --git a/cask/src/cask/internal/Util.scala b/cask/src/cask/internal/Util.scala index 87e2a15623..4b829f93f6 100644 --- a/cask/src/cask/internal/Util.scala +++ b/cask/src/cask/internal/Util.scala @@ -1,24 +1,121 @@ package cask.internal import java.io.{InputStream, PrintWriter, StringWriter} - import scala.collection.generic.CanBuildFrom import scala.collection.mutable import java.io.OutputStream - +import java.lang.invoke.{MethodHandles, MethodType} +import java.util.concurrent.{Executor, ExecutorService, ForkJoinPool, ThreadFactory} import scala.annotation.switch import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.Try +import scala.util.control.NonFatal object Util { + private val lookup = MethodHandles.lookup() + + import cask.util.Logger.Console.globalLogger + + /** + * Create a virtual thread executor with the given executor as the scheduler. + * */ + def createVirtualThreadExecutor(executor: Executor): Option[Executor] = { + (for { + factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor)) + executor <- Try(createNewThreadPerTaskExecutor(factory)) + } yield executor).toOption + } + + /** + * Create a default cask virtual thread executor if possible. + * */ + def createDefaultCaskVirtualThreadExecutor: Option[Executor] = { + for { + scheduler <- getDefaultVirtualThreadScheduler + executor <- createVirtualThreadExecutor(scheduler) + } yield executor + } + + /** + * Try to get the default virtual thread scheduler, or null if not supported. + * */ + def getDefaultVirtualThreadScheduler: Option[ForkJoinPool] = { + try { + val virtualThreadClass = Class.forName("java.lang.VirtualThread") + val privateLookup = MethodHandles.privateLookupIn(virtualThreadClass, lookup) + val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, "DEFAULT_SCHEDULER", classOf[ForkJoinPool]) + Option(defaultSchedulerField.get().asInstanceOf[ForkJoinPool]) + } catch { + case NonFatal(e) => + //--add-opens java.base/java.lang=ALL-UNNAMED + globalLogger.exception(e) + None + } + } + + def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = { + try { + val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors") + val newThreadPerTaskExecutorMethod = lookup.findStatic( + executorsClazz, + "newThreadPerTaskExecutor", + MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory])) + newThreadPerTaskExecutorMethod.invoke(threadFactory) + .asInstanceOf[ExecutorService] + } catch { + case NonFatal(e) => + globalLogger.exception(e) + throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e) + } + } + + /** + * Create a virtual thread factory with a executor, the executor will be used as the scheduler of + * virtual thread. + * + * The executor should run task on platform threads. + * + * returns null if not supported. + */ + def createVirtualThreadFactory(prefix: String, + executor: Executor): ThreadFactory = + try { + val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") + val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") + val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass)) + var builder = ofVirtualMethod.invoke() + if (executor != null) { + val clazz = builder.getClass + val privateLookup = MethodHandles.privateLookupIn( + clazz, + lookup + ) + val schedulerFieldSetter = privateLookup + .findSetter(clazz, "scheduler", classOf[Executor]) + schedulerFieldSetter.invoke(builder, executor) + } + val nameMethod = lookup.findVirtual(ofVirtualClass, "name", + MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long])) + val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory])) + builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L) + factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] + } catch { + case NonFatal(e) => + globalLogger.exception(e) + //--add-opens java.base/java.lang=ALL-UNNAMED + throw new UnsupportedOperationException("Failed to create virtual thread factory.", e) + } + def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = { val p = Promise[T] futures.foreach(_.foreach(p.trySuccess)) p.future } + /** - * Convert a string to a C&P-able literal. Basically - * copied verbatim from the uPickle source code. - */ + * Convert a string to a C&P-able literal. Basically + * copied verbatim from the uPickle source code. + */ def literalize(s: IndexedSeq[Char], unicode: Boolean = true) = { val sb = new StringBuilder sb.append('"') @@ -47,8 +144,8 @@ object Util { def transferTo(in: InputStream, out: OutputStream) = { val buffer = new Array[Byte](8192) - while ({ - in.read(buffer) match{ + while ( { + in.read(buffer) match { case -1 => false case n => out.write(buffer, 0, n) @@ -56,20 +153,21 @@ object Util { } }) () } + def pluralize(s: String, n: Int) = { if (n == 1) s else s + "s" } /** - * Splits a string into path segments; automatically removes all - * leading/trailing slashes, and ignores empty path segments. - * - * Written imperatively for performance since it's used all over the place. - */ + * Splits a string into path segments; automatically removes all + * leading/trailing slashes, and ignores empty path segments. + * + * Written imperatively for performance since it's used all over the place. + */ def splitPath(p: String): collection.IndexedSeq[String] = { val pLength = p.length var i = 0 - while(i < pLength && p(i) == '/') i += 1 + while (i < pLength && p(i) == '/') i += 1 var segmentStart = i val out = mutable.ArrayBuffer.empty[String] @@ -81,7 +179,7 @@ object Util { segmentStart = i + 1 } - while(i < pLength){ + while (i < pLength) { if (p(i) == '/') complete() i += 1 } @@ -96,6 +194,7 @@ object Util { pw.flush() trace.toString } + def softWrap(s: String, leftOffset: Int, maxWidth: Int) = { val oneLine = s.linesIterator.mkString(" ").split(' ') @@ -103,13 +202,13 @@ object Util { val output = new StringBuilder(oneLine.head) var currentLineWidth = oneLine.head.length - for(chunk <- oneLine.tail){ + for (chunk <- oneLine.tail) { val addedWidth = currentLineWidth + chunk.length + 1 - if (addedWidth > maxWidth){ + if (addedWidth > maxWidth) { output.append("\n" + indent) output.append(chunk) currentLineWidth = chunk.length - } else{ + } else { currentLineWidth = addedWidth output.append(' ') output.append(chunk) @@ -117,12 +216,13 @@ object Util { } output.mkString } + def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])( implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = { in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) { - case (acc, el) => - for (a <- acc; e <- el) yield a += e - } + case (acc, el) => + for (a <- acc; e <- el) yield a += e + } .map(_.result()) } } diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 67b41e5c1c..ee440345f3 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -2,7 +2,7 @@ package cask.main import cask.endpoints.{WebsocketResult, WsHandler} import cask.model._ -import cask.internal.{DispatchTrie, Util} +import cask.internal.{DispatchTrie, ThreadBlockingHandler, Util} import cask.main import cask.router.{Decorator, EndpointMetadata, EntryPoint, RawDecorator, Result} import cask.util.Logger @@ -11,6 +11,7 @@ import io.undertow.server.{HttpHandler, HttpServerExchange} import io.undertow.server.handlers.BlockingHandler import io.undertow.util.HttpString +import java.util.concurrent.Executor import scala.concurrent.ExecutionContext /** @@ -46,9 +47,33 @@ abstract class Main{ def dispatchTrie = Main.prepareDispatchTrie(allRoutes) - def defaultHandler = new BlockingHandler( - new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) - ) + /** + * The handler that will be used to handle incoming requests. By default, + * when a `null` handler is provided, a default handler will be used, + * otherwise the provided executor will be used to handle requests. + * + * When `cask.virtual-thread.enabled` is set to `true` and running with a JDK + * where virtual threads are supported, then a virtual thread executor will be used. + * */ + protected def handlerExecutor(): Executor = { + if (enableVirtualThread) { + Util.createDefaultCaskVirtualThreadExecutor.orNull + } else null + } + + protected def enableVirtualThread: Boolean = { + val enableVirtualThread = System.getProperty(Main.VIRTUAL_THREAD_ENABLED) + enableVirtualThread != null && enableVirtualThread.toBoolean + } + + def defaultHandler: HttpHandler = { + val executor = handlerExecutor() + val mainHandler = new Main.DefaultHandler( + dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) + if (executor ne null) { + new ThreadBlockingHandler(executor, mainHandler) + } else new BlockingHandler(mainHandler) + } def handleNotFound(req: Request): Response.Raw = Main.defaultHandleNotFound(req) @@ -72,7 +97,12 @@ abstract class Main{ } -object Main{ +object Main { + /** + * property key to enable virtual thread support. + * */ + val VIRTUAL_THREAD_ENABLED = "cask.virtual-thread.enabled" + class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]], mainDecorators: Seq[Decorator[_, _, _, _]], debugMode: Boolean, @@ -86,7 +116,7 @@ object Main{ Tuple2( "websocket", (r: Any) => - r.asInstanceOf[WebsocketResult] match{ + r.asInstanceOf[WebsocketResult] match { case l: WsHandler => io.undertow.Handlers.websocket(l).handleRequest(exchange) case l: WebsocketResult.Listener => @@ -131,8 +161,9 @@ object Main{ } } } - }catch{case e: Throwable => - e.printStackTrace() + } catch { + case e: Throwable => + e.printStackTrace() } } @@ -160,7 +191,7 @@ object Main{ val methodMap = methods.toMap[String, (Routes, EndpointMetadata[_])] val subpath = metadata.endpoint.subpath || - metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments)) + metadata.entryPoint.argSignatures.exists(_.exists(_.reads.remainingPathSegments)) (segments, methodMap, subpath) } @@ -175,10 +206,10 @@ object Main{ } def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = { - response.data.headers.foreach{case (k, v) => + response.data.headers.foreach { case (k, v) => exchange.getResponseHeaders.put(new HttpString(k), v) } - response.headers.foreach{case (k, v) => + response.headers.foreach { case (k, v) => exchange.getResponseHeaders.put(new HttpString(k), v) } response.cookies.foreach(c => exchange.setResponseCookie(Cookie.toUndertow(c))) diff --git a/docs/pages/1 - Cask - a Scala HTTP micro-framework.md b/docs/pages/1 - Cask - a Scala HTTP micro-framework.md index b9b1697823..1f77a28237 100644 --- a/docs/pages/1 - Cask - a Scala HTTP micro-framework.md +++ b/docs/pages/1 - Cask - a Scala HTTP micro-framework.md @@ -130,6 +130,27 @@ $$$minimalApplication2 You can split up your routes into separate `cask.Routes` objects as makes sense and pass them all into `cask.Main`. +$$minimalApplicationWithLoom + +Cask can support using Virtual Threads to handle the request out of the box, you can enable it with the next steps: + +1. Running cask with Java 21 or later +2. add `--add-opens java.base/java.lang=ALL-UNNAMED` to your JVM options, which is needed to name the virtual threads. +3. add `-Dcask.virtual-thread.enabled=true` to your JVM options, which is needed to enable the virtual threads. +4. tweak the underlying carrier threads with `-Djdk.virtualThreadScheduler.parallelism`, `jdk.virtualThreadScheduler.maxPoolSize` and `jdk.unparker.maxPoolSize`. + +**Advanced Features**: +1. You can change the default scheduler of the carrier threads with `cask.internal.Util.createVirtualThreadExecutor` method, but keep in mind, that's not officially supported by JDK for now. +2. You can supply your own `Executor` by override the `handlerExecutor()` method in your `cask.Main` object, which will be called only once when the server starts. +3. You can use `jdk.internal.misc.Blocker`'s `begin` and `end` methods to help the `ForkJoinPool` when needed. + +**NOTE**: +1. If your code is CPU-bound, you should not use virtual threads, because it will not improve the performance, but will increase the overhead. +2. OOM is a common issue when you have many virtual threads, you should limit the max in-flight requests to avoid it. +3. There are some known issues which can leads to a deadlock, you should be careful when using it in production, at least after long time stress test. +3. [JEP 491: Synchronize Virtual Threads without Pinning](https://openjdk.org/jeps/491) will be shipped in Java 24. +4. Some info from early adaptor [faire](https://craft.faire.com/java-virtual-threads-increasing-search-performance-while-avoiding-deadlocks-f12fa296d521) + ## Variable Routes $$$variableRoutes diff --git a/example/minimalApplicationWithLoom/app/src/MinimalApplicationWithLoom.scala b/example/minimalApplicationWithLoom/app/src/MinimalApplicationWithLoom.scala new file mode 100644 index 0000000000..3dc93fb4ba --- /dev/null +++ b/example/minimalApplicationWithLoom/app/src/MinimalApplicationWithLoom.scala @@ -0,0 +1,60 @@ +package app + +import cask.internal.Util +import cask.main.Main + +import java.lang.management.{ManagementFactory, RuntimeMXBean} +import java.util.concurrent.{Executor, Executors} + +// run benchmark with : ./mill benchmark.runBenchmark +object MinimalApplicationWithLoom extends cask.MainRoutes { + // Print Java version + private val javaVersion: String = System.getProperty("java.version") + println("Java Version: " + javaVersion) + + // Print JVM arguments// Print JVM arguments + private val runtimeMxBean: RuntimeMXBean = ManagementFactory.getRuntimeMXBean + private val jvmArguments = runtimeMxBean.getInputArguments + println("JVM Arguments:") + + jvmArguments.forEach((arg: String) => println(arg)) + + println(Main.VIRTUAL_THREAD_ENABLED + " :" + System.getProperty(Main.VIRTUAL_THREAD_ENABLED)) + + //Use the same underlying executor for both virtual and non-virtual threads + private val executor = Executors.newFixedThreadPool(4) + + //TO USE LOOM: + //1. JDK 21 or later is needed. + //2. add VM option: --add-opens java.base/java.lang=ALL-UNNAMED + //3. set system property: cask.virtual-thread.enabled=true + //4. NOTE: `java.util.concurrent.Executors.newVirtualThreadPerTaskExecutor` is using the shared + // ForkJoinPool in VirtualThread. If you want to use a separate ForkJoinPool, you can create + // a new ForkJoinPool instance and pass it to `createVirtualThreadExecutor` method. + + override protected def handlerExecutor(): Executor = { + if (enableVirtualThread) { + Util.createVirtualThreadExecutor(executor).get + } else { + executor + } + } + + /** + * With curl: curl -X GET http://localhost:8080/ + * you wil see something like: + * Hello World! from thread:VirtualThread[#63,cask-handler-executor-virtual-thread-10]/runnable@ForkJoinPool-1-worker-1% + * */ + @cask.get("/") + def hello() = { + Thread.sleep(100) // simulate some blocking work + "Hello World!" + } + + @cask.post("/do-thing") + def doThing(request: cask.Request) = { + request.text().reverse + } + + initialize() +} diff --git a/example/minimalApplicationWithLoom/app/test/src/ExampleTests.scala b/example/minimalApplicationWithLoom/app/test/src/ExampleTests.scala new file mode 100644 index 0000000000..28f4247020 --- /dev/null +++ b/example/minimalApplicationWithLoom/app/test/src/ExampleTests.scala @@ -0,0 +1,34 @@ +package app +import io.undertow.Undertow +import org.xnio.Options +import utest._ + +object ExampleTests extends TestSuite{ + def withServer[T](example: cask.main.Main)(f: String => T): T = { + val server = Undertow.builder + .addHttpListener(8081, "localhost") + .setSocketOption(Options.REUSE_ADDRESSES, java.lang.Boolean.TRUE) + .setHandler(example.defaultHandler) + .build + server.start() + val res = + try f("http://localhost:8081") + finally server.stop() + res + } + + val tests = Tests { + test("MinimalApplicationWithLoom") - withServer(MinimalApplicationWithLoom) { host => + val success = requests.get(host) + + success.text() ==> "Hello World!" + success.statusCode ==> 200 + + requests.get(s"$host/doesnt-exist", check = false).statusCode ==> 404 + + requests.post(s"$host/do-thing", data = "hello").text() ==> "olleh" + + requests.delete(s"$host/do-thing", check = false).statusCode ==> 405 + } + } +} diff --git a/example/minimalApplicationWithLoom/package.mill b/example/minimalApplicationWithLoom/package.mill new file mode 100644 index 0000000000..791d8f494b --- /dev/null +++ b/example/minimalApplicationWithLoom/package.mill @@ -0,0 +1,49 @@ +package build.example.minimalApplicationWithLoom + +import mill._, scalalib._ +import mill.define.ModuleRef + +object app extends Cross[AppModule](build.scalaVersions) +trait AppModule extends CrossScalaModule{ + + private def parseJvmArgs(argsStr: String) = { + argsStr.split(" ").filter(_.nonEmpty).toSeq + } + + def forkArgs = T { + //TODO not sure why the env passing is not working + val envVirtualThread: String = T.env.getOrElse("CASK_VIRTUAL_THREAD", "null") + println("envVirtualThread: " + envVirtualThread) + + val systemProps = if (envVirtualThread == "true") { + Seq("-Dcask.virtual-thread.enabled=true") + } else Nil + + val baseArgs = Seq( + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + + val seq = baseArgs ++ systemProps + println("final forkArgs: " + seq) + seq + } + + def zincWorker = ModuleRef(ZincWorkerJava11Latest) + + def moduleDeps = Seq(build.cask(crossScalaVersion)) + + def ivyDeps = Agg[Dep]( + ) + + object test extends ScalaTests with TestModule.Utest { + def ivyDeps = Agg( + ivy"com.lihaoyi::utest::0.8.4", + ivy"com.lihaoyi::requests::0.9.0", + ) + } +} + +object ZincWorkerJava11Latest extends ZincWorkerModule with CoursierModule { + def jvmId = "temurin:23.0.1" + def jvmIndexVersion = "latest.release" +} diff --git a/example/staticFiles/package.mill b/example/staticFiles/package.mill index 003685dc15..6d0e492776 100644 --- a/example/staticFiles/package.mill +++ b/example/staticFiles/package.mill @@ -22,7 +22,7 @@ trait AppModule extends CrossScalaModule{ app => // redirect this to the forked `test` to make sure static file serving works def testLocal(args: String*) = T.command{ - test(args:_*) + this.test(args:_*) } } } diff --git a/example/staticFiles2/package.mill b/example/staticFiles2/package.mill index 27881395e7..9fc9b512ab 100644 --- a/example/staticFiles2/package.mill +++ b/example/staticFiles2/package.mill @@ -22,7 +22,7 @@ trait AppModule extends CrossScalaModule{ app => // redirect this to the forked `test` to make sure static file serving works def testLocal(args: String*) = T.command{ - test(args:_*) + this.test(args:_*) } } } diff --git a/example/staticFilesWithLoom/app/resources/cask/example.txt b/example/staticFilesWithLoom/app/resources/cask/example.txt new file mode 100644 index 0000000000..5184576553 --- /dev/null +++ b/example/staticFilesWithLoom/app/resources/cask/example.txt @@ -0,0 +1 @@ +the quick brown fox jumps over the lazy dog \ No newline at end of file diff --git a/example/staticFilesWithLoom/app/src/StaticFilesWithLoom.scala b/example/staticFilesWithLoom/app/src/StaticFilesWithLoom.scala new file mode 100644 index 0000000000..ec98bf10b7 --- /dev/null +++ b/example/staticFilesWithLoom/app/src/StaticFilesWithLoom.scala @@ -0,0 +1,33 @@ +package app + +import cask.internal.Util +import java.util.concurrent.{Executor, Executors} + +object StaticFilesWithLoom extends cask.MainRoutes{ + private val executor = Executors.newFixedThreadPool(4) + + override protected def handlerExecutor(): Executor = { + println("use virtual thread : " + enableVirtualThread) + if (enableVirtualThread) { + Util.createVirtualThreadExecutor(executor).get + } else { + executor + } + } + + @cask.get("/") + def index() = { + "Hello!" + } + + @cask.staticFiles("/static/file") + def staticFileRoutes() = "resources/cask" + + @cask.staticResources("/static/resource") + def staticResourceRoutes() = "cask" + + @cask.staticResources("/static/resource2") + def staticResourceRoutes2() = "." + + initialize() +} diff --git a/example/staticFilesWithLoom/app/test/src/ExampleTests.scala b/example/staticFilesWithLoom/app/test/src/ExampleTests.scala new file mode 100644 index 0000000000..d8d723e558 --- /dev/null +++ b/example/staticFilesWithLoom/app/test/src/ExampleTests.scala @@ -0,0 +1,35 @@ +package app +import io.undertow.Undertow + +import utest._ + +object ExampleTests extends TestSuite{ + def withServer[T](example: cask.main.Main)(f: String => T): T = { + val server = Undertow.builder + .addHttpListener(8081, "localhost") + .setHandler(example.defaultHandler) + .build + server.start() + val res = + try f("http://localhost:8081") + finally server.stop() + res + } + + val tests = Tests{ + + test("StaticFiles") - withServer(StaticFilesWithLoom){ host => + requests.get(s"$host/static/file/example.txt").text() ==> + "the quick brown fox jumps over the lazy dog" + + requests.get(s"$host/static/resource/example.txt").text() ==> + "the quick brown fox jumps over the lazy dog" + + requests.get(s"$host/static/resource2/cask/example.txt").text() ==> + "the quick brown fox jumps over the lazy dog" + + requests.get(s"$host/static/file/../../../build.sc", check = false).statusCode ==> 404 + } + + } +} diff --git a/example/staticFilesWithLoom/package.mill b/example/staticFilesWithLoom/package.mill new file mode 100644 index 0000000000..b7292d584e --- /dev/null +++ b/example/staticFilesWithLoom/package.mill @@ -0,0 +1,60 @@ +package build.example.staticFilesWithLoom +import mill._, scalalib._ +import mill.define.ModuleRef + +object app extends Cross[AppModule](build.scalaVersions) +trait AppModule extends CrossScalaModule{ app => + + def moduleDeps = Seq(build.cask(crossScalaVersion)) + + def forkWorkingDir = app.millSourcePath + def ivyDeps = Agg[Dep]( + ) + + private def parseJvmArgs(argsStr: String) = { + argsStr.split(" ").filter(_.nonEmpty).toSeq + } + + def forkArgs = T { + //TODO not sure why the env passing is not working + val envVirtualThread: String = T.env.getOrElse("CASK_VIRTUAL_THREAD", "null") + println("envVirtualThread: " + envVirtualThread) + + val systemProps = if (envVirtualThread == "true") { + Seq("-Dcask.virtual-thread.enabled=true") + } else Nil + + val baseArgs = Seq( + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + + val seq = baseArgs ++ systemProps + println("final forkArgs: " + seq) + seq + } + + def zincWorker = ModuleRef(ZincWorkerJava11Latest) + + object test extends ScalaTests with TestModule.Utest{ + + def ivyDeps = Agg( + ivy"com.lihaoyi::utest::0.8.4", + ivy"com.lihaoyi::requests::0.9.0", + ) + + def forkWorkingDir = app.millSourcePath + + def testSandboxWorkingDir = false + + // redirect this to the forked `test` to make sure static file serving works + def testLocal(args: String*) = T.command{ + this.test(args:_*) + } + } +} + + +object ZincWorkerJava11Latest extends ZincWorkerModule with CoursierModule { + def jvmId = "temurin:23.0.1" + def jvmIndexVersion = "latest.release" +} \ No newline at end of file diff --git a/example/todoDbWithLoom/app/src/TodoMvcDbWithLoom.scala b/example/todoDbWithLoom/app/src/TodoMvcDbWithLoom.scala new file mode 100644 index 0000000000..3d85e34ed0 --- /dev/null +++ b/example/todoDbWithLoom/app/src/TodoMvcDbWithLoom.scala @@ -0,0 +1,94 @@ +package app +import cask.internal.Util +import scalasql.DbApi.Txn +import scalasql.Sc +import scalasql.SqliteDialect._ + +import java.util.concurrent.{Executor, Executors} + +object TodoMvcDbWithLoom extends cask.MainRoutes { + val tmpDb = java.nio.file.Files.createTempDirectory("todo-cask-sqlite") + val sqliteDataSource = new org.sqlite.SQLiteDataSource() + sqliteDataSource.setUrl(s"jdbc:sqlite:$tmpDb/file.db") + lazy val sqliteClient = new scalasql.DbClient.DataSource( + sqliteDataSource, + config = new scalasql.Config {} + ) + + private val executor = Executors.newFixedThreadPool(4) + override protected def handlerExecutor(): Executor = { + if (enableVirtualThread) { + Util.createVirtualThreadExecutor(executor).get + } else { + executor + } + } + + class transactional extends cask.RawDecorator{ + def wrapFunction(pctx: cask.Request, delegate: Delegate) = { + sqliteClient.transaction { txn => + val res = delegate(pctx, Map("txn" -> txn)) + if (res.isInstanceOf[cask.router.Result.Error]) txn.rollback() + res + } + } + } + + case class Todo[T[_]](id: T[Int], checked: T[Boolean], text: T[String]) + object Todo extends scalasql.Table[Todo]{ + implicit def todoRW = upickle.default.macroRW[Todo[Sc]] + } + + sqliteClient.getAutoCommitClientConnection.updateRaw( + """CREATE TABLE todo ( + | id INTEGER PRIMARY KEY AUTOINCREMENT, + | checked BOOLEAN, + | text TEXT + |); + | + |INSERT INTO todo (checked, text) VALUES + |(1, 'Get started with Cask'), + |(0, 'Profit!'); + |""".stripMargin + ) + + @transactional + @cask.get("/list/:state") + def list(state: String)(txn: Txn) = { + val filteredTodos = state match{ + case "all" => txn.run(Todo.select) + case "active" => txn.run(Todo.select.filter(!_.checked)) + case "completed" => txn.run(Todo.select.filter(_.checked)) + } + upickle.default.write(filteredTodos) + } + + @transactional + @cask.post("/add") + def add(request: cask.Request)(txn: Txn) = { + val body = request.text() + txn.run( + Todo + .insert + .columns(_.checked := false, _.text := body) + .returning(_.id) + .single + ) + + if (body == "FORCE FAILURE") throw new Exception("FORCE FAILURE BODY") + } + + @transactional + @cask.post("/toggle/:index") + def toggle(index: Int)(txn: Txn) = { + txn.run(Todo.update(_.id === index).set(p => p.checked := !p.checked)) + } + + @transactional + @cask.post("/delete/:index") + def delete(index: Int)(txn: Txn) = { + txn.run(Todo.delete(_.id === index)) + } + + initialize() +} diff --git a/example/todoDbWithLoom/app/test/src/ExampleTests.scala b/example/todoDbWithLoom/app/test/src/ExampleTests.scala new file mode 100644 index 0000000000..3882a97ab2 --- /dev/null +++ b/example/todoDbWithLoom/app/test/src/ExampleTests.scala @@ -0,0 +1,50 @@ +package app +import io.undertow.Undertow + +import utest._ + +object ExampleTests extends TestSuite{ + def withServer[T](example: cask.main.Main)(f: String => T): T = { + val server = Undertow.builder + .addHttpListener(8081, "localhost") + .setHandler(example.defaultHandler) + .build + server.start() + val res = + try f("http://localhost:8081") + finally server.stop() + res + } + + val tests = Tests{ + test("TodoMvcDb") - withServer(TodoMvcDbWithLoom){ host => + requests.get(s"$host/list/all").text() ==> + """[{"id":1,"checked":true,"text":"Get started with Cask"},{"id":2,"checked":false,"text":"Profit!"}]""" + requests.get(s"$host/list/active").text() ==> + """[{"id":2,"checked":false,"text":"Profit!"}]""" + requests.get(s"$host/list/completed").text() ==> + """[{"id":1,"checked":true,"text":"Get started with Cask"}]""" + + requests.post(s"$host/toggle/2") + + requests.get(s"$host/list/all").text() ==> + """[{"id":1,"checked":true,"text":"Get started with Cask"},{"id":2,"checked":true,"text":"Profit!"}]""" + + requests.get(s"$host/list/active").text() ==> + """[]""" + + requests.post(s"$host/add", data = "new Task") + + // Make sure endpoint failures do not commit their transaction + requests.post(s"$host/add", data = "FORCE FAILURE", check = false).statusCode ==> 500 + + requests.get(s"$host/list/active").text() ==> + """[{"id":3,"checked":false,"text":"new Task"}]""" + + requests.post(s"$host/delete/3") + + requests.get(s"$host/list/active").text() ==> + """[]""" + } + } +} diff --git a/example/todoDbWithLoom/package.mill b/example/todoDbWithLoom/package.mill new file mode 100644 index 0000000000..bb829eaa19 --- /dev/null +++ b/example/todoDbWithLoom/package.mill @@ -0,0 +1,50 @@ +package build.example.todoDbWithLoom +import mill._, scalalib._ +import mill.define.ModuleRef + +object app extends Cross[AppModule](build.scala213) +trait AppModule extends CrossScalaModule{ + + private def parseJvmArgs(argsStr: String) = { + argsStr.split(" ").filter(_.nonEmpty).toSeq + } + + def forkArgs = T { + //TODO not sure why the env passing is not working + val envVirtualThread: String = T.env.getOrElse("CASK_VIRTUAL_THREAD", "null") + println("envVirtualThread: " + envVirtualThread) + + val systemProps = if (envVirtualThread == "true") { + Seq("-Dcask.virtual-thread.enabled=true") + } else Nil + + val baseArgs = Seq( + "--add-opens", "java.base/java.lang=ALL-UNNAMED" + ) + + val seq = baseArgs ++ systemProps + println("final forkArgs: " + seq) + seq + } + + def zincWorker = ModuleRef(ZincWorkerJava11Latest) + + def moduleDeps = Seq(build.cask(crossScalaVersion)) + + def ivyDeps = Agg[Dep]( + ivy"org.xerial:sqlite-jdbc:3.42.0.0", + ivy"com.lihaoyi::scalasql:0.1.0", + ) + + object test extends ScalaTests with TestModule.Utest { + def ivyDeps = Agg( + ivy"com.lihaoyi::utest::0.8.4", + ivy"com.lihaoyi::requests::0.9.0", + ) + } +} + +object ZincWorkerJava11Latest extends ZincWorkerModule with CoursierModule { + def jvmId = "temurin:23.0.1" + def jvmIndexVersion = "latest.release" +} \ No newline at end of file