diff --git a/build.sc b/build.sc index 0391de934..21ac0279b 100644 --- a/build.sc +++ b/build.sc @@ -159,12 +159,14 @@ class ScalaInterpreter(val crossScalaVersion: String) extends AlmondModule with if (crossScalaVersion.startsWith("3.")) Seq( shared.interpreter(ScalaVersions.scala3Compat), - scala.`scala-kernel-api`() + scala.`scala-kernel-api`(), + scala.`toree-hooks`(ScalaVersions.binary(crossScalaVersion)) ) else Seq( shared.interpreter(), - scala.`scala-kernel-api`() + scala.`scala-kernel-api`(), + scala.`toree-hooks`(ScalaVersions.binary(crossScalaVersion)) ) def ivyDeps = T { val metabrowse = @@ -302,6 +304,12 @@ class Echo(val crossScalaVersion: String) extends AlmondModule { } } +class ToreeHooks(val crossScalaVersion: String) extends AlmondModule { + def compileModuleDeps = super.compileModuleDeps ++ Seq( + scala.`scala-kernel-api`(ScalaVersions.binary(crossScalaVersion)) + ) +} + object shared extends Module { object `logger-scala2-macros` extends Cross[LoggerScala2Macros](ScalaVersions.binaries: _*) object logger extends Cross[Logger](ScalaVersions.binaries: _*) @@ -325,6 +333,8 @@ object scala extends Module { extends Cross[ScalaKernelHelper](ScalaVersions.all.filter(_.startsWith("3.")): _*) object `almond-scalapy` extends Cross[AlmondScalaPy](ScalaVersions.binaries: _*) object `almond-rx` extends Cross[AlmondRx](ScalaVersions.scala212, ScalaVersions.scala213) + + object `toree-hooks` extends Cross[ToreeHooks](ScalaVersions.binaries: _*) } object echo extends Cross[Echo](ScalaVersions.binaries: _*) diff --git a/docs/pages/api-jupyter.md b/docs/pages/api-jupyter.md index 676a356f3..640c4a3ba 100644 --- a/docs/pages/api-jupyter.md +++ b/docs/pages/api-jupyter.md @@ -153,3 +153,50 @@ kernel.publish.updateHtml("Got all items", id) ``` ![](/demo/updatable.gif) + +### Hooks + +Hooks allow to pre-process code right before it's executed. Use like +```scala +private def runSql(sql: String): String = { + println("Running query...") + val fakeResult = + """ + | + | + | + | + | + | + | + | + | + | + | + | + |
IdName
1Tree
2Apple
+ |""".stripMargin + fakeResult +} + +kernel.addExecuteHook { code => + import almond.api.JupyterApi + import almond.interpreter.api.DisplayData + + if (code.linesIterator.take(1).toList == List("%sql")) { + val sql = code.linesWithSeparators.drop(1).mkString // drop first line with "%sql" + val result = runSql(sql) + Left(JupyterApi.ExecuteHookResult.Success(DisplayData.html(result))) + } + else + Right(code) // just pass on code +} +``` + +Such code can be run either in a cell or in a predef file. + +Later on, users can run things like +```text +%sql +SELECT id, name FROM my_table +``` diff --git a/docs/pages/install-options.md b/docs/pages/install-options.md index 6a72a6cfc..d26058bac 100644 --- a/docs/pages/install-options.md +++ b/docs/pages/install-options.md @@ -157,3 +157,12 @@ Add entries to the Help menu (Jupyter classic). Use like #### `--connection-file` +## Other + +#### `--toree-magics` + +Enable experimental support for [Toree](https://toree.apache.org) magics. + +Simple line magics such as `%AddDeps` (always assumed to be transitive as of writing this, `--transitive` is just ignored), +`%AddJar`, and cell magics such as `%%html` or `%%javascript` are supported. Note that `%%javascript` only works from Jupyter +classic, as JupyterLab doesn't allow for random javascript code execution. diff --git a/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala b/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala index 5cd72741b..7409cd11d 100644 --- a/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala +++ b/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala @@ -2,7 +2,7 @@ package almond.api import java.util.UUID -import almond.interpreter.api.{CommHandler, OutputHandler} +import almond.interpreter.api.{CommHandler, DisplayData, OutputHandler} import jupyter.{Displayer, Displayers} import scala.reflect.{ClassTag, classTag} @@ -47,10 +47,62 @@ abstract class JupyterApi { api => def display(t: T) = f(t).asJava } ) + + def addExecuteHook(hook: JupyterApi.ExecuteHook): Boolean + def removeExecuteHook(hook: JupyterApi.ExecuteHook): Boolean } object JupyterApi { + /** A hook, that can pre-process code right before it's executed + */ + @FunctionalInterface + abstract class ExecuteHook { + + /** Pre-processes code right before it's executed. + * + * Like when actual code is executed, `println` / `{Console,System}.{out,err}` get sent to the + * cell output, stdin can be requested from users, CommHandler and OutputHandler can be used, + * etc. + * + * When several hooks were added, they are called in the order they were added. The output of + * the previous hook gets passed to the next one, as long as hooks return code to be executed + * rather than an `ExecuteHookResult`. + * + * @param code + * Code to be pre-processed + * @return + * Either code to be executed (right), or an `ExecuteHookResult` (left) + */ + def hook(code: String): Either[ExecuteHookResult, String] + } + + /** Can be returned by `ExecuteHook.hook` to stop code execution. + */ + sealed abstract class ExecuteHookResult extends Product with Serializable + object ExecuteHookResult { + + /** Returns data to be displayed */ + final case class Success(data: DisplayData = DisplayData.empty) extends ExecuteHookResult + + /** Exception-like error + * + * If you'd like to build one out of an actual `Throwable`, just throw it. It will then be + * caught while the hook is running, and sent to users. + */ + final case class Error( + name: String, + message: String, + stackTrace: List[String] + ) extends ExecuteHookResult + + /** Tells the front-end that execution was aborted */ + case object Abort extends ExecuteHookResult + + /** Should instruct the front-end to prompt the user for exit */ + case object Exit extends ExecuteHookResult + } + abstract class UpdatableResults { @deprecated("Use updatable instead", "0.4.1") diff --git a/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala b/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala index d3a076b23..415581050 100644 --- a/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala +++ b/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala @@ -6,11 +6,11 @@ import scalaparse.Scala._ object AlmondParsers { - private def Prelude[_: P] = P((Annot ~ OneNLMax).rep ~ (Mod ~/ Pass).rep) + private def Prelude[X: P] = P((Annot ~ OneNLMax).rep ~ (Mod ~/ Pass).rep) // same as the methods with the same name in ammonite.interp.Parsers, but keeping the type aside in LHS - def PatVarSplitter[_: P] = { + def PatVarSplitter[X: P] = { def Prefixes = P(Prelude ~ (`var` | `val`)) def Lhs = P(Prefixes ~/ VarId) def TypeAnnotation = P((`:` ~/ Type.!).?) diff --git a/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala b/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala index 1a86514a2..43d48bb54 100644 --- a/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala +++ b/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala @@ -14,6 +14,7 @@ import ammonite.util.Util.newLine import metabrowse.server.{MetabrowseServer, Sourcepath} import scala.meta.dialects +import scala.collection.compat._ import scala.tools.nsc.Global import scala.tools.nsc.interactive.{Global => Interactive} import scala.util.Random @@ -231,13 +232,13 @@ object ScalaInterpreterInspections { javaDirs.exists(path.startsWith) } - def classpath(cl: ClassLoader): Stream[java.net.URL] = + def classpath(cl: ClassLoader): immutable.LazyList[java.net.URL] = if (cl == null) - Stream() + immutable.LazyList() else { val cp = cl match { - case u: java.net.URLClassLoader => u.getURLs.toStream - case _ => Stream() + case u: java.net.URLClassLoader => u.getURLs.to(immutable.LazyList) + case _ => immutable.LazyList() } cp #::: classpath(cl.getParent) diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala b/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala index a9b7a8b77..072ec88e6 100644 --- a/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala +++ b/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala @@ -26,7 +26,7 @@ import fastparse.Parsed import scala.collection.mutable import scala.concurrent.{Await, ExecutionContext} import scala.concurrent.duration.Duration -import scala.util.{Failure, Success} +import scala.util.{Failure, Success, Try} /** Wraps contextual things around when executing code (capturing output, stdin via front-ends, * interruption, etc.) @@ -65,7 +65,7 @@ final class Execute( Duration.Inf ) } - log.info("Received input") + log.info(s"Received input ${res.map { case "" => "[empty]"; case _ => "[non empty]" }}") res match { case Success(s) => Some(s) @@ -136,7 +136,7 @@ final class Execute( } } - private def withInputManager[T](m: Option[InputManager])(f: => T): T = { + private def withInputManager[T](m: Option[InputManager], done: Boolean = true)(f: => T): T = { val previous = currentInputManagerOpt0 try { currentInputManagerOpt0 = m @@ -144,7 +144,8 @@ final class Execute( } finally { currentInputManagerOpt0 = previous - m.foreach(_.done()) + if (done) + m.foreach(_.done()) } } @@ -337,7 +338,8 @@ final class Execute( inputManager: Option[InputManager], outputHandler: Option[OutputHandler], colors0: Ref[Colors], - storeHistory: Boolean + storeHistory: Boolean, + executeHooks: Seq[JupyterApi.ExecuteHook] ): ExecuteResult = { if (storeHistory) { @@ -345,36 +347,84 @@ final class Execute( history0 = history0 :+ code } - ammResult(ammInterp, code, inputManager, outputHandler, storeHistory) match { - case Res.Success((_, data)) => - ExecuteResult.Success(data) - case Res.Failure(msg) => - interruptedStackTraceOpt0 match { - case None => - val err = Execute.error(colors0(), None, msg) - outputHandler.foreach(_.stderr(err.message)) // necessary? - err - case Some(st) => - val cutoff = Set("$main", "evaluatorRunPrinter") - - ExecuteResult.Error( - ( - "Interrupted!" +: st - .takeWhile(x => !cutoff(x.getMethodName)) - .map(Execute.highlightFrame(_, fansi.Attr.Reset, colors0().literal())) - ).mkString(System.lineSeparator()) - ) + val finalCodeOrResult = + withOutputHandler(outputHandler) { + interruptible { + withInputManager(inputManager, done = false) { + withClientStdin { + capturingOutput { + executeHooks.foldLeft[Try[Either[JupyterApi.ExecuteHookResult, String]]]( + Success(Right(code)) + ) { + (codeOrDisplayDataAttempt, hook) => + codeOrDisplayDataAttempt.flatMap { codeOrDisplayData => + try Success(codeOrDisplayData.flatMap { value => + hook.hook(value) + }) + catch { + case e: Throwable => // kind of meh, but Ammonite does the same it seemsā€¦ + Failure(e) + } + } + } + } + } + } } + } - case Res.Exception(ex, msg) => - log.error(s"exception in user code (${ex.getMessage})", ex) - Execute.error(colors0(), Some(ex), msg) + finalCodeOrResult match { + case Failure(ex) => + log.error(s"exception when running hooks (${ex.getMessage})", ex) + Execute.error(colors0(), Some(ex), "") + + case Success(Left(res)) => + res match { + case s: JupyterApi.ExecuteHookResult.Success => + ExecuteResult.Success(s.data) + case e: JupyterApi.ExecuteHookResult.Error => + ExecuteResult.Error(e.name, e.message, e.stackTrace) + case JupyterApi.ExecuteHookResult.Abort => + ExecuteResult.Abort + case JupyterApi.ExecuteHookResult.Exit => + ExecuteResult.Exit + } - case Res.Skip => + case Success(Right(emptyCode)) if emptyCode.trim.isEmpty => ExecuteResult.Success() - case Res.Exit(_) => - ExecuteResult.Exit + case Success(Right(finalCode)) => + ammResult(ammInterp, finalCode, inputManager, outputHandler, storeHistory) match { + case Res.Success((_, data)) => + ExecuteResult.Success(data) + case Res.Failure(msg) => + interruptedStackTraceOpt0 match { + case None => + val err = Execute.error(colors0(), None, msg) + outputHandler.foreach(_.stderr(err.message)) // necessary? + err + case Some(st) => + val cutoff = Set("$main", "evaluatorRunPrinter") + + ExecuteResult.Error( + ( + "Interrupted!" +: st + .takeWhile(x => !cutoff(x.getMethodName)) + .map(Execute.highlightFrame(_, fansi.Attr.Reset, colors0().literal())) + ).mkString(System.lineSeparator()) + ) + } + + case Res.Exception(ex, msg) => + log.error(s"exception in user code (${ex.getMessage})", ex) + Execute.error(colors0(), Some(ex), msg) + + case Res.Skip => + ExecuteResult.Success() + + case Res.Exit(_) => + ExecuteResult.Exit + } } } } diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala b/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala index ab739f5f5..137f00223 100644 --- a/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala +++ b/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala @@ -9,6 +9,7 @@ import almond.interpreter.api.CommHandler import ammonite.util.Ref import pprint.{TPrint, TPrintColors} +import scala.collection.mutable import scala.concurrent.Await import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -69,4 +70,20 @@ final class JupyterApiImpl( protected def updatableResults0: JupyterApi.UpdatableResults = execute.updatableResults + + private val executeHooks0 = new mutable.ListBuffer[JupyterApi.ExecuteHook] + def executeHooks: Seq[JupyterApi.ExecuteHook] = + executeHooks0.toList + def addExecuteHook(hook: JupyterApi.ExecuteHook): Boolean = + !executeHooks0.contains(hook) && { + executeHooks0.append(hook) + true + } + def removeExecuteHook(hook: JupyterApi.ExecuteHook): Boolean = { + val idx = executeHooks0.indexOf(hook) + idx >= 0 && { + executeHooks0.remove(idx) + true + } + } } diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala index a8c776d08..2d38dcf18 100644 --- a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala +++ b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala @@ -8,6 +8,7 @@ import almond.interpreter.input.InputManager import almond.interpreter.util.AsyncInterpreterOps import almond.logger.LoggerContext import almond.protocol.KernelInfo +import almond.toree.{CellMagicHook, LineMagicHook} import ammonite.compiler.Parsers import ammonite.repl.{ReplApiImpl => _, _} import ammonite.runtime._ @@ -62,27 +63,32 @@ final class ScalaInterpreter( params.useThreadInterrupt ) - lazy val ammInterp: ammonite.interp.Interpreter = { + val sessApi = new SessionApiImpl(frames0) - val sessApi = new SessionApiImpl(frames0) + val replApi = + new ReplApiImpl( + execute0, + storage, + colors0, + ammInterp, + sessApi + ) - val replApi = - new ReplApiImpl( - execute0, - storage, - colors0, - ammInterp, - sessApi - ) + val jupyterApi = + new JupyterApiImpl( + execute0, + commHandlerOpt, + replApi, + silent0, + params.allowVariableInspector + ) - val jupyterApi = - new JupyterApiImpl( - execute0, - commHandlerOpt, - replApi, - silent0, - params.allowVariableInspector - ) + if (params.toreeMagics) { + jupyterApi.addExecuteHook(LineMagicHook.hook(replApi.pprinter)) + jupyterApi.addExecuteHook(CellMagicHook.hook(jupyterApi.publish)) + } + + lazy val ammInterp: ammonite.interp.Interpreter = { for (ec <- params.updateBackgroundVariablesEcOpt) UpdatableFuture.setup(replApi, jupyterApi, ec) @@ -129,7 +135,15 @@ final class ScalaInterpreter( inputManager: Option[InputManager], outputHandler: Option[OutputHandler] ): ExecuteResult = - execute0(ammInterp, code, inputManager, outputHandler, colors0, storeHistory) + execute0( + ammInterp, + code, + inputManager, + outputHandler, + colors0, + storeHistory, + jupyterApi.executeHooks + ) def currentLine(): Int = execute0.currentLine diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala index fdfe45c63..f8d0be7db 100644 --- a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala +++ b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala @@ -34,5 +34,6 @@ final case class ScalaInterpreterParams( autoUpdateVars: Boolean = true, allowVariableInspector: Option[Boolean] = None, useThreadInterrupt: Boolean = false, - outputDir: Either[os.Path, Boolean] = Right(true) + outputDir: Either[os.Path, Boolean] = Right(true), + toreeMagics: Boolean = false ) diff --git a/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala b/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala index b9f26fc44..82103854b 100644 --- a/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala +++ b/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala @@ -20,6 +20,7 @@ import fs2.Stream import utest._ import scala.collection.compat._ +import scala.util.Properties object ScalaKernelTests extends TestSuite { @@ -30,6 +31,8 @@ object ScalaKernelTests extends TestSuite { val threads = KernelThreads.create("test") + val maybePostImportNewLine = if (TestUtil.isScala2) "" else System.lineSeparator() + override def utestAfterAll() = { threads.attemptShutdown() if (!attemptShutdownExecutionContext(interpreterEc)) @@ -40,6 +43,16 @@ object ScalaKernelTests extends TestSuite { test("stdin") { + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + // How the pseudo-client behaves val inputHandler = MessageHandler(Channel.Input, Input.requestType) { msg => @@ -64,30 +77,21 @@ object ScalaKernelTests extends TestSuite { (_, m) => IO.pure(m.header.msg_type == "execute_reply" && m.content.toString().contains("exit")) - val sessionId = UUID.randomUUID().toString + implicit val sessionId: SessionId = SessionId() // Initial messages from client val input = Stream( - execute(sessionId, "val n = scala.io.StdIn.readInt()"), - execute(sessionId, "val m = new java.util.Scanner(System.in).nextInt()"), - execute(sessionId, """val s = "exit"""") + execute("val n = scala.io.StdIn.readInt()"), + execute("val m = new java.util.Scanner(System.in).nextInt()"), + execute("""val s = "exit"""") ) val streams = ClientStreams.create(input, stopWhen, inputHandler.orElse(ignoreExpectedReplies)) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val replies = streams.executeReplies @@ -102,12 +106,29 @@ object ScalaKernelTests extends TestSuite { test("stop on error") { - // How the pseudo-client behaves + // There's something non-deterministic in this test. + // It requires the 3 cells to execute to be sent at once at the beginning. + // The exception running the first cell must make the kernel discard (not compile + // nor run) the other two, that are queued. + // That means, the other 2 cells must have been queued when the first cell's thrown + // exception is caught by the kernel. + // Because of that, we can't rely on individual calls to 'kernel.execute' like + // the other tests do, as these send messages one after the other, sending the + // next one when the previous one is done running (so no messages would be queued.) - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) - // When the pseudo-client exits + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + + val lastMsgId = UUID.randomUUID().toString val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = (_, m) => @@ -115,27 +136,16 @@ object ScalaKernelTests extends TestSuite { m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) ) - // Initial messages from client - val input = Stream( - execute(sessionId, """sys.error("foo")"""), - execute(sessionId, "val n = 2"), - execute(sessionId, """val s = "other"""", lastMsgId) + execute("""sys.error("foo")"""), + execute("val n = 2"), + execute("""val s = "other"""", lastMsgId) ) val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val messageTypes = streams.generatedMessageTypes() @@ -161,32 +171,6 @@ object ScalaKernelTests extends TestSuite { test("jvm-repr") { - // How the pseudo-client behaves - - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString - - // When the pseudo-client exits - - val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = - (_, m) => - IO.pure( - m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) - ) - - // Initial messages from client - - val input = Stream( - execute(sessionId, """class Bar(val value: String)"""), - execute( - sessionId, - """kernel.register[Bar](bar => Map("text/plain" -> s"Bar(${bar.value})"))""" - ), - execute(sessionId, """val b = new Bar("other")""", lastMsgId) - ) - - val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( initialColors = Colors.BlackWhite @@ -194,63 +178,25 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - val messageTypes = streams.generatedMessageTypes() + implicit val sessionId: SessionId = SessionId() - val expectedMessageTypes = Seq( - "execute_input", - "execute_result", - "execute_reply", - "execute_input", - "execute_reply", - "execute_input", - "display_data", - "execute_reply" + kernel.execute("""class Bar(val value: String)""", "defined class Bar") + kernel.execute( + """kernel.register[Bar](bar => Map("text/plain" -> s"Bar(${bar.value})"))""", + "" ) - - assert(messageTypes == expectedMessageTypes) - - val displayData = streams.displayData - - val expectedDisplayData = Seq( - ProtocolExecute.DisplayData( - Map("text/plain" -> RawJson("\"Bar(other)\"".bytes)), - Map() - ) -> false + kernel.execute( + """val b = new Bar("other")""", + "", + displaysText = Seq("Bar(other)") ) - - assert(displayData == expectedDisplayData) } test("updatable display") { - // How the pseudo-client behaves - - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString - - // When the pseudo-client exits - - val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = - (_, m) => - IO.pure( - m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) - ) - - // Initial messages from client - - val input = Stream( - execute(sessionId, """val handle = Html("foo")"""), - execute(sessionId, """handle.withContent("bzz").update()"""), - execute(sessionId, """val s = "other"""", lastMsgId) - ) - - val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( initialColors = Colors.BlackWhite @@ -258,83 +204,26 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - t.unsafeRunTimedOrThrow() + implicit val sessionId: SessionId = SessionId() - val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector - val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector - - val expectedRequestsMessageTypes = Seq( - "execute_reply", - "execute_reply", - "execute_reply" + kernel.execute( + """val handle = Html("foo")""", + "", + displaysHtml = Seq("foo") ) - val expectedPublishMessageTypes = Seq( - "execute_input", - "display_data", - "execute_input", - "update_display_data", - "execute_input", - "execute_result" + kernel.execute( + """handle.withContent("bzz").update()""", + "", + displaysHtmlUpdates = Seq("bzz") ) - - assert(requestsMessageTypes == expectedRequestsMessageTypes) - assert(publishMessageTypes == expectedPublishMessageTypes) - - val displayData = streams.displayData - val id = { - val ids = displayData.flatMap(_._1.transient.display_id).toSet - assert(ids.size == 1) - ids.head - } - - val expectedDisplayData = Seq( - ProtocolExecute.DisplayData( - Map("text/html" -> RawJson("\"foo\"".bytes)), - Map(), - ProtocolExecute.DisplayData.Transient(Some(id)) - ) -> false, - ProtocolExecute.DisplayData( - Map("text/html" -> RawJson("\"bzz\"".bytes)), - Map(), - ProtocolExecute.DisplayData.Transient(Some(id)) - ) -> true - ) - - assert(displayData == expectedDisplayData) } test("auto-update Future results upon completion") { - // How the pseudo-client behaves - - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString - - // When the pseudo-client exits - - val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = - (_, m) => - IO.pure( - m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) - ) - - // Initial messages from client - - val input = Stream( - execute( - sessionId, - "import scala.concurrent.Future; import scala.concurrent.ExecutionContext.Implicits.global" - ), - execute(sessionId, "val f = Future { Thread.sleep(3000L); 2 }"), - execute(sessionId, "Thread.sleep(6000L)", lastMsgId) - ) - - val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( updateBackgroundVariablesEcOpt = Some(bgVarEc), @@ -343,37 +232,57 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - t.unsafeRunTimedOrThrow() + implicit val sessionId: SessionId = SessionId() - val messageTypes = streams.generatedMessageTypes() + val sp = " " + val ls = System.lineSeparator() - val expectedMessageTypes = Seq( - "execute_input", - "execute_result", - "execute_reply", - "execute_input", - "display_data", - "execute_reply", - "execute_input", - // that one originates from the second line, but arrives while the third one is running - "update_display_data", - "execute_reply" + kernel.execute( + "import scala.concurrent.Future; import scala.concurrent.ExecutionContext.Implicits.global", + // Multi-line with stripMargin seems to be a problem on our Windows CI for this test, + // but not for the other ones using stripMarginā€¦ + s"import scala.concurrent.Future;$sp$ls" + + s"import scala.concurrent.ExecutionContext.Implicits.global$maybePostImportNewLine" ) - assert(messageTypes == expectedMessageTypes) + kernel.execute( + "val f = Future { Thread.sleep(3000L); 2 }", + "", + displaysText = Seq("f: Future[Int] = [running]") + ) + kernel.execute( + "Thread.sleep(6000L)", + "", + // the update originates from the previous cell, but arrives while the third one is running + displaysTextUpdates = Seq( + if (TestUtil.isScala212) "f: Future[Int] = Success(2)" + else "f: Future[Int] = Success(value = 2)" + ) + ) } test("auto-update Future results in background upon completion") { + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + updateBackgroundVariablesEcOpt = Some(bgVarEc), + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + // same as above, except no cell is running when the future completes // How the pseudo-client behaves - val sessionId = UUID.randomUUID().toString + implicit val sessionId: SessionId = SessionId() // When the pseudo-client exits @@ -385,26 +294,15 @@ object ScalaKernelTests extends TestSuite { val input = Stream( execute( - sessionId, "import scala.concurrent.Future; import scala.concurrent.ExecutionContext.Implicits.global" ), - execute(sessionId, "val f = Future { Thread.sleep(3000L); 2 }") + execute("val f = Future { Thread.sleep(3000L); 2 }") ) val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - updateBackgroundVariablesEcOpt = Some(bgVarEc), - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val messageTypes = streams.generatedMessageTypes() @@ -424,13 +322,23 @@ object ScalaKernelTests extends TestSuite { test("auto-update Rx stuff upon change") { if (isScala212) { - // How the pseudo-client behaves - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + updateBackgroundVariablesEcOpt = Some(bgVarEc), + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() // When the pseudo-client exits + val lastMsgId = UUID.randomUUID().toString val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = (_, m) => IO.pure( @@ -440,26 +348,16 @@ object ScalaKernelTests extends TestSuite { // Initial messages from client val input = Stream( - execute(sessionId, "almondrx.setup()"), - execute(sessionId, "val a = rx.Var(1)"), - execute(sessionId, "a() = 2"), - execute(sessionId, "a() = 3", lastMsgId) + execute("almondrx.setup()"), + execute("val a = rx.Var(1)"), + execute("a() = 2"), + execute("a() = 3", lastMsgId) ) val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - updateBackgroundVariablesEcOpt = Some(bgVarEc), - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector @@ -522,17 +420,24 @@ object ScalaKernelTests extends TestSuite { test("handle interrupt messages") { - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) - // How the pseudo-client behaves + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() val interruptOnInput = MessageHandler(Channel.Input, Input.requestType) { msg => Message( Header( UUID.randomUUID().toString, "test", - sessionId, + sessionId.sessionId, Interrupt.requestType.messageType, Some(Protocol.versionStr) ), @@ -548,6 +453,7 @@ object ScalaKernelTests extends TestSuite { // When the pseudo-client exits + val lastMsgId = UUID.randomUUID().toString val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = (_, m) => IO.pure( @@ -559,24 +465,15 @@ object ScalaKernelTests extends TestSuite { // Initial messages from client val input = Stream( - execute(sessionId, "val n = scala.io.StdIn.readInt()"), - execute(sessionId, """val s = "ok done"""", msgId = lastMsgId) + execute("val n = scala.io.StdIn.readInt()"), + execute("""val s = "ok done"""", msgId = lastMsgId) ) val streams = ClientStreams.create(input, stopWhen, interruptOnInput.orElse(ignoreExpectedReplies)) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val messageTypes = streams.generatedMessageTypes() val controlMessageTypes = streams.generatedMessageTypes(Set(Channel.Control)) @@ -600,31 +497,6 @@ object ScalaKernelTests extends TestSuite { test("start from custom class loader") { - // How the pseudo-client behaves - - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString - - // When the pseudo-client exits - - val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = - (_, m) => - IO.pure( - m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) - ) - - // Initial messages from client - - val input = Stream( - execute( - sessionId, - """val url = Thread.currentThread().getContextClassLoader.getResource("foo")""" - ), - execute(sessionId, """assert(url.toString == "https://google.fr")""", lastMsgId) - ) - - val streams = ClientStreams.create(input, stopWhen) - val loader = new URLClassLoader(Array(), Thread.currentThread().getContextClassLoader) { override def getResource(name: String) = if (name == "foo") @@ -641,10 +513,31 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + val lastMsgId = UUID.randomUUID().toString - t.unsafeRunTimedOrThrow() + // When the pseudo-client exits + + val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = + (_, m) => + IO.pure( + m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) + ) + + // Initial messages from client + + val input = Stream( + execute("""val url = Thread.currentThread().getContextClassLoader.getResource("foo")"""), + execute("""assert(url.toString == "https://google.fr")""", lastMsgId) + ) + + val streams = ClientStreams.create(input, stopWhen) + + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val messageTypes = streams.generatedMessageTypes() @@ -661,17 +554,6 @@ object ScalaKernelTests extends TestSuite { test("exit") { - val sessionId = UUID.randomUUID().toString - - // Initial messages from client - - val input = Stream( - execute(sessionId, "val n = 2"), - execute(sessionId, "exit") - ) - - val streams = ClientStreams.create(input) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( initialColors = Colors.BlackWhite @@ -679,10 +561,22 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads) - .flatMap(_.run(streams.source, streams.sink)) + val kernel = Kernel.create(interpreter, interpreterEc, threads) + .unsafeRunTimedOrThrow() - t.unsafeRunTimedOrThrow() + implicit val sessionId: SessionId = SessionId() + + // Initial messages from client + + val input = Stream( + execute("val n = 2"), + execute("exit") + ) + + val streams = ClientStreams.create(input) + + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val messageTypes = streams.generatedMessageTypes() @@ -709,19 +603,6 @@ object ScalaKernelTests extends TestSuite { test("trap output") { - val sessionId = UUID.randomUUID().toString - - // Initial messages from client - - val input = Stream( - execute(sessionId, "val n = 2"), - execute(sessionId, """println("Hello")"""), - execute(sessionId, """System.err.println("Bbbb")"""), - execute(sessionId, "exit") - ) - - val streams = ClientStreams.create(input) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( initialColors = Colors.BlackWhite, @@ -730,10 +611,24 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads) - .flatMap(_.run(streams.source, streams.sink)) + val kernel = Kernel.create(interpreter, interpreterEc, threads) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + + // Initial messages from client + + val input = Stream( + execute("val n = 2"), + execute("""println("Hello")"""), + execute("""System.err.println("Bbbb")"""), + execute("exit") + ) + + val streams = ClientStreams.create(input) - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val messageTypes = streams.generatedMessageTypes() @@ -755,29 +650,6 @@ object ScalaKernelTests extends TestSuite { test("last exception") { - // How the pseudo-client behaves - - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString - - // When the pseudo-client exits - - val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = - (_, m) => - IO.pure( - m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) - ) - - // Initial messages from client - - val input = Stream( - execute(sessionId, """val nullBefore = repl.lastException == null"""), - execute(sessionId, """sys.error("foo")""", stopOnError = false), - execute(sessionId, """val nullAfter = repl.lastException == null""", lastMsgId) - ) - - val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( initialColors = Colors.BlackWhite @@ -785,70 +657,21 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector - val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector + implicit val sessionId: SessionId = SessionId() - val expectedRequestsMessageTypes = Seq( - "execute_reply", - "execute_reply", - "execute_reply" + kernel.execute( + """val nullBefore = repl.lastException == null""", + "nullBefore: Boolean = true" ) - - val expectedPublishMessageTypes = Seq( - "execute_input", - "execute_result", - "execute_input", - "error", - "execute_input", - "execute_result" - ) - - assert(requestsMessageTypes == expectedRequestsMessageTypes) - assert(publishMessageTypes == expectedPublishMessageTypes) - - val replies = streams.executeReplies - val expectedReplies = Map( - 1 -> "nullBefore: Boolean = true", - 3 -> "nullAfter: Boolean = false" - ) - assert(replies == expectedReplies) + kernel.execute("""sys.error("foo")""", expectError = true) + kernel.execute("""val nullAfter = repl.lastException == null""", "nullAfter: Boolean = false") } test("history") { - // How the pseudo-client behaves - - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString - - // When the pseudo-client exits - - val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = - (_, m) => - IO.pure( - m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId) - ) - - // Initial messages from client - - val input = Stream( - execute(sessionId, """val before = repl.history.toVector"""), - execute(sessionId, """val a = 2"""), - execute(sessionId, """val b = a + 1"""), - execute( - sessionId, - """val after = repl.history.toVector.mkString(",").toString""", - lastMsgId - ) - ) - - val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( params = ScalaInterpreterParams( initialColors = Colors.BlackWhite @@ -856,55 +679,40 @@ object ScalaKernelTests extends TestSuite { logCtx = logCtx ) - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector - val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector + implicit val sessionId: SessionId = SessionId() - val expectedRequestsMessageTypes = Seq( - "execute_reply", - "execute_reply", - "execute_reply", - "execute_reply" - ) - - val expectedPublishMessageTypes = Seq( - "execute_input", - "execute_result", - "execute_input", - "execute_result", - "execute_input", - "execute_result", - "execute_input", - "execute_result" + kernel.execute( + """val before = repl.history.toVector""", + """before: Vector[String] = Vector("val before = repl.history.toVector")""" ) - - assert(requestsMessageTypes == expectedRequestsMessageTypes) - assert(publishMessageTypes == expectedPublishMessageTypes) - - val replies = streams.executeReplies - val expectedReplies = Map( - 1 -> """before: Vector[String] = Vector("val before = repl.history.toVector")""", - 2 -> """a: Int = 2""", - 3 -> """b: Int = 3""", - 4 -> """after: String = "val before = repl.history.toVector,val a = 2,val b = a + 1,val after = repl.history.toVector.mkString(\",\").toString"""" + kernel.execute("val a = 2", "a: Int = 2") + kernel.execute("val b = a + 1", "b: Int = 3") + kernel.execute( + """val after = repl.history.toVector.mkString(",").toString""", + """after: String = "val before = repl.history.toVector,val a = 2,val b = a + 1,val after = repl.history.toVector.mkString(\",\").toString"""" ) - assert(replies == expectedReplies) } test("update vars") { if (AlmondCompilerLifecycleManager.isAtLeast_2_12_7 && TestUtil.isScala2) { - // How the pseudo-client behaves + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + updateBackgroundVariablesEcOpt = Some(bgVarEc), + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - // When the pseudo-client exits + implicit val sessionId: SessionId = SessionId() + val lastMsgId = UUID.randomUUID().toString val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = (_, m) => IO.pure( @@ -914,25 +722,15 @@ object ScalaKernelTests extends TestSuite { // Initial messages from client val input = Stream( - execute(sessionId, """var n = 2"""), - execute(sessionId, """n = n + 1"""), - execute(sessionId, """n += 2""", lastMsgId) + execute("""var n = 2"""), + execute("""n = n + 1"""), + execute("""n += 2""", lastMsgId) ) val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - updateBackgroundVariablesEcOpt = Some(bgVarEc), - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector @@ -992,13 +790,20 @@ object ScalaKernelTests extends TestSuite { def updateLazyValsTest(): Unit = { - // How the pseudo-client behaves + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + updateBackgroundVariablesEcOpt = Some(bgVarEc), + initialColors = Colors.BlackWhite + ), + logCtx = logCtx + ) - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() - // When the pseudo-client exits + implicit val sessionId: SessionId = SessionId() + val lastMsgId = UUID.randomUUID().toString val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = (_, m) => IO.pure( @@ -1008,25 +813,15 @@ object ScalaKernelTests extends TestSuite { // Initial messages from client val input = Stream( - execute(sessionId, """lazy val n = 2"""), - execute(sessionId, """val a = { n; () }"""), - execute(sessionId, """val b = { n; () }""", lastMsgId) + execute("""lazy val n = 2"""), + execute("""val a = { n; () }"""), + execute("""val b = { n; () }""", lastMsgId) ) val streams = ClientStreams.create(input, stopWhen) - val interpreter = new ScalaInterpreter( - params = ScalaInterpreterParams( - updateBackgroundVariablesEcOpt = Some(bgVarEc), - initialColors = Colors.BlackWhite - ), - logCtx = logCtx - ) - - val t = Kernel.create(interpreter, interpreterEc, threads, logCtx) - .flatMap(_.run(streams.source, streams.sink)) - - t.unsafeRunTimedOrThrow() + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector @@ -1081,6 +876,238 @@ object ScalaKernelTests extends TestSuite { if (TestUtil.isScala2) updateLazyValsTest() else "disabled" } + + test("hooks") { + + val predef = + """private val foos0 = new scala.collection.mutable.ListBuffer[String] + | + |def foos(): List[String] = + | foos0.result() + | + |kernel.addExecuteHook { code => + | import almond.api.JupyterApi + | + | var errorOpt = Option.empty[JupyterApi.ExecuteHookResult] + | val remainingCode = code.linesWithSeparators.zip(code.linesIterator) + | .map { + | case (originalLine, line) => + | if (line == "%AddFoo") "" + | else if (line.startsWith("%AddFoo ")) { + | foos0 ++= line.split("\\s+").drop(1).toSeq + | "" + | } + | else if (line == "%Error") { + | errorOpt = Some(JupyterApi.ExecuteHookResult.Error("thing", "erroring", List("aa", "bb"))) + | "" + | } + | else if (line == "%Abort") { + | errorOpt = Some(JupyterApi.ExecuteHookResult.Abort) + | "" + | } + | else if (line == "%Exit") { + | errorOpt = Some(JupyterApi.ExecuteHookResult.Exit) + | "" + | } + | else + | originalLine + | } + | .mkString + | + | errorOpt.toLeft(remainingCode) + |} + |""".stripMargin + + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite, + predefCode = predef + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + + kernel.execute( + "val before = foos()", + "before: List[String] = List()" + ) + kernel.execute("""%AddFoo a""", "") + kernel.execute("""%AddFoo b""", "") + kernel.execute( + "val after = foos()", + """after: List[String] = List("a", "b")""" + ) + kernel.execute( + "%Error", + errors = Seq( + ("thing", "erroring", List("thing: erroring", " aa", " bb")) + ) + ) + kernel.execute("%Abort") + kernel.execute( + """val a = "a" + |""".stripMargin, + """a: String = "a"""" + ) + kernel.execute( + """%Exit + |val b = "b" + |""".stripMargin, + "" + ) + } + + test("toree AddDeps") { + toreeAddDepsTest() + } + // unsupported yet, needs tweaking in the Ammonite dependency parser + // test("toree AddDeps intransitive") { + // toreeAddDepsTest(transitive = false) + // } + + def toreeAddDepsTest(transitive: Boolean = true): Unit = { + + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite, + toreeMagics = true + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + + val sbv = { + val sv = Properties.versionNumberString + if (sv.startsWith("2.")) sv.split('.').take(2).mkString(".") + else sv.takeWhile(_ != '.') + } + + kernel.execute( + "import caseapp.CaseApp", + errors = Seq( + ("", "Compilation Failed", List("Compilation Failed")) + ), + ignoreStreams = true + ) + kernel.execute( + "import caseapp.util._", + errors = Seq( + ("", "Compilation Failed", List("Compilation Failed")) + ), + ignoreStreams = true + ) + val suffix = if (transitive) " --transitive" else "" + kernel.execute( + s"%AddDeps com.github.alexarchambault case-app_$sbv 2.1.0-M24" + suffix, + "import $ivy.$ " + (" " * sbv.length) + maybePostImportNewLine, + ignoreStreams = true // ignoring coursier messages (that it prints when downloading things) + ) + kernel.execute( + "import caseapp.CaseApp", + "import caseapp.CaseApp" + maybePostImportNewLine + ) + if (transitive) + kernel.execute( + "import caseapp.util._", + "import caseapp.util._" + maybePostImportNewLine + ) + else + kernel.execute( + "import caseapp.util._", + errors = Seq( + ("", "Compilation Failed", List("Compilation Failed")) + ), + ignoreStreams = true + ) + } + + test("toree Html") { + + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite, + toreeMagics = true + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + + kernel.execute( + """%%html + |

+ |Hello + |

+ |""".stripMargin, + "", + displaysHtml = Seq( + """

+ |Hello + |

+ |""".stripMargin + ) + ) + } + + test("toree Truncation") { + + val interpreter = new ScalaInterpreter( + params = ScalaInterpreterParams( + initialColors = Colors.BlackWhite, + toreeMagics = true + ), + logCtx = logCtx + ) + + val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx) + .unsafeRunTimedOrThrow() + + implicit val sessionId: SessionId = SessionId() + + val nl = System.lineSeparator() + + kernel.execute( + "%truncation", + "", + stdout = + "Truncation is currently on" + nl + ) + kernel.execute( + "%truncation off", + "", + stdout = + "Output will NOT be truncated" + nl + ) + kernel.execute( + "(1 to 200).toVector", + "res0: Vector[Int] = " + (1 to 200).toVector.toString + ) + kernel.execute( + "%truncation on", + "", + stdout = + "Output WILL be truncated." + nl + ) + kernel.execute( + "(1 to 200).toVector", + "res1: Vector[Int] = " + + (1 to 38) + .toVector + .map(" " + _ + "," + "\n") + .mkString("Vector(" + "\n", "", "...") + ) + } } } diff --git a/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala b/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala index 9e620fc24..4a0b769b8 100644 --- a/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala +++ b/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala @@ -44,17 +44,144 @@ object TestUtil { } } + final case class SessionId(sessionId: String = UUID.randomUUID().toString) + + implicit class KernelOps(private val kernel: Kernel) extends AnyVal { + def execute( + code: String, + reply: String = null, + expectError: Boolean = false, + errors: Seq[(String, String, List[String])] = null, + displaysText: Seq[String] = null, + displaysHtml: Seq[String] = null, + displaysTextUpdates: Seq[String] = null, + displaysHtmlUpdates: Seq[String] = null, + ignoreStreams: Boolean = false, + stdout: String = null + )(implicit sessionId: SessionId): Unit = { + + val expectError0 = expectError || Option(errors).nonEmpty + val ignoreStreams0 = ignoreStreams || Option(stdout).nonEmpty + + val input = Stream( + TestUtil.execute(code, stopOnError = !expectError0) + ) + + val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = + (_, m) => IO.pure(m.header.msg_type == "execute_reply") + + val streams = ClientStreams.create(input, stopWhen) + + kernel.run(streams.source, streams.sink) + .unsafeRunTimedOrThrow() + + val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector + val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector + .filter(if (ignoreStreams0) _ != "stream" else _ => true) + + val expectedRequestsMessageTypes = + if (reply == null && !expectError0) + Nil + else + Seq("execute_reply") + assert(requestsMessageTypes == Seq("execute_reply")) + + val expectedPublishMessageTypes = { + val displayDataCount = Seq( + Option(displaysText).fold(0)(_.length), + Option(displaysHtml).fold(0)(_.length) + ).max + val updateDisplayDataCount = Seq( + Option(displaysTextUpdates).fold(0)(_.length), + Option(displaysHtmlUpdates).fold(0)(_.length) + ).max + val prefix = Seq("execute_input") ++ + Seq.fill(displayDataCount)("display_data") ++ + Seq.fill(updateDisplayDataCount)("update_display_data") + if (expectError0) + prefix :+ "error" + else if (reply == null || reply.isEmpty) + prefix + else + prefix :+ "execute_result" + } + assert(publishMessageTypes == expectedPublishMessageTypes) + + if (stdout != null) { + val stdoutMessages = streams.output.mkString + assert(stdout == stdoutMessages) + } + + val replies = streams.executeReplies.toVector.sortBy(_._1).map(_._2) + assert(replies == Option(reply).toVector) + + for (expectedTextDisplay <- Option(displaysText)) { + import ClientStreams.RawJsonOps + + val textDisplay = streams.displayData.collect { + case (data, false) => + data.data.get("text/plain") + .map(_.stringOrEmpty) + .getOrElse("") + } + + assert(textDisplay == expectedTextDisplay) + } + + val receivedErrors = streams.executeErrors.toVector.sortBy(_._1).map(_._2) + assert(errors == null || receivedErrors == errors) + + for (expectedHtmlDisplay <- Option(displaysHtml)) { + import ClientStreams.RawJsonOps + + val htmlDisplay = streams.displayData.collect { + case (data, false) => + data.data.get("text/html") + .map(_.stringOrEmpty) + .getOrElse("") + } + + assert(htmlDisplay == expectedHtmlDisplay) + } + + for (expectedTextDisplayUpdates <- Option(displaysTextUpdates)) { + import ClientStreams.RawJsonOps + + val textDisplayUpdates = streams.displayData.collect { + case (data, true) => + data.data.get("text/plain") + .map(_.stringOrEmpty) + .getOrElse("") + } + + assert(textDisplayUpdates == expectedTextDisplayUpdates) + } + + for (expectedHtmlDisplayUpdates <- Option(displaysHtmlUpdates)) { + import ClientStreams.RawJsonOps + + val htmlDisplayUpdates = streams.displayData.collect { + case (data, true) => + data.data.get("text/html") + .map(_.stringOrEmpty) + .getOrElse("") + } + + assert(htmlDisplayUpdates == expectedHtmlDisplayUpdates) + } + } + } + def execute( - sessionId: String, code: String, msgId: String = UUID.randomUUID().toString, stopOnError: Boolean = true - ) = + )(implicit sessionId: SessionId) = Message( Header( msgId, "test", - sessionId, + sessionId.sessionId, ProtocolExecute.requestType.messageType, Some(Protocol.versionStr) ), @@ -71,8 +198,8 @@ object TestUtil { val (input, replies) = inputs.unzip - val sessionId = UUID.randomUUID().toString - val lastMsgId = UUID.randomUUID().toString + implicit val sessionId: SessionId = SessionId() + val lastMsgId = UUID.randomUUID().toString val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] = (_, m) => @@ -83,8 +210,8 @@ object TestUtil { assert(input.nonEmpty) val input0 = Stream( - input.init.map(s => execute(sessionId, s)) :+ - execute(sessionId, input.last, lastMsgId): _* + input.init.map(s => execute(s)) :+ + execute(input.last, lastMsgId): _* ) val streams = ClientStreams.create(input0, stopWhen) diff --git a/modules/scala/scala-kernel/src/main/scala/almond/Options.scala b/modules/scala/scala-kernel/src/main/scala/almond/Options.scala index 7dc437fcc..741950b2a 100644 --- a/modules/scala/scala-kernel/src/main/scala/almond/Options.scala +++ b/modules/scala/scala-kernel/src/main/scala/almond/Options.scala @@ -72,7 +72,10 @@ final case class Options( @ExtraName("outputDir") outputDirectory: Option[String] = None, @ExtraName("tmpOutputDir") - tmpOutputDirectory: Option[Boolean] = None + tmpOutputDirectory: Option[Boolean] = None, + + @HelpMessage("Add experimental support for Toree magics") + toreeMagics: Boolean = false ) { // format: on diff --git a/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala b/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala index b6562b733..f5ea9bea9 100644 --- a/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala +++ b/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala @@ -131,7 +131,8 @@ object ScalaKernel extends CaseApp[Options] { .toLeft { options.tmpOutputDirectory .getOrElse(true) // Create tmp output dir by default - } + }, + toreeMagics = options.toreeMagics ), logCtx = logCtx ) diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandler.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandler.scala new file mode 100644 index 000000000..e92497289 --- /dev/null +++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandler.scala @@ -0,0 +1,7 @@ +package almond.toree + +import almond.api.JupyterApi + +trait CellMagicHandler { + def handle(name: String, content: String): Either[JupyterApi.ExecuteHookResult, String] +} diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandlers.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandlers.scala new file mode 100644 index 000000000..8142c11c2 --- /dev/null +++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandlers.scala @@ -0,0 +1,24 @@ +package almond.toree + +import almond.api.JupyterApi +import almond.interpreter.api.DisplayData.ContentType +import almond.interpreter.api.{DisplayData, OutputHandler} + +object CellMagicHandlers { + + class DisplayDataHandler(publish: OutputHandler, contentType: String) extends CellMagicHandler { + def handle(name: String, content: String): Either[JupyterApi.ExecuteHookResult, String] = { + publish.display(DisplayData(Map(contentType -> content))) + Right("") + } + } + + def handlers(publish: OutputHandler) = Map( + "html" -> new DisplayDataHandler(publish, ContentType.html), + "javascript" -> new DisplayDataHandler(publish, ContentType.js) + ) + + def handlerKeys: Iterable[String] = + handlers(OutputHandler.NopOutputHandler).keys + +} diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHook.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHook.scala new file mode 100644 index 000000000..038074a86 --- /dev/null +++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHook.scala @@ -0,0 +1,32 @@ +package almond.toree + +import almond.api.JupyterApi +import almond.interpreter.api.OutputHandler + +import java.util.Locale + +object CellMagicHook { + + def hook(publish: OutputHandler): JupyterApi.ExecuteHook = { + val handlers = CellMagicHandlers.handlers(publish) + code => + val nameOpt = code.linesIterator.take(1).toList.collectFirst { + case name if name.startsWith("%%") => + name.stripPrefix("%%") + } + nameOpt match { + case Some(name) => + handlers.get(name.toLowerCase(Locale.ROOT)) match { + case Some(handler) => + val content = code.linesWithSeparators.drop(1).mkString + handler.handle(name, content) + case None => + System.err.println(s"Warning: ignoring unrecognized Toree cell magic $name") + Right("") + } + case None => + Right(code) + } + } + +} diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandler.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandler.scala new file mode 100644 index 000000000..e10e1832f --- /dev/null +++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandler.scala @@ -0,0 +1,7 @@ +package almond.toree + +import almond.api.JupyterApi + +trait LineMagicHandler { + def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] +} diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandlers.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandlers.scala new file mode 100644 index 000000000..328b341c6 --- /dev/null +++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandlers.scala @@ -0,0 +1,185 @@ +package almond.toree + +import ammonite.util.Ref +import almond.api.JupyterApi + +import java.net.URI +import java.nio.file.Paths + +object LineMagicHandlers { + + class AddDepHandler extends LineMagicHandler { + import AddDepHandler._ + def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] = + values match { + case Seq(org, name, ver, other @ _*) => + val params = parseParams(other.toList, Params()) + val depSuffix = if (params.transitive) "" else ",intransitive" + val repoSuffix = params.repositories.map(repo => s"; import $$repo.`$repo`").mkString + + if (params.trace) + System.err.println(s"Warning: ignoring unsupported %AddDeps argument --trace") + if (params.verbose) + System.err.println(s"Warning: ignoring unsupported %AddDeps argument --verbose") + if (params.abortOnResolutionErrors) + System.err.println( + s"Warning: ignoring unsupported %AddDeps argument --abort-on-resolution-errors" + ) + + Right(s"import $$ivy.`$org:$name:$ver$depSuffix`$repoSuffix") + case _ => + System.err.println( + s"Warning: ignoring malformed %AddDeps Toree magic (expected '%AddDeps org name version [optional-arguments*]')" + ) + Right("") + } + } + + object AddDepHandler { + // not 100% sure about the default values, but https://github.com/apache/incubator-toree/blob/5b19aac2e56a56d35c888acc4ed5e549b1f4ed7c/etc/examples/notebooks/magic-tutorial.ipynb + // seems to imply these are correct + private case class Params( + transitive: Boolean = false, + trace: Boolean = false, + verbose: Boolean = false, + abortOnResolutionErrors: Boolean = false, + repositories: Seq[String] = Nil + ) + + private def parseParams(args: List[String], params: Params): Params = + args match { + case Nil => params + case "--trace" :: t => parseParams(t, params.copy(trace = true)) + case "--verbose" :: t => parseParams(t, params.copy(verbose = true)) + case "--transitive" :: t => parseParams(t, params.copy(transitive = true)) + case "--abort-on-resolution-errors" :: t => + parseParams(t, params.copy(abortOnResolutionErrors = true)) + case "--repository" :: repo :: t => + parseParams(t, params.copy(repositories = params.repositories :+ repo)) + case other :: t => + System.err.println(s"Warning: ignoring unrecognized %AddDeps argument '$other'") + parseParams(t, params) + } + } + + class AddJarHandler extends LineMagicHandler { + import AddJarHandler._ + def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] = + values match { + case Seq(url, other @ _*) => + val uri = new URI(url) + val params = parseParams(other.toList, Params()) + if (params.force) + System.err.println(s"Warning: ignoring unsupported %AddJar argument -f") + if (params.magic) + System.err.println(s"Warning: ignoring unsupported %AddJar argument --magic") + if (uri.getScheme == "file") { + val path = Paths.get(uri) + Right(s"import $$cp.`$path`") + } + else { + System.err.println( + s"Warning: ignoring %AddJar URL $url (unsupported protocol ${uri.getScheme})" + ) + Right("") + } + case _ => + System.err.println( + s"Warning: ignoring malformed %AddJar Toree magic (expected '%AddJar url [optional-arguments*]')" + ) + Right("") + } + } + + object AddJarHandler { + private case class Params( + force: Boolean = false, + magic: Boolean = false + ) + + private def parseParams(args: List[String], params: Params): Params = + args match { + case Nil => params + case "-f" :: t => parseParams(t, params.copy(force = true)) + case "--magic" :: t => parseParams(t, params.copy(magic = true)) + case other :: t => + System.err.println(s"Warning: ignoring unrecognized %AddJar argument '$other'") + parseParams(t, params) + } + } + + class LsMagicHandler extends LineMagicHandler { + def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] = { + if (values.nonEmpty) + System.err.println( + s"Warning: ignoring unrecognized values passed to %LsMagic: ${values.mkString(" ")}" + ) + + println("Available line magics:") + println(handlerKeys.toVector.sorted.map("%" + _).mkString(" ")) + println() + + println("Available cell magics:") + println(CellMagicHandlers.handlerKeys.toVector.sorted.map("%%" + _).mkString(" ")) + println() + + Right("") + } + } + + class TruncationHandler(pprinter: Ref[pprint.PPrinter]) extends LineMagicHandler { + private def enabled() = { + val current = pprinter() + current.defaultWidth != Int.MaxValue && current.defaultHeight != Int.MaxValue + } + private var formerWidth = pprinter().defaultWidth + private var formerHeight = pprinter().defaultHeight + private def disable(): Unit = + if (enabled()) { + formerWidth = pprinter().defaultWidth + formerHeight = pprinter().defaultHeight + + pprinter.update { + pprinter().copy( + defaultWidth = Int.MaxValue, + defaultHeight = Int.MaxValue + ) + } + } + private def enable(): Unit = + if (!enabled()) + pprinter.update { + pprinter().copy( + defaultWidth = formerWidth, + defaultHeight = formerHeight + ) + } + def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] = { + values match { + case Seq() => + val state = if (enabled()) "on" else "off" + println(s"Truncation is currently $state") + case Seq("on") => + enable() + println("Output WILL be truncated.") + case Seq("off") => + disable() + println("Output will NOT be truncated") + case _ => + System.err.println( + s"Warning: ignoring %truncation magic with unrecognized parameters ${values.mkString(" ")}" + ) + } + Right("") + } + } + + def handlers(pprinter: Ref[pprint.PPrinter]) = Map( + "adddeps" -> new AddDepHandler, + "addjar" -> new AddJarHandler, + "lsmagic" -> new LsMagicHandler, + "truncation" -> new TruncationHandler(pprinter) + ) + def handlerKeys: Iterable[String] = + handlers(Ref(pprint.PPrinter())).keys +} diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHook.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHook.scala new file mode 100644 index 000000000..0a3c1d7bd --- /dev/null +++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHook.scala @@ -0,0 +1,69 @@ +package almond.toree + +import almond.api.JupyterApi +import ammonite.util.Ref + +import java.util.regex.Pattern +import java.util.Locale + +import scala.collection.mutable + +object LineMagicHook { + + private val sep = Pattern.compile("\\s+") + + def inspect(code: String): Iterator[Either[(Seq[String], String, String), String]] = { + var parsingMagics = true + code.linesWithSeparators.zip(code.linesIterator).map { + case (rawLine, line) => + if (parsingMagics && line.startsWith("%") && !line.drop(1).startsWith("%")) + Left((sep.split(line).toSeq, rawLine, line)) + else { + if (parsingMagics) { + val trimmed = line.trim() + parsingMagics = trimmed.isEmpty || trimmed.startsWith("//") + } + Right(rawLine) + } + } + } + + def hook(pprinter: Ref[pprint.PPrinter]): JupyterApi.ExecuteHook = { + + val handlers = LineMagicHandlers.handlers(pprinter) + + code => + + val magicsIt = inspect(code) + var errorOpt = Option.empty[JupyterApi.ExecuteHookResult] + val remainingCode = new StringBuilder + while (magicsIt.hasNext && errorOpt.isEmpty) + magicsIt.next() match { + case Left((elems, rawLine, line)) => + assert(elems.nonEmpty) + + val name = elems.head + val values = elems.tail + + assert(name.startsWith("%")) + + handlers.get(name.toLowerCase(Locale.ROOT).stripPrefix("%")) match { + case None => + System.err.println(s"Warning: ignoring unrecognized Toree line magic $name") + case Some(handler) => + handler.handle(name, values) match { + case Left(res) => errorOpt = Some(res) + case Right(substituteCode) => + remainingCode ++= substituteCode + remainingCode ++= rawLine.substring(line.length) + } + } + + case Right(code) => + remainingCode ++= code + } + + errorOpt.toLeft(remainingCode.result()) + } + +} diff --git a/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala b/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala index ce0453463..04df0a615 100644 --- a/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala +++ b/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala @@ -86,4 +86,11 @@ object OutputHandler { underlying.updateDisplay(displayData) } + object NopOutputHandler extends OutputHandler { + def stdout(s: String): Unit = () + def stderr(s: String): Unit = () + def display(displayData: DisplayData): Unit = () + def updateDisplay(displayData: DisplayData): Unit = () + } + } diff --git a/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala b/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala index 248788ef6..0f6ccc598 100644 --- a/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala +++ b/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala @@ -29,9 +29,8 @@ final class InterpreterToIOInterpreter( private val cancelledSignal0 = { implicit val shift = - IO.contextShift( - interpreterEc - ) // maybe not the right ec, but that one shouldn't be used yet at that point + // maybe not the right ec, but that one shouldn't be used yet at that point + IO.contextShift(interpreterEc) SignallingRef[IO, Boolean](false).unsafeRunSync() } def cancelledSignal: SignallingRef[IO, Boolean] = diff --git a/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala b/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala index 130e37957..cdbd5d702 100644 --- a/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala +++ b/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala @@ -120,7 +120,7 @@ final case class Kernel( immediateHandlers.handleOrLogError(channel, rawMessage, log) match { case None => - log.warn(s"Ignoring unhandled message:\n$rawMessage") + log.warn(s"Ignoring unhandled message on $channel:\n$rawMessage") Stream.empty case Some(output) => diff --git a/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala b/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala index 4cfca75e0..cf0b213eb 100644 --- a/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala +++ b/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala @@ -115,6 +115,23 @@ final case class ClientStreams( } .toMap + def executeErrors: Map[Int, (String, String, List[String])] = + generatedMessages + .iterator + .collect { + case Left((Channel.Requests, m)) if m.header.msg_type == Execute.replyType.messageType => + m.decodeAs[Execute.Reply] match { + case Left(_) => Nil + case Right(m) => Seq(m.content) + } + } + .flatten + .collect { + case e: Execute.Reply.Error => + (e.execution_count, (e.ename, e.evalue, e.traceback)) + } + .toMap + def executeReplyPayloads: Map[Int, Seq[RawJson]] = generatedMessages .iterator @@ -175,6 +192,16 @@ final case class ClientStreams( case _ => Nil } } + case Left((Channel.Publish, m)) + if m.header.msg_type == "stream" => + m.decodeAs[Execute.Stream] match { + case Left(_) => Nil + case Right(m) => + if (m.content.name == "stdout") + Seq(m.content.text) + else + Nil + } case Left((Channel.Publish, m)) if m.header.msg_type == "display_data" || m.header.msg_type == "update_display_data" => m.decodeAs[Execute.DisplayData] match { @@ -191,7 +218,7 @@ object ClientStreams { import com.github.plokhotnyuk.jsoniter_scala.core._ - private implicit class RawJsonOps(private val rawJson: RawJson) extends AnyVal { + implicit class RawJsonOps(private val rawJson: RawJson) extends AnyVal { def stringOrEmpty: String = Try(readFromArray[String](rawJson.value)).toOption.getOrElse("") } diff --git a/project/deps.sc b/project/deps.sc index 9cc186289..41cca24f6 100644 --- a/project/deps.sc +++ b/project/deps.sc @@ -57,6 +57,7 @@ object Deps { def jvmRepr = ivy"com.github.jupyter:jvm-repr:0.4.0" def mdoc = ivy"org.scalameta::mdoc:2.3.7" def metabrowseServer = ivy"org.scalameta:::metabrowse-server:0.2.9" + def pprint = ivy"com.lihaoyi::pprint:0.8.1" def scalafmtDynamic = ivy"org.scalameta::scalafmt-dynamic:${Versions.scalafmt}" def scalapy = ivy"me.shadaj::scalapy-core:0.5.2" def scalaReflect(sv: String) = ivy"org.scala-lang:scala-reflect:$sv" @@ -103,4 +104,9 @@ object ScalaVersions { "2.12.10", "2.12.9" ).distinct + + def binary(sv: String) = + if (sv.startsWith("2.12.")) scala212 + else if (sv.startsWith("2.13.")) scala213 + else scala3Compat } diff --git a/project/jupyterserver.sc b/project/jupyterserver.sc index c52cacc0a..5304f0e05 100644 --- a/project/jupyterserver.sc +++ b/project/jupyterserver.sc @@ -13,7 +13,8 @@ def writeKernelJson(launcher: Path, jupyterDir: Path): Unit = { "$launcherPath", "--log", "info", "--connection-file", "{connection_file}", - "--variable-inspector" + "--variable-inspector", + "--toree-magics" ] }""" Files.write(dir.resolve("kernel.json"), kernelJson.getBytes("UTF-8"))