Skip to content

Commit

Permalink
Scala 3 fixes and improvements (#509)
Browse files Browse the repository at this point in the history
* fix default param with type param

* path dependent types support added to scala 3

* fix context bounded classes
  • Loading branch information
goshacodes authored Feb 26, 2024
1 parent a43f4de commit 02c41bd
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 46 deletions.
20 changes: 13 additions & 7 deletions shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
package org.scalamock.clazz

import org.scalamock.context.MockContext

import scala.quoted.*
import scala.reflect.Selectable

Expand All @@ -42,16 +41,23 @@ private[clazz] object MockMaker:
def asParent(tree: TypeTree): TypeTree | Term =
val constructorFieldsFilledWithNulls: List[List[Term]] =
tree.tpe.dealias.typeSymbol.primaryConstructor.paramSymss
.filter(_.exists(!_.isType))
.map(_.map(_.typeRef.asType match { case '[t] => '{ null.asInstanceOf[t] }.asTerm }))
.filterNot(_.exists(_.isType))
.map(_.map(_.info.widen match {
case t@AppliedType(inner, applied) =>
Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(inner.appliedTo(tpe.typeArgs)))
case other =>
Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(other))
}))

if constructorFieldsFilledWithNulls.forall(_.isEmpty) then
tree
else
Select(
New(TypeIdent(tree.tpe.typeSymbol)),
tree.tpe.typeSymbol.primaryConstructor
).appliedToArgss(constructorFieldsFilledWithNulls)
).appliedToTypes(tree.tpe.typeArgs)
.appliedToArgss(constructorFieldsFilledWithNulls)



val parents =
Expand Down Expand Up @@ -91,15 +97,15 @@ private[clazz] object MockMaker:
Symbol.newVal(
parent = classSymbol,
name = definition.symbol.name,
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
flags = Flags.Override,
privateWithin = Symbol.noSymbol
)
else
Symbol.newMethod(
parent = classSymbol,
name = definition.symbol.name,
tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol),
tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol),
flags = Flags.Override,
privateWithin = Symbol.noSymbol
)
Expand Down Expand Up @@ -177,7 +183,7 @@ private[clazz] object MockMaker:
"asInstanceOf"
),
definition.tpe
.resolveParamRefs(definition.resTypeWithPathDependentOverrideFor(classSymbol), args)
.resolveParamRefs(definition.resTypeWithInnerTypesOverrideFor(classSymbol), args)
.asType match { case '[t] => List(TypeTree.of[t]) }
)
)
Expand Down
98 changes: 62 additions & 36 deletions shared/src/main/scala-3/org/scalamock/clazz/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@ package org.scalamock.clazz
import scala.quoted.*
import org.scalamock.context.MockContext

import scala.annotation.tailrec
import scala.annotation.{experimental, tailrec}
private[clazz] class Utils(using val quotes: Quotes):
import quotes.reflect.*

extension (tpe: TypeRepr)
def collectPathDependent(ownerSymbol: Symbol): List[TypeRepr] =
def collectInnerTypes(ownerSymbol: Symbol): List[TypeRepr] =
def loop(currentTpe: TypeRepr, names: List[String]): List[TypeRepr] =
currentTpe match
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent(ownerSymbol))
case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes(ownerSymbol))
case TypeRef(inner, name) if name == ownerSymbol.name && names.nonEmpty => List(tpe)
case TypeRef(inner, name) => loop(inner, name :: names)
case _ => Nil

loop(tpe, Nil)

def pathDependentOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
def innerTypeOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr =
@tailrec
def loop(currentTpe: TypeRepr, names: List[(String, List[TypeRepr])], appliedTypes: List[TypeRepr]): TypeRepr =
currentTpe match
Expand Down Expand Up @@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes):
case _ =>
tpe

@experimental
def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) =
def loop(baseBindings: TypeRepr, typeRepr: TypeRepr): TypeRepr =
typeRepr match
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe
tpe match
case baseBindings: PolyType =>
def loop(typeRepr: TypeRepr): TypeRepr =
typeRepr match
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe

case AppliedType(tycon, args) =>
AppliedType(tycon, args.map(arg => loop(baseBindings, arg)))
case AppliedType(tycon, args) =>
AppliedType(loop(tycon), args.map(arg => loop(arg)))

case other => other
case ff @ TypeRef(ref @ ParamRef(bindings, idx), name) =>
def getIndex(bindings: TypeRepr): Int =
@tailrec
def loop(bindings: TypeRepr, idx: Int): Int =
bindings match
case MethodType(_, _, method: MethodType) => loop(method, idx + 1)
case _ => idx

tpe match
case pt: PolyType => loop(pt, resType)
case _ => resType
loop(bindings, 1)

val maxIndex = methodArgs.length
val parameterListIdx = maxIndex - getIndex(bindings)

TypeSelect(methodArgs(parameterListIdx)(idx).asInstanceOf[Term], name).tpe

case other => other

loop(resType)
case _ =>
resType


def collectTypes: List[TypeRepr] =
def loop(currentTpe: TypeRepr, params: List[TypeRepr]): List[TypeRepr] =
def collectTypes: (List[TypeRepr], TypeRepr) =
@tailrec
def loop(currentTpe: TypeRepr, argTypesAcc: List[List[TypeRepr]], resType: TypeRepr): (List[TypeRepr], TypeRepr) =
currentTpe match
case PolyType(_, _, res) => loop(res, Nil)
case MethodType(_, argTypes, res) => argTypes ++ loop(res, params)
case other => List(other)
loop(tpe, Nil)
case PolyType(_, _, res) => loop(res, List.empty[TypeRepr] :: argTypesAcc, resType)
case MethodType(_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType)
case other => (argTypesAcc.reverse.flatten, other)
loop(tpe, Nil, TypeRepr.of[Nothing])

case class MockableDefinition(idx: Int, symbol: Symbol, ownerTpe: TypeRepr):
val mockValName = s"mock$$${symbol.name}$$$idx"
val tpe = ownerTpe.memberType(symbol)
private val rawTypes = tpe.widen.collectTypes
private val (rawTypes, rawResType) = tpe.widen.collectTypes
val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init

def resTypeWithPathDependentOverrideFor(classSymbol: Symbol): TypeRepr =
val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol)
val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated)
def resTypeWithInnerTypesOverrideFor(classSymbol: Symbol): TypeRepr =
updatePathDependent(rawResType, List(rawResType), classSymbol)

def tpeWithSubstitutedInnerTypesFor(classSymbol: Symbol): TypeRepr =
updatePathDependent(tpe, rawResType :: rawTypes, classSymbol)

def tpeWithSubstitutedPathDependentFor(classSymbol: Symbol): TypeRepr =
val pathDependentTypes = rawTypes.flatMap(_.collectPathDependent(ownerTpe.typeSymbol))
val pdUpdated = pathDependentTypes.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
tpe.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)
private def updatePathDependent(where: TypeRepr, types: List[TypeRepr], classSymbol: Symbol): TypeRepr =
val pathDependentTypes = types.flatMap(_.collectInnerTypes(ownerTpe.typeSymbol))
val pdUpdated = pathDependentTypes.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false))
where.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated)

def prepareTypesFor(classSymbol: Symbol) = rawTypes
.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
def prepareTypesFor(classSymbol: Symbol) = (rawTypes :+ rawResType)
.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
.map { typeRepr =>
val adjusted =
typeRepr.widen.mapParamRefWithWildcard match
case TypeBounds(lower, upper) => upper
case AppliedType(TypeRef(_, "<repeated>"), elemTyps) =>
TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps)
case other => other
case TypeRef(_: ParamRef, _) =>
TypeRepr.of[Any]
case AppliedType(TypeRef(_: ParamRef, _), _) =>
TypeRepr.of[Any]
case other =>
other
adjusted.asType match
case '[t] => TypeTree.of[t]
}
Expand All @@ -128,10 +153,11 @@ private[clazz] class Utils(using val quotes: Quotes):

def apply(tpe: TypeRepr): List[MockableDefinition] =
val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr.of[Object].typeSymbol.methodMembers).toList
.filter(sym => !sym.flags.is(Flags.Private) && !sym.flags.is(Flags.Final) && !sym.flags.is(Flags.Mutable))
.filterNot(sym => tpe.memberType(sym) match
case defaultParam @ ByNameType(AnnotatedType(_, Apply(Select(New(Inferred()), "<init>"), Nil))) => true
case _ => false
.filter(sym =>
!sym.flags.is(Flags.Private) &&
!sym.flags.is(Flags.Final) &&
!sym.flags.is(Flags.Mutable) &&
!sym.name.contains("$default$")
)
.zipWithIndex
.map((sym, idx) => MockableDefinition(idx, sym, tpe))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.paulbutcher.test

import org.scalamock.scalatest.MockFactory
import org.scalatest.funspec.AnyFunSpec

import scala.reflect.ClassTag

class ClassWithContextBoundSpec extends AnyFunSpec with MockFactory {

it("compile without args") {
class ContextBounded[T: ClassTag] {
def method(x: Int): Unit = ()
}

val m = mock[ContextBounded[String]]

}

it("compile with args") {
class ContextBounded[T: ClassTag](x: Int) {
def method(x: Int): Unit = ()
}

val m = mock[ContextBounded[String]]

}

it("compile with provided explicitly type class") {
class ContextBounded[T](x: ClassTag[T]) {
def method(x: Int): Unit = ()
}

val m = mock[ContextBounded[String]]

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.paulbutcher.test

import org.scalamock.matchers.Matchers
import org.scalamock.scalatest.MockFactory
import org.scalatest.funspec.AnyFunSpec

class PathDependentParamSpec extends AnyFunSpec with Matchers with MockFactory {

trait Command {
type Answer
type AnswerConstructor[A]
}

case class IntCommand() extends Command {
override type Answer = Int
override type AnswerConstructor[A] = Option[A]
}

val cmd = IntCommand()

trait PathDependent {

def call0[T <: Command](cmd: T): cmd.Answer

def call1[T <: Command](x: Int)(cmd: T): cmd.Answer

def call2[T <: Command](y: String)(cmd: T)(x: Int): cmd.Answer

def call3[T <: Command](cmd: T)(y: String)(x: Int): cmd.Answer

def call4[T <: Command](cmd: T): Option[cmd.Answer]

def call5[T <: Command](cmd: T)(x: cmd.Answer): Unit

def call6[T <: Command](cmd: T): cmd.AnswerConstructor[Int]

def call7[T <: Command](cmd: T)(x: cmd.AnswerConstructor[String])(y: cmd.Answer): Unit
}


it("path dependent in return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call0[IntCommand] _).expects(cmd).returns(5)

assert(pathDependent.call0(cmd) == 5)
}

it("path dependent in return type and parameter in last parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call1(_: Int)(_: IntCommand)).expects(5, cmd).returns(5)

assert(pathDependent.call1(5)(cmd) == 5)
}

it("path dependent in return type and parameter in middle parameter list ") {
val pathDependent = mock[PathDependent]

(pathDependent.call2(_: String)(_: IntCommand)(_: Int)).expects("5", cmd, 5).returns(5)

assert(pathDependent.call2("5")(cmd)(5) == 5)
}

it("path dependent in return type and parameter in first parameter list ") {
val pathDependent = mock[PathDependent]

(pathDependent.call3(_: IntCommand)(_: String)(_: Int)).expects(cmd, "5", 5).returns(5)

assert(pathDependent.call3(cmd)("5")(5) == 5)
}

it("path dependent in tycon return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call4[IntCommand] _).expects(cmd).returns(Some(5))

assert(pathDependent.call4(cmd) == Some(5))
}

it("path dependent in parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call5(_: IntCommand)(_: Int)).expects(cmd, 5).returns(())

assert(pathDependent.call5(cmd)(5) == ())
}

it("path dependent tycon in return type") {
val pathDependent = mock[PathDependent]

(pathDependent.call6[IntCommand] _).expects(cmd).returns(Some(5))

assert(pathDependent.call6(cmd) == Some(5))
}

it("path dependent tycon in parameter list") {
val pathDependent = mock[PathDependent]

(pathDependent.call7[IntCommand](_: IntCommand)(_: Option[String])(_: Int))
.expects(cmd, Some("5"), 6)
.returns(())

assert(pathDependent.call7(cmd)(Some("5"))(6) == ())
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec {

trait TraitHavingMethodsWithDefaultParams {
def withAllDefaultParams(a: String = "default", b: CaseClass = CaseClass(42)): String

def withDefaultParamAndTypeParam[T](a: String = "default", b: Int = 5): T
}

behavior of "Mocks"
Expand Down Expand Up @@ -84,5 +86,13 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec {
m.withAllDefaultParams("other", CaseClass(99))
}

they should "mock trait methods with type param and default parameters" in {
val m = mock[TraitHavingMethodsWithDefaultParams]

(m.withDefaultParamAndTypeParam[Int] _).expects("default", 5).returns(5)

m.withDefaultParamAndTypeParam[Int]("default", 5) shouldBe 5
}

override def newInstance = new MethodsWithDefaultParamsTest
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@

package org.scalamock.test.scalatest

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest._
import org.scalatest.flatspec.{AnyFlatSpec, AsyncFlatSpec}
import org.scalamock.scalatest.{MockFactory, AsyncMockFactory}

/**
* Tests for issue #371
*/
@Ignore
class AsyncSyncMixinTest extends AnyFlatSpec {

"MockFactory" should "be mixed only with Any*Spec and not Async*Spec traits" in {
Expand Down

0 comments on commit 02c41bd

Please sign in to comment.