Skip to content

Commit

Permalink
Fix lack of stability in inferred types (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK authored Mar 13, 2024
1 parent c6a1c1e commit 1be5a14
Show file tree
Hide file tree
Showing 114 changed files with 1,326 additions and 1,075 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Specializer(monoer: Monomorph)(using debug: Debug){
case TypeVal(name) =>
BoundedTerm(monoer.createObjValue(name, extractFuncArgs(rhs).map(getRes).toList))
case l@LiteralVal(i) => BoundedTerm(l)
case _ => utils.die
}.fold(BoundedTerm())(_ ++ _))
case New(Some((constructor, args)), body) =>
evaluate(args)
Expand Down Expand Up @@ -102,6 +103,7 @@ class Specializer(monoer: Monomorph)(using debug: Debug){
then BoundedTerm(LiteralVal(UnitLit(false)))
else stmts.reverse.head match
case t: Term => getRes(t)
case _ => utils.die
})
case If(body, alternate) =>
val res = body match
Expand Down Expand Up @@ -138,6 +140,7 @@ class Specializer(monoer: Monomorph)(using debug: Debug){
// case Assign(lhs, rhs) => ???
// case New(None, body) => ???
// case Rcd(fields) => ???
case _ => utils.die
debug.outdent()
debug.writeLine(s"╙Result ${getRes(term).getValue.map(_.toString).toList}:")

Expand Down Expand Up @@ -178,10 +181,12 @@ class Specializer(monoer: Monomorph)(using debug: Debug){
then
IfThen(App(Var(name), toTuple(params.map(k => Var(k)).toList)), field)
else throw MonomorphError(s"Selection of field ${field} from object ${o} results in no values")
case _ => utils.die
}
valSetToBranches(next, Left(branchCase) :: acc)
// case t@TupVal(fields) =>
// val selValue = fields.getOrElse(field, throw MonomorphError(s"Invalid field selection ${field} from Tuple"))
case _ => utils.die


val ret = term match
Expand Down
2 changes: 1 addition & 1 deletion shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
// 200
250

type ExtrCtx = MutMap[TV, Buffer[(Bool, ST)]] // tv, is-lower, bound
type ExtrCtx = MutSortMap[TV, Buffer[(Bool, ST)]] // tv, is-lower, bound

protected var currentConstrainingRun = 0

Expand Down
79 changes: 77 additions & 2 deletions shared/src/main/scala/mlscript/NormalForms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,33 @@ class NormalForms extends TyperDatatypes { self: Typer =>


sealed abstract class LhsNf {
final def compareEquiv(that: LhsNf): Int = (this, that) match {
case (LhsRefined(b1, ts1, r1, trs1), LhsRefined(b2, ts2, r2, trs2)) =>
var cmp = (b1, b2) match {
case (S(c1), S(c2)) => c1.compareEquiv(c2)
case (S(c1), N) => -1
case (N, S(c2)) => 1
case (N, N) => 0
}
if (cmp =/= 0) return cmp
// * Just compare the heads for simplicity...
cmp = (trs1.headOption, trs2.headOption) match {
case (S((n1, _)), S((n2, _))) =>
n1.compare(n2) // * in principle we could go on to compare the tails if this is 0
case (S(_), N) => 1
case (N, S(_)) => -1
case (N, N) => 0
}
if (cmp =/= 0) return cmp
cmp = -trs1.sizeCompare(trs2)
if (cmp =/= 0) return cmp
cmp = ts1.sizeCompare(ts2.size)
if (cmp =/= 0) return cmp
cmp = r1.fields.sizeCompare(r2.fields)
cmp
case (LhsTop, _) => 1
case (_, LhsTop) => -1
}
def toTypes: Ls[SimpleType] = toType() :: Nil
def toType(sort: Bool = false): SimpleType =
if (sort) mkType(true) else underlying
Expand Down Expand Up @@ -264,6 +291,32 @@ class NormalForms extends TyperDatatypes { self: Typer =>


sealed abstract class RhsNf {
final def compareEquiv(that: RhsNf): Int = (this, that) match {
case (RhsField(n1, t1), RhsField(n2, t2)) => n1.compare(n2)
case (RhsBases(ps1, bf1, trs1), RhsBases(ps2, bf2, trs2)) =>
var cmp = ps1.minOption match {
case S(m1) => ps2.minOption match {
case S(m2) => m1.compare(m2)
case N => ps1.size.compare(ps2.size)
}
case N => ps1.size.compare(ps2.size)
}
if (cmp =/= 0) return cmp
cmp = (trs1.headOption, trs2.headOption) match {
case (S((n1, _)), S((n2, _))) => n1.compare(n2)
case (S(_), N) => 1
case (N, S(_)) => -1
case (N, N) => 0
}
if (cmp =/= 0) return cmp
cmp = -trs1.sizeCompare(trs2)
cmp
case (_: RhsBases, _) => -1
case (_, _: RhsBases) => 1
case (_: RhsField, _) => -1
case (_, _: RhsField) => 1
case (RhsBot, RhsBot) => 0
}
def toTypes: Ls[SimpleType] = toType() :: Nil
def toType(sort: Bool = false): SimpleType =
if (sort) mkType(true) else underlying
Expand Down Expand Up @@ -444,8 +497,20 @@ class NormalForms extends TyperDatatypes { self: Typer =>
}


case class Conjunct(lnf: LhsNf, vars: SortedSet[TypeVariable], rnf: RhsNf, nvars: SortedSet[TypeVariable]) extends Ordered[Conjunct] {
def compare(that: Conjunct): Int = this.mkString compare that.mkString // TODO less inefficient!!
case class Conjunct(lnf: LhsNf, vars: SortedSet[TypeVariable], rnf: RhsNf, nvars: SortedSet[TypeVariable]) {
final def compareEquiv(that: Conjunct): Int =
// trace(s"compareEquiv($this, $that)")(compareEquivImpl(that))(r => s"= $r")
compareEquivImpl(that)
final def compareEquivImpl(that: Conjunct): Int = {
var cmp = lnf.compareEquiv(that.lnf)
if (cmp =/= 0) return cmp
cmp = rnf.compareEquiv(that.rnf)
if (cmp =/= 0) return cmp
cmp = -vars.sizeCompare(that.vars)
if (cmp =/= 0) return cmp
cmp = -nvars.sizeCompare(that.nvars)
cmp
}
def toType(sort: Bool = false): SimpleType =
toTypeWith(_.toType(sort), _.toType(sort), sort)
def toTypeWith(f: LhsNf => SimpleType, g: RhsNf => SimpleType, sort: Bool = false): SimpleType =
Expand Down Expand Up @@ -576,6 +641,15 @@ class NormalForms extends TyperDatatypes { self: Typer =>
case RhsBot => RhsBot
}, nvars)
}
/** Scala's standard library is weird. I would have normally made Conjunct extend Ordered[Conjunct],
* but the contract of Ordered says that `equals` and `hashCode` should be "consistent" with `compare`,
* which I understand as two things comparing to 0 HAVING to be equal and to have the same hash code...
* But achieving this is very expensive for general type forms.
* All we want to do here is to define an ordering between implicit equivalence classes
* whose members are not necessarily equal. Which is fine since we only use this to do stable sorts. */
implicit object Ordering extends Ordering[Conjunct] {
def compare(x: Conjunct, y: Conjunct): Int = x.compareEquiv(y)
}
}


Expand Down Expand Up @@ -679,6 +753,7 @@ class NormalForms extends TyperDatatypes { self: Typer =>
val (newLvl, thisCs, thatCs, thisCons, thatCons) = levelWith(that)
// println(s"-- $polymLevel ${that.polymLevel} $newLvl")
thatCs.foldLeft(DNF(newLvl, thisCons ::: thatCons, thisCs))(_ | _)
// ^ Note: conjuncting the constrained-type constraints here is probably the wrong thing to do...
}
def & (that: Conjunct, pol: Bool)(implicit ctx: Ctx, etf: ExpandTupleFields): DNF =
DNF(polymLevel, cons, cs.flatMap(_ & (that, pol))) // TODO may need further simplif afterward
Expand Down
32 changes: 21 additions & 11 deletions shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ trait TypeSimplifier { self: Typer =>
/** Remove bounds that are not reachable by traversing the type following variances.
* Note that doing this on annotated type signatures would need to use polarity None
* because a type signature can both be used (positively) and checked against (negatively). */
def removeIrrelevantBounds(ty: TypeLike, pol: Opt[Bool], inPlace: Bool = false)
def removeIrrelevantBounds(ty: TypeLike, pol: Opt[Bool], reverseBoundsOrder: Bool, inPlace: Bool = false)
(implicit ctx: Ctx): TypeLike =
{
val _ctx = ctx
Expand Down Expand Up @@ -67,11 +67,15 @@ trait TypeSimplifier { self: Typer =>
}
case N =>
nv.lowerBounds = if (allVarPols(tv).forall(_ === true))
tv.lowerBounds.reverseIterator.map(process(_, S(true -> tv)))
(if (reverseBoundsOrder) tv.lowerBounds.reverseIterator
else tv.lowerBounds.iterator
).map(process(_, S(true -> tv)))
.reduceOption(_ | _).filterNot(_.isBot).toList
else Nil
nv.upperBounds = if (allVarPols(tv).forall(_ === false))
tv.upperBounds.reverseIterator.map(process(_, S(false -> tv)))
(if (reverseBoundsOrder) tv.upperBounds.reverseIterator
else tv.upperBounds.iterator
).map(process(_, S(false -> tv)))
.reduceOption(_ &- _).filterNot(_.isTop).toList
else Nil
}
Expand Down Expand Up @@ -198,7 +202,7 @@ trait TypeSimplifier { self: Typer =>
// * where `T` is `A & B & C`.
// * It is fine to call `go` because we made sure A, B, C, etc. do not themsleves have any negative components.
val csNegs2 = if (csNegs.isEmpty) BotType
else go(csNegs.foldLeft(TopType: ST)(_ & _.toType().neg()), pol.map(!_)).neg()
else go(csNegs.foldLeft(TopType: ST)(_ & _.toType(sort = true).neg()), pol.map(!_)).neg() // TODO sort?! csNegs and toType

val otherCs2 = otherCs.sorted.map { c =>
c.vars.foreach(processVar)
Expand Down Expand Up @@ -484,8 +488,10 @@ trait TypeSimplifier { self: Typer =>
case S(ty) =>
tv.assignedTo = S(go(ty, N))
case N =>
tv.lowerBounds = tv.lowerBounds.map(go(_, S(true)))
tv.upperBounds = tv.upperBounds.map(go(_, S(false)))
// tv.lowerBounds = tv.lowerBounds.map(go(_, S(true)))
// tv.upperBounds = tv.upperBounds.map(go(_, S(false)))
tv.lowerBounds = tv.lowerBounds.reduceOption(_ | _).fold(nil[ST])(go(_, S(true)) :: Nil)
tv.upperBounds = tv.upperBounds.reduceOption(_ & _).fold(nil[ST])(go(_, S(false)) :: Nil)
}
}
}
Expand Down Expand Up @@ -566,7 +572,7 @@ trait TypeSimplifier { self: Typer =>

// * Note: for negatively-quantified vars, the notion of co-occurrence is reversed (wrt unions/inters)...

val coOccurrences: MutMap[(Bool, TypeVariable), MutSet[SimpleType]] = LinkedHashMap.empty
val coOccurrences: MutMap[(Bool, TypeVariable), LinkedHashSet[SimpleType]] = MutMap.empty

// * Remember which TVs we analyzed at which polarity
val analyzed2 = MutSet.empty[Bool -> ST]
Expand Down Expand Up @@ -668,7 +674,7 @@ trait TypeSimplifier { self: Typer =>
}

def processImpl(st: SimpleType, pol: PolMap, occPol: Bool) = {
val newOccs = MutSet.empty[SimpleType]
val newOccs = LinkedHashSet.empty[SimpleType]

println(s">> Processing $st at [${printPol(S(occPol))}]")

Expand Down Expand Up @@ -705,7 +711,7 @@ trait TypeSimplifier { self: Typer =>
case Some(os) =>
// Q: filter out vars of different level?
os.filterInPlace(occs) // computes the intersection
case None => coOccurrences(pol -> tv) = occs.clone() // `clone` not needed?
case None => coOccurrences(pol -> tv) = LinkedHashSet.from(occs) // copy not needed?
}
}
pol(tv) match {
Expand Down Expand Up @@ -1264,7 +1270,9 @@ trait TypeSimplifier { self: Typer =>
debugOutput(s"⬤ Initial: ${cur}")
debugOutput(s" where: ${cur.showBounds}")

cur = removeIrrelevantBounds(cur, pol, inPlace = false)
cur = removeIrrelevantBounds(cur, pol,
reverseBoundsOrder = true, // bounds are accumulated by type inference in reverse order of appearance; so nicer to reverse them here
inPlace = false)
debugOutput(s"⬤ Cleaned up: ${cur}")
debugOutput(s" where: ${cur.showBounds}")

Expand All @@ -1285,7 +1293,9 @@ trait TypeSimplifier { self: Typer =>
debugOutput(s"⬤ Normalized: ${cur}")
debugOutput(s" where: ${cur.showBounds}")

cur = removeIrrelevantBounds(cur, pol, inPlace = true)
cur = removeIrrelevantBounds(cur, pol,
reverseBoundsOrder = false,
inPlace = true)
debugOutput(s"⬤ Cleaned up: ${cur}")
debugOutput(s" where: ${cur.showBounds}")

Expand Down
30 changes: 17 additions & 13 deletions shared/src/main/scala/mlscript/Typer.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package mlscript

import scala.collection.mutable
import scala.collection.mutable.{Map => MutMap, Set => MutSet}
import scala.collection.mutable.{Map => MutMap, Set => MutSet, SortedMap => MutSortMap, LinkedHashMap, LinkedHashSet, Buffer}
import scala.collection.immutable.{SortedSet, SortedMap}
import Set.{empty => semp}
import scala.util.chaining._
Expand Down Expand Up @@ -63,7 +63,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
mthEnv: MutMap[(Str, Str) \/ (Opt[Str], Str), MethodType],
lvl: Int,
quoteSkolemEnv: MutMap[Str, SkolemTag], // * SkolemTag for variables in quasiquotes
freeVarsInCurrentQuote: MutSet[ST], // * Free variables appearing in the current quote scope
freeVarsInCurrentQuote: LinkedHashSet[ST], // * Free variables appearing in the current quote scope
inQuote: Bool, // * Is in quasiquote
inPattern: Bool,
tyDefs: Map[Str, TypeDef],
Expand Down Expand Up @@ -135,10 +135,10 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
* So for `code"x => ..."`, freeVarsInCurrentQuote = {'a}, quoteSkolemEnv = {'gx}, where 'gx <= 'a.
* After calling `enterQuotedScope`, **solve the constraints** using `solveQuoteContext` to make sure free variables are handled correctly.
*/
def enterQuotedScope: Ctx = copy(Some(this), MutMap.empty, MutMap.empty, lvl = lvl + 1, inQuote = true, quoteSkolemEnv = MutMap.empty, freeVarsInCurrentQuote = MutSet.empty)
def enterQuotedScope: Ctx = copy(Some(this), MutMap.empty, MutMap.empty, lvl = lvl + 1, inQuote = true, quoteSkolemEnv = MutMap.empty, freeVarsInCurrentQuote = LinkedHashSet.empty)
def enterUnquote: Ctx = copy(Some(this), MutMap.empty, MutMap.empty, inQuote = false)
def nextLevel[R](k: Ctx => R)(implicit raise: Raise, prov: TP): R = {
val newCtx = copy(lvl = lvl + 1, extrCtx = MutMap.empty)
val newCtx = copy(lvl = lvl + 1, extrCtx = MutSortMap.empty)
val res = k(newCtx)
val ec = newCtx.extrCtx
assert(constrainedTypes || newCtx.extrCtx.isEmpty)
Expand Down Expand Up @@ -204,13 +204,13 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
mthEnv = MutMap.empty,
lvl = MinLevel,
quoteSkolemEnv = MutMap.empty,
freeVarsInCurrentQuote = MutSet.empty,
freeVarsInCurrentQuote = LinkedHashSet.empty,
inQuote = false,
inPattern = false,
tyDefs = Map.from(builtinTypes.map(t => t.nme.name -> t)),
tyDefs2 = MutMap.empty,
inRecursiveDef = N,
MutMap.empty,
MutSortMap.empty,
)
def init: Ctx = if (!newDefs) initBase else {
val res = initBase.copy(
Expand Down Expand Up @@ -1943,8 +1943,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
case tv: TypeVariable if stopAtTyVars => tv.asTypeVar
case tv: TypeVariable => ectx.tps.getOrElse(tv, {
val nv = tv.asTypeVar
if (!seenVars(tv)) {
seenVars += tv
if (seenVars.add(tv)) {
tv.assignedTo match {
case S(ty) =>
val b = go(ty)
Expand Down Expand Up @@ -2017,13 +2016,18 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
newBounds.iterator.flatMap(_._2.freeTypeVariables)
val fvars = qvars.filter(tv => ftvs.contains(tv.asTypeVar))
if (fvars.isEmpty) b else
PolyType(fvars.map(_.asTypeVar pipe (R(_))).toList, b)
PolyType(fvars
.toArray.sorted
.map(_.asTypeVar pipe (R(_))).toList, b)
case ConstrainedType(cs, bod) =>
val (ubs, others1) = cs.groupMap(_._1)(_._2).toList.partition(_._2.sizeIs > 1)
val lbs = others1.mapValues(_.head).groupMap(_._2)(_._1).toList
val groups1, groups2 = LinkedHashMap.empty[ST, Buffer[ST]]
cs.foreach { case (lo, hi) => groups1.getOrElseUpdate(lo, Buffer.empty) += hi }
val (ubs, others1) = groups1.toList.partition(_._2.sizeIs > 1)
others1.foreach { case (k, vs) => groups2.getOrElseUpdate(vs.head, Buffer.empty) += k }
val lbs = groups2.toList
val bounds = (ubs.mapValues(_.reduce(_ &- _)) ++ lbs.mapValues(_.reduce(_ | _)).map(_.swap))
val procesased = bounds.map { case (lo, hi) => Bounds(go(lo), go(hi)) }
Constrained(go(bod), Nil, procesased)
val processed = bounds.map { case (lo, hi) => Bounds(go(lo), go(hi)) }
Constrained(go(bod), Nil, processed)

// case DeclType(lvl, info) =>

Expand Down
21 changes: 21 additions & 0 deletions shared/src/main/scala/mlscript/TyperDatatypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,27 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>

sealed abstract class BaseTypeOrTag extends SimpleType
sealed abstract class BaseType extends BaseTypeOrTag {
def compareEquiv(that: BaseType): Int = (this, that) match {
case (a: TypeTag, b: TypeTag) => a.compare(b)
case (a: TypeTag, _) => -1
case (_, b: TypeTag) => 1
case (_: FunctionType, _: FunctionType) => 0
case (_: FunctionType, _) => -1
case (_, _: FunctionType) => 1
case (_: ArrayType, _: ArrayType) => 0
case (_: ArrayType, _) => -1
case (_, _: ArrayType) => 1
case (_: TupleType, _: TupleType) => 0
case (_: TupleType, _) => -1
case (_, _: TupleType) => 1
case (_: Without, _: Without) => 0
case (_: Without, _) => -1
case (_, _: Without) => 1
case (_: Overload, _: Overload) => 0
case (_: Overload, _) => -1
case (_, _: Overload) => 1
case (_: SpliceType, _: SpliceType) => 0
}
def toRecord: RecordType = RecordType.empty
protected def freshenAboveImpl(lim: Int, rigidify: Bool)(implicit ctx: Ctx, freshened: MutMap[TV, ST]): BaseType
override def freshenAbove(lim: Int, rigidify: Bool)(implicit ctx: Ctx, freshened: MutMap[TV, ST]): BaseType =
Expand Down
2 changes: 1 addition & 1 deletion shared/src/main/scala/mlscript/TyperHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ abstract class TyperHelpers { Typer: Typer =>
def rebuild(cs: Ls[Ls[ST]]): ST =
cs.iterator.map(_.foldLeft(TopType: ST)(_ & _)).foldLeft(BotType: ST)(_ | _)
if (cs.sizeCompare(1) <= 0) return rebuild(cs)
val factors = MutMap.empty[Factorizable, Int]
val factors = LinkedHashMap.empty[Factorizable, Int]
cs.foreach { c =>
c.foreach {
case tv: TV =>
Expand Down
Loading

0 comments on commit 1be5a14

Please sign in to comment.