diff --git a/build.sc b/build.sc
index 0391de934..21ac0279b 100644
--- a/build.sc
+++ b/build.sc
@@ -159,12 +159,14 @@ class ScalaInterpreter(val crossScalaVersion: String) extends AlmondModule with
if (crossScalaVersion.startsWith("3."))
Seq(
shared.interpreter(ScalaVersions.scala3Compat),
- scala.`scala-kernel-api`()
+ scala.`scala-kernel-api`(),
+ scala.`toree-hooks`(ScalaVersions.binary(crossScalaVersion))
)
else
Seq(
shared.interpreter(),
- scala.`scala-kernel-api`()
+ scala.`scala-kernel-api`(),
+ scala.`toree-hooks`(ScalaVersions.binary(crossScalaVersion))
)
def ivyDeps = T {
val metabrowse =
@@ -302,6 +304,12 @@ class Echo(val crossScalaVersion: String) extends AlmondModule {
}
}
+class ToreeHooks(val crossScalaVersion: String) extends AlmondModule {
+ def compileModuleDeps = super.compileModuleDeps ++ Seq(
+ scala.`scala-kernel-api`(ScalaVersions.binary(crossScalaVersion))
+ )
+}
+
object shared extends Module {
object `logger-scala2-macros` extends Cross[LoggerScala2Macros](ScalaVersions.binaries: _*)
object logger extends Cross[Logger](ScalaVersions.binaries: _*)
@@ -325,6 +333,8 @@ object scala extends Module {
extends Cross[ScalaKernelHelper](ScalaVersions.all.filter(_.startsWith("3.")): _*)
object `almond-scalapy` extends Cross[AlmondScalaPy](ScalaVersions.binaries: _*)
object `almond-rx` extends Cross[AlmondRx](ScalaVersions.scala212, ScalaVersions.scala213)
+
+ object `toree-hooks` extends Cross[ToreeHooks](ScalaVersions.binaries: _*)
}
object echo extends Cross[Echo](ScalaVersions.binaries: _*)
diff --git a/docs/pages/api-jupyter.md b/docs/pages/api-jupyter.md
index 676a356f3..640c4a3ba 100644
--- a/docs/pages/api-jupyter.md
+++ b/docs/pages/api-jupyter.md
@@ -153,3 +153,50 @@ kernel.publish.updateHtml("Got all items", id)
```
![](/demo/updatable.gif)
+
+### Hooks
+
+Hooks allow to pre-process code right before it's executed. Use like
+```scala
+private def runSql(sql: String): String = {
+ println("Running query...")
+ val fakeResult =
+ """
+ |
+ |Id |
+ |Name |
+ |
+ |
+ |1 |
+ |Tree |
+ |
+ |
+ |2 |
+ |Apple |
+ |
+ |
+ |""".stripMargin
+ fakeResult
+}
+
+kernel.addExecuteHook { code =>
+ import almond.api.JupyterApi
+ import almond.interpreter.api.DisplayData
+
+ if (code.linesIterator.take(1).toList == List("%sql")) {
+ val sql = code.linesWithSeparators.drop(1).mkString // drop first line with "%sql"
+ val result = runSql(sql)
+ Left(JupyterApi.ExecuteHookResult.Success(DisplayData.html(result)))
+ }
+ else
+ Right(code) // just pass on code
+}
+```
+
+Such code can be run either in a cell or in a predef file.
+
+Later on, users can run things like
+```text
+%sql
+SELECT id, name FROM my_table
+```
diff --git a/docs/pages/install-options.md b/docs/pages/install-options.md
index 6a72a6cfc..d26058bac 100644
--- a/docs/pages/install-options.md
+++ b/docs/pages/install-options.md
@@ -157,3 +157,12 @@ Add entries to the Help menu (Jupyter classic). Use like
#### `--connection-file`
+## Other
+
+#### `--toree-magics`
+
+Enable experimental support for [Toree](https://toree.apache.org) magics.
+
+Simple line magics such as `%AddDeps` (always assumed to be transitive as of writing this, `--transitive` is just ignored),
+`%AddJar`, and cell magics such as `%%html` or `%%javascript` are supported. Note that `%%javascript` only works from Jupyter
+classic, as JupyterLab doesn't allow for random javascript code execution.
diff --git a/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala b/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala
index 5cd72741b..7409cd11d 100644
--- a/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala
+++ b/modules/scala/jupyter-api/src/main/scala/almond/api/JupyterApi.scala
@@ -2,7 +2,7 @@ package almond.api
import java.util.UUID
-import almond.interpreter.api.{CommHandler, OutputHandler}
+import almond.interpreter.api.{CommHandler, DisplayData, OutputHandler}
import jupyter.{Displayer, Displayers}
import scala.reflect.{ClassTag, classTag}
@@ -47,10 +47,62 @@ abstract class JupyterApi { api =>
def display(t: T) = f(t).asJava
}
)
+
+ def addExecuteHook(hook: JupyterApi.ExecuteHook): Boolean
+ def removeExecuteHook(hook: JupyterApi.ExecuteHook): Boolean
}
object JupyterApi {
+ /** A hook, that can pre-process code right before it's executed
+ */
+ @FunctionalInterface
+ abstract class ExecuteHook {
+
+ /** Pre-processes code right before it's executed.
+ *
+ * Like when actual code is executed, `println` / `{Console,System}.{out,err}` get sent to the
+ * cell output, stdin can be requested from users, CommHandler and OutputHandler can be used,
+ * etc.
+ *
+ * When several hooks were added, they are called in the order they were added. The output of
+ * the previous hook gets passed to the next one, as long as hooks return code to be executed
+ * rather than an `ExecuteHookResult`.
+ *
+ * @param code
+ * Code to be pre-processed
+ * @return
+ * Either code to be executed (right), or an `ExecuteHookResult` (left)
+ */
+ def hook(code: String): Either[ExecuteHookResult, String]
+ }
+
+ /** Can be returned by `ExecuteHook.hook` to stop code execution.
+ */
+ sealed abstract class ExecuteHookResult extends Product with Serializable
+ object ExecuteHookResult {
+
+ /** Returns data to be displayed */
+ final case class Success(data: DisplayData = DisplayData.empty) extends ExecuteHookResult
+
+ /** Exception-like error
+ *
+ * If you'd like to build one out of an actual `Throwable`, just throw it. It will then be
+ * caught while the hook is running, and sent to users.
+ */
+ final case class Error(
+ name: String,
+ message: String,
+ stackTrace: List[String]
+ ) extends ExecuteHookResult
+
+ /** Tells the front-end that execution was aborted */
+ case object Abort extends ExecuteHookResult
+
+ /** Should instruct the front-end to prompt the user for exit */
+ case object Exit extends ExecuteHookResult
+ }
+
abstract class UpdatableResults {
@deprecated("Use updatable instead", "0.4.1")
diff --git a/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala b/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala
index d3a076b23..415581050 100644
--- a/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala
+++ b/modules/scala/scala-interpreter/src/main/scala-2/almond/amm/AlmondParsers.scala
@@ -6,11 +6,11 @@ import scalaparse.Scala._
object AlmondParsers {
- private def Prelude[_: P] = P((Annot ~ OneNLMax).rep ~ (Mod ~/ Pass).rep)
+ private def Prelude[X: P] = P((Annot ~ OneNLMax).rep ~ (Mod ~/ Pass).rep)
// same as the methods with the same name in ammonite.interp.Parsers, but keeping the type aside in LHS
- def PatVarSplitter[_: P] = {
+ def PatVarSplitter[X: P] = {
def Prefixes = P(Prelude ~ (`var` | `val`))
def Lhs = P(Prefixes ~/ VarId)
def TypeAnnotation = P((`:` ~/ Type.!).?)
diff --git a/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala b/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala
index 1a86514a2..43d48bb54 100644
--- a/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala
+++ b/modules/scala/scala-interpreter/src/main/scala-2/almond/internals/ScalaInterpreterInspections.scala
@@ -14,6 +14,7 @@ import ammonite.util.Util.newLine
import metabrowse.server.{MetabrowseServer, Sourcepath}
import scala.meta.dialects
+import scala.collection.compat._
import scala.tools.nsc.Global
import scala.tools.nsc.interactive.{Global => Interactive}
import scala.util.Random
@@ -231,13 +232,13 @@ object ScalaInterpreterInspections {
javaDirs.exists(path.startsWith)
}
- def classpath(cl: ClassLoader): Stream[java.net.URL] =
+ def classpath(cl: ClassLoader): immutable.LazyList[java.net.URL] =
if (cl == null)
- Stream()
+ immutable.LazyList()
else {
val cp = cl match {
- case u: java.net.URLClassLoader => u.getURLs.toStream
- case _ => Stream()
+ case u: java.net.URLClassLoader => u.getURLs.to(immutable.LazyList)
+ case _ => immutable.LazyList()
}
cp #::: classpath(cl.getParent)
diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala b/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala
index a9b7a8b77..072ec88e6 100644
--- a/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala
+++ b/modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala
@@ -26,7 +26,7 @@ import fastparse.Parsed
import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration.Duration
-import scala.util.{Failure, Success}
+import scala.util.{Failure, Success, Try}
/** Wraps contextual things around when executing code (capturing output, stdin via front-ends,
* interruption, etc.)
@@ -65,7 +65,7 @@ final class Execute(
Duration.Inf
)
}
- log.info("Received input")
+ log.info(s"Received input ${res.map { case "" => "[empty]"; case _ => "[non empty]" }}")
res match {
case Success(s) => Some(s)
@@ -136,7 +136,7 @@ final class Execute(
}
}
- private def withInputManager[T](m: Option[InputManager])(f: => T): T = {
+ private def withInputManager[T](m: Option[InputManager], done: Boolean = true)(f: => T): T = {
val previous = currentInputManagerOpt0
try {
currentInputManagerOpt0 = m
@@ -144,7 +144,8 @@ final class Execute(
}
finally {
currentInputManagerOpt0 = previous
- m.foreach(_.done())
+ if (done)
+ m.foreach(_.done())
}
}
@@ -337,7 +338,8 @@ final class Execute(
inputManager: Option[InputManager],
outputHandler: Option[OutputHandler],
colors0: Ref[Colors],
- storeHistory: Boolean
+ storeHistory: Boolean,
+ executeHooks: Seq[JupyterApi.ExecuteHook]
): ExecuteResult = {
if (storeHistory) {
@@ -345,36 +347,84 @@ final class Execute(
history0 = history0 :+ code
}
- ammResult(ammInterp, code, inputManager, outputHandler, storeHistory) match {
- case Res.Success((_, data)) =>
- ExecuteResult.Success(data)
- case Res.Failure(msg) =>
- interruptedStackTraceOpt0 match {
- case None =>
- val err = Execute.error(colors0(), None, msg)
- outputHandler.foreach(_.stderr(err.message)) // necessary?
- err
- case Some(st) =>
- val cutoff = Set("$main", "evaluatorRunPrinter")
-
- ExecuteResult.Error(
- (
- "Interrupted!" +: st
- .takeWhile(x => !cutoff(x.getMethodName))
- .map(Execute.highlightFrame(_, fansi.Attr.Reset, colors0().literal()))
- ).mkString(System.lineSeparator())
- )
+ val finalCodeOrResult =
+ withOutputHandler(outputHandler) {
+ interruptible {
+ withInputManager(inputManager, done = false) {
+ withClientStdin {
+ capturingOutput {
+ executeHooks.foldLeft[Try[Either[JupyterApi.ExecuteHookResult, String]]](
+ Success(Right(code))
+ ) {
+ (codeOrDisplayDataAttempt, hook) =>
+ codeOrDisplayDataAttempt.flatMap { codeOrDisplayData =>
+ try Success(codeOrDisplayData.flatMap { value =>
+ hook.hook(value)
+ })
+ catch {
+ case e: Throwable => // kind of meh, but Ammonite does the same it seemsā¦
+ Failure(e)
+ }
+ }
+ }
+ }
+ }
+ }
}
+ }
- case Res.Exception(ex, msg) =>
- log.error(s"exception in user code (${ex.getMessage})", ex)
- Execute.error(colors0(), Some(ex), msg)
+ finalCodeOrResult match {
+ case Failure(ex) =>
+ log.error(s"exception when running hooks (${ex.getMessage})", ex)
+ Execute.error(colors0(), Some(ex), "")
+
+ case Success(Left(res)) =>
+ res match {
+ case s: JupyterApi.ExecuteHookResult.Success =>
+ ExecuteResult.Success(s.data)
+ case e: JupyterApi.ExecuteHookResult.Error =>
+ ExecuteResult.Error(e.name, e.message, e.stackTrace)
+ case JupyterApi.ExecuteHookResult.Abort =>
+ ExecuteResult.Abort
+ case JupyterApi.ExecuteHookResult.Exit =>
+ ExecuteResult.Exit
+ }
- case Res.Skip =>
+ case Success(Right(emptyCode)) if emptyCode.trim.isEmpty =>
ExecuteResult.Success()
- case Res.Exit(_) =>
- ExecuteResult.Exit
+ case Success(Right(finalCode)) =>
+ ammResult(ammInterp, finalCode, inputManager, outputHandler, storeHistory) match {
+ case Res.Success((_, data)) =>
+ ExecuteResult.Success(data)
+ case Res.Failure(msg) =>
+ interruptedStackTraceOpt0 match {
+ case None =>
+ val err = Execute.error(colors0(), None, msg)
+ outputHandler.foreach(_.stderr(err.message)) // necessary?
+ err
+ case Some(st) =>
+ val cutoff = Set("$main", "evaluatorRunPrinter")
+
+ ExecuteResult.Error(
+ (
+ "Interrupted!" +: st
+ .takeWhile(x => !cutoff(x.getMethodName))
+ .map(Execute.highlightFrame(_, fansi.Attr.Reset, colors0().literal()))
+ ).mkString(System.lineSeparator())
+ )
+ }
+
+ case Res.Exception(ex, msg) =>
+ log.error(s"exception in user code (${ex.getMessage})", ex)
+ Execute.error(colors0(), Some(ex), msg)
+
+ case Res.Skip =>
+ ExecuteResult.Success()
+
+ case Res.Exit(_) =>
+ ExecuteResult.Exit
+ }
}
}
}
diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala b/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala
index ab739f5f5..137f00223 100644
--- a/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala
+++ b/modules/scala/scala-interpreter/src/main/scala/almond/JupyterApiImpl.scala
@@ -9,6 +9,7 @@ import almond.interpreter.api.CommHandler
import ammonite.util.Ref
import pprint.{TPrint, TPrintColors}
+import scala.collection.mutable
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
@@ -69,4 +70,20 @@ final class JupyterApiImpl(
protected def updatableResults0: JupyterApi.UpdatableResults =
execute.updatableResults
+
+ private val executeHooks0 = new mutable.ListBuffer[JupyterApi.ExecuteHook]
+ def executeHooks: Seq[JupyterApi.ExecuteHook] =
+ executeHooks0.toList
+ def addExecuteHook(hook: JupyterApi.ExecuteHook): Boolean =
+ !executeHooks0.contains(hook) && {
+ executeHooks0.append(hook)
+ true
+ }
+ def removeExecuteHook(hook: JupyterApi.ExecuteHook): Boolean = {
+ val idx = executeHooks0.indexOf(hook)
+ idx >= 0 && {
+ executeHooks0.remove(idx)
+ true
+ }
+ }
}
diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala
index a8c776d08..2d38dcf18 100644
--- a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala
+++ b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreter.scala
@@ -8,6 +8,7 @@ import almond.interpreter.input.InputManager
import almond.interpreter.util.AsyncInterpreterOps
import almond.logger.LoggerContext
import almond.protocol.KernelInfo
+import almond.toree.{CellMagicHook, LineMagicHook}
import ammonite.compiler.Parsers
import ammonite.repl.{ReplApiImpl => _, _}
import ammonite.runtime._
@@ -62,27 +63,32 @@ final class ScalaInterpreter(
params.useThreadInterrupt
)
- lazy val ammInterp: ammonite.interp.Interpreter = {
+ val sessApi = new SessionApiImpl(frames0)
- val sessApi = new SessionApiImpl(frames0)
+ val replApi =
+ new ReplApiImpl(
+ execute0,
+ storage,
+ colors0,
+ ammInterp,
+ sessApi
+ )
- val replApi =
- new ReplApiImpl(
- execute0,
- storage,
- colors0,
- ammInterp,
- sessApi
- )
+ val jupyterApi =
+ new JupyterApiImpl(
+ execute0,
+ commHandlerOpt,
+ replApi,
+ silent0,
+ params.allowVariableInspector
+ )
- val jupyterApi =
- new JupyterApiImpl(
- execute0,
- commHandlerOpt,
- replApi,
- silent0,
- params.allowVariableInspector
- )
+ if (params.toreeMagics) {
+ jupyterApi.addExecuteHook(LineMagicHook.hook(replApi.pprinter))
+ jupyterApi.addExecuteHook(CellMagicHook.hook(jupyterApi.publish))
+ }
+
+ lazy val ammInterp: ammonite.interp.Interpreter = {
for (ec <- params.updateBackgroundVariablesEcOpt)
UpdatableFuture.setup(replApi, jupyterApi, ec)
@@ -129,7 +135,15 @@ final class ScalaInterpreter(
inputManager: Option[InputManager],
outputHandler: Option[OutputHandler]
): ExecuteResult =
- execute0(ammInterp, code, inputManager, outputHandler, colors0, storeHistory)
+ execute0(
+ ammInterp,
+ code,
+ inputManager,
+ outputHandler,
+ colors0,
+ storeHistory,
+ jupyterApi.executeHooks
+ )
def currentLine(): Int =
execute0.currentLine
diff --git a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala
index fdfe45c63..f8d0be7db 100644
--- a/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala
+++ b/modules/scala/scala-interpreter/src/main/scala/almond/ScalaInterpreterParams.scala
@@ -34,5 +34,6 @@ final case class ScalaInterpreterParams(
autoUpdateVars: Boolean = true,
allowVariableInspector: Option[Boolean] = None,
useThreadInterrupt: Boolean = false,
- outputDir: Either[os.Path, Boolean] = Right(true)
+ outputDir: Either[os.Path, Boolean] = Right(true),
+ toreeMagics: Boolean = false
)
diff --git a/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala b/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala
index b9f26fc44..82103854b 100644
--- a/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala
+++ b/modules/scala/scala-interpreter/src/test/scala/almond/ScalaKernelTests.scala
@@ -20,6 +20,7 @@ import fs2.Stream
import utest._
import scala.collection.compat._
+import scala.util.Properties
object ScalaKernelTests extends TestSuite {
@@ -30,6 +31,8 @@ object ScalaKernelTests extends TestSuite {
val threads = KernelThreads.create("test")
+ val maybePostImportNewLine = if (TestUtil.isScala2) "" else System.lineSeparator()
+
override def utestAfterAll() = {
threads.attemptShutdown()
if (!attemptShutdownExecutionContext(interpreterEc))
@@ -40,6 +43,16 @@ object ScalaKernelTests extends TestSuite {
test("stdin") {
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
// How the pseudo-client behaves
val inputHandler = MessageHandler(Channel.Input, Input.requestType) { msg =>
@@ -64,30 +77,21 @@ object ScalaKernelTests extends TestSuite {
(_, m) =>
IO.pure(m.header.msg_type == "execute_reply" && m.content.toString().contains("exit"))
- val sessionId = UUID.randomUUID().toString
+ implicit val sessionId: SessionId = SessionId()
// Initial messages from client
val input = Stream(
- execute(sessionId, "val n = scala.io.StdIn.readInt()"),
- execute(sessionId, "val m = new java.util.Scanner(System.in).nextInt()"),
- execute(sessionId, """val s = "exit"""")
+ execute("val n = scala.io.StdIn.readInt()"),
+ execute("val m = new java.util.Scanner(System.in).nextInt()"),
+ execute("""val s = "exit"""")
)
val streams =
ClientStreams.create(input, stopWhen, inputHandler.orElse(ignoreExpectedReplies))
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val replies = streams.executeReplies
@@ -102,12 +106,29 @@ object ScalaKernelTests extends TestSuite {
test("stop on error") {
- // How the pseudo-client behaves
+ // There's something non-deterministic in this test.
+ // It requires the 3 cells to execute to be sent at once at the beginning.
+ // The exception running the first cell must make the kernel discard (not compile
+ // nor run) the other two, that are queued.
+ // That means, the other 2 cells must have been queued when the first cell's thrown
+ // exception is caught by the kernel.
+ // Because of that, we can't rely on individual calls to 'kernel.execute' like
+ // the other tests do, as these send messages one after the other, sending the
+ // next one when the previous one is done running (so no messages would be queued.)
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
- // When the pseudo-client exits
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+
+ val lastMsgId = UUID.randomUUID().toString
val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
(_, m) =>
@@ -115,27 +136,16 @@ object ScalaKernelTests extends TestSuite {
m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
)
- // Initial messages from client
-
val input = Stream(
- execute(sessionId, """sys.error("foo")"""),
- execute(sessionId, "val n = 2"),
- execute(sessionId, """val s = "other"""", lastMsgId)
+ execute("""sys.error("foo")"""),
+ execute("val n = 2"),
+ execute("""val s = "other"""", lastMsgId)
)
val streams = ClientStreams.create(input, stopWhen)
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val messageTypes = streams.generatedMessageTypes()
@@ -161,32 +171,6 @@ object ScalaKernelTests extends TestSuite {
test("jvm-repr") {
- // How the pseudo-client behaves
-
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
-
- // When the pseudo-client exits
-
- val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
- (_, m) =>
- IO.pure(
- m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
- )
-
- // Initial messages from client
-
- val input = Stream(
- execute(sessionId, """class Bar(val value: String)"""),
- execute(
- sessionId,
- """kernel.register[Bar](bar => Map("text/plain" -> s"Bar(${bar.value})"))"""
- ),
- execute(sessionId, """val b = new Bar("other")""", lastMsgId)
- )
-
- val streams = ClientStreams.create(input, stopWhen)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite
@@ -194,63 +178,25 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- val messageTypes = streams.generatedMessageTypes()
+ implicit val sessionId: SessionId = SessionId()
- val expectedMessageTypes = Seq(
- "execute_input",
- "execute_result",
- "execute_reply",
- "execute_input",
- "execute_reply",
- "execute_input",
- "display_data",
- "execute_reply"
+ kernel.execute("""class Bar(val value: String)""", "defined class Bar")
+ kernel.execute(
+ """kernel.register[Bar](bar => Map("text/plain" -> s"Bar(${bar.value})"))""",
+ ""
)
-
- assert(messageTypes == expectedMessageTypes)
-
- val displayData = streams.displayData
-
- val expectedDisplayData = Seq(
- ProtocolExecute.DisplayData(
- Map("text/plain" -> RawJson("\"Bar(other)\"".bytes)),
- Map()
- ) -> false
+ kernel.execute(
+ """val b = new Bar("other")""",
+ "",
+ displaysText = Seq("Bar(other)")
)
-
- assert(displayData == expectedDisplayData)
}
test("updatable display") {
- // How the pseudo-client behaves
-
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
-
- // When the pseudo-client exits
-
- val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
- (_, m) =>
- IO.pure(
- m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
- )
-
- // Initial messages from client
-
- val input = Stream(
- execute(sessionId, """val handle = Html("foo")"""),
- execute(sessionId, """handle.withContent("bzz").update()"""),
- execute(sessionId, """val s = "other"""", lastMsgId)
- )
-
- val streams = ClientStreams.create(input, stopWhen)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite
@@ -258,83 +204,26 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- t.unsafeRunTimedOrThrow()
+ implicit val sessionId: SessionId = SessionId()
- val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
- val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
-
- val expectedRequestsMessageTypes = Seq(
- "execute_reply",
- "execute_reply",
- "execute_reply"
+ kernel.execute(
+ """val handle = Html("foo")""",
+ "",
+ displaysHtml = Seq("foo")
)
- val expectedPublishMessageTypes = Seq(
- "execute_input",
- "display_data",
- "execute_input",
- "update_display_data",
- "execute_input",
- "execute_result"
+ kernel.execute(
+ """handle.withContent("bzz").update()""",
+ "",
+ displaysHtmlUpdates = Seq("bzz")
)
-
- assert(requestsMessageTypes == expectedRequestsMessageTypes)
- assert(publishMessageTypes == expectedPublishMessageTypes)
-
- val displayData = streams.displayData
- val id = {
- val ids = displayData.flatMap(_._1.transient.display_id).toSet
- assert(ids.size == 1)
- ids.head
- }
-
- val expectedDisplayData = Seq(
- ProtocolExecute.DisplayData(
- Map("text/html" -> RawJson("\"foo\"".bytes)),
- Map(),
- ProtocolExecute.DisplayData.Transient(Some(id))
- ) -> false,
- ProtocolExecute.DisplayData(
- Map("text/html" -> RawJson("\"bzz\"".bytes)),
- Map(),
- ProtocolExecute.DisplayData.Transient(Some(id))
- ) -> true
- )
-
- assert(displayData == expectedDisplayData)
}
test("auto-update Future results upon completion") {
- // How the pseudo-client behaves
-
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
-
- // When the pseudo-client exits
-
- val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
- (_, m) =>
- IO.pure(
- m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
- )
-
- // Initial messages from client
-
- val input = Stream(
- execute(
- sessionId,
- "import scala.concurrent.Future; import scala.concurrent.ExecutionContext.Implicits.global"
- ),
- execute(sessionId, "val f = Future { Thread.sleep(3000L); 2 }"),
- execute(sessionId, "Thread.sleep(6000L)", lastMsgId)
- )
-
- val streams = ClientStreams.create(input, stopWhen)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
updateBackgroundVariablesEcOpt = Some(bgVarEc),
@@ -343,37 +232,57 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- t.unsafeRunTimedOrThrow()
+ implicit val sessionId: SessionId = SessionId()
- val messageTypes = streams.generatedMessageTypes()
+ val sp = " "
+ val ls = System.lineSeparator()
- val expectedMessageTypes = Seq(
- "execute_input",
- "execute_result",
- "execute_reply",
- "execute_input",
- "display_data",
- "execute_reply",
- "execute_input",
- // that one originates from the second line, but arrives while the third one is running
- "update_display_data",
- "execute_reply"
+ kernel.execute(
+ "import scala.concurrent.Future; import scala.concurrent.ExecutionContext.Implicits.global",
+ // Multi-line with stripMargin seems to be a problem on our Windows CI for this test,
+ // but not for the other ones using stripMarginā¦
+ s"import scala.concurrent.Future;$sp$ls" +
+ s"import scala.concurrent.ExecutionContext.Implicits.global$maybePostImportNewLine"
)
- assert(messageTypes == expectedMessageTypes)
+ kernel.execute(
+ "val f = Future { Thread.sleep(3000L); 2 }",
+ "",
+ displaysText = Seq("f: Future[Int] = [running]")
+ )
+ kernel.execute(
+ "Thread.sleep(6000L)",
+ "",
+ // the update originates from the previous cell, but arrives while the third one is running
+ displaysTextUpdates = Seq(
+ if (TestUtil.isScala212) "f: Future[Int] = Success(2)"
+ else "f: Future[Int] = Success(value = 2)"
+ )
+ )
}
test("auto-update Future results in background upon completion") {
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ updateBackgroundVariablesEcOpt = Some(bgVarEc),
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
// same as above, except no cell is running when the future completes
// How the pseudo-client behaves
- val sessionId = UUID.randomUUID().toString
+ implicit val sessionId: SessionId = SessionId()
// When the pseudo-client exits
@@ -385,26 +294,15 @@ object ScalaKernelTests extends TestSuite {
val input = Stream(
execute(
- sessionId,
"import scala.concurrent.Future; import scala.concurrent.ExecutionContext.Implicits.global"
),
- execute(sessionId, "val f = Future { Thread.sleep(3000L); 2 }")
+ execute("val f = Future { Thread.sleep(3000L); 2 }")
)
val streams = ClientStreams.create(input, stopWhen)
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- updateBackgroundVariablesEcOpt = Some(bgVarEc),
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val messageTypes = streams.generatedMessageTypes()
@@ -424,13 +322,23 @@ object ScalaKernelTests extends TestSuite {
test("auto-update Rx stuff upon change") {
if (isScala212) {
- // How the pseudo-client behaves
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ updateBackgroundVariablesEcOpt = Some(bgVarEc),
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
// When the pseudo-client exits
+ val lastMsgId = UUID.randomUUID().toString
val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
(_, m) =>
IO.pure(
@@ -440,26 +348,16 @@ object ScalaKernelTests extends TestSuite {
// Initial messages from client
val input = Stream(
- execute(sessionId, "almondrx.setup()"),
- execute(sessionId, "val a = rx.Var(1)"),
- execute(sessionId, "a() = 2"),
- execute(sessionId, "a() = 3", lastMsgId)
+ execute("almondrx.setup()"),
+ execute("val a = rx.Var(1)"),
+ execute("a() = 2"),
+ execute("a() = 3", lastMsgId)
)
val streams = ClientStreams.create(input, stopWhen)
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- updateBackgroundVariablesEcOpt = Some(bgVarEc),
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
@@ -522,17 +420,24 @@ object ScalaKernelTests extends TestSuite {
test("handle interrupt messages") {
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
- // How the pseudo-client behaves
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
val interruptOnInput = MessageHandler(Channel.Input, Input.requestType) { msg =>
Message(
Header(
UUID.randomUUID().toString,
"test",
- sessionId,
+ sessionId.sessionId,
Interrupt.requestType.messageType,
Some(Protocol.versionStr)
),
@@ -548,6 +453,7 @@ object ScalaKernelTests extends TestSuite {
// When the pseudo-client exits
+ val lastMsgId = UUID.randomUUID().toString
val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
(_, m) =>
IO.pure(
@@ -559,24 +465,15 @@ object ScalaKernelTests extends TestSuite {
// Initial messages from client
val input = Stream(
- execute(sessionId, "val n = scala.io.StdIn.readInt()"),
- execute(sessionId, """val s = "ok done"""", msgId = lastMsgId)
+ execute("val n = scala.io.StdIn.readInt()"),
+ execute("""val s = "ok done"""", msgId = lastMsgId)
)
val streams =
ClientStreams.create(input, stopWhen, interruptOnInput.orElse(ignoreExpectedReplies))
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val messageTypes = streams.generatedMessageTypes()
val controlMessageTypes = streams.generatedMessageTypes(Set(Channel.Control))
@@ -600,31 +497,6 @@ object ScalaKernelTests extends TestSuite {
test("start from custom class loader") {
- // How the pseudo-client behaves
-
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
-
- // When the pseudo-client exits
-
- val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
- (_, m) =>
- IO.pure(
- m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
- )
-
- // Initial messages from client
-
- val input = Stream(
- execute(
- sessionId,
- """val url = Thread.currentThread().getContextClassLoader.getResource("foo")"""
- ),
- execute(sessionId, """assert(url.toString == "https://google.fr")""", lastMsgId)
- )
-
- val streams = ClientStreams.create(input, stopWhen)
-
val loader = new URLClassLoader(Array(), Thread.currentThread().getContextClassLoader) {
override def getResource(name: String) =
if (name == "foo")
@@ -641,10 +513,31 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+ val lastMsgId = UUID.randomUUID().toString
- t.unsafeRunTimedOrThrow()
+ // When the pseudo-client exits
+
+ val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
+ (_, m) =>
+ IO.pure(
+ m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
+ )
+
+ // Initial messages from client
+
+ val input = Stream(
+ execute("""val url = Thread.currentThread().getContextClassLoader.getResource("foo")"""),
+ execute("""assert(url.toString == "https://google.fr")""", lastMsgId)
+ )
+
+ val streams = ClientStreams.create(input, stopWhen)
+
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val messageTypes = streams.generatedMessageTypes()
@@ -661,17 +554,6 @@ object ScalaKernelTests extends TestSuite {
test("exit") {
- val sessionId = UUID.randomUUID().toString
-
- // Initial messages from client
-
- val input = Stream(
- execute(sessionId, "val n = 2"),
- execute(sessionId, "exit")
- )
-
- val streams = ClientStreams.create(input)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite
@@ -679,10 +561,22 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads)
- .flatMap(_.run(streams.source, streams.sink))
+ val kernel = Kernel.create(interpreter, interpreterEc, threads)
+ .unsafeRunTimedOrThrow()
- t.unsafeRunTimedOrThrow()
+ implicit val sessionId: SessionId = SessionId()
+
+ // Initial messages from client
+
+ val input = Stream(
+ execute("val n = 2"),
+ execute("exit")
+ )
+
+ val streams = ClientStreams.create(input)
+
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val messageTypes = streams.generatedMessageTypes()
@@ -709,19 +603,6 @@ object ScalaKernelTests extends TestSuite {
test("trap output") {
- val sessionId = UUID.randomUUID().toString
-
- // Initial messages from client
-
- val input = Stream(
- execute(sessionId, "val n = 2"),
- execute(sessionId, """println("Hello")"""),
- execute(sessionId, """System.err.println("Bbbb")"""),
- execute(sessionId, "exit")
- )
-
- val streams = ClientStreams.create(input)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite,
@@ -730,10 +611,24 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads)
- .flatMap(_.run(streams.source, streams.sink))
+ val kernel = Kernel.create(interpreter, interpreterEc, threads)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+
+ // Initial messages from client
+
+ val input = Stream(
+ execute("val n = 2"),
+ execute("""println("Hello")"""),
+ execute("""System.err.println("Bbbb")"""),
+ execute("exit")
+ )
+
+ val streams = ClientStreams.create(input)
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val messageTypes = streams.generatedMessageTypes()
@@ -755,29 +650,6 @@ object ScalaKernelTests extends TestSuite {
test("last exception") {
- // How the pseudo-client behaves
-
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
-
- // When the pseudo-client exits
-
- val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
- (_, m) =>
- IO.pure(
- m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
- )
-
- // Initial messages from client
-
- val input = Stream(
- execute(sessionId, """val nullBefore = repl.lastException == null"""),
- execute(sessionId, """sys.error("foo")""", stopOnError = false),
- execute(sessionId, """val nullAfter = repl.lastException == null""", lastMsgId)
- )
-
- val streams = ClientStreams.create(input, stopWhen)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite
@@ -785,70 +657,21 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
- val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
+ implicit val sessionId: SessionId = SessionId()
- val expectedRequestsMessageTypes = Seq(
- "execute_reply",
- "execute_reply",
- "execute_reply"
+ kernel.execute(
+ """val nullBefore = repl.lastException == null""",
+ "nullBefore: Boolean = true"
)
-
- val expectedPublishMessageTypes = Seq(
- "execute_input",
- "execute_result",
- "execute_input",
- "error",
- "execute_input",
- "execute_result"
- )
-
- assert(requestsMessageTypes == expectedRequestsMessageTypes)
- assert(publishMessageTypes == expectedPublishMessageTypes)
-
- val replies = streams.executeReplies
- val expectedReplies = Map(
- 1 -> "nullBefore: Boolean = true",
- 3 -> "nullAfter: Boolean = false"
- )
- assert(replies == expectedReplies)
+ kernel.execute("""sys.error("foo")""", expectError = true)
+ kernel.execute("""val nullAfter = repl.lastException == null""", "nullAfter: Boolean = false")
}
test("history") {
- // How the pseudo-client behaves
-
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
-
- // When the pseudo-client exits
-
- val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
- (_, m) =>
- IO.pure(
- m.header.msg_type == "execute_reply" && m.parent_header.exists(_.msg_id == lastMsgId)
- )
-
- // Initial messages from client
-
- val input = Stream(
- execute(sessionId, """val before = repl.history.toVector"""),
- execute(sessionId, """val a = 2"""),
- execute(sessionId, """val b = a + 1"""),
- execute(
- sessionId,
- """val after = repl.history.toVector.mkString(",").toString""",
- lastMsgId
- )
- )
-
- val streams = ClientStreams.create(input, stopWhen)
-
val interpreter = new ScalaInterpreter(
params = ScalaInterpreterParams(
initialColors = Colors.BlackWhite
@@ -856,55 +679,40 @@ object ScalaKernelTests extends TestSuite {
logCtx = logCtx
)
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
- val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
+ implicit val sessionId: SessionId = SessionId()
- val expectedRequestsMessageTypes = Seq(
- "execute_reply",
- "execute_reply",
- "execute_reply",
- "execute_reply"
- )
-
- val expectedPublishMessageTypes = Seq(
- "execute_input",
- "execute_result",
- "execute_input",
- "execute_result",
- "execute_input",
- "execute_result",
- "execute_input",
- "execute_result"
+ kernel.execute(
+ """val before = repl.history.toVector""",
+ """before: Vector[String] = Vector("val before = repl.history.toVector")"""
)
-
- assert(requestsMessageTypes == expectedRequestsMessageTypes)
- assert(publishMessageTypes == expectedPublishMessageTypes)
-
- val replies = streams.executeReplies
- val expectedReplies = Map(
- 1 -> """before: Vector[String] = Vector("val before = repl.history.toVector")""",
- 2 -> """a: Int = 2""",
- 3 -> """b: Int = 3""",
- 4 -> """after: String = "val before = repl.history.toVector,val a = 2,val b = a + 1,val after = repl.history.toVector.mkString(\",\").toString""""
+ kernel.execute("val a = 2", "a: Int = 2")
+ kernel.execute("val b = a + 1", "b: Int = 3")
+ kernel.execute(
+ """val after = repl.history.toVector.mkString(",").toString""",
+ """after: String = "val before = repl.history.toVector,val a = 2,val b = a + 1,val after = repl.history.toVector.mkString(\",\").toString""""
)
- assert(replies == expectedReplies)
}
test("update vars") {
if (AlmondCompilerLifecycleManager.isAtLeast_2_12_7 && TestUtil.isScala2) {
- // How the pseudo-client behaves
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ updateBackgroundVariablesEcOpt = Some(bgVarEc),
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- // When the pseudo-client exits
+ implicit val sessionId: SessionId = SessionId()
+ val lastMsgId = UUID.randomUUID().toString
val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
(_, m) =>
IO.pure(
@@ -914,25 +722,15 @@ object ScalaKernelTests extends TestSuite {
// Initial messages from client
val input = Stream(
- execute(sessionId, """var n = 2"""),
- execute(sessionId, """n = n + 1"""),
- execute(sessionId, """n += 2""", lastMsgId)
+ execute("""var n = 2"""),
+ execute("""n = n + 1"""),
+ execute("""n += 2""", lastMsgId)
)
val streams = ClientStreams.create(input, stopWhen)
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- updateBackgroundVariablesEcOpt = Some(bgVarEc),
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
@@ -992,13 +790,20 @@ object ScalaKernelTests extends TestSuite {
def updateLazyValsTest(): Unit = {
- // How the pseudo-client behaves
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ updateBackgroundVariablesEcOpt = Some(bgVarEc),
+ initialColors = Colors.BlackWhite
+ ),
+ logCtx = logCtx
+ )
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
- // When the pseudo-client exits
+ implicit val sessionId: SessionId = SessionId()
+ val lastMsgId = UUID.randomUUID().toString
val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
(_, m) =>
IO.pure(
@@ -1008,25 +813,15 @@ object ScalaKernelTests extends TestSuite {
// Initial messages from client
val input = Stream(
- execute(sessionId, """lazy val n = 2"""),
- execute(sessionId, """val a = { n; () }"""),
- execute(sessionId, """val b = { n; () }""", lastMsgId)
+ execute("""lazy val n = 2"""),
+ execute("""val a = { n; () }"""),
+ execute("""val b = { n; () }""", lastMsgId)
)
val streams = ClientStreams.create(input, stopWhen)
- val interpreter = new ScalaInterpreter(
- params = ScalaInterpreterParams(
- updateBackgroundVariablesEcOpt = Some(bgVarEc),
- initialColors = Colors.BlackWhite
- ),
- logCtx = logCtx
- )
-
- val t = Kernel.create(interpreter, interpreterEc, threads, logCtx)
- .flatMap(_.run(streams.source, streams.sink))
-
- t.unsafeRunTimedOrThrow()
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
@@ -1081,6 +876,238 @@ object ScalaKernelTests extends TestSuite {
if (TestUtil.isScala2) updateLazyValsTest()
else "disabled"
}
+
+ test("hooks") {
+
+ val predef =
+ """private val foos0 = new scala.collection.mutable.ListBuffer[String]
+ |
+ |def foos(): List[String] =
+ | foos0.result()
+ |
+ |kernel.addExecuteHook { code =>
+ | import almond.api.JupyterApi
+ |
+ | var errorOpt = Option.empty[JupyterApi.ExecuteHookResult]
+ | val remainingCode = code.linesWithSeparators.zip(code.linesIterator)
+ | .map {
+ | case (originalLine, line) =>
+ | if (line == "%AddFoo") ""
+ | else if (line.startsWith("%AddFoo ")) {
+ | foos0 ++= line.split("\\s+").drop(1).toSeq
+ | ""
+ | }
+ | else if (line == "%Error") {
+ | errorOpt = Some(JupyterApi.ExecuteHookResult.Error("thing", "erroring", List("aa", "bb")))
+ | ""
+ | }
+ | else if (line == "%Abort") {
+ | errorOpt = Some(JupyterApi.ExecuteHookResult.Abort)
+ | ""
+ | }
+ | else if (line == "%Exit") {
+ | errorOpt = Some(JupyterApi.ExecuteHookResult.Exit)
+ | ""
+ | }
+ | else
+ | originalLine
+ | }
+ | .mkString
+ |
+ | errorOpt.toLeft(remainingCode)
+ |}
+ |""".stripMargin
+
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite,
+ predefCode = predef
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+
+ kernel.execute(
+ "val before = foos()",
+ "before: List[String] = List()"
+ )
+ kernel.execute("""%AddFoo a""", "")
+ kernel.execute("""%AddFoo b""", "")
+ kernel.execute(
+ "val after = foos()",
+ """after: List[String] = List("a", "b")"""
+ )
+ kernel.execute(
+ "%Error",
+ errors = Seq(
+ ("thing", "erroring", List("thing: erroring", " aa", " bb"))
+ )
+ )
+ kernel.execute("%Abort")
+ kernel.execute(
+ """val a = "a"
+ |""".stripMargin,
+ """a: String = "a""""
+ )
+ kernel.execute(
+ """%Exit
+ |val b = "b"
+ |""".stripMargin,
+ ""
+ )
+ }
+
+ test("toree AddDeps") {
+ toreeAddDepsTest()
+ }
+ // unsupported yet, needs tweaking in the Ammonite dependency parser
+ // test("toree AddDeps intransitive") {
+ // toreeAddDepsTest(transitive = false)
+ // }
+
+ def toreeAddDepsTest(transitive: Boolean = true): Unit = {
+
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite,
+ toreeMagics = true
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+
+ val sbv = {
+ val sv = Properties.versionNumberString
+ if (sv.startsWith("2.")) sv.split('.').take(2).mkString(".")
+ else sv.takeWhile(_ != '.')
+ }
+
+ kernel.execute(
+ "import caseapp.CaseApp",
+ errors = Seq(
+ ("", "Compilation Failed", List("Compilation Failed"))
+ ),
+ ignoreStreams = true
+ )
+ kernel.execute(
+ "import caseapp.util._",
+ errors = Seq(
+ ("", "Compilation Failed", List("Compilation Failed"))
+ ),
+ ignoreStreams = true
+ )
+ val suffix = if (transitive) " --transitive" else ""
+ kernel.execute(
+ s"%AddDeps com.github.alexarchambault case-app_$sbv 2.1.0-M24" + suffix,
+ "import $ivy.$ " + (" " * sbv.length) + maybePostImportNewLine,
+ ignoreStreams = true // ignoring coursier messages (that it prints when downloading things)
+ )
+ kernel.execute(
+ "import caseapp.CaseApp",
+ "import caseapp.CaseApp" + maybePostImportNewLine
+ )
+ if (transitive)
+ kernel.execute(
+ "import caseapp.util._",
+ "import caseapp.util._" + maybePostImportNewLine
+ )
+ else
+ kernel.execute(
+ "import caseapp.util._",
+ errors = Seq(
+ ("", "Compilation Failed", List("Compilation Failed"))
+ ),
+ ignoreStreams = true
+ )
+ }
+
+ test("toree Html") {
+
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite,
+ toreeMagics = true
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+
+ kernel.execute(
+ """%%html
+ |
+ |Hello
+ |
+ |""".stripMargin,
+ "",
+ displaysHtml = Seq(
+ """
+ |Hello
+ |
+ |""".stripMargin
+ )
+ )
+ }
+
+ test("toree Truncation") {
+
+ val interpreter = new ScalaInterpreter(
+ params = ScalaInterpreterParams(
+ initialColors = Colors.BlackWhite,
+ toreeMagics = true
+ ),
+ logCtx = logCtx
+ )
+
+ val kernel = Kernel.create(interpreter, interpreterEc, threads, logCtx)
+ .unsafeRunTimedOrThrow()
+
+ implicit val sessionId: SessionId = SessionId()
+
+ val nl = System.lineSeparator()
+
+ kernel.execute(
+ "%truncation",
+ "",
+ stdout =
+ "Truncation is currently on" + nl
+ )
+ kernel.execute(
+ "%truncation off",
+ "",
+ stdout =
+ "Output will NOT be truncated" + nl
+ )
+ kernel.execute(
+ "(1 to 200).toVector",
+ "res0: Vector[Int] = " + (1 to 200).toVector.toString
+ )
+ kernel.execute(
+ "%truncation on",
+ "",
+ stdout =
+ "Output WILL be truncated." + nl
+ )
+ kernel.execute(
+ "(1 to 200).toVector",
+ "res1: Vector[Int] = " +
+ (1 to 38)
+ .toVector
+ .map(" " + _ + "," + "\n")
+ .mkString("Vector(" + "\n", "", "...")
+ )
+ }
}
}
diff --git a/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala b/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala
index 9e620fc24..4a0b769b8 100644
--- a/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala
+++ b/modules/scala/scala-interpreter/src/test/scala/almond/TestUtil.scala
@@ -44,17 +44,144 @@ object TestUtil {
}
}
+ final case class SessionId(sessionId: String = UUID.randomUUID().toString)
+
+ implicit class KernelOps(private val kernel: Kernel) extends AnyVal {
+ def execute(
+ code: String,
+ reply: String = null,
+ expectError: Boolean = false,
+ errors: Seq[(String, String, List[String])] = null,
+ displaysText: Seq[String] = null,
+ displaysHtml: Seq[String] = null,
+ displaysTextUpdates: Seq[String] = null,
+ displaysHtmlUpdates: Seq[String] = null,
+ ignoreStreams: Boolean = false,
+ stdout: String = null
+ )(implicit sessionId: SessionId): Unit = {
+
+ val expectError0 = expectError || Option(errors).nonEmpty
+ val ignoreStreams0 = ignoreStreams || Option(stdout).nonEmpty
+
+ val input = Stream(
+ TestUtil.execute(code, stopOnError = !expectError0)
+ )
+
+ val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
+ (_, m) => IO.pure(m.header.msg_type == "execute_reply")
+
+ val streams = ClientStreams.create(input, stopWhen)
+
+ kernel.run(streams.source, streams.sink)
+ .unsafeRunTimedOrThrow()
+
+ val requestsMessageTypes = streams.generatedMessageTypes(Set(Channel.Requests)).toVector
+ val publishMessageTypes = streams.generatedMessageTypes(Set(Channel.Publish)).toVector
+ .filter(if (ignoreStreams0) _ != "stream" else _ => true)
+
+ val expectedRequestsMessageTypes =
+ if (reply == null && !expectError0)
+ Nil
+ else
+ Seq("execute_reply")
+ assert(requestsMessageTypes == Seq("execute_reply"))
+
+ val expectedPublishMessageTypes = {
+ val displayDataCount = Seq(
+ Option(displaysText).fold(0)(_.length),
+ Option(displaysHtml).fold(0)(_.length)
+ ).max
+ val updateDisplayDataCount = Seq(
+ Option(displaysTextUpdates).fold(0)(_.length),
+ Option(displaysHtmlUpdates).fold(0)(_.length)
+ ).max
+ val prefix = Seq("execute_input") ++
+ Seq.fill(displayDataCount)("display_data") ++
+ Seq.fill(updateDisplayDataCount)("update_display_data")
+ if (expectError0)
+ prefix :+ "error"
+ else if (reply == null || reply.isEmpty)
+ prefix
+ else
+ prefix :+ "execute_result"
+ }
+ assert(publishMessageTypes == expectedPublishMessageTypes)
+
+ if (stdout != null) {
+ val stdoutMessages = streams.output.mkString
+ assert(stdout == stdoutMessages)
+ }
+
+ val replies = streams.executeReplies.toVector.sortBy(_._1).map(_._2)
+ assert(replies == Option(reply).toVector)
+
+ for (expectedTextDisplay <- Option(displaysText)) {
+ import ClientStreams.RawJsonOps
+
+ val textDisplay = streams.displayData.collect {
+ case (data, false) =>
+ data.data.get("text/plain")
+ .map(_.stringOrEmpty)
+ .getOrElse("")
+ }
+
+ assert(textDisplay == expectedTextDisplay)
+ }
+
+ val receivedErrors = streams.executeErrors.toVector.sortBy(_._1).map(_._2)
+ assert(errors == null || receivedErrors == errors)
+
+ for (expectedHtmlDisplay <- Option(displaysHtml)) {
+ import ClientStreams.RawJsonOps
+
+ val htmlDisplay = streams.displayData.collect {
+ case (data, false) =>
+ data.data.get("text/html")
+ .map(_.stringOrEmpty)
+ .getOrElse("")
+ }
+
+ assert(htmlDisplay == expectedHtmlDisplay)
+ }
+
+ for (expectedTextDisplayUpdates <- Option(displaysTextUpdates)) {
+ import ClientStreams.RawJsonOps
+
+ val textDisplayUpdates = streams.displayData.collect {
+ case (data, true) =>
+ data.data.get("text/plain")
+ .map(_.stringOrEmpty)
+ .getOrElse("")
+ }
+
+ assert(textDisplayUpdates == expectedTextDisplayUpdates)
+ }
+
+ for (expectedHtmlDisplayUpdates <- Option(displaysHtmlUpdates)) {
+ import ClientStreams.RawJsonOps
+
+ val htmlDisplayUpdates = streams.displayData.collect {
+ case (data, true) =>
+ data.data.get("text/html")
+ .map(_.stringOrEmpty)
+ .getOrElse("")
+ }
+
+ assert(htmlDisplayUpdates == expectedHtmlDisplayUpdates)
+ }
+ }
+ }
+
def execute(
- sessionId: String,
code: String,
msgId: String = UUID.randomUUID().toString,
stopOnError: Boolean = true
- ) =
+ )(implicit sessionId: SessionId) =
Message(
Header(
msgId,
"test",
- sessionId,
+ sessionId.sessionId,
ProtocolExecute.requestType.messageType,
Some(Protocol.versionStr)
),
@@ -71,8 +198,8 @@ object TestUtil {
val (input, replies) = inputs.unzip
- val sessionId = UUID.randomUUID().toString
- val lastMsgId = UUID.randomUUID().toString
+ implicit val sessionId: SessionId = SessionId()
+ val lastMsgId = UUID.randomUUID().toString
val stopWhen: (Channel, Message[RawJson]) => IO[Boolean] =
(_, m) =>
@@ -83,8 +210,8 @@ object TestUtil {
assert(input.nonEmpty)
val input0 = Stream(
- input.init.map(s => execute(sessionId, s)) :+
- execute(sessionId, input.last, lastMsgId): _*
+ input.init.map(s => execute(s)) :+
+ execute(input.last, lastMsgId): _*
)
val streams = ClientStreams.create(input0, stopWhen)
diff --git a/modules/scala/scala-kernel/src/main/scala/almond/Options.scala b/modules/scala/scala-kernel/src/main/scala/almond/Options.scala
index 7dc437fcc..741950b2a 100644
--- a/modules/scala/scala-kernel/src/main/scala/almond/Options.scala
+++ b/modules/scala/scala-kernel/src/main/scala/almond/Options.scala
@@ -72,7 +72,10 @@ final case class Options(
@ExtraName("outputDir")
outputDirectory: Option[String] = None,
@ExtraName("tmpOutputDir")
- tmpOutputDirectory: Option[Boolean] = None
+ tmpOutputDirectory: Option[Boolean] = None,
+
+ @HelpMessage("Add experimental support for Toree magics")
+ toreeMagics: Boolean = false
) {
// format: on
diff --git a/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala b/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala
index b6562b733..f5ea9bea9 100644
--- a/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala
+++ b/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala
@@ -131,7 +131,8 @@ object ScalaKernel extends CaseApp[Options] {
.toLeft {
options.tmpOutputDirectory
.getOrElse(true) // Create tmp output dir by default
- }
+ },
+ toreeMagics = options.toreeMagics
),
logCtx = logCtx
)
diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandler.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandler.scala
new file mode 100644
index 000000000..e92497289
--- /dev/null
+++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandler.scala
@@ -0,0 +1,7 @@
+package almond.toree
+
+import almond.api.JupyterApi
+
+trait CellMagicHandler {
+ def handle(name: String, content: String): Either[JupyterApi.ExecuteHookResult, String]
+}
diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandlers.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandlers.scala
new file mode 100644
index 000000000..8142c11c2
--- /dev/null
+++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHandlers.scala
@@ -0,0 +1,24 @@
+package almond.toree
+
+import almond.api.JupyterApi
+import almond.interpreter.api.DisplayData.ContentType
+import almond.interpreter.api.{DisplayData, OutputHandler}
+
+object CellMagicHandlers {
+
+ class DisplayDataHandler(publish: OutputHandler, contentType: String) extends CellMagicHandler {
+ def handle(name: String, content: String): Either[JupyterApi.ExecuteHookResult, String] = {
+ publish.display(DisplayData(Map(contentType -> content)))
+ Right("")
+ }
+ }
+
+ def handlers(publish: OutputHandler) = Map(
+ "html" -> new DisplayDataHandler(publish, ContentType.html),
+ "javascript" -> new DisplayDataHandler(publish, ContentType.js)
+ )
+
+ def handlerKeys: Iterable[String] =
+ handlers(OutputHandler.NopOutputHandler).keys
+
+}
diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHook.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHook.scala
new file mode 100644
index 000000000..038074a86
--- /dev/null
+++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/CellMagicHook.scala
@@ -0,0 +1,32 @@
+package almond.toree
+
+import almond.api.JupyterApi
+import almond.interpreter.api.OutputHandler
+
+import java.util.Locale
+
+object CellMagicHook {
+
+ def hook(publish: OutputHandler): JupyterApi.ExecuteHook = {
+ val handlers = CellMagicHandlers.handlers(publish)
+ code =>
+ val nameOpt = code.linesIterator.take(1).toList.collectFirst {
+ case name if name.startsWith("%%") =>
+ name.stripPrefix("%%")
+ }
+ nameOpt match {
+ case Some(name) =>
+ handlers.get(name.toLowerCase(Locale.ROOT)) match {
+ case Some(handler) =>
+ val content = code.linesWithSeparators.drop(1).mkString
+ handler.handle(name, content)
+ case None =>
+ System.err.println(s"Warning: ignoring unrecognized Toree cell magic $name")
+ Right("")
+ }
+ case None =>
+ Right(code)
+ }
+ }
+
+}
diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandler.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandler.scala
new file mode 100644
index 000000000..e10e1832f
--- /dev/null
+++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandler.scala
@@ -0,0 +1,7 @@
+package almond.toree
+
+import almond.api.JupyterApi
+
+trait LineMagicHandler {
+ def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String]
+}
diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandlers.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandlers.scala
new file mode 100644
index 000000000..328b341c6
--- /dev/null
+++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHandlers.scala
@@ -0,0 +1,185 @@
+package almond.toree
+
+import ammonite.util.Ref
+import almond.api.JupyterApi
+
+import java.net.URI
+import java.nio.file.Paths
+
+object LineMagicHandlers {
+
+ class AddDepHandler extends LineMagicHandler {
+ import AddDepHandler._
+ def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] =
+ values match {
+ case Seq(org, name, ver, other @ _*) =>
+ val params = parseParams(other.toList, Params())
+ val depSuffix = if (params.transitive) "" else ",intransitive"
+ val repoSuffix = params.repositories.map(repo => s"; import $$repo.`$repo`").mkString
+
+ if (params.trace)
+ System.err.println(s"Warning: ignoring unsupported %AddDeps argument --trace")
+ if (params.verbose)
+ System.err.println(s"Warning: ignoring unsupported %AddDeps argument --verbose")
+ if (params.abortOnResolutionErrors)
+ System.err.println(
+ s"Warning: ignoring unsupported %AddDeps argument --abort-on-resolution-errors"
+ )
+
+ Right(s"import $$ivy.`$org:$name:$ver$depSuffix`$repoSuffix")
+ case _ =>
+ System.err.println(
+ s"Warning: ignoring malformed %AddDeps Toree magic (expected '%AddDeps org name version [optional-arguments*]')"
+ )
+ Right("")
+ }
+ }
+
+ object AddDepHandler {
+ // not 100% sure about the default values, but https://github.com/apache/incubator-toree/blob/5b19aac2e56a56d35c888acc4ed5e549b1f4ed7c/etc/examples/notebooks/magic-tutorial.ipynb
+ // seems to imply these are correct
+ private case class Params(
+ transitive: Boolean = false,
+ trace: Boolean = false,
+ verbose: Boolean = false,
+ abortOnResolutionErrors: Boolean = false,
+ repositories: Seq[String] = Nil
+ )
+
+ private def parseParams(args: List[String], params: Params): Params =
+ args match {
+ case Nil => params
+ case "--trace" :: t => parseParams(t, params.copy(trace = true))
+ case "--verbose" :: t => parseParams(t, params.copy(verbose = true))
+ case "--transitive" :: t => parseParams(t, params.copy(transitive = true))
+ case "--abort-on-resolution-errors" :: t =>
+ parseParams(t, params.copy(abortOnResolutionErrors = true))
+ case "--repository" :: repo :: t =>
+ parseParams(t, params.copy(repositories = params.repositories :+ repo))
+ case other :: t =>
+ System.err.println(s"Warning: ignoring unrecognized %AddDeps argument '$other'")
+ parseParams(t, params)
+ }
+ }
+
+ class AddJarHandler extends LineMagicHandler {
+ import AddJarHandler._
+ def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] =
+ values match {
+ case Seq(url, other @ _*) =>
+ val uri = new URI(url)
+ val params = parseParams(other.toList, Params())
+ if (params.force)
+ System.err.println(s"Warning: ignoring unsupported %AddJar argument -f")
+ if (params.magic)
+ System.err.println(s"Warning: ignoring unsupported %AddJar argument --magic")
+ if (uri.getScheme == "file") {
+ val path = Paths.get(uri)
+ Right(s"import $$cp.`$path`")
+ }
+ else {
+ System.err.println(
+ s"Warning: ignoring %AddJar URL $url (unsupported protocol ${uri.getScheme})"
+ )
+ Right("")
+ }
+ case _ =>
+ System.err.println(
+ s"Warning: ignoring malformed %AddJar Toree magic (expected '%AddJar url [optional-arguments*]')"
+ )
+ Right("")
+ }
+ }
+
+ object AddJarHandler {
+ private case class Params(
+ force: Boolean = false,
+ magic: Boolean = false
+ )
+
+ private def parseParams(args: List[String], params: Params): Params =
+ args match {
+ case Nil => params
+ case "-f" :: t => parseParams(t, params.copy(force = true))
+ case "--magic" :: t => parseParams(t, params.copy(magic = true))
+ case other :: t =>
+ System.err.println(s"Warning: ignoring unrecognized %AddJar argument '$other'")
+ parseParams(t, params)
+ }
+ }
+
+ class LsMagicHandler extends LineMagicHandler {
+ def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] = {
+ if (values.nonEmpty)
+ System.err.println(
+ s"Warning: ignoring unrecognized values passed to %LsMagic: ${values.mkString(" ")}"
+ )
+
+ println("Available line magics:")
+ println(handlerKeys.toVector.sorted.map("%" + _).mkString(" "))
+ println()
+
+ println("Available cell magics:")
+ println(CellMagicHandlers.handlerKeys.toVector.sorted.map("%%" + _).mkString(" "))
+ println()
+
+ Right("")
+ }
+ }
+
+ class TruncationHandler(pprinter: Ref[pprint.PPrinter]) extends LineMagicHandler {
+ private def enabled() = {
+ val current = pprinter()
+ current.defaultWidth != Int.MaxValue && current.defaultHeight != Int.MaxValue
+ }
+ private var formerWidth = pprinter().defaultWidth
+ private var formerHeight = pprinter().defaultHeight
+ private def disable(): Unit =
+ if (enabled()) {
+ formerWidth = pprinter().defaultWidth
+ formerHeight = pprinter().defaultHeight
+
+ pprinter.update {
+ pprinter().copy(
+ defaultWidth = Int.MaxValue,
+ defaultHeight = Int.MaxValue
+ )
+ }
+ }
+ private def enable(): Unit =
+ if (!enabled())
+ pprinter.update {
+ pprinter().copy(
+ defaultWidth = formerWidth,
+ defaultHeight = formerHeight
+ )
+ }
+ def handle(name: String, values: Seq[String]): Either[JupyterApi.ExecuteHookResult, String] = {
+ values match {
+ case Seq() =>
+ val state = if (enabled()) "on" else "off"
+ println(s"Truncation is currently $state")
+ case Seq("on") =>
+ enable()
+ println("Output WILL be truncated.")
+ case Seq("off") =>
+ disable()
+ println("Output will NOT be truncated")
+ case _ =>
+ System.err.println(
+ s"Warning: ignoring %truncation magic with unrecognized parameters ${values.mkString(" ")}"
+ )
+ }
+ Right("")
+ }
+ }
+
+ def handlers(pprinter: Ref[pprint.PPrinter]) = Map(
+ "adddeps" -> new AddDepHandler,
+ "addjar" -> new AddJarHandler,
+ "lsmagic" -> new LsMagicHandler,
+ "truncation" -> new TruncationHandler(pprinter)
+ )
+ def handlerKeys: Iterable[String] =
+ handlers(Ref(pprint.PPrinter())).keys
+}
diff --git a/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHook.scala b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHook.scala
new file mode 100644
index 000000000..0a3c1d7bd
--- /dev/null
+++ b/modules/scala/toree-hooks/src/main/scala/almond/toree/LineMagicHook.scala
@@ -0,0 +1,69 @@
+package almond.toree
+
+import almond.api.JupyterApi
+import ammonite.util.Ref
+
+import java.util.regex.Pattern
+import java.util.Locale
+
+import scala.collection.mutable
+
+object LineMagicHook {
+
+ private val sep = Pattern.compile("\\s+")
+
+ def inspect(code: String): Iterator[Either[(Seq[String], String, String), String]] = {
+ var parsingMagics = true
+ code.linesWithSeparators.zip(code.linesIterator).map {
+ case (rawLine, line) =>
+ if (parsingMagics && line.startsWith("%") && !line.drop(1).startsWith("%"))
+ Left((sep.split(line).toSeq, rawLine, line))
+ else {
+ if (parsingMagics) {
+ val trimmed = line.trim()
+ parsingMagics = trimmed.isEmpty || trimmed.startsWith("//")
+ }
+ Right(rawLine)
+ }
+ }
+ }
+
+ def hook(pprinter: Ref[pprint.PPrinter]): JupyterApi.ExecuteHook = {
+
+ val handlers = LineMagicHandlers.handlers(pprinter)
+
+ code =>
+
+ val magicsIt = inspect(code)
+ var errorOpt = Option.empty[JupyterApi.ExecuteHookResult]
+ val remainingCode = new StringBuilder
+ while (magicsIt.hasNext && errorOpt.isEmpty)
+ magicsIt.next() match {
+ case Left((elems, rawLine, line)) =>
+ assert(elems.nonEmpty)
+
+ val name = elems.head
+ val values = elems.tail
+
+ assert(name.startsWith("%"))
+
+ handlers.get(name.toLowerCase(Locale.ROOT).stripPrefix("%")) match {
+ case None =>
+ System.err.println(s"Warning: ignoring unrecognized Toree line magic $name")
+ case Some(handler) =>
+ handler.handle(name, values) match {
+ case Left(res) => errorOpt = Some(res)
+ case Right(substituteCode) =>
+ remainingCode ++= substituteCode
+ remainingCode ++= rawLine.substring(line.length)
+ }
+ }
+
+ case Right(code) =>
+ remainingCode ++= code
+ }
+
+ errorOpt.toLeft(remainingCode.result())
+ }
+
+}
diff --git a/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala b/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala
index ce0453463..04df0a615 100644
--- a/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala
+++ b/modules/shared/interpreter-api/src/main/scala/almond/interpreter/api/OutputHandler.scala
@@ -86,4 +86,11 @@ object OutputHandler {
underlying.updateDisplay(displayData)
}
+ object NopOutputHandler extends OutputHandler {
+ def stdout(s: String): Unit = ()
+ def stderr(s: String): Unit = ()
+ def display(displayData: DisplayData): Unit = ()
+ def updateDisplay(displayData: DisplayData): Unit = ()
+ }
+
}
diff --git a/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala b/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala
index 248788ef6..0f6ccc598 100644
--- a/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala
+++ b/modules/shared/interpreter/src/main/scala/almond/interpreter/InterpreterToIOInterpreter.scala
@@ -29,9 +29,8 @@ final class InterpreterToIOInterpreter(
private val cancelledSignal0 = {
implicit val shift =
- IO.contextShift(
- interpreterEc
- ) // maybe not the right ec, but that one shouldn't be used yet at that point
+ // maybe not the right ec, but that one shouldn't be used yet at that point
+ IO.contextShift(interpreterEc)
SignallingRef[IO, Boolean](false).unsafeRunSync()
}
def cancelledSignal: SignallingRef[IO, Boolean] =
diff --git a/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala b/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala
index 130e37957..cdbd5d702 100644
--- a/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala
+++ b/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala
@@ -120,7 +120,7 @@ final case class Kernel(
immediateHandlers.handleOrLogError(channel, rawMessage, log) match {
case None =>
- log.warn(s"Ignoring unhandled message:\n$rawMessage")
+ log.warn(s"Ignoring unhandled message on $channel:\n$rawMessage")
Stream.empty
case Some(output) =>
diff --git a/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala b/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala
index 4cfca75e0..cf0b213eb 100644
--- a/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala
+++ b/modules/shared/kernel/src/test/scala/almond/kernel/ClientStreams.scala
@@ -115,6 +115,23 @@ final case class ClientStreams(
}
.toMap
+ def executeErrors: Map[Int, (String, String, List[String])] =
+ generatedMessages
+ .iterator
+ .collect {
+ case Left((Channel.Requests, m)) if m.header.msg_type == Execute.replyType.messageType =>
+ m.decodeAs[Execute.Reply] match {
+ case Left(_) => Nil
+ case Right(m) => Seq(m.content)
+ }
+ }
+ .flatten
+ .collect {
+ case e: Execute.Reply.Error =>
+ (e.execution_count, (e.ename, e.evalue, e.traceback))
+ }
+ .toMap
+
def executeReplyPayloads: Map[Int, Seq[RawJson]] =
generatedMessages
.iterator
@@ -175,6 +192,16 @@ final case class ClientStreams(
case _ => Nil
}
}
+ case Left((Channel.Publish, m))
+ if m.header.msg_type == "stream" =>
+ m.decodeAs[Execute.Stream] match {
+ case Left(_) => Nil
+ case Right(m) =>
+ if (m.content.name == "stdout")
+ Seq(m.content.text)
+ else
+ Nil
+ }
case Left((Channel.Publish, m))
if m.header.msg_type == "display_data" || m.header.msg_type == "update_display_data" =>
m.decodeAs[Execute.DisplayData] match {
@@ -191,7 +218,7 @@ object ClientStreams {
import com.github.plokhotnyuk.jsoniter_scala.core._
- private implicit class RawJsonOps(private val rawJson: RawJson) extends AnyVal {
+ implicit class RawJsonOps(private val rawJson: RawJson) extends AnyVal {
def stringOrEmpty: String =
Try(readFromArray[String](rawJson.value)).toOption.getOrElse("")
}
diff --git a/project/deps.sc b/project/deps.sc
index 9cc186289..41cca24f6 100644
--- a/project/deps.sc
+++ b/project/deps.sc
@@ -57,6 +57,7 @@ object Deps {
def jvmRepr = ivy"com.github.jupyter:jvm-repr:0.4.0"
def mdoc = ivy"org.scalameta::mdoc:2.3.7"
def metabrowseServer = ivy"org.scalameta:::metabrowse-server:0.2.9"
+ def pprint = ivy"com.lihaoyi::pprint:0.8.1"
def scalafmtDynamic = ivy"org.scalameta::scalafmt-dynamic:${Versions.scalafmt}"
def scalapy = ivy"me.shadaj::scalapy-core:0.5.2"
def scalaReflect(sv: String) = ivy"org.scala-lang:scala-reflect:$sv"
@@ -103,4 +104,9 @@ object ScalaVersions {
"2.12.10",
"2.12.9"
).distinct
+
+ def binary(sv: String) =
+ if (sv.startsWith("2.12.")) scala212
+ else if (sv.startsWith("2.13.")) scala213
+ else scala3Compat
}
diff --git a/project/jupyterserver.sc b/project/jupyterserver.sc
index c52cacc0a..5304f0e05 100644
--- a/project/jupyterserver.sc
+++ b/project/jupyterserver.sc
@@ -13,7 +13,8 @@ def writeKernelJson(launcher: Path, jupyterDir: Path): Unit = {
"$launcherPath",
"--log", "info",
"--connection-file", "{connection_file}",
- "--variable-inspector"
+ "--variable-inspector",
+ "--toree-magics"
]
}"""
Files.write(dir.resolve("kernel.json"), kernelJson.getBytes("UTF-8"))