diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala index 6aa51b329253..3cd2ef79ccce 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala @@ -1,14 +1,26 @@ package dotty.tools.pc.completions +import scala.util.Try + import dotty.tools.dotc.ast.Trees.ValDef import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.ContextOps.localContext import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Definitions import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Flags.Method import dotty.tools.dotc.core.NameKinds.DefaultGetterName import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.Symbols import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.AndType +import dotty.tools.dotc.core.Types.AppliedType +import dotty.tools.dotc.core.Types.OrType +import dotty.tools.dotc.core.Types.TermRef import dotty.tools.dotc.core.Types.Type +import dotty.tools.dotc.core.Types.TypeBounds +import dotty.tools.dotc.core.Types.WildcardType import dotty.tools.dotc.util.SourcePosition import dotty.tools.pc.IndexedContext import dotty.tools.pc.utils.MtagsEnrichments.* @@ -60,6 +72,7 @@ object NamedArgCompletions: apply.fun match case Select(New(_), _) => false case Select(_, name) if name.decoded == "apply" => false + case Select(This(_), _) => false // is a select statement without a dot `qual.name` case sel @ Select(qual, _) if !sel.symbol.is(Flags.Synthetic) => !(qual.span.end until sel.nameSpan.start) @@ -72,7 +85,7 @@ object NamedArgCompletions: apply: Apply, indexedContext: IndexedContext, clientSupportsSnippets: Boolean - )(using Context): List[CompletionValue] = + )(using context: Context): List[CompletionValue] = def isUselessLiteral(arg: Tree): Boolean = arg match case Literal(Constant(())) => true // unitLiteral @@ -98,52 +111,121 @@ object NamedArgCompletions: end collectArgss val method = apply.fun - val methodSym = method.symbol - // paramSymss contains both type params and value params - val vparamss = - methodSym.paramSymss.filter(params => params.forall(p => p.isTerm)) val argss = collectArgss(apply) - // get params and args we are interested in - // e.g. - // in the following case, the interesting args and params are - // - params: [apple, banana] - // - args: [apple, b] - // ``` - // def curry(x; Int)(apple: String, banana: String) = ??? - // curry(1)(apple = "test", b@@) - // ``` - val (baseParams, baseArgs) = - vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) - - val args = ident - .map(i => baseArgs.filterNot(_ == i)) - .getOrElse(baseArgs) - .filterNot(isUselessLiteral) - - val isNamed: Set[Name] = args.iterator - .zip(baseParams.iterator) - // filter out synthesized args and default arg getters - .filterNot { - case (arg, _) if arg.symbol.denot.is(Flags.Synthetic) => true - case (Ident(name), _) => name.is(DefaultGetterName) // default args - case (Select(Ident(_), name), _) => - name.is(DefaultGetterName) // default args for apply method - case _ => false - } - .map { - case (NamedArg(name, _), _) => name - case (_, param) => param.name - } - .toSet - val allParams: List[Symbol] = + // fallback for when multiple overloaded methods match the supplied args + def fallbackFindMatchingMethods() = + def maybeNameAndIndexedContext( + method: Tree + ): Option[(Name, IndexedContext)] = + method match + case Ident(name) => Some((name, indexedContext)) + case Select(This(_), name) => Some((name, indexedContext)) + case Select(from, name) => + val symbol = from.symbol + val ownerSymbol = + if symbol.is(Method) && symbol.owner.isClass then + Some(symbol.owner) + else Try(symbol.info.classSymbol).toOption + ownerSymbol.map(sym => + (name, IndexedContext(context.localContext(from, sym))) + ) + case Apply(fun, _) => maybeNameAndIndexedContext(fun) + case _ => None + val matchingMethods = + for + (name, indxContext) <- maybeNameAndIndexedContext(method) + potentialMatches <- indxContext.findSymbol(name) + yield potentialMatches.collect { + case m + if m.is(Flags.Method) && + m.vparamss.length >= argss.length && + Try(m.isAccessibleFrom(apply.symbol.info)).toOption + .getOrElse(false) && + m.vparamss + .zip(argss) + .reverse + .zipWithIndex + .forall { case (pair, index) => + FuzzyArgMatcher(m.tparams) + .doMatch(allArgsProvided = index != 0) + .tupled(pair) + } => + m + } + matchingMethods.getOrElse(Nil) + end fallbackFindMatchingMethods + + val matchingMethods: List[Symbols.Symbol] = + if method.symbol.paramSymss.nonEmpty + then + val allArgsAreSupplied = + val vparamss = method.symbol.vparamss + vparamss.length == argss.length && vparamss + .zip(argss) + .lastOption + .exists { case (baseParams, baseArgs) => + baseArgs.length == baseParams.length + } + // ``` + // m(arg : Int) + // m(arg : Int, anotherArg : Int) + // m(a@@) + // ``` + // complier will choose the first `m`, so we need to manually look for the other one + if allArgsAreSupplied then + val foundPotential = fallbackFindMatchingMethods() + if foundPotential.contains(method.symbol) then foundPotential + else method.symbol :: foundPotential + else List(method.symbol) + else fallbackFindMatchingMethods() + end if + end matchingMethods + + val allParams = matchingMethods.flatMap { methodSym => + val vparamss = methodSym.vparamss + + // get params and args we are interested in + // e.g. + // in the following case, the interesting args and params are + // - params: [apple, banana] + // - args: [apple, b] + // ``` + // def curry(x: Int)(apple: String, banana: String) = ??? + // curry(1)(apple = "test", b@@) + // ``` + val (baseParams, baseArgs) = + vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) + + val args = ident + .map(i => baseArgs.filterNot(_ == i)) + .getOrElse(baseArgs) + .filterNot(isUselessLiteral) + + val isNamed: Set[Name] = args.iterator + .zip(baseParams.iterator) + // filter out synthesized args and default arg getters + .filterNot { + case (arg, _) if arg.symbol.denot.is(Flags.Synthetic) => true + case (Ident(name), _) => name.is(DefaultGetterName) // default args + case (Select(Ident(_), name), _) => + name.is(DefaultGetterName) // default args for apply method + case _ => false + } + .map { + case (NamedArg(name, _), _) => name + case (_, param) => param.name + } + .toSet + baseParams.filterNot(param => isNamed(param.name) || param.denot.is( Flags.Synthetic ) // filter out synthesized param, like evidence ) + } val prefix = ident @@ -151,7 +233,9 @@ object NamedArgCompletions: .getOrElse("") .replace(Cursor.value, "") val params: List[Symbol] = - allParams.filter(param => param.name.startsWith(prefix)) + allParams + .filter(param => param.name.startsWith(prefix)) + .distinctBy(sym => (sym.name, sym.info)) val completionSymbols = indexedContext.scopeSymbols def matchingTypesInScope(paramType: Type): List[String] = @@ -173,11 +257,11 @@ object NamedArgCompletions: def fillAllFields(): List[CompletionValue] = val suffix = "autofill" - val shouldShow = + def shouldShow = allParams.exists(param => param.name.startsWith(prefix)) - val isExplicitlyCalled = suffix.startsWith(prefix) - val hasParamsToFill = allParams.count(!_.is(Flags.HasDefault)) > 1 - if (shouldShow || isExplicitlyCalled) && hasParamsToFill && clientSupportsSnippets + def isExplicitlyCalled = suffix.startsWith(prefix) + def hasParamsToFill = allParams.count(!_.is(Flags.HasDefault)) > 1 + if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill then val editText = allParams.zipWithIndex .collect { @@ -215,4 +299,59 @@ object NamedArgCompletions: ) ::: findPossibleDefaults() ::: fillAllFields() end contribute + extension (method: Symbols.Symbol) + def vparamss(using Context) = method.filteredParamss(_.isTerm) + def tparams(using Context) = method.filteredParamss(_.isType).flatten + def filteredParamss(f: Symbols.Symbol => Boolean)(using Context) = + method.paramSymss.filter(params => params.forall(f)) end NamedArgCompletions + +class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context): + + /** + * A heuristic for checking if the passed arguments match the method's arguments' types. + * For non-polymorphic methods we use the subtype relation (`<:<`) + * and for polymorphic methods we use a heuristic. + * We check the args types not the result type. + */ + def doMatch( + allArgsProvided: Boolean + )(expectedArgs: List[Symbols.Symbol], actualArgs: List[Tree]) = + (expectedArgs.length == actualArgs.length || + (!allArgsProvided && expectedArgs.length >= actualArgs.length)) && + actualArgs.zipWithIndex.forall { + case (Ident(name), _) if name.endsWith(Cursor.value) => true + case (NamedArg(name, arg), _) => + expectedArgs.exists { expected => + expected.name == name && (!arg.hasType || arg.typeOpt.unfold + .fuzzyArg_<:<(expected.info)) + } + case (arg, i) => + !arg.hasType || arg.typeOpt.unfold.fuzzyArg_<:<(expectedArgs(i).info) + } + + extension (arg: Type) + def fuzzyArg_<:<(expected: Type) = + if tparams.isEmpty then arg <:< expected + else arg <:< substituteTypeParams(expected) + def unfold = + arg match + case arg: TermRef => arg.underlying + case e => e + + private def substituteTypeParams(t: Type): Type = + t match + case e if tparams.exists(_ == e.typeSymbol) => + val matchingParam = tparams.find(_ == e.typeSymbol).get + matchingParam.info match + case b @ TypeBounds(_, _) => WildcardType(b) + case _ => WildcardType + case o @ OrType(e1, e2) => + OrType(substituteTypeParams(e1), substituteTypeParams(e2), o.isSoft) + case AndType(e1, e2) => + AndType(substituteTypeParams(e1), substituteTypeParams(e2)) + case AppliedType(et, eparams) => + AppliedType(et, eparams.map(substituteTypeParams)) + case _ => t + +end FuzzyArgMatcher diff --git a/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala index 119a320fde2b..b376284aa6a6 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/completion/CompletionArgSuite.scala @@ -20,7 +20,7 @@ class CompletionArgSuite extends BaseCompletionSuite: |} |""".stripMargin, """|assertion = : Boolean - |Main test + |message = : => Any |""".stripMargin, topLines = Option(2) ) @@ -587,6 +587,7 @@ class CompletionArgSuite extends BaseCompletionSuite: |} |""".stripMargin, """|foo = : Int + |fooBar = : Int |""".stripMargin ) @@ -620,6 +621,8 @@ class CompletionArgSuite extends BaseCompletionSuite: |""".stripMargin, """|foo = : Int |foo = a : Int + |fooBar = : Int + |fooBar = a : Int |""".stripMargin ) @@ -676,3 +679,205 @@ class CompletionArgSuite extends BaseCompletionSuite: topLines = Some(4), ) + @Test def `overloaded-with-param` = + check( + """|def m(idd : String, abb: Int): Int = ??? + |def m(inn : Int, uuu: Option[Int]): Int = ??? + |def m(inn : Int, aaa: Int): Int = ??? + |def k: Int = m(1, a@@) + |""".stripMargin, + """|aaa = : Int + |assert(assertion: Boolean): Unit + |""".stripMargin, + topLines = Some(2), + ) + + @Test def `overloaded-with-named-param` = + check( + """|def m(idd : String, abb: Int): Int = ??? + |def m(inn : Int, uuu: Option[Int]): Int = ??? + |def m(inn : Int, aaa: Int): Int = ??? + |def k: Int = m(inn = 1, a@@) + |""".stripMargin, + """|aaa = : Int + |assert(assertion: Boolean): Unit + |""".stripMargin, + topLines = Some(2), + ) + + @Test def `overloaded-generic` = + check( + """|object M: + | val g = 3 + | val l : List[Int] = List(1,2,3) + | def m[T](inn : List[T], yy: Int, aaa: Int, abb: Option[Int]): Int = ??? + | def m[T](inn : List[T], yy: Int, aaa: Int, abb: Int): Int = ??? + | def k: Int = m(yy = 3, inn = l, a@@) + |""".stripMargin, + """|aaa = : Int + |aaa = g : Int + |abb = : Option[Int] + |abb = : Int + |abb = g : Int + |""".stripMargin, + topLines = Some(5), + ) + + @Test def `overloaded-methods` = + check( + """|class A(): + | def m(anInt : Int): Int = ??? + | def m(aString : String): String = ??? + | + |object O: + | def m(aaa: Int): Int = ??? + | val k = new A().m(a@@) + |""".stripMargin, + """|aString = : String + |anInt = : Int + |""".stripMargin, + topLines = Some(2), + ) + + @Test def `overloaded-methods2` = + check( + """|class A(): + | def m(anInt : Int): Int = ??? + | def m(aString : String): String = ??? + | private def m(aBoolean: Boolean): Boolean = ??? + | + |object O: + | def m(aaa: Int): Int = ??? + | val myInstance = new A() + | val k = myInstance.m(a@@) + |""".stripMargin, + """|aString = : String + |anInt = : Int + |""".stripMargin, + topLines = Some(2), + ) + + @Test def `overloaded-select` = + check( + """|package a.b { + | object A { + | def m(anInt : Int): Int = ??? + | def m(aString : String): String = ??? + | } + |} + |object O { + | def m(aaa: Int): Int = ??? + | val k = a.b.A.m(a@@) + |} + |""".stripMargin, + """|aString = : String + |anInt = : Int + |""".stripMargin, + topLines = Some(2), + ) + + @Test def `overloaded-in-a-class` = + check( + """|trait Planet + |case class Venus() extends Planet + |class Main[T <: Planet](t : T) { + | def m(inn: Planet, abb: Option[Int]): Int = ??? + | def m(inn: Planet, aaa: Int): Int = ??? + | def k = m(t, a@@) + |} + |""".stripMargin, + """|aaa = : Int + |abb = : Option[Int] + |""".stripMargin, + topLines = Some(2), + ) + + @Test def `overloaded-function-param` = + check( + """|def m[T](i: Int)(inn: T => Int, abb: Option[Int]): Int = ??? + |def m[T](i: Int)(inn: T => Int, aaa: Int): Int = ??? + |def m[T](i: Int)(inn: T => String, acc: List[Int]): Int = ??? + |def k = m(1)(inn = identity[Int], a@@) + |""".stripMargin, + """|aaa = : Int + |abb = : Option[Int] + |assert(assertion: Boolean): Unit + |""".stripMargin, + topLines = Some(3), + ) + + @Test def `overloaded-function-param2` = + check( + """|def m[T](i: Int)(inn: T => Int, abb: Option[Int]): Int = ??? + |def m[T](i: Int)(inn: T => Int, aaa: Int): Int = ??? + |def m[T](i: String)(inn: T => Int, acc: List[Int]): Int = ??? + |def k = m(1)(inn = identity[Int], a@@) + |""".stripMargin, + """|aaa = : Int + |abb = : Option[Int] + |assert(assertion: Boolean): Unit + |""".stripMargin, + topLines = Some(3), + ) + + @Test def `overloaded-applied-type` = + check( + """|trait MyCollection[+T] + |case class IntCollection() extends MyCollection[Int] + |object Main { + | def m[T](inn: MyCollection[T], abb: Option[Int]): Int = ??? + | def m[T](inn: MyCollection[T], aaa: Int): Int = ??? + | def m[T](inn: List[T], acc: Int): Int = ??? + | def k = m(IntCollection(), a@@) + |} + |""".stripMargin, + """|aaa = : Int + |abb = : Option[Int] + |assert(assertion: Boolean): Unit + |""".stripMargin, + topLines = Some(3), + ) + + @Test def `overloaded-bounds` = + check( + """|trait Planet + |case class Moon() + |object Main { + | def m[M](inn: M, abb: Option[Int]): M = ??? + | def m[M](inn: M, acc: List[Int]): M = ??? + | def m[M <: Planet](inn: M, aaa: Int): M = ??? + | def k = m(Moon(), a@@) + |} + |""".stripMargin, + """|abb = : Option[Int] + |acc = : List[Int] + |assert(assertion: Boolean): Unit + |""".stripMargin, + topLines = Some(3), + ) + + @Test def `overloaded-or-type` = + check( + """|object Main: + | val h : Int = 3 + | def m[T](inn: String | T, abb: Option[Int]): Int = ??? + | def m(inn: Int, aaa: Int): Int = ??? + | def k: Int = m(3, a@@) + |""".stripMargin, + """|aaa = : Int + |aaa = h : Int + |abb = : Option[Int] + |""".stripMargin, + topLines = Some(3), + ) + + @Test def `overloaded-function-param3` = + check( + """|def m[T](inn: Int => T, abb: Option[Int]): Int = ??? + |def m[T](inn: String => T, aaa: Int): Int = ??? + |def k = m(identity[Int], a@@) + |""".stripMargin, + """|abb = : Option[Int] + |""".stripMargin, + topLines = Some(1), + )