diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 8b9a64924ace..776828de8f67 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -113,7 +113,7 @@ object Feature: * feature is defined. */ def enabled(feature: TermName)(using Context): Boolean = - enabledBySetting(feature) || enabledByImport(feature) + enabledBySetting(feature) || enabledByImport(feature) || feature == modularity /** Is auto-tupling enabled? */ def autoTuplingEnabled(using Context): Boolean = !enabled(nme.noAutoTupling) diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 6167db62fbe0..f927bb113a18 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -877,16 +877,16 @@ class Namer { typer: Typer => protected def addAnnotations(sym: Symbol): Unit = original match { case original: untpd.MemberDef => lazy val annotCtx = annotContext(original, sym) - original.setMods: + original.setMods: original.mods.withAnnotations : - original.mods.annotations.mapConserve: annotTree => + original.mods.annotations.mapConserve: annotTree => val cls = typedAheadAnnotationClass(annotTree)(using annotCtx) if (cls eq sym) report.error(em"An annotation class cannot be annotated with iself", annotTree.srcPos) annotTree else - val ann = - if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass) + val ann = + if cls.is(JavaDefined) then Checking.checkNamedArgumentForJavaAnnotation(annotTree, cls.asClass) else annotTree val ann1 = Annotation.deferred(cls)(typedAheadExpr(ann)(using annotCtx)) sym.addAnnotation(ann1) @@ -1545,6 +1545,8 @@ class Namer { typer: Typer => case completer: Completer => completer.indexConstructor(constr, constrSym) case _ => + // constrSym.info = typeSig(constrSym) + tempInfo = denot.asClass.classInfo.integrateOpaqueMembers.asInstanceOf[TempClassInfo] denot.info = savedInfo } @@ -1653,7 +1655,8 @@ class Namer { typer: Typer => case tp: MethodOrPoly => Method | Synthetic | Deferred | Tracked case _ if name.isTermName => Synthetic | Deferred | Tracked case _ => Synthetic | Deferred - refinedSyms += newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered + val s = newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered + refinedSyms += s if refinedSyms.nonEmpty then typr.println(i"parent refinement symbols: ${refinedSyms.toList}") original.pushAttachment(ParentRefinements, refinedSyms.toList) @@ -1928,7 +1931,7 @@ class Namer { typer: Typer => val mt = wrapMethType(effectiveResultType(sym, paramSymss)) if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner) mt - else if sym.isAllOf(Given | Method) && Feature.enabled(modularity) then + else if Feature.enabled(modularity) then // set every context bound evidence parameter of a given companion method // to be tracked, provided it has a type that has an abstract type member. // Add refinements for all tracked parameters to the result type. @@ -1986,14 +1989,77 @@ class Namer { typer: Typer => cls.srcPos) case _ => - /** Under x.modularity, we add `tracked` to context bound witnesses - * that have abstract type members + /** Try to infer if the parameter needs a `tracked` modifier */ def needsTracked(sym: Symbol, param: ValDef)(using Context) = !sym.is(Tracked) - && param.hasAttachment(ContextBoundParam) + && sym.isTerm + && sym.maybeOwner.isPrimaryConstructor + // && !sym.flags.is(Synthetic) + // && !sym.maybeOwner.flags.is(Synthetic) + && !sym.maybeOwner.maybeOwner.flags.is(Synthetic) + && ( + isContextBoundWitnessWithAbstractMembers(sym, param) + || isReferencedInPublicSignatures(sym) + || isPassedToTrackedParentParameter(sym, param) + ) + + /** Under x.modularity, we add `tracked` to context bound witnesses + * that have abstract type members + */ + def isContextBoundWitnessWithAbstractMembers(sym: Symbol, param: ValDef)(using Context): Boolean = + param.hasAttachment(ContextBoundParam) && sym.info.memberNames(abstractTypeNameFilter).nonEmpty + /** Under x.modularity, we add `tracked` to term parameters whose types are referenced + * in public signatures of the defining class + */ + def isReferencedInPublicSignatures(sym: Symbol)(using Context): Boolean = + val owner = sym.maybeOwner.maybeOwner + val accessorSyms = maybeParamAccessors(owner, sym) + def checkOwnerMemberSignatures(owner: Symbol): Boolean = + owner.infoOrCompleter match + case info: ClassInfo => + info.decls.filter(_.isTerm).filter(_.isPublic) + .filter(_ != sym.maybeOwner) + .exists(d => tpeContainsSymbolRef(d.info, accessorSyms)) + case _ => false + checkOwnerMemberSignatures(owner) + + def isPassedToTrackedParentParameter(sym: Symbol, param: ValDef)(using Context): Boolean = + // TODO(kπ) Add tracked if the param is passed as a tracked arg in parent. Can we touch the inheritance terms? + val owner = sym.maybeOwner.maybeOwner + val accessorSyms = maybeParamAccessors(owner, sym) + owner.infoOrCompleter match + // case info: ClassInfo => + // info.parents.foreach(println) + // info.parents.exists(tpeContainsSymbolRef(_, accessorSyms)) + case _ => false + + private def namedTypeWithPrefixContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean = tpe match + case tpe: NamedType => tpe.prefix.exists && tpeContainsSymbolRef(tpe.prefix, syms) + case _ => false + + private def tpeContainsSymbolRef(tpe0: Type, syms: List[Symbol])(using Context): Boolean = + val tpe = tpe0.dropAlias.safeDealias + tpe match + case ExprType(resType) => tpeContainsSymbolRef(resType, syms) + case m : MethodOrPoly => + m.paramInfos.exists(tpeContainsSymbolRef(_, syms)) + || tpeContainsSymbolRef(m.resultType, syms) + case r @ RefinedType(parent, _, refinedInfo) => tpeContainsSymbolRef(parent, syms) || tpeContainsSymbolRef(refinedInfo, syms) + case TypeBounds(lo, hi) => tpeContainsSymbolRef(lo, syms) || tpeContainsSymbolRef(hi, syms) + case t: Type => + tpe.termSymbol.exists && syms.contains(tpe.termSymbol) + || tpe.argInfos.exists(tpeContainsSymbolRef(_, syms)) + || namedTypeWithPrefixContainsSymbolRef(tpe, syms) + + private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] = + owner.infoOrCompleter match + case info: ClassInfo => + info.decls.lookupAll(sym.name).filter(d => d.is(ParamAccessor)).toList + case _ => List.empty + /** Under x.modularity, set every context bound evidence parameter of a class to be tracked, * provided it has a type that has an abstract type member. Reset private and local flags * so that the parameter becomes a `val`. diff --git a/tests/pos/infer-tracked-1.scala b/tests/pos/infer-tracked-1.scala new file mode 100644 index 000000000000..b4976a963074 --- /dev/null +++ b/tests/pos/infer-tracked-1.scala @@ -0,0 +1,34 @@ +import scala.language.experimental.modularity +import scala.language.future + +trait Ordering { + type T + def compare(t1:T, t2: T): Int +} + +class SetFunctor(val ord: Ordering) { + type Set = List[ord.T] + def empty: Set = Nil + + implicit class helper(s: Set) { + def add(x: ord.T): Set = x :: remove(x) + def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0) + def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0) + } +} + +object Test { + val orderInt = new Ordering { + type T = Int + def compare(t1: T, t2: T): Int = t1 - t2 + } + + val IntSet = new SetFunctor(orderInt) + import IntSet.* + + def main(args: Array[String]) = { + val set = IntSet.empty.add(6).add(8).add(23) + assert(!set.member(7)) + assert(set.member(8)) + } +} diff --git a/tests/pos/infer-tracked-parsercombinators-expanded.scala b/tests/pos/infer-tracked-parsercombinators-expanded.scala new file mode 100644 index 000000000000..63c6aec9e84a --- /dev/null +++ b/tests/pos/infer-tracked-parsercombinators-expanded.scala @@ -0,0 +1,65 @@ +import scala.language.experimental.modularity +import scala.language.future + +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +object test: + + class apply[C, E] extends Combinator[Apply[C, E]]: + type Context = C + type Element = E + extension(self: Apply[C, E]) + def parse(context: C): Option[E] = self.action(context) + + def apply[C, E]: apply[C, E] = new apply[C, E] + + class combine[A, B]( + val f: Combinator[A], + val s: Combinator[B] { type Context = f.Context} + ) extends Combinator[Combine[A, B]]: + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) + def parse(context: Context): Option[Element] = ??? + + def combine[A, B]( + _f: Combinator[A], + _s: Combinator[B] { type Context = _f.Context} + ) = new combine[A, B](_f, _s) + // cast is needed since the type of new combine[A, B](_f, _s) + // drops the required refinement. + + extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + + @main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val c = combine( + apply[mutable.ListBuffer[Int], Int], + apply[mutable.ListBuffer[Int], Int] + ) + val r = c.parse(m)(stream) // was type mismatch, now OK + val rc: Option[(Int, Int)] = r + } diff --git a/tests/pos/infer-tracked-parsercombinators-givens.scala b/tests/pos/infer-tracked-parsercombinators-givens.scala new file mode 100644 index 000000000000..8bb514c8a75a --- /dev/null +++ b/tests/pos/infer-tracked-parsercombinators-givens.scala @@ -0,0 +1,56 @@ +import scala.language.experimental.modularity +import scala.language.future + +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +given apply[C, E]: Combinator[Apply[C, E]] with { + type Context = C + type Element = E + extension(self: Apply[C, E]) { + def parse(context: C): Option[E] = self.action(context) + } +} + +// TODO(kπ) infer tracked correctly here +given combine[A, B](using + tracked val f: Combinator[A], + tracked val s: Combinator[B] { type Context = f.Context } +): Combinator[Combine[A, B]] with { + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) { + def parse(context: Context): Option[Element] = ??? + } +} + +extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + +@main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context` + val rc: Option[(Int, Int)] = r + // it would be great if this worked +} diff --git a/tests/pos/infer-tracked-vector.scala b/tests/pos/infer-tracked-vector.scala new file mode 100644 index 000000000000..e748dc9cbe8e --- /dev/null +++ b/tests/pos/infer-tracked-vector.scala @@ -0,0 +1,82 @@ +import scala.language.experimental.modularity +import scala.language.future + +object typeparams: + sealed trait Nat + object Z extends Nat + final case class S[N <: Nat]() extends Nat + + type Zero = Z.type + type Succ[N <: Nat] = S[N] + + sealed trait Fin[N <: Nat] + case class FZero[N <: Nat]() extends Fin[Succ[N]] + case class FSucc[N <: Nat](pred: Fin[N]) extends Fin[Succ[N]] + + object Fin: + def zero[N <: Nat]: Fin[Succ[N]] = FZero() + def succ[N <: Nat](i: Fin[N]): Fin[Succ[N]] = FSucc(i) + + sealed trait Vec[A, N <: Nat] + case class VNil[A]() extends Vec[A, Zero] + case class VCons[A, N <: Nat](head: A, tail: Vec[A, N]) extends Vec[A, Succ[N]] + + object Vec: + def empty[A]: Vec[A, Zero] = VNil() + def cons[A, N <: Nat](head: A, tail: Vec[A, N]): Vec[A, Succ[N]] = VCons(head, tail) + + def get[A, N <: Nat](v: Vec[A, N], index: Fin[N]): A = (v, index) match + case (VCons(h, _), FZero()) => h + case (VCons(_, t), FSucc(pred)) => get(t, pred) + + def runVec(): Unit = + val v: Vec[Int, Succ[Succ[Succ[Zero]]]] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty))) + + println(s"Element at index 0: ${Vec.get(v, Fin.zero)}") + println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}") + println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}") + // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") // error + +// TODO(kπ) check if I can get it to work +// object typemembers: +// sealed trait Nat +// object Z extends Nat +// case class S() extends Nat: +// type N <: Nat + +// type Zero = Z.type +// type Succ[N1 <: Nat] = S { type N = N1 } + +// sealed trait Fin: +// type N <: Nat +// case class FZero[N1 <: Nat]() extends Fin: +// type N = Succ[N1] +// case class FSucc(tracked val pred: Fin) extends Fin: +// type N = Succ[pred.N] + +// object Fin: +// def zero[N1 <: Nat]: Fin { type N = Succ[N1] } = FZero[N1]() +// def succ[N1 <: Nat](i: Fin { type N = N1 }): Fin { type N = Succ[N1] } = FSucc(i) + +// sealed trait Vec[A]: +// type N <: Nat +// case class VNil[A]() extends Vec[A]: +// type N = Zero +// case class VCons[A](head: A, tracked val tail: Vec[A]) extends Vec[A]: +// type N = Succ[tail.N] + +// object Vec: +// def empty[A]: Vec[A] = VNil() +// def cons[A](head: A, tail: Vec[A]): Vec[A] = VCons(head, tail) + +// def get[A](v: Vec[A], index: Fin { type N = v.N }): A = (v, index) match +// case (VCons(h, _), FZero()) => h +// case (VCons(_, t), FSucc(pred)) => get(t, pred) + +// // def runVec(): Unit = +// val v: Vec[Int] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty))) + +// println(s"Element at index 0: ${Vec.get(v, Fin.zero)}") +// println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}") +// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}") +// // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") diff --git a/tests/pos/infer-tracked.scala b/tests/pos/infer-tracked.scala new file mode 100644 index 000000000000..496508ffdc6c --- /dev/null +++ b/tests/pos/infer-tracked.scala @@ -0,0 +1,46 @@ +import scala.language.experimental.modularity +import scala.language.future + +abstract class C: + type T + def foo: T + +class F(val x: C): + val result: x.T = x.foo + +class G(override val x: C) extends F(x) + +class H(val x: C): + type T1 = x.T + val result: T1 = x.foo + +class I(val c: C, val t: c.T) + +case class J(c: C): + val result: c.T = c.foo + +case class K(c: C): + def result[B >: c.T]: B = c.foo + +def Test = + val c = new C: + type T = Int + def foo = 42 + + val f = new F(c) + val _: Int = f.result + + // val g = new G(c) + // val _: Int = g.result + + val h = new H(c) + val _: Int = h.result + + val i = new I(c, c.foo) + val _: Int = i.t + + val j = J(c) + val _: Int = j.result + + val k = K(c) + val _: Int = k.result