Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiments into infering tracked #21628

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ object Feature:
* feature is defined.
*/
def enabled(feature: TermName)(using Context): Boolean =
enabledBySetting(feature) || enabledByImport(feature)
enabledBySetting(feature) || enabledByImport(feature) || feature == modularity

/** Is auto-tupling enabled? */
def autoTuplingEnabled(using Context): Boolean = !enabled(nme.noAutoTupling)
Expand Down
84 changes: 75 additions & 9 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -877,16 +877,16 @@ class Namer { typer: Typer =>
protected def addAnnotations(sym: Symbol): Unit = original match {
case original: untpd.MemberDef =>
lazy val annotCtx = annotContext(original, sym)
original.setMods:
original.setMods:
original.mods.withAnnotations :
original.mods.annotations.mapConserve: annotTree =>
original.mods.annotations.mapConserve: annotTree =>
val cls = typedAheadAnnotationClass(annotTree)(using annotCtx)
if (cls eq sym)
report.error(em"An annotation class cannot be annotated with iself", annotTree.srcPos)
annotTree
else
val ann =
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
val ann =
if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass)
else annotTree
val ann1 = Annotation.deferred(cls)(typedAheadExpr(ann)(using annotCtx))
sym.addAnnotation(ann1)
Expand Down Expand Up @@ -1545,6 +1545,8 @@ class Namer { typer: Typer =>
case completer: Completer => completer.indexConstructor(constr, constrSym)
case _ =>

// constrSym.info = typeSig(constrSym)

tempInfo = denot.asClass.classInfo.integrateOpaqueMembers.asInstanceOf[TempClassInfo]
denot.info = savedInfo
}
Expand Down Expand Up @@ -1653,7 +1655,8 @@ class Namer { typer: Typer =>
case tp: MethodOrPoly => Method | Synthetic | Deferred | Tracked
case _ if name.isTermName => Synthetic | Deferred | Tracked
case _ => Synthetic | Deferred
refinedSyms += newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
val s = newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
refinedSyms += s
if refinedSyms.nonEmpty then
typr.println(i"parent refinement symbols: ${refinedSyms.toList}")
original.pushAttachment(ParentRefinements, refinedSyms.toList)
Expand Down Expand Up @@ -1928,7 +1931,7 @@ class Namer { typer: Typer =>
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
mt
else if sym.isAllOf(Given | Method) && Feature.enabled(modularity) then
else if Feature.enabled(modularity) then
// set every context bound evidence parameter of a given companion method
// to be tracked, provided it has a type that has an abstract type member.
// Add refinements for all tracked parameters to the result type.
Expand Down Expand Up @@ -1986,14 +1989,77 @@ class Namer { typer: Typer =>
cls.srcPos)
case _ =>

/** Under x.modularity, we add `tracked` to context bound witnesses
* that have abstract type members
/** Try to infer if the parameter needs a `tracked` modifier
*/
def needsTracked(sym: Symbol, param: ValDef)(using Context) =
!sym.is(Tracked)
&& param.hasAttachment(ContextBoundParam)
&& sym.isTerm
&& sym.maybeOwner.isPrimaryConstructor
// && !sym.flags.is(Synthetic)
// && !sym.maybeOwner.flags.is(Synthetic)
&& !sym.maybeOwner.maybeOwner.flags.is(Synthetic)
&& (
isContextBoundWitnessWithAbstractMembers(sym, param)
|| isReferencedInPublicSignatures(sym)
|| isPassedToTrackedParentParameter(sym, param)
)

/** Under x.modularity, we add `tracked` to context bound witnesses
* that have abstract type members
*/
def isContextBoundWitnessWithAbstractMembers(sym: Symbol, param: ValDef)(using Context): Boolean =
param.hasAttachment(ContextBoundParam)
&& sym.info.memberNames(abstractTypeNameFilter).nonEmpty

/** Under x.modularity, we add `tracked` to term parameters whose types are referenced
* in public signatures of the defining class
*/
def isReferencedInPublicSignatures(sym: Symbol)(using Context): Boolean =
val owner = sym.maybeOwner.maybeOwner
val accessorSyms = maybeParamAccessors(owner, sym)
def checkOwnerMemberSignatures(owner: Symbol): Boolean =
owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.filter(_.isTerm).filter(_.isPublic)
.filter(_ != sym.maybeOwner)
.exists(d => tpeContainsSymbolRef(d.info, accessorSyms))
case _ => false
checkOwnerMemberSignatures(owner)

def isPassedToTrackedParentParameter(sym: Symbol, param: ValDef)(using Context): Boolean =
// TODO(kπ) Add tracked if the param is passed as a tracked arg in parent. Can we touch the inheritance terms?
val owner = sym.maybeOwner.maybeOwner
val accessorSyms = maybeParamAccessors(owner, sym)
owner.infoOrCompleter match
// case info: ClassInfo =>
// info.parents.foreach(println)
// info.parents.exists(tpeContainsSymbolRef(_, accessorSyms))
case _ => false

private def namedTypeWithPrefixContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean = tpe match
case tpe: NamedType => tpe.prefix.exists && tpeContainsSymbolRef(tpe.prefix, syms)
case _ => false

private def tpeContainsSymbolRef(tpe0: Type, syms: List[Symbol])(using Context): Boolean =
val tpe = tpe0.dropAlias.safeDealias
tpe match
case ExprType(resType) => tpeContainsSymbolRef(resType, syms)
case m : MethodOrPoly =>
m.paramInfos.exists(tpeContainsSymbolRef(_, syms))
|| tpeContainsSymbolRef(m.resultType, syms)
case r @ RefinedType(parent, _, refinedInfo) => tpeContainsSymbolRef(parent, syms) || tpeContainsSymbolRef(refinedInfo, syms)
case TypeBounds(lo, hi) => tpeContainsSymbolRef(lo, syms) || tpeContainsSymbolRef(hi, syms)
case t: Type =>
tpe.termSymbol.exists && syms.contains(tpe.termSymbol)
|| tpe.argInfos.exists(tpeContainsSymbolRef(_, syms))
|| namedTypeWithPrefixContainsSymbolRef(tpe, syms)

private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] =
owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.lookupAll(sym.name).filter(d => d.is(ParamAccessor)).toList
case _ => List.empty

/** Under x.modularity, set every context bound evidence parameter of a class to be tracked,
* provided it has a type that has an abstract type member. Reset private and local flags
* so that the parameter becomes a `val`.
Expand Down
34 changes: 34 additions & 0 deletions tests/pos/infer-tracked-1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import scala.language.experimental.modularity
import scala.language.future

trait Ordering {
type T
def compare(t1:T, t2: T): Int
}

class SetFunctor(val ord: Ordering) {
type Set = List[ord.T]
def empty: Set = Nil

implicit class helper(s: Set) {
def add(x: ord.T): Set = x :: remove(x)
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
}
}

object Test {
val orderInt = new Ordering {
type T = Int
def compare(t1: T, t2: T): Int = t1 - t2
}

val IntSet = new SetFunctor(orderInt)
import IntSet.*

def main(args: Array[String]) = {
val set = IntSet.empty.add(6).add(8).add(23)
assert(!set.member(7))
assert(set.member(8))
}
}
65 changes: 65 additions & 0 deletions tests/pos/infer-tracked-parsercombinators-expanded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import scala.language.experimental.modularity
import scala.language.future

import collection.mutable

/// A parser combinator.
trait Combinator[T]:

/// The context from which elements are being parsed, typically a stream of tokens.
type Context
/// The element being parsed.
type Element

extension (self: T)
/// Parses and returns an element from `context`.
def parse(context: Context): Option[Element]
end Combinator

final case class Apply[C, E](action: C => Option[E])
final case class Combine[A, B](first: A, second: B)

object test:

class apply[C, E] extends Combinator[Apply[C, E]]:
type Context = C
type Element = E
extension(self: Apply[C, E])
def parse(context: C): Option[E] = self.action(context)

def apply[C, E]: apply[C, E] = new apply[C, E]

class combine[A, B](
val f: Combinator[A],
val s: Combinator[B] { type Context = f.Context}
) extends Combinator[Combine[A, B]]:
type Context = f.Context
type Element = (f.Element, s.Element)
extension(self: Combine[A, B])
def parse(context: Context): Option[Element] = ???

def combine[A, B](
_f: Combinator[A],
_s: Combinator[B] { type Context = _f.Context}
) = new combine[A, B](_f, _s)
// cast is needed since the type of new combine[A, B](_f, _s)
// drops the required refinement.

extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
if buf.isEmpty then None
else try Some(buf.head) finally buf.remove(0)

@main def hello: Unit = {
val source = (0 to 10).toList
val stream = source.to(mutable.ListBuffer)

val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
val m = Combine(n, n)

val c = combine(
apply[mutable.ListBuffer[Int], Int],
apply[mutable.ListBuffer[Int], Int]
)
val r = c.parse(m)(stream) // was type mismatch, now OK
val rc: Option[(Int, Int)] = r
}
56 changes: 56 additions & 0 deletions tests/pos/infer-tracked-parsercombinators-givens.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import scala.language.experimental.modularity
import scala.language.future

import collection.mutable

/// A parser combinator.
trait Combinator[T]:

/// The context from which elements are being parsed, typically a stream of tokens.
type Context
/// The element being parsed.
type Element

extension (self: T)
/// Parses and returns an element from `context`.
def parse(context: Context): Option[Element]
end Combinator

final case class Apply[C, E](action: C => Option[E])
final case class Combine[A, B](first: A, second: B)

given apply[C, E]: Combinator[Apply[C, E]] with {
type Context = C
type Element = E
extension(self: Apply[C, E]) {
def parse(context: C): Option[E] = self.action(context)
}
}

// TODO(kπ) infer tracked correctly here
given combine[A, B](using
tracked val f: Combinator[A],
tracked val s: Combinator[B] { type Context = f.Context }
): Combinator[Combine[A, B]] with {
type Context = f.Context
type Element = (f.Element, s.Element)
extension(self: Combine[A, B]) {
def parse(context: Context): Option[Element] = ???
}
}

extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
if buf.isEmpty then None
else try Some(buf.head) finally buf.remove(0)

@main def hello: Unit = {
val source = (0 to 10).toList
val stream = source.to(mutable.ListBuffer)

val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
val m = Combine(n, n)

val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context`
val rc: Option[(Int, Int)] = r
// it would be great if this worked
}
82 changes: 82 additions & 0 deletions tests/pos/infer-tracked-vector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import scala.language.experimental.modularity
import scala.language.future

object typeparams:
sealed trait Nat
object Z extends Nat
final case class S[N <: Nat]() extends Nat

type Zero = Z.type
type Succ[N <: Nat] = S[N]

sealed trait Fin[N <: Nat]
case class FZero[N <: Nat]() extends Fin[Succ[N]]
case class FSucc[N <: Nat](pred: Fin[N]) extends Fin[Succ[N]]

object Fin:
def zero[N <: Nat]: Fin[Succ[N]] = FZero()
def succ[N <: Nat](i: Fin[N]): Fin[Succ[N]] = FSucc(i)

sealed trait Vec[A, N <: Nat]
case class VNil[A]() extends Vec[A, Zero]
case class VCons[A, N <: Nat](head: A, tail: Vec[A, N]) extends Vec[A, Succ[N]]

object Vec:
def empty[A]: Vec[A, Zero] = VNil()
def cons[A, N <: Nat](head: A, tail: Vec[A, N]): Vec[A, Succ[N]] = VCons(head, tail)

def get[A, N <: Nat](v: Vec[A, N], index: Fin[N]): A = (v, index) match
case (VCons(h, _), FZero()) => h
case (VCons(_, t), FSucc(pred)) => get(t, pred)

def runVec(): Unit =
val v: Vec[Int, Succ[Succ[Succ[Zero]]]] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty)))

println(s"Element at index 0: ${Vec.get(v, Fin.zero)}")
println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}")
println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}")
// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") // error

// TODO(kπ) check if I can get it to work
// object typemembers:
// sealed trait Nat
// object Z extends Nat
// case class S() extends Nat:
// type N <: Nat

// type Zero = Z.type
// type Succ[N1 <: Nat] = S { type N = N1 }

// sealed trait Fin:
// type N <: Nat
// case class FZero[N1 <: Nat]() extends Fin:
// type N = Succ[N1]
// case class FSucc(tracked val pred: Fin) extends Fin:
// type N = Succ[pred.N]

// object Fin:
// def zero[N1 <: Nat]: Fin { type N = Succ[N1] } = FZero[N1]()
// def succ[N1 <: Nat](i: Fin { type N = N1 }): Fin { type N = Succ[N1] } = FSucc(i)

// sealed trait Vec[A]:
// type N <: Nat
// case class VNil[A]() extends Vec[A]:
// type N = Zero
// case class VCons[A](head: A, tracked val tail: Vec[A]) extends Vec[A]:
// type N = Succ[tail.N]

// object Vec:
// def empty[A]: Vec[A] = VNil()
// def cons[A](head: A, tail: Vec[A]): Vec[A] = VCons(head, tail)

// def get[A](v: Vec[A], index: Fin { type N = v.N }): A = (v, index) match
// case (VCons(h, _), FZero()) => h
// case (VCons(_, t), FSucc(pred)) => get(t, pred)

// // def runVec(): Unit =
// val v: Vec[Int] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty)))

// println(s"Element at index 0: ${Vec.get(v, Fin.zero)}")
// println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}")
// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}")
// // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}")
Loading
Loading