Skip to content

Commit

Permalink
Fix code generation for local functions declared by let rec (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu authored Jan 2, 2024
1 parent 29e38d1 commit 1a7463e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 30 deletions.
47 changes: 26 additions & 21 deletions shared/src/main/scala/mlscript/JSBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ import scala.util.chaining._

abstract class JSBackend(allowUnresolvedSymbols: Bool) {
def oldDefs: Bool

protected implicit class TermOps(term: Term) {
def isLam: Bool = term match {
case _: Lam => true
case Bra(false, inner) => inner.isLam
case Asc(inner, _) => inner.isLam
case _ => false
}
}

/**
* The root scope of the program.
Expand Down Expand Up @@ -237,7 +246,8 @@ abstract class JSBackend(allowUnresolvedSymbols: Bool) {
case (t: Term, index) => JSExprStmt(translateTerm(t)(blkScope))
case (NuFunDef(isLetRec, Var(nme), symNme, _, L(rhs)), _) =>
val symb = symNme.map(_.name)
val pat = blkScope.declareValue(nme, isLetRec, isLetRec.isEmpty, symb)
val isLocalFunction = isLetRec.isEmpty || rhs.isLam
val pat = blkScope.declareValue(nme, isLetRec, isLocalFunction, symb)
JSLetDecl(Ls(pat.runtimeName -> S(translateTerm(rhs)(blkScope))))
case (nt: NuTypeDef, _) => translateLocalNewType(nt)(blkScope)
// TODO: find out if we need to support this.
Expand Down Expand Up @@ -1166,18 +1176,17 @@ class JSWebBackend extends JSBackend(allowUnresolvedSymbols = true) {
// ```
.concat(otherStmts.flatMap {
case Def(recursive, Var(name), L(body), isByname) =>
val isLam = body.isInstanceOf[Lam]
val (originalExpr, sym) = if (recursive) {
val isByvalueRecIn = if (isByname) None else Some(true)
val sym = topLevelScope.declareValue(name, isByvalueRecIn, isLam, N)
val sym = topLevelScope.declareValue(name, isByvalueRecIn, body.isLam, N)
val translated = translateTerm(body)(topLevelScope)
topLevelScope.unregisterSymbol(sym)
val isByvalueRecOut = if (isByname) None else Some(false)
(translated, topLevelScope.declareValue(name, isByvalueRecOut, isLam, N))
(translated, topLevelScope.declareValue(name, isByvalueRecOut, body.isLam, N))
} else {
val translatedBody = translateTerm(body)(topLevelScope)
val isByvalueRec = if (isByname) None else Some(false)
(translatedBody, topLevelScope.declareValue(name, isByvalueRec, isLam, N))
(translatedBody, topLevelScope.declareValue(name, isByvalueRec, body.isLam, N))
}
val translatedBody = if (sym.isByvalueRec.isEmpty && !sym.isLam) JSArrowFn(Nil, L(originalExpr)) else originalExpr
topLevelScope.tempVars `with` JSConstDecl(sym.runtimeName, translatedBody) ::
Expand Down Expand Up @@ -1227,22 +1236,21 @@ class JSWebBackend extends JSBackend(allowUnresolvedSymbols = true) {
case NuFunDef(isLetRec, nme @ Var(name), symNme, tys, rhs @ L(body)) =>
val recursive = isLetRec.getOrElse(true)
val isByname = isLetRec.isEmpty
val bodyIsLam = body match { case _: Lam => true case _ => false }
val symb = symNme.map(_.name)
val (originalExpr, sym) = (if (recursive) {
val isByvalueRecIn = if (isByname) None else Some(true)

// TODO Improve: (Lionel) what?!
val sym = topLevelScope.declareValue(name, isByvalueRecIn, bodyIsLam, N)
val sym = topLevelScope.declareValue(name, isByvalueRecIn, body.isLam, N)
val translated = translateTerm(body)(topLevelScope)
topLevelScope.unregisterSymbol(sym)

val isByvalueRecOut = if (isByname) None else Some(false)
(translated, topLevelScope.declareValue(name, isByvalueRecOut, bodyIsLam, symb))
(translated, topLevelScope.declareValue(name, isByvalueRecOut, body.isLam, symb))
} else {
val translated = translateTerm(body)(topLevelScope)
val isByvalueRec = if (isByname) None else Some(false)
(translated, topLevelScope.declareValue(name, isByvalueRec, bodyIsLam, symb))
(translated, topLevelScope.declareValue(name, isByvalueRec, body.isLam, symb))
})
val translatedBody = if (sym.isByvalueRec.isEmpty && !sym.isLam) JSArrowFn(Nil, L(originalExpr)) else originalExpr
resultNames += sym.runtimeName
Expand Down Expand Up @@ -1320,23 +1328,22 @@ abstract class JSTestBackend extends JSBackend(allowUnresolvedSymbols = false) {
// Generate statements.
val queries = otherStmts.map {
case Def(recursive, Var(name), L(body), isByname) =>
val bodyIsLam = body match { case _: Lam => true case _ => false }
(if (recursive) {
val isByvalueRecIn = if (isByname) None else Some(true)
val sym = scope.declareValue(name, isByvalueRecIn, bodyIsLam, N)
val sym = scope.declareValue(name, isByvalueRecIn, body.isLam, N)
try {
val translated = translateTerm(body)
scope.unregisterSymbol(sym)
val isByvalueRecOut = if (isByname) None else Some(false)
R((translated, scope.declareValue(name, isByvalueRecOut, bodyIsLam, N)))
R((translated, scope.declareValue(name, isByvalueRecOut, body.isLam, N)))
} catch {
case e: UnimplementedError =>
scope.stubize(sym, e.symbol)
L(e.getMessage())
case NonFatal(e) =>
scope.unregisterSymbol(sym)
val isByvalueRecOut = if (isByname) None else Some(false)
scope.declareValue(name, isByvalueRecOut, bodyIsLam, N)
scope.declareValue(name, isByvalueRecOut, body.isLam, N)
throw e
}
} else {
Expand All @@ -1346,7 +1353,7 @@ abstract class JSTestBackend extends JSBackend(allowUnresolvedSymbols = false) {
L(e.getMessage())
}) map {
val isByvalueRec = if (isByname) None else Some(false)
expr => (expr, scope.declareValue(name, isByvalueRec, bodyIsLam, N))
expr => (expr, scope.declareValue(name, isByvalueRec, body.isLam, N))
}
}) match {
case R((originalExpr, sym)) =>
Expand Down Expand Up @@ -1404,9 +1411,8 @@ abstract class JSTestBackend extends JSBackend(allowUnresolvedSymbols = false) {
case fd @ NuFunDef(isLetRec, Var(nme), symNme, _, L(body)) =>
val isByname = isLetRec.isEmpty
val isByvalueRecIn = if (isByname) None else Some(true)
val bodyIsLam = body match { case _: Lam => true case _ => false }
val symb = symNme.map(_.name)
scope.declareValue(nme, isByvalueRecIn, bodyIsLam, symb, true)
scope.declareValue(nme, isByvalueRecIn, body.isLam, symb, true)
case _ => ()
}

Expand Down Expand Up @@ -1434,26 +1440,25 @@ abstract class JSTestBackend extends JSBackend(allowUnresolvedSymbols = false) {
case NuFunDef(isLetRec, nme @ Var(name), symNme, tys, rhs @ L(body)) =>
val recursive = isLetRec.getOrElse(true)
val isByname = isLetRec.isEmpty
val bodyIsLam = body match { case _: Lam => true case _ => false }
val symb = symNme.map(_.name)
(if (recursive) {
val isByvalueRecIn = if (isByname) None else Some(true)
val sym = scope.resolveValue(name) match {
case Some(s: ValueSymbol) => s
case _ => scope.declareValue(name, isByvalueRecIn, bodyIsLam, symb)
case _ => scope.declareValue(name, isByvalueRecIn, body.isLam, symb)
}
val isByvalueRecOut = if (isByname) None else Some(false)
try {
val translated = translateTerm(body) // TODO Improve: (Lionel) Why are the bodies translated in the SAME scope?!
scope.unregisterSymbol(sym) // TODO Improve: (Lionel) ???
R((translated, scope.declareValue(name, isByvalueRecOut, bodyIsLam, symb)))
R((translated, scope.declareValue(name, isByvalueRecOut, body.isLam, symb)))
} catch {
case e: UnimplementedError =>
scope.stubize(sym, e.symbol)
L(e.getMessage())
case NonFatal(e) =>
scope.unregisterSymbol(sym) // TODO Improve: (Lionel) You should only try/catch around the part that may actually fail, and if `unregisterSymbol` should always be called, that should be done in `finally`... but the very logic of calling `unregisterSymbol` is very fishy, to say the least
scope.declareValue(name, isByvalueRecOut, bodyIsLam, symb)
scope.declareValue(name, isByvalueRecOut, body.isLam, symb)
throw e
}
} else {
Expand All @@ -1463,7 +1468,7 @@ abstract class JSTestBackend extends JSBackend(allowUnresolvedSymbols = false) {
L(e.getMessage())
}) map {
val isByvalueRec = if (isByname) None else Some(false)
expr => (expr, scope.declareValue(name, isByvalueRec, bodyIsLam, symb))
expr => (expr, scope.declareValue(name, isByvalueRec, body.isLam, symb))
}
}) match {
case R((originalExpr, sym)) =>
Expand Down
18 changes: 9 additions & 9 deletions shared/src/test/diff/nu/LetRec.mls
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,29 @@ fun test =
//│ Code generation encountered an error:
//│ unguarded recursive use of by-value binding f

:ge // TODO this one should actually be accepted by codegen!
fun test =
let rec f() = f()
//│ fun test: ()
//│ Code generation encountered an error:
//│ unguarded recursive use of by-value binding f

:ge // TODO this one should actually be accepted by codegen!
fun test =
let rec lol = () => lol
//│ fun test: ()
//│ Code generation encountered an error:
//│ unguarded recursive use of by-value binding lol

:ge // TODO this one should actually be accepted by codegen!
fun test =
let rec lol() = lol
lol
//│ fun test: forall 'lol. 'lol
//│ where
//│ 'lol :> () -> 'lol
//│ Code generation encountered an error:
//│ unguarded recursive use of by-value binding lol

fun testWithAsc =
let rec aux: Int -> Int = x => if x <= 0 then 1 else x * aux(x - 1)
aux(10)
testWithAsc
//│ fun testWithAsc: Int
//│ Int
//│ res
//│ = 3628800

let rec lol = () => lol
//│ let rec lol: forall 'lol. 'lol
Expand Down

0 comments on commit 1a7463e

Please sign in to comment.