Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type Selection #219

Open
wants to merge 2 commits into
base: mlscript
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
val nTerms = termList.map(liftTerm(_)(using emptyCtx, nCache, globFuncs, nOuter)).unzip
clsList.foreach(x => liftTypeDef(x)(using nCache, globFuncs, nOuter))
retSeq = retSeq.appended(NuTypeDef(
kind, nName, nTps.map((None, _)), kind match
kind, nName, nTps.map((TypeParamInfo(None, false, N, N), _)), kind match
case Mod => None
case _ => S(Tup(nParams))
, None, None, nPars._1, None, None, TypingUnit(nFuncs._1 ++ nTerms._1))(None, None, Nil))
Expand Down
29 changes: 16 additions & 13 deletions shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ class ConstraintSolver extends NormalForms { self: Typer =>
ErrorReport(
msg"${info.decl.kind.str.capitalize} `${info.decl.name}` does not contain member `${fld.name}`" -> fld.toLoc :: Nil, newDefs)

def lookupMember(clsNme: Str, rfnt: Var => Opt[FieldType], fld: Var)
(implicit ctx: Ctx, raise: Raise)
def lookupMember(clsNme: Str, rfnt: Var => Opt[FieldType], fld: Var)(implicit ctx: Ctx, raise: Raise)
: Either[Diagnostic, NuMember]
= {
val info = ctx.tyDefs2.getOrElse(clsNme, ???/*TODO*/)

= ctx.tyDefs2.get(clsNme).toRight(ErrorReport(msg"Cannot find class ${clsNme}" -> N :: Nil, newDefs)) flatMap { info =>
if (info.isComputing) {

??? // TODO support?
Expand Down Expand Up @@ -109,7 +106,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
Nil)
S(p.ty)
case S(m) =>
S(err(msg"Access to ${m.kind.str} member not yet supported", fld.toLoc).toUpper(noProv))
S(err(msg"Access to ${m.kind.str} member ${fld.name} not yet supported", fld.toLoc).toUpper(noProv))
case N => N
}

Expand All @@ -131,17 +128,20 @@ class ConstraintSolver extends NormalForms { self: Typer =>
implicit val shadows: Shadows = Shadows.empty

info.tparams.foreach { case (tn, _tv, vi) =>
val targ = rfnt(Var(info.decl.name + "#" + tn.name)) match {
val targ = rfnt(tparamField(TypeName(info.decl.name), tn, vi.visible)) match {
// * TODO to avoid infinite recursion due to ever-expanding type args,
// * we should set the shadows of the targ to be the same as that of the parameter it replaces...
case S(fty) if vi === S(VarianceInfo.co) => fty.ub
case S(fty) if vi === S(VarianceInfo.contra) => fty.lb.getOrElse(BotType)
case S(fty) if vi.varinfo === S(VarianceInfo.co) =>
println(s"Lookup: Found $fty")
fty.ub
case S(fty) if vi.varinfo === S(VarianceInfo.contra) =>
println(s"Lookup: Found $fty")
fty.lb.getOrElse(BotType)
case S(fty) =>
TypeBounds.mk(
fty.lb.getOrElse(BotType),
fty.ub,
)
println(s"Lookup: Found $fty")
TypeBounds.mk(fty.lb.getOrElse(BotType), fty.ub)
case N =>
println(s"Lookup: field not found, creating new bounds")
TypeBounds(
// _tv.lowerBounds.foldLeft(BotType: ST)(_ | _),
// _tv.upperBounds.foldLeft(TopType: ST)(_ & _),
Expand Down Expand Up @@ -192,6 +192,9 @@ class ConstraintSolver extends NormalForms { self: Typer =>
}()


private val DummyTV: TV = freshVar(noProv, N, S("<DUMMY>"), Nil, Nil, false)(-1)


// * Each type has a shadow which identifies all variables created from copying
// * variables that existed at the start of constraining.
// * The intent is to make the total number of shadows in a given constraint
Expand Down
5 changes: 3 additions & 2 deletions shared/src/main/scala/mlscript/NewLexer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ object NewLexer {
"val",
"var",
// "is",
// "as",
"as",
"of",
// "and",
// "or",
Expand Down Expand Up @@ -498,7 +498,8 @@ object NewLexer {
"undefined",
"abstract",
"constructor",
"virtual"
"virtual",
"restricts",
)

def printToken(tl: TokLoc): Str = tl match {
Expand Down
51 changes: 39 additions & 12 deletions shared/src/main/scala/mlscript/NewParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,9 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
case (KEYWORD("super"), l0) :: _ =>
consume
exprCont(Super().withLoc(S(l0)), prec, allowNewlines = false)
case (IDENT("?", true), l0) :: _ =>
consume
exprCont(Var("?").withLoc(S(l0)), prec, allowNewlines = false)
case (IDENT("~", _), l0) :: _ =>
consume
val rest = expr(prec, allowSpace = true)
Expand Down Expand Up @@ -1080,7 +1083,7 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
else App(App(v, PlainTup(acc)), PlainTup(rhs))
}, prec, allowNewlines)
}
case (KEYWORD(":"), l0) :: _ if prec <= NewParser.prec(':') =>
case (KEYWORD("as" | ":"), l0) :: _ if prec <= NewParser.prec(':') =>
consume
R(Asc(acc, typ(0)))
case (KEYWORD("where"), l0) :: _ if prec <= 1 =>
Expand Down Expand Up @@ -1278,35 +1281,59 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
}

// TODO support line-broken param lists; share logic with args/argsOrIf
def typeParams(implicit fe: FoundErr, et: ExpectThen): Ls[(Opt[VarianceInfo], TypeName)] = {
def typeParams(implicit fe: FoundErr, et: ExpectThen): Ls[(TypeParamInfo, TypeName)] = {
val visinfo = yeetSpaces match {
case (KEYWORD("type"), l0) :: _ =>
consume
S(l0)
case _ => N
}
val vinfo = yeetSpaces match {
case (KEYWORD("in"), l0) :: (KEYWORD("out"), l1) :: _ =>
consume
S(VarianceInfo.in, l0 ++ l1)
S(VarianceInfo.in -> (l0++l1))
case (KEYWORD("in"), l0) :: _ =>
consume
S(VarianceInfo.contra, l0)
S(VarianceInfo.contra -> l0)
case (KEYWORD("out"), l0) :: _ =>
consume
S(VarianceInfo.co, l0)
case _ => N
S(VarianceInfo.co -> l0)
case _ =>
N
}
yeetSpaces match {
case (IDENT(nme, false), l0) :: _ =>
consume
val tyNme = TypeName(nme).withLoc(S(l0))

@inline def getTypeName(kw: String) = yeetSpaces match {
case (KEYWORD(k), l0) :: _ if k === kw => consume
yeetSpaces match {
case (IDENT(nme, false), l1) :: _ =>
consume; S(TypeName(nme).withLoc(S(l1)))
case _ => err(msg"dangling $kw keyword" -> S(l0) :: Nil); N
}
case _ => N
}
val lb = getTypeName("restricts")
val ub = getTypeName("extends")
// TODO update `TypeParamInfo` to use lb and ub
yeetSpaces match {
case (COMMA, l0) :: _ =>
consume
vinfo.map(_._1) -> tyNme :: typeParams
TypeParamInfo(vinfo.map(_._1), visinfo.isDefined, lb, ub) -> tyNme :: typeParams
case _ =>
vinfo.map(_._1) -> tyNme :: Nil
TypeParamInfo(vinfo.map(_._1), visinfo.isDefined, lb, ub) -> tyNme :: Nil
}
case _ =>
vinfo match {
case S((_, loc)) =>
(visinfo, vinfo) match {
case (S(l1), S(_ -> l2)) =>
err(msg"dangling type member and variance information" -> S(l1 ++ l2) :: Nil)
case (_, S(_ -> loc)) =>
err(msg"dangling variance information" -> S(loc) :: Nil)
case N =>
case (S(loc), _) =>
err(msg"dangling visible type member" -> S(loc) :: Nil)
case (N, N) =>
}
Nil
}
Expand Down Expand Up @@ -1373,7 +1400,7 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
case (NEWLINE, _) :: _ => // TODO: | ...
assert(seqAcc.isEmpty)
acc.reverse
case (IDENT(nme, true), _) :: _ if nme =/= "-" => // TODO: | ...
case (IDENT(nme, true), _) :: _ if nme =/= "-" && nme =/= "?" => // TODO: | ...
assert(seqAcc.isEmpty)
acc.reverse
case _ =>
Expand Down
38 changes: 22 additions & 16 deletions shared/src/main/scala/mlscript/NuTypeDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>


type Params = Ls[Var -> FieldType]
type TyParams = Ls[(TN, TV, Opt[VarianceInfo])]
type TyParams = Ls[(TN, TV, TypeParamInfo)]


sealed abstract class NuDeclInfo
Expand Down Expand Up @@ -174,8 +174,9 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>

// TODO dedup with the one in TypedNuCls
lazy val virtualMembers: Map[Str, NuMember] = members ++ tparams.map {
case (nme @ TypeName(name), tv, _) =>
td.nme.name+"#"+name -> NuParam(nme, FieldType(S(tv), tv)(provTODO), isPublic = true)(level)
case (nme @ TypeName(name), tv, TypeParamInfo(_, v, _, _)) =>
tparamField(td.nme, nme, v).name ->
NuParam(nme, FieldType(S(tv), tv)(provTODO), isPublic = true)(level)
} ++ parentTP

def freshenAbove(lim: Int, rigidify: Bool)
Expand Down Expand Up @@ -239,8 +240,8 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>

/** Includes class-name-coded type parameter fields. */
lazy val virtualMembers: Map[Str, NuMember] = members ++ tparams.map {
case (nme @ TypeName(name), tv, _) =>
td.nme.name+"#"+name -> NuParam(nme, FieldType(S(tv), tv)(provTODO), isPublic = true)(level)
case (nme @ TypeName(name), tv, TypeParamInfo(_, v, _, _)) =>
tparamField(td.nme, nme, v).name -> NuParam(nme, FieldType(S(tv), tv)(provTODO), isPublic = true)(level)
} ++ parentTP

// TODO
Expand Down Expand Up @@ -283,7 +284,7 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
}

// TODO check consistency with explicitVariances
val res = store ++ tparams.iterator.collect { case (_, tv, S(vi)) => tv -> vi }
val res = store ++ tparams.iterator.collect { case (_, tv, TypeParamInfo(S(vi), _, _, _)) => tv -> vi }

_variances = S(res)

Expand Down Expand Up @@ -356,9 +357,9 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
def isPublic = true // TODO

lazy val virtualMembers: Map[Str, NuMember] = members ++ tparams.map {
case (nme @ TypeName(name), tv, _) =>
td.nme.name+"#"+name -> NuParam(nme, FieldType(S(tv), tv)(provTODO), isPublic = false)(level)
}
case (nme @ TypeName(name), tv, TypeParamInfo(_, v, _, _)) =>
tparamField(td.nme, nme, v).name -> NuParam(nme, FieldType(S(tv), tv)(provTODO), isPublic = false)(level)
}

def freshenAbove(lim: Int, rigidify: Bool)
(implicit ctx: Ctx, freshened: MutMap[TV,ST])
Expand Down Expand Up @@ -929,19 +930,23 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
lazy val tparams: TyParams = ctx.nest.nextLevel { implicit ctx =>
decl match {
case td: NuTypeDef =>
td.tparams.map(tp =>
(tp._2, freshVar(
td.tparams.map(tp => {
val fv = freshVar(
TypeProvenance(tp._2.toLoc, "type parameter",
S(tp._2.name),
isType = true),
N, S(tp._2.name)), tp._1))
N, S(tp._2.name))
// TODO assign the correct bounds (`typeType` causes cycle)
// fv.lowerBounds = tp._1.lb.toList.map(TypeRef(_, Nil)(provTODO))
// fv.upperBounds = tp._1.ub.toList.map(TypeRef(_, Nil)(provTODO))
(tp._2, fv, tp._1) })
case fd: NuFunDef =>
fd.tparams.map { tn =>
(tn, freshVar(
TypeProvenance(tn.toLoc, "method type parameter",
originName = S(tn.name),
isType = true),
N, S(tn.name)), N)
N, S(tn.name)), TypeParamInfo(N, false, N, N))
}
}
}
Expand All @@ -950,7 +955,7 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
}

lazy val explicitVariances: VarianceStore =
MutMap.from(tparams.iterator.map(tp => tp._2 -> tp._3.getOrElse(VarianceInfo.in)))
MutMap.from(tparams.iterator.map(tp => tp._2 -> tp._3.varinfo.getOrElse(VarianceInfo.in)))

def varianceOf(tv: TV)(implicit ctx: Ctx): VarianceInfo =
// TODO make use of inferred vce if result is completed
Expand Down Expand Up @@ -1541,7 +1546,7 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
val finalType = thisTV

val tparamMems = tparams.map { case (tp, tv, vi) => // TODO use vi
val fldNme = td.nme.name + "#" + tp.name
val fldNme = tparamField(td.nme.name, tp.name, vi.visible)
val skol = SkolemTag(tv)(tv.prov)
NuParam(TypeName(fldNme).withLocOf(tp), FieldType(S(skol), skol)(tv.prov), isPublic = true)(lvl)
}
Expand Down Expand Up @@ -1921,7 +1926,8 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>
tv
})
freshened += _tv -> tv
rawName+"#"+tn.name -> NuParam(tn, FieldType(S(tv), tv)(provTODO), isPublic = true)(ctx.lvl)
tparamField(rawName, tn.name, vi.visible) ->
NuParam(tn, FieldType(S(tv), tv)(provTODO), isPublic = true)(ctx.lvl)
}

freshened -> parTP.toMap
Expand Down
11 changes: 7 additions & 4 deletions shared/src/main/scala/mlscript/TypeDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ class TypeDefs extends NuTypeDefs { Typer: Typer =>
}
}


def tparamField(clsNme: TypeName, tparamNme: TypeName): Var =
Var(clsNme.name + "#" + tparamNme.name)
def tparamField(clsNme: TypeName, tparamNme: TypeName, visible: Bool): Var =
Var(tparamField(clsNme.name, tparamNme.name, visible))

def tparamField(clsNme: String, tparamNme: String, visible: Bool): String =
if (!visible) clsNme + "#" + tparamNme else tparamNme

def clsNameToNomTag(td: NuTypeDef)(prov: TypeProvenance, ctx: Ctx): ClassTag = {
require((td.kind is Cls) || (td.kind is Mod), td.kind)
Expand Down Expand Up @@ -343,7 +345,8 @@ class TypeDefs extends NuTypeDefs { Typer: Typer =>
case _ =>
val fields = fieldsOf(td.bodyTy, paramTags = true)
val tparamTags = td.tparamsargs.map { case (tp, tv) =>
tparamField(td.nme, tp) -> FieldType(Some(tv), tv)(tv.prov) }
// `false` means using `C#A` (old type member names)
tparamField(td.nme, tp, false) -> FieldType(Some(tv), tv)(tv.prov) }
val ctor = k match {
case Cls =>
val nomTag = clsNameToNomTag(td)(originProv(td.nme.toLoc, "class", td.nme.name), ctx)
Expand Down
6 changes: 3 additions & 3 deletions shared/src/main/scala/mlscript/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ trait TypeSimplifier { self: Typer =>
else v -> default :: Nil
}
case S(trt: TypedNuTrt) => // TODO factor w/ above & generalize
trt.tparams.iterator.find(_._1.name === postfix).flatMap(_._3).getOrElse(VarianceInfo.in) match {
trt.tparams.iterator.find(_._1.name === postfix).flatMap(_._3.varinfo).getOrElse(VarianceInfo.in) match {
case VarianceInfo(true, true) => Nil
case VarianceInfo(co, contra) =>
if (co) v -> FieldType(S(BotType), process(fty.ub, N))(fty.prov) :: Nil
Expand Down Expand Up @@ -248,7 +248,7 @@ trait TypeSimplifier { self: Typer =>

// * Reconstruct a TypeRef from its current structural components
val typeRef = TypeRef(td.nme, td.tparamsargs.zipWithIndex.map { case ((tp, tv), tpidx) =>
val fieldTagNme = tparamField(clsTyNme, tp)
val fieldTagNme = tparamField(clsTyNme, tp, false)
val fromTyRef = trs2.get(clsTyNme).map(_.targs(tpidx) |> { ta => FieldType(S(ta), ta)(noProv) })
fromTyRef.++(rcd2.fields.iterator.filter(_._1 === fieldTagNme).map(_._2))
.foldLeft((BotType: ST, TopType: ST)) {
Expand Down Expand Up @@ -358,7 +358,7 @@ trait TypeSimplifier { self: Typer =>

// * Reconstruct a TypeRef from its current structural components
val typeRef = TypeRef(cls.td.nme, cls.tparams.zipWithIndex.map { case ((tp, tv, vi), tpidx) =>
val fieldTagNme = tparamField(clsTyNme, tp)
val fieldTagNme = tparamField(clsTyNme, tp, vi.visible)
val fromTyRef = trs2.get(clsTyNme).map(_.targs(tpidx) |> { ta => FieldType(S(ta), ta)(noProv) })
fromTyRef.++(rcd2.fields.iterator.filter(_._1 === fieldTagNme).map(_._2))
.foldLeft((BotType: ST, TopType: ST)) {
Expand Down
Loading
Loading