Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple suspension points support #348

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,25 @@ internal open class ParallelThreadsRunner(
// Coroutine will be resumed. Call method so that strategy can learn it.
afterCoroutineResumed(threadId)
// Check whether the result of the suspension point with the continuation has been stored
// by the resuming thread and invoke the follow-up part in this case.
if (completion.resWithCont.get() !== null) {
// Suspended thread got the result of the suspension point and continuation to resume.
val resumedValue = completion.resWithCont.get().first
// It is important to run the coroutine resumption part outside the ignored section
// to track the events inside resumption.
// by the resuming thread, and invoke the follow-up part in this case
val suspendResultToContinuation = completion.resWithCont.get()
if (suspendResultToContinuation !== null) {
// Suspended thread got result of the suspension point and continuation to resume
val (resumedValue, continuation) = suspendResultToContinuation
// Erase the current result
completion.resWithCont.set(null)
// We must exit the ignored section to keep tracking execution after the resumption.
runOutsideIgnoredSection(thread) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should this code be called outside the ignored section?

completion.resWithCont.get().second.resumeWith(resumedValue)
// Resume the execution of the coroutine.
continuation.resumeWith(resumedValue)
}
}
// If we've suspended again - then clean the completion status and rerun all the logic of this method to
// wait for resumption.
if (suspensionPointResults[threadId][actorId] == NoResult) {
completionStatuses[threadId].set(actorId, null)
return waitAndInvokeFollowUp(thread, actorId)
}
return suspensionPointResults[threadId][actorId]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -970,19 +970,25 @@ abstract class ManagedStrategy(
params: Array<Any?>
) {
val guarantee = runInIgnoredSection {
val threadId = currentThread
val atomicMethodDescriptor = getAtomicMethodDescriptor(owner, methodName)
val guarantee = when {
(atomicMethodDescriptor != null) -> ManagedGuaranteeType.TREAT_AS_ATOMIC
else -> methodGuaranteeType(owner, className, methodName)
else -> methodGuaranteeType(owner, className.canonicalClassName, methodName)
}
if (owner == null && atomicMethodDescriptor == null && guarantee == null) { // static method
LincheckJavaAgent.ensureClassHierarchyIsTransformed(className.canonicalClassName)
}
if (collectTrace) {
traceCollector!!.checkActiveLockDetected()
addBeforeMethodCallTracePoint(owner, codeLocation, methodId, className, methodName, params, atomicMethodDescriptor)
addBeforeMethodCallTracePoint(threadId, owner, codeLocation, methodId, className, methodName, params,
atomicMethodDescriptor
)
}
if (guarantee == ManagedGuaranteeType.TREAT_AS_ATOMIC) {
if (guarantee == ManagedGuaranteeType.TREAT_AS_ATOMIC &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a comment explaining what is going on here?

// do not create a trace point on resumption
!isResumptionMethodCall(threadId, className.canonicalClassName, methodName, params, atomicMethodDescriptor)
) {
newSwitchPointOnAtomicMethodCall(codeLocation, params)
}
if (guarantee == null) {
Expand Down Expand Up @@ -1052,39 +1058,18 @@ abstract class ManagedStrategy(
loopDetector.passParameters(params)
}

private fun isSuspendFunction(className: String, methodName: String, params: Array<Any?>): Boolean =
try {
// While this code is inefficient, it is called only when an error is detected.
getMethod(className.canonicalClassName, methodName, params)?.isSuspendable() ?: false
} catch (t: Throwable) {
// Something went wrong. Ignore it, as the error might lead only
// to an extra "<cont>" in the method call line in the trace.
false
}

private fun getMethod(className: String, methodName: String, params: Array<Any?>): Method? {
val clazz = Class.forName(className)

// Filter methods by name
val possibleMethods = clazz.declaredMethods.filter { it.name == methodName }

for (method in possibleMethods) {
val parameterTypes = method.parameterTypes
if (parameterTypes.size != params.size) continue

var match = true
for (i in parameterTypes.indices) {
val paramType = params[i]?.javaClass
if (paramType != null && !parameterTypes[i].isAssignableFrom(paramType)) {
match = false
break
}
}

if (match) return method
}

return null // or throw an exception if a match is mandatory
private fun isResumptionMethodCall(
threadId: Int,
className: String,
methodName: String,
methodParams: Array<Any?>,
atomicMethodDescriptor: AtomicMethodDescriptor?,
): Boolean {
// optimization - first quickly check if the method is atomics API method,
// in which case it cannot be suspended/resumed method
if (atomicMethodDescriptor != null) return false
val suspendedMethodStack = suspendedFunctionsStack[threadId]
return suspendedMethodStack.isNotEmpty() && isSuspendFunction(className, methodName, methodParams)
}

/**
Expand Down Expand Up @@ -1124,6 +1109,7 @@ abstract class ManagedStrategy(
}

private fun addBeforeMethodCallTracePoint(
iThread: Int,
owner: Any?,
codeLocation: Int,
methodId: Int,
Expand All @@ -1132,12 +1118,9 @@ abstract class ManagedStrategy(
methodParams: Array<Any?>,
atomicMethodDescriptor: AtomicMethodDescriptor?,
) {
val iThread = currentThread
val callStackTrace = callStackTrace[iThread]
val suspendedMethodStack = suspendedFunctionsStack[iThread]
val isSuspending = isSuspendFunction(className, methodName, methodParams)
val isResumption = isSuspending && suspendedMethodStack.isNotEmpty()
if (isResumption) {
if (isResumptionMethodCall(iThread, className.canonicalClassName, methodName, methodParams, atomicMethodDescriptor)) {
// In case of resumption, we need to find a call stack frame corresponding to the resumed function
var elementIndex = suspendedMethodStack.indexOfFirst {
it.tracePoint.className == className && it.tracePoint.methodName == methodName
Expand Down Expand Up @@ -1172,7 +1155,7 @@ abstract class ManagedStrategy(
return
}
val callId = callStackTraceElementId++
val params = if (isSuspending) {
val params = if (isSuspendFunction(className.canonicalClassName, methodName, methodParams)) {
methodParams.dropLast(1).toTypedArray()
} else {
methodParams
Expand Down Expand Up @@ -1513,11 +1496,15 @@ abstract class ManagedStrategy(
beforeMethodCallSwitch = beforeMethodCallSwitch
)
}
val callStackTrace = when (reason) {
SwitchReason.SUSPENDED -> suspendedFunctionsStack[iThread].reversed()
else -> callStackTrace[iThread]
}
_trace += SwitchEventTracePoint(
iThread = iThread,
actorId = currentActorId[iThread],
reason = reason,
callStackTrace = callStackTrace[iThread]
callStackTrace = callStackTrace,
)
spinCycleStartAdded = false
}
Expand Down
70 changes: 70 additions & 0 deletions src/jvm/main/org/jetbrains/kotlinx/lincheck/util/Reflection.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Lincheck
*
* Copyright (C) 2019 - 2024 JetBrains s.r.o.
*
* This Source Code Form is subject to the terms of the
* Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed
* with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/

package org.jetbrains.kotlinx.lincheck.util

import org.jetbrains.kotlinx.lincheck.isSuspendable
import kotlin.coroutines.Continuation
import java.lang.reflect.Method


/**
* Determines whether a given method is a suspending function.
*
* @param className The name of the class containing the method.
* @param methodName The name of the method to check.
* @param params An array of parameters passed to the method used to infer the method signature.
* @return `true` if the method is a suspending function; `false` otherwise.
*/
internal fun isSuspendFunction(className: String, methodName: String, params: Array<Any?>): Boolean {
// fast-path: if the last parameter is not continuation - then this is not suspending function
if (params.lastOrNull() !is Continuation<*>) return false
val result = runCatching {
// While this code is inefficient, it is called only on the slow path.
val method = getMethod(className, methodName, params)
method?.isSuspendable() == true
}
return result.getOrElse {
// Something went wrong. Ignore it, as the error might lead only
// to an extra "<cont>" in the method call line in the trace.
false
}
}

/**
* Retrieves a `Method` object representing a method of the specified name and parameter types.
*
* @param className The name of the class containing the method.
* @param methodName The name of the method to retrieve.
* @param params An array of parameters to match against the method's parameter types.
* The method is selected if its parameter types are compatible
* with the runtime classes of the elements in this array.
* @return The matching [Method] object if found, or `null` if no method matches.
*/
internal fun getMethod(className: String, methodName: String, params: Array<Any?>): Method? {
val clazz = Class.forName(className)
// filter methods by name
val possibleMethods = clazz.declaredMethods.filter { it.name == methodName }
// search through all possible methods, matching the arguments' types
for (method in possibleMethods) {
val parameterTypes = method.parameterTypes
if (parameterTypes.size != params.size) continue
var match = true
for (i in parameterTypes.indices) {
val paramType = params[i]?.javaClass
if (paramType != null && !parameterTypes[i].isAssignableFrom(paramType)) {
match = false
break
}
}
if (match) return method
}
return null // or throw an exception if a match is mandatory
}
Loading