Skip to content

Commit

Permalink
Merge pull request #1077 from alexarchambault/execution-hooks
Browse files Browse the repository at this point in the history
Add support for execution hooks and Toree magics
  • Loading branch information
alexarchambault authored Apr 17, 2023
2 parents 79be90f + 331ddd9 commit 55e8bda
Show file tree
Hide file tree
Showing 26 changed files with 1,251 additions and 528 deletions.
14 changes: 12 additions & 2 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,14 @@ class ScalaInterpreter(val crossScalaVersion: String) extends AlmondModule with
if (crossScalaVersion.startsWith("3."))
Seq(
shared.interpreter(ScalaVersions.scala3Compat),
scala.`scala-kernel-api`()
scala.`scala-kernel-api`(),
scala.`toree-hooks`(ScalaVersions.binary(crossScalaVersion))
)
else
Seq(
shared.interpreter(),
scala.`scala-kernel-api`()
scala.`scala-kernel-api`(),
scala.`toree-hooks`(ScalaVersions.binary(crossScalaVersion))
)
def ivyDeps = T {
val metabrowse =
Expand Down Expand Up @@ -302,6 +304,12 @@ class Echo(val crossScalaVersion: String) extends AlmondModule {
}
}

class ToreeHooks(val crossScalaVersion: String) extends AlmondModule {
def compileModuleDeps = super.compileModuleDeps ++ Seq(
scala.`scala-kernel-api`(ScalaVersions.binary(crossScalaVersion))
)
}

object shared extends Module {
object `logger-scala2-macros` extends Cross[LoggerScala2Macros](ScalaVersions.binaries: _*)
object logger extends Cross[Logger](ScalaVersions.binaries: _*)
Expand All @@ -325,6 +333,8 @@ object scala extends Module {
extends Cross[ScalaKernelHelper](ScalaVersions.all.filter(_.startsWith("3.")): _*)
object `almond-scalapy` extends Cross[AlmondScalaPy](ScalaVersions.binaries: _*)
object `almond-rx` extends Cross[AlmondRx](ScalaVersions.scala212, ScalaVersions.scala213)

object `toree-hooks` extends Cross[ToreeHooks](ScalaVersions.binaries: _*)
}

object echo extends Cross[Echo](ScalaVersions.binaries: _*)
Expand Down
47 changes: 47 additions & 0 deletions docs/pages/api-jupyter.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,50 @@ kernel.publish.updateHtml("Got all items", id)
```

![](/demo/updatable.gif)

### Hooks

Hooks allow to pre-process code right before it's executed. Use like
```scala
private def runSql(sql: String): String = {
println("Running query...")
val fakeResult =
"""<table>
|<tr>
|<th>Id</th>
|<th>Name</th>
|</tr>
|<tr>
|<td>1</td>
|<td>Tree</td>
|</tr>
|<tr>
|<td>2</td>
|<td>Apple</td>
|</tr>
|</table>
|""".stripMargin
fakeResult
}

kernel.addExecuteHook { code =>
import almond.api.JupyterApi
import almond.interpreter.api.DisplayData

if (code.linesIterator.take(1).toList == List("%sql")) {
val sql = code.linesWithSeparators.drop(1).mkString // drop first line with "%sql"
val result = runSql(sql)
Left(JupyterApi.ExecuteHookResult.Success(DisplayData.html(result)))
}
else
Right(code) // just pass on code
}
```

Such code can be run either in a cell or in a predef file.

Later on, users can run things like
```text
%sql
SELECT id, name FROM my_table
```
9 changes: 9 additions & 0 deletions docs/pages/install-options.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,12 @@ Add entries to the Help menu (Jupyter classic). Use like

#### `--connection-file`

## Other

#### `--toree-magics`

Enable experimental support for [Toree](https://toree.apache.org) magics.

Simple line magics such as `%AddDeps` (always assumed to be transitive as of writing this, `--transitive` is just ignored),
`%AddJar`, and cell magics such as `%%html` or `%%javascript` are supported. Note that `%%javascript` only works from Jupyter
classic, as JupyterLab doesn't allow for random javascript code execution.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package almond.api

import java.util.UUID

import almond.interpreter.api.{CommHandler, OutputHandler}
import almond.interpreter.api.{CommHandler, DisplayData, OutputHandler}
import jupyter.{Displayer, Displayers}

import scala.reflect.{ClassTag, classTag}
Expand Down Expand Up @@ -47,10 +47,62 @@ abstract class JupyterApi { api =>
def display(t: T) = f(t).asJava
}
)

def addExecuteHook(hook: JupyterApi.ExecuteHook): Boolean
def removeExecuteHook(hook: JupyterApi.ExecuteHook): Boolean
}

object JupyterApi {

/** A hook, that can pre-process code right before it's executed
*/
@FunctionalInterface
abstract class ExecuteHook {

/** Pre-processes code right before it's executed.
*
* Like when actual code is executed, `println` / `{Console,System}.{out,err}` get sent to the
* cell output, stdin can be requested from users, CommHandler and OutputHandler can be used,
* etc.
*
* When several hooks were added, they are called in the order they were added. The output of
* the previous hook gets passed to the next one, as long as hooks return code to be executed
* rather than an `ExecuteHookResult`.
*
* @param code
* Code to be pre-processed
* @return
* Either code to be executed (right), or an `ExecuteHookResult` (left)
*/
def hook(code: String): Either[ExecuteHookResult, String]
}

/** Can be returned by `ExecuteHook.hook` to stop code execution.
*/
sealed abstract class ExecuteHookResult extends Product with Serializable
object ExecuteHookResult {

/** Returns data to be displayed */
final case class Success(data: DisplayData = DisplayData.empty) extends ExecuteHookResult

/** Exception-like error
*
* If you'd like to build one out of an actual `Throwable`, just throw it. It will then be
* caught while the hook is running, and sent to users.
*/
final case class Error(
name: String,
message: String,
stackTrace: List[String]
) extends ExecuteHookResult

/** Tells the front-end that execution was aborted */
case object Abort extends ExecuteHookResult

/** Should instruct the front-end to prompt the user for exit */
case object Exit extends ExecuteHookResult
}

abstract class UpdatableResults {

@deprecated("Use updatable instead", "0.4.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import scalaparse.Scala._

object AlmondParsers {

private def Prelude[_: P] = P((Annot ~ OneNLMax).rep ~ (Mod ~/ Pass).rep)
private def Prelude[X: P] = P((Annot ~ OneNLMax).rep ~ (Mod ~/ Pass).rep)

// same as the methods with the same name in ammonite.interp.Parsers, but keeping the type aside in LHS

def PatVarSplitter[_: P] = {
def PatVarSplitter[X: P] = {
def Prefixes = P(Prelude ~ (`var` | `val`))
def Lhs = P(Prefixes ~/ VarId)
def TypeAnnotation = P((`:` ~/ Type.!).?)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ammonite.util.Util.newLine
import metabrowse.server.{MetabrowseServer, Sourcepath}
import scala.meta.dialects

import scala.collection.compat._
import scala.tools.nsc.Global
import scala.tools.nsc.interactive.{Global => Interactive}
import scala.util.Random
Expand Down Expand Up @@ -231,13 +232,13 @@ object ScalaInterpreterInspections {
javaDirs.exists(path.startsWith)
}

def classpath(cl: ClassLoader): Stream[java.net.URL] =
def classpath(cl: ClassLoader): immutable.LazyList[java.net.URL] =
if (cl == null)
Stream()
immutable.LazyList()
else {
val cp = cl match {
case u: java.net.URLClassLoader => u.getURLs.toStream
case _ => Stream()
case u: java.net.URLClassLoader => u.getURLs.to(immutable.LazyList)
case _ => immutable.LazyList()
}

cp #::: classpath(cl.getParent)
Expand Down
110 changes: 80 additions & 30 deletions modules/scala/scala-interpreter/src/main/scala/almond/Execute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import fastparse.Parsed
import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success}
import scala.util.{Failure, Success, Try}

/** Wraps contextual things around when executing code (capturing output, stdin via front-ends,
* interruption, etc.)
Expand Down Expand Up @@ -65,7 +65,7 @@ final class Execute(
Duration.Inf
)
}
log.info("Received input")
log.info(s"Received input ${res.map { case "" => "[empty]"; case _ => "[non empty]" }}")

res match {
case Success(s) => Some(s)
Expand Down Expand Up @@ -136,15 +136,16 @@ final class Execute(
}
}

private def withInputManager[T](m: Option[InputManager])(f: => T): T = {
private def withInputManager[T](m: Option[InputManager], done: Boolean = true)(f: => T): T = {
val previous = currentInputManagerOpt0
try {
currentInputManagerOpt0 = m
f
}
finally {
currentInputManagerOpt0 = previous
m.foreach(_.done())
if (done)
m.foreach(_.done())
}
}

Expand Down Expand Up @@ -337,44 +338,93 @@ final class Execute(
inputManager: Option[InputManager],
outputHandler: Option[OutputHandler],
colors0: Ref[Colors],
storeHistory: Boolean
storeHistory: Boolean,
executeHooks: Seq[JupyterApi.ExecuteHook]
): ExecuteResult = {

if (storeHistory) {
storage.fullHistory() = storage.fullHistory() :+ code
history0 = history0 :+ code
}

ammResult(ammInterp, code, inputManager, outputHandler, storeHistory) match {
case Res.Success((_, data)) =>
ExecuteResult.Success(data)
case Res.Failure(msg) =>
interruptedStackTraceOpt0 match {
case None =>
val err = Execute.error(colors0(), None, msg)
outputHandler.foreach(_.stderr(err.message)) // necessary?
err
case Some(st) =>
val cutoff = Set("$main", "evaluatorRunPrinter")

ExecuteResult.Error(
(
"Interrupted!" +: st
.takeWhile(x => !cutoff(x.getMethodName))
.map(Execute.highlightFrame(_, fansi.Attr.Reset, colors0().literal()))
).mkString(System.lineSeparator())
)
val finalCodeOrResult =
withOutputHandler(outputHandler) {
interruptible {
withInputManager(inputManager, done = false) {
withClientStdin {
capturingOutput {
executeHooks.foldLeft[Try[Either[JupyterApi.ExecuteHookResult, String]]](
Success(Right(code))
) {
(codeOrDisplayDataAttempt, hook) =>
codeOrDisplayDataAttempt.flatMap { codeOrDisplayData =>
try Success(codeOrDisplayData.flatMap { value =>
hook.hook(value)
})
catch {
case e: Throwable => // kind of meh, but Ammonite does the same it seems…
Failure(e)
}
}
}
}
}
}
}
}

case Res.Exception(ex, msg) =>
log.error(s"exception in user code (${ex.getMessage})", ex)
Execute.error(colors0(), Some(ex), msg)
finalCodeOrResult match {
case Failure(ex) =>
log.error(s"exception when running hooks (${ex.getMessage})", ex)
Execute.error(colors0(), Some(ex), "")

case Success(Left(res)) =>
res match {
case s: JupyterApi.ExecuteHookResult.Success =>
ExecuteResult.Success(s.data)
case e: JupyterApi.ExecuteHookResult.Error =>
ExecuteResult.Error(e.name, e.message, e.stackTrace)
case JupyterApi.ExecuteHookResult.Abort =>
ExecuteResult.Abort
case JupyterApi.ExecuteHookResult.Exit =>
ExecuteResult.Exit
}

case Res.Skip =>
case Success(Right(emptyCode)) if emptyCode.trim.isEmpty =>
ExecuteResult.Success()

case Res.Exit(_) =>
ExecuteResult.Exit
case Success(Right(finalCode)) =>
ammResult(ammInterp, finalCode, inputManager, outputHandler, storeHistory) match {
case Res.Success((_, data)) =>
ExecuteResult.Success(data)
case Res.Failure(msg) =>
interruptedStackTraceOpt0 match {
case None =>
val err = Execute.error(colors0(), None, msg)
outputHandler.foreach(_.stderr(err.message)) // necessary?
err
case Some(st) =>
val cutoff = Set("$main", "evaluatorRunPrinter")

ExecuteResult.Error(
(
"Interrupted!" +: st
.takeWhile(x => !cutoff(x.getMethodName))
.map(Execute.highlightFrame(_, fansi.Attr.Reset, colors0().literal()))
).mkString(System.lineSeparator())
)
}

case Res.Exception(ex, msg) =>
log.error(s"exception in user code (${ex.getMessage})", ex)
Execute.error(colors0(), Some(ex), msg)

case Res.Skip =>
ExecuteResult.Success()

case Res.Exit(_) =>
ExecuteResult.Exit
}
}
}
}
Expand Down
Loading

0 comments on commit 55e8bda

Please sign in to comment.