From a3b85c2bab2591a9d048e633cec5749cac190e0e Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Fri, 16 Feb 2024 16:34:12 +0100 Subject: [PATCH] Move prompt generation and psi utils into separate module named core --- PromptGenerator/build.gradle.kts | 20 - build.gradle.kts | 3 + core/build.gradle.kts | 46 +++ .../testspark/core}/generation/Patterns.kt | 3 +- .../core/generation/prompt/PromptGenerator.kt | 384 +++++++++++++++++ .../core/generation/prompt/PromptKeyword.kt | 32 ++ .../configuration/GenerationSettings.kt | 5 + .../prompt/configuration/PromptTemplates.kt | 7 + .../testspark/core}/helpers/PsiHelper.kt | 2 +- settings.gradle.kts | 2 +- .../testspark/actions/TestSparkAction.kt | 2 +- .../research/testspark/tools/Pipeline.kt | 2 +- .../testspark/tools/evosuite/EvoSuite.kt | 6 +- .../research/testspark/tools/llm/Llm.kt | 8 +- .../tools/llm/generation/PromptManager.kt | 385 ++---------------- .../tools/llm/generation/TestsAssembler.kt | 2 + 16 files changed, 516 insertions(+), 393 deletions(-) delete mode 100644 PromptGenerator/build.gradle.kts create mode 100644 core/build.gradle.kts rename {src/main/kotlin/org/jetbrains/research/testspark/tools/llm => core/src/main/kotlin/org/jetbrains/research/testspark/core}/generation/Patterns.kt (87%) create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt create mode 100644 core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt rename {src/main/kotlin/org/jetbrains/research/testspark => core/src/main/kotlin/org/jetbrains/research/testspark/core}/helpers/PsiHelper.kt (99%) diff --git a/PromptGenerator/build.gradle.kts b/PromptGenerator/build.gradle.kts deleted file mode 100644 index 55d6eb487..000000000 --- a/PromptGenerator/build.gradle.kts +++ /dev/null @@ -1,20 +0,0 @@ -plugins { - kotlin("jvm") -} - -group = "org.jetbrains.research" - -repositories { - mavenCentral() -} - -dependencies { - testImplementation("org.jetbrains.kotlin:kotlin-test") -} - -tasks.test { - useJUnitPlatform() -} -kotlin { - jvmToolchain(19) -} \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index f6162e5c1..46eace43c 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -104,6 +104,8 @@ dependencies { implementation(files("lib/byte-buddy-agent-1.14.6.jar")) implementation(files("lib/JUnitRunner.jar")) + implementation(project(":core")) + // validation dependencies // https://mvnrepository.com/artifact/junit/junit implementation("junit:junit:4.13") @@ -182,6 +184,7 @@ tasks { compileKotlin { dependsOn("updateEvosuite") dependsOn("copyJUnitRunnerLib") + dependsOn(":core:buildPlugin") } // Set the JVM compatibility versions properties("javaVersion").let { diff --git a/core/build.gradle.kts b/core/build.gradle.kts new file mode 100644 index 000000000..72454ebc1 --- /dev/null +++ b/core/build.gradle.kts @@ -0,0 +1,46 @@ + +fun properties(key: String) = project.findProperty(key).toString() + +plugins { + kotlin("jvm") + id("org.jetbrains.intellij") +} + +group = "org.jetbrains.research" + +repositories { + mavenCentral() + maven("https://cache-redirector.jetbrains.com/intellij-dependencies") +} + +dependencies { + testImplementation("org.jetbrains.kotlin:kotlin-test") + compileOnly(kotlin("stdlib")) +} + +// TODO: already configured in parent project, how to inherit it? +intellij { + version.set(properties("platformVersion")) + type.set(properties("platformType")) + + plugins.set(properties("platformPlugins").split(',').map(String::trim).filter(String::isNotEmpty)) +} + +tasks { + runIde { enabled = false } + runIdeForUiTests { enabled = false } + buildSearchableOptions { enabled = false } + + patchPluginXml { + sinceBuild.set(properties("pluginSinceBuild")) + untilBuild.set(properties("pluginUntilBuild")) + } +} + + +tasks.test { + useJUnitPlatform() +} +kotlin { + jvmToolchain(17) +} \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/Patterns.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/Patterns.kt similarity index 87% rename from src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/Patterns.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/Patterns.kt index 5cd3f387a..3079e4663 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/Patterns.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/Patterns.kt @@ -1,4 +1,5 @@ -package org.jetbrains.research.testspark.tools.llm.generation +package org.jetbrains.research.testspark.core.generation + val importPattern = Regex( pattern = "^import\\s+(static\\s)?((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?;", diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt new file mode 100644 index 000000000..690bf27e1 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt @@ -0,0 +1,384 @@ +package org.jetbrains.research.testspark.core.generation.prompt + +import com.intellij.openapi.project.Project +import com.intellij.openapi.util.TextRange +import com.intellij.psi.* +import com.intellij.psi.search.GlobalSearchScope +import com.intellij.psi.search.searches.ClassInheritorsSearch +import com.intellij.psi.util.PsiTypesUtil +import org.jetbrains.research.testspark.core.generation.importPattern +import org.jetbrains.research.testspark.core.generation.packagePattern +import org.jetbrains.research.testspark.core.generation.prompt.configuration.GenerationSettings +import org.jetbrains.research.testspark.core.generation.prompt.configuration.PromptTemplates +import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor + + +class PromptGenerator( + private val project: Project, + private val cut: PsiClass, + private val classesToTest: MutableList, + private val settings: GenerationSettings, + private val promptTemplates: PromptTemplates +) { + + /** + * Generates a prompt for generating unit tests in Java for a given class. + * + * @return The generated prompt. + */ + fun generatePromptForClass(): String { + var classPrompt = promptTemplates.classPrompt + val interestingPsiClasses = getInterestingPsiClasses(classesToTest) + + classPrompt = insertLanguage(classPrompt) + classPrompt = insertName(classPrompt, cut.qualifiedName!!) + classPrompt = insertTestingPlatform(classPrompt) + classPrompt = insertMockingFramework(classPrompt) + classPrompt = insertCodeUnderTest(classPrompt, getClassFullText(cut)) + classPrompt = insertMethodsSignatures(classPrompt, interestingPsiClasses) + classPrompt = + insertPolymorphismRelations(classPrompt, getPolymorphismRelations(project, interestingPsiClasses, cut)) + + return classPrompt + } + + /** + * Generates a prompt for a method. + * + * @param methodDescriptor The descriptor of the method. + * @return The generated prompt. + */ + fun generatePromptForMethod(methodDescriptor: String): String { + var methodPrompt = promptTemplates.methodPrompt + val psiMethod = getPsiMethod(cut, methodDescriptor)!! + + methodPrompt = insertLanguage(methodPrompt) + methodPrompt = insertName(methodPrompt, "${cut.qualifiedName!!}.${psiMethod.name}") + methodPrompt = insertTestingPlatform(methodPrompt) + methodPrompt = insertMockingFramework(methodPrompt) + methodPrompt = insertCodeUnderTest(methodPrompt, psiMethod.text) + methodPrompt = insertMethodsSignatures(methodPrompt, getInterestingPsiClasses(psiMethod)) + methodPrompt = insertPolymorphismRelations( + methodPrompt, + getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), + ) + + return methodPrompt + } + + + /** + * Generates a prompt for a specific line number in the code. + * + * @param lineNumber the line number for which to generate the prompt + * @return the generated prompt string + */ + fun generatePromptForLine(lineNumber: Int): String { + var linePrompt = promptTemplates.linePrompt + val methodDescriptor = getMethodDescriptor(cut, lineNumber) + val psiMethod = getPsiMethod(cut, methodDescriptor)!! + + // get code of line under test + val document = PsiDocumentManager.getInstance(project).getDocument(cut.containingFile) + val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) + val lineEndOffset = document.getLineEndOffset(lineNumber - 1) + val lineUnderTest = document.getText(TextRange.create(lineStartOffset, lineEndOffset)) + + linePrompt = insertLanguage(linePrompt) + linePrompt = insertName(linePrompt, lineUnderTest.trim()) + linePrompt = insertTestingPlatform(linePrompt) + linePrompt = insertMockingFramework(linePrompt) + linePrompt = insertCodeUnderTest(linePrompt, psiMethod.text) + linePrompt = insertMethodsSignatures(linePrompt, getInterestingPsiClasses(psiMethod)) + linePrompt = insertPolymorphismRelations( + linePrompt, + getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), + ) + + return linePrompt + } + + + /** + * Returns the method descriptor of the method containing the given line number in the specified PsiClass. + * + * @param psiClass the PsiClass containing the method + * @param lineNumber the line number within the file where the method is located + * @return the method descriptor as a String, or an empty string if no method is found + */ + private fun getMethodDescriptor(psiClass: PsiClass, lineNumber: Int): String { + for (currentPsiMethod in psiClass.allMethods) { + if (isLineInPsiMethod(currentPsiMethod, lineNumber)) return generateMethodDescriptor(currentPsiMethod) + } + return "" + } + + /** + * Checks if the given line number is within the range of the specified PsiMethod. + * + * @param method The PsiMethod to check. + * @param lineNumber The line number to check. + * @return `true` if the line number is within the range of the method, `false` otherwise. + */ + private fun isLineInPsiMethod(method: PsiMethod, lineNumber: Int): Boolean { + val psiFile = method.containingFile ?: return false + val document = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return false + val textRange = method.textRange + val startLine = document.getLineNumber(textRange.startOffset) + 1 + val endLine = document.getLineNumber(textRange.endOffset) + 1 + return lineNumber in startLine..endLine + } + + /** + * Retrieves a PsiMethod matching the given method descriptor within the provided PsiClass. + * + * @param psiClass The PsiClass in which to search for the method. + * @param methodDescriptor The method descriptor to match against. + * @return The matching PsiMethod if found, otherwise an empty string. + */ + private fun getPsiMethod(psiClass: PsiClass, methodDescriptor: String): PsiMethod? { + for (currentPsiMethod in psiClass.allMethods) { + if (generateMethodDescriptor(currentPsiMethod) == methodDescriptor) return currentPsiMethod + } + return null + } + + /** + * Returns a set of interesting PsiClasses based on the given PsiMethod. + * + * @param psiMethod the PsiMethod for which to find interesting PsiClasses + * @return a mutable set of interesting PsiClasses + */ + private fun getInterestingPsiClasses(psiMethod: PsiMethod): MutableSet { + val interestingMethods = mutableSetOf(psiMethod) + for (currentPsiMethod in cut.allMethods) { + if (currentPsiMethod.isConstructor) interestingMethods.add(currentPsiMethod) + } + val interestingPsiClasses = mutableSetOf(cut) + interestingMethods.forEach { methodIt -> + methodIt.parameterList.parameters.forEach { paramIt -> + PsiTypesUtil.getPsiClass(paramIt.type)?.let { + if (it.qualifiedName != null && !it.qualifiedName!!.startsWith("java.")) { + interestingPsiClasses.add(it) + } + } + } + } + return interestingPsiClasses + } + + /** + * Retrieves a set of interesting PsiClasses based on a given cutPsiClass and a list of classesToTest. + * + * @param classesToTest The list of classes to test for interesting PsiClasses. + * @return The set of interesting PsiClasses found during the search. + */ + private fun getInterestingPsiClasses(classesToTest: MutableList): MutableSet { + val interestingPsiClasses: MutableSet = mutableSetOf() + + var currentLevelClasses = mutableListOf().apply { addAll(classesToTest) } + + for (i in 0 until settings.maxInputParamsDepth) { + val tempListOfClasses = mutableSetOf() + + currentLevelClasses.forEach { classIt -> + classIt.methods.forEach { methodIt -> + methodIt.parameterList.parameters.forEach { paramIt -> + PsiTypesUtil.getPsiClass(paramIt.type)?.let { + if (!interestingPsiClasses.contains(it) && it.qualifiedName != null && + !it.qualifiedName!!.startsWith("java.") + ) { + tempListOfClasses.add(it) + } + } + } + } + } + currentLevelClasses = mutableListOf().apply { addAll(tempListOfClasses) } + interestingPsiClasses.addAll(tempListOfClasses) + } + + return interestingPsiClasses + } + + + private fun isPromptValid(keyword: PromptKeyword, prompt: String): Boolean { + val keywordText = keyword.text + val isMandatory = keyword.mandatory + + return (prompt.contains(keywordText) || !isMandatory) + } + + + private fun insertLanguage(classPrompt: String): String { + if (isPromptValid(PromptKeyword.LANGUAGE, classPrompt)) { + val keyword = "\$${PromptKeyword.LANGUAGE.text}" + return classPrompt.replace(keyword, "Java", ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") + } + } + + private fun insertName(classPrompt: String, classDisplayName: String): String { + if (isPromptValid(PromptKeyword.NAME, classPrompt)) { + val keyword = "\$${PromptKeyword.NAME.text}" + return classPrompt.replace(keyword, classDisplayName, ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}") + } + } + + private fun insertTestingPlatform(classPrompt: String): String { + if (isPromptValid(PromptKeyword.TESTING_PLATFORM, classPrompt)) { + val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}" + return classPrompt.replace(keyword, "JUnit 4", ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}") + } + } + + private fun insertMockingFramework(classPrompt: String): String { + if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, classPrompt)) { + val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}" + return classPrompt.replace(keyword, "Mockito 5", ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}") + } + } + + private fun insertCodeUnderTest(classPrompt: String, classFullText: String): String { + if (isPromptValid(PromptKeyword.CODE, classPrompt)) { + val keyword = "\$${PromptKeyword.CODE.text}" + var fullText = "```\n${classFullText}\n```\n" + + for (i in 2..classesToTest.size) { + val subClass = classesToTest[i - 2] + val superClass = classesToTest[i - 1] + + fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + + "The source code of ${superClass.qualifiedName} is:\n```\n${getClassFullText(superClass)}\n" + + "```\n" + } + return classPrompt.replace(keyword, fullText, ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}") + } + } + + private fun insertMethodsSignatures(classPrompt: String, interestingPsiClasses: MutableSet): String { + val keyword = "\$${PromptKeyword.METHODS.text}" + + if (isPromptValid(PromptKeyword.METHODS, classPrompt)) { + var fullText = "" + for (interestingPsiClass: PsiClass in interestingPsiClasses) { + if (interestingPsiClass.qualifiedName!!.startsWith("java")) { + continue + } + + fullText += "=== methods in ${interestingPsiClass.qualifiedName!!}:\n" + for (currentPsiMethod in interestingPsiClass.allMethods) { + // Skip java methods + if (currentPsiMethod.containingClass!!.qualifiedName!!.startsWith("java")) { + continue + } + fullText += " - ${currentPsiMethod.getSignatureString()}\n" + } + } + return classPrompt.replace(keyword, fullText, ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}") + } + } + + private fun PsiMethod.getSignatureString(): String { + val bodyStart = body?.startOffsetInParent ?: this.textLength + return text.substring(0, bodyStart).replace('\n', ' ').trim() + } + + private fun insertPolymorphismRelations( + classPrompt: String, + polymorphismRelations: MutableMap>, + ): String { + val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" + if (isPromptValid(PromptKeyword.POLYMORPHISM, classPrompt)) { + var fullText = "" + + polymorphismRelations.forEach { entry -> + for (currentSubClass in entry.value) { + currentSubClass.qualifiedName ?: continue + fullText += "${currentSubClass.qualifiedName} is a sub-class of ${entry.key.qualifiedName}.\n" + } + } + return classPrompt.replace(keyword, fullText, ignoreCase = false) + } else { + throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}") + } + } + + /** + * Retrieves the polymorphism relations between a given set of interesting PsiClasses and a cut PsiClass. + * + * @param project The project context in which the PsiClasses exist. + * @param interestingPsiClasses The set of PsiClasses that are considered interesting. + * @param cutPsiClass The cut PsiClass to determine polymorphism relations against. + * @return A mutable map where the key represents an interesting PsiClass and the value is a list of its detected subclasses. + */ + private fun getPolymorphismRelations( + project: Project, + interestingPsiClasses: MutableSet, + cutPsiClass: PsiClass, + ): MutableMap> { + val polymorphismRelations: MutableMap> = mutableMapOf() + + val psiClassesToVisit: ArrayDeque = ArrayDeque(listOf(cutPsiClass)) + interestingPsiClasses.add(cutPsiClass) + + interestingPsiClasses.forEach { currentInterestingClass -> + val scope = GlobalSearchScope.projectScope(project) + val query = ClassInheritorsSearch.search(currentInterestingClass, scope, false) + val detectedSubClasses: Collection = query.findAll() + + detectedSubClasses.forEach { detectedSubClass -> + if (!polymorphismRelations.contains(currentInterestingClass)) { + polymorphismRelations[currentInterestingClass] = ArrayList() + } + polymorphismRelations[currentInterestingClass]?.add(detectedSubClass) + if (!psiClassesToVisit.contains(detectedSubClass)) { + psiClassesToVisit.addLast(detectedSubClass) + } + } + } + + return polymorphismRelations + } + + /** + * Returns the full text of a given class including the package, imports, and class code. + * + * @param cl The PsiClass object representing the class. + * @return The full text of the class. + */ + private fun getClassFullText(cl: PsiClass): String { + var fullText = "" + val fileText = cl.containingFile.text + + // get package + packagePattern.findAll(fileText, 0).map { + it.groupValues[0] + }.forEach { + fullText += "$it\n\n" + } + + // get imports + importPattern.findAll(fileText, 0).map { + it.groupValues[0] + }.forEach { + fullText += "$it\n" + } + + // Add class code + fullText += cl.text + + return fullText + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt new file mode 100644 index 000000000..81f025bc6 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt @@ -0,0 +1,32 @@ +package org.jetbrains.research.testspark.core.generation.prompt + + +enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) { + NAME("NAME", "The name of the code under test (Class name, method name, line number)", true), + CODE("CODE", "The code under test (Class, method, or line)", true), + LANGUAGE("LANGUAGE", "Programming language of the project under test (only Java supported at this point)", true), + TESTING_PLATFORM( + "TESTING_PLATFORM", + "testing platform used in the project (Only JUnit 4 is supported at this point)", + true, + ), + MOCKING_FRAMEWORK( + "MOCKING_FRAMEWORK", + "mock framework that can be used in generated test (Only Mockito is supported at this point)", + false, + ), + METHODS("METHODS", "signature of methods used in the code under tests", false), + POLYMORPHISM("POLYMORPHISM", "polymorphism relations between classes involved in the code under test.", false), + ; + + fun getOffsets(prompt: String): Pair? { + val textToHighlight = "\$$text" + if (!prompt.contains(textToHighlight)) { + return null + } + + val startOffset = prompt.indexOf(textToHighlight) + val endOffset = startOffset + textToHighlight.length + return Pair(startOffset, endOffset) + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt new file mode 100644 index 000000000..157c0c4ac --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt @@ -0,0 +1,5 @@ +package org.jetbrains.research.testspark.core.generation.prompt.configuration + +data class GenerationSettings( + val maxInputParamsDepth: Int +) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt new file mode 100644 index 000000000..47cf21081 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt @@ -0,0 +1,7 @@ +package org.jetbrains.research.testspark.core.generation.prompt.configuration + +data class PromptTemplates( + val classPrompt: String, + val methodPrompt: String, + val linePrompt: String, +) \ No newline at end of file diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/PsiHelper.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/helpers/PsiHelper.kt similarity index 99% rename from src/main/kotlin/org/jetbrains/research/testspark/helpers/PsiHelper.kt rename to core/src/main/kotlin/org/jetbrains/research/testspark/core/helpers/PsiHelper.kt index 97df6fcdc..2bf4d264f 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/PsiHelper.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/helpers/PsiHelper.kt @@ -1,4 +1,4 @@ -package org.jetbrains.research.testspark.helpers +package org.jetbrains.research.testspark.core.helpers import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys diff --git a/settings.gradle.kts b/settings.gradle.kts index 5436e759d..a5261e06c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -3,4 +3,4 @@ plugins { } rootProject.name = "TestSpark" include("JUnitRunner") -include("PromptGenerator") +include("core") diff --git a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt index e1a5535da..6c6130fd4 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/actions/TestSparkAction.kt @@ -5,8 +5,8 @@ import com.intellij.openapi.actionSystem.AnAction import com.intellij.openapi.actionSystem.AnActionEvent import org.jetbrains.research.testspark.actions.evosuite.EvoSuitePanelFactory import org.jetbrains.research.testspark.actions.llm.LLMPanelFactory +import org.jetbrains.research.testspark.core.helpers.getCurrentListOfCodeTypes import org.jetbrains.research.testspark.display.TestSparkIcons -import org.jetbrains.research.testspark.helpers.getCurrentListOfCodeTypes import org.jetbrains.research.testspark.tools.Manager import org.jetbrains.research.testspark.tools.evosuite.EvoSuite import org.jetbrains.research.testspark.tools.llm.Llm diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index 29f96deeb..0b5f6a05b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -9,9 +9,9 @@ import com.intellij.openapi.progress.Task import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.roots.ProjectRootManager import org.jetbrains.research.testspark.bundles.TestSparkBundle +import org.jetbrains.research.testspark.core.helpers.getSurroundingClass import org.jetbrains.research.testspark.data.DataFilesUtil import org.jetbrains.research.testspark.data.FragmentToTestData -import org.jetbrains.research.testspark.helpers.getSurroundingClass import org.jetbrains.research.testspark.services.ClearService import org.jetbrains.research.testspark.services.ProjectContextService import org.jetbrains.research.testspark.services.TestStorageProcessingService diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt index 7d78c2fcb..cbc55119c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/EvoSuite.kt @@ -9,11 +9,11 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.roots.ProjectRootManager import com.intellij.psi.PsiFile import com.intellij.psi.PsiMethod +import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor +import org.jetbrains.research.testspark.core.helpers.getSurroundingLine +import org.jetbrains.research.testspark.core.helpers.getSurroundingMethod import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData -import org.jetbrains.research.testspark.helpers.generateMethodDescriptor -import org.jetbrains.research.testspark.helpers.getSurroundingLine -import org.jetbrains.research.testspark.helpers.getSurroundingMethod import org.jetbrains.research.testspark.services.SettingsProjectService import org.jetbrains.research.testspark.tools.Pipeline import org.jetbrains.research.testspark.tools.evosuite.generation.EvoSuiteProcessManager diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt index 423628fa6..4a3562177 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/Llm.kt @@ -8,12 +8,12 @@ import com.intellij.openapi.project.Project import com.intellij.psi.PsiClass import com.intellij.psi.PsiFile import com.intellij.psi.PsiMethod +import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor +import org.jetbrains.research.testspark.core.helpers.getSurroundingClass +import org.jetbrains.research.testspark.core.helpers.getSurroundingLine +import org.jetbrains.research.testspark.core.helpers.getSurroundingMethod import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData -import org.jetbrains.research.testspark.helpers.generateMethodDescriptor -import org.jetbrains.research.testspark.helpers.getSurroundingClass -import org.jetbrains.research.testspark.helpers.getSurroundingLine -import org.jetbrains.research.testspark.helpers.getSurroundingMethod import org.jetbrains.research.testspark.services.LLMChatService import org.jetbrains.research.testspark.tools.Pipeline import org.jetbrains.research.testspark.tools.llm.generation.LLMProcessManager diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index 4ecf4e4da..0c624d736 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -13,9 +13,14 @@ import com.intellij.psi.search.GlobalSearchScope import com.intellij.psi.search.searches.ClassInheritorsSearch import com.intellij.psi.util.PsiTypesUtil import org.jetbrains.research.testspark.bundles.TestSparkBundle +import org.jetbrains.research.testspark.core.generation.importPattern +import org.jetbrains.research.testspark.core.generation.packagePattern +import org.jetbrains.research.testspark.core.generation.prompt.PromptGenerator +import org.jetbrains.research.testspark.core.generation.prompt.configuration.GenerationSettings +import org.jetbrains.research.testspark.core.generation.prompt.configuration.PromptTemplates import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData -import org.jetbrains.research.testspark.helpers.generateMethodDescriptor +import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor import org.jetbrains.research.testspark.services.PromptKeyword import org.jetbrains.research.testspark.services.SettingsApplicationService import org.jetbrains.research.testspark.services.TestGenerationDataService @@ -41,15 +46,30 @@ class PromptManager( private val llmErrorManager: LLMErrorManager = LLMErrorManager() fun generatePrompt(codeType: FragmentToTestData): String { + val promptGenerator = PromptGenerator( + project, + cut, + classesToTest, + GenerationSettings( + maxInputParamsDepth = SettingsArguments.maxInputParamsDepth(project) + ), + PromptTemplates( + classPrompt = settingsState.classPrompt, + methodPrompt = settingsState.methodPrompt, + linePrompt = settingsState.linePrompt, + ) + ) + val prompt = ApplicationManager.getApplication().runReadAction( Computable { when (codeType.type!!) { - CodeType.CLASS -> generatePromptForClass() - CodeType.METHOD -> generatePromptForMethod(codeType.objectDescription) - CodeType.LINE -> generatePromptForLine(codeType.objectIndex) + CodeType.CLASS -> promptGenerator.generatePromptForClass() + CodeType.METHOD -> promptGenerator.generatePromptForMethod(codeType.objectDescription) + CodeType.LINE -> promptGenerator.generatePromptForLine(codeType.objectIndex) } } ) + log.info("Prompt is:\n$prompt") return prompt } @@ -82,361 +102,4 @@ class PromptManager( project, ) } - - /** - * Generates a prompt for generating unit tests in Java for a given class. - * - * @return The generated prompt. - */ - private fun generatePromptForClass(): String { - var classPrompt = settingsState.classPrompt - val interestingPsiClasses = getInterestingPsiClasses(classesToTest) - - classPrompt = insertLanguage(classPrompt) - classPrompt = insertName(classPrompt, cut.qualifiedName!!) - classPrompt = insertTestingPlatform(classPrompt) - classPrompt = insertMockingFramework(classPrompt) - classPrompt = insertCodeUnderTest(classPrompt, getClassFullText(cut)) - classPrompt = insertMethodsSignatures(classPrompt, interestingPsiClasses) - classPrompt = - insertPolymorphismRelations(classPrompt, getPolymorphismRelations(project, interestingPsiClasses, cut)) - - return classPrompt - } - - /** - * Generates a prompt for a method. - * - * @param methodDescriptor The descriptor of the method. - * @return The generated prompt. - */ - private fun generatePromptForMethod(methodDescriptor: String): String { - var methodPrompt = settingsState.methodPrompt - val psiMethod = getPsiMethod(cut, methodDescriptor)!! - - methodPrompt = insertLanguage(methodPrompt) - methodPrompt = insertName(methodPrompt, "${cut.qualifiedName!!}.${psiMethod.name}") - methodPrompt = insertTestingPlatform(methodPrompt) - methodPrompt = insertMockingFramework(methodPrompt) - methodPrompt = insertCodeUnderTest(methodPrompt, psiMethod.text) - methodPrompt = insertMethodsSignatures(methodPrompt, getInterestingPsiClasses(psiMethod)) - methodPrompt = insertPolymorphismRelations( - methodPrompt, - getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), - ) - - return methodPrompt - } - - /** - * Generates a prompt for a specific line number in the code. - * - * @param lineNumber the line number for which to generate the prompt - * @return the generated prompt string - */ - private fun generatePromptForLine(lineNumber: Int): String { - var linePrompt = settingsState.linePrompt - val methodDescriptor = getMethodDescriptor(cut, lineNumber) - val psiMethod = getPsiMethod(cut, methodDescriptor)!! - - // get code of line under test - val document = PsiDocumentManager.getInstance(project).getDocument(cut.containingFile) - val lineStartOffset = document!!.getLineStartOffset(lineNumber - 1) - val lineEndOffset = document.getLineEndOffset(lineNumber - 1) - val lineUnderTest = document.getText(TextRange.create(lineStartOffset, lineEndOffset)) - - linePrompt = insertLanguage(linePrompt) - linePrompt = insertName(linePrompt, lineUnderTest.trim()) - linePrompt = insertTestingPlatform(linePrompt) - linePrompt = insertMockingFramework(linePrompt) - linePrompt = insertCodeUnderTest(linePrompt, psiMethod.text) - linePrompt = insertMethodsSignatures(linePrompt, getInterestingPsiClasses(psiMethod)) - linePrompt = insertPolymorphismRelations( - linePrompt, - getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), - ) - - return linePrompt - } - - private fun isPromptValid(keyword: PromptKeyword, prompt: String): Boolean { - val keywordText = keyword.text - val isMandatory = keyword.mandatory - - return (prompt.contains(keywordText) || !isMandatory) - } - - private fun insertLanguage(classPrompt: String): String { - if (isPromptValid(PromptKeyword.LANGUAGE, classPrompt)) { - val keyword = "\$${PromptKeyword.LANGUAGE.text}" - return classPrompt.replace(keyword, "Java", ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") - } - } - - private fun insertName(classPrompt: String, classDisplayName: String): String { - if (isPromptValid(PromptKeyword.NAME, classPrompt)) { - val keyword = "\$${PromptKeyword.NAME.text}" - return classPrompt.replace(keyword, classDisplayName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}") - } - } - - private fun insertTestingPlatform(classPrompt: String): String { - if (isPromptValid(PromptKeyword.TESTING_PLATFORM, classPrompt)) { - val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}" - return classPrompt.replace(keyword, "JUnit 4", ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}") - } - } - - private fun insertMockingFramework(classPrompt: String): String { - if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, classPrompt)) { - val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}" - return classPrompt.replace(keyword, "Mockito 5", ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}") - } - } - - private fun insertCodeUnderTest(classPrompt: String, classFullText: String): String { - if (isPromptValid(PromptKeyword.CODE, classPrompt)) { - val keyword = "\$${PromptKeyword.CODE.text}" - var fullText = "```\n${classFullText}\n```\n" - - for (i in 2..classesToTest.size) { - val subClass = classesToTest[i - 2] - val superClass = classesToTest[i - 1] - - fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + - "The source code of ${superClass.qualifiedName} is:\n```\n${getClassFullText(superClass)}\n" + - "```\n" - } - return classPrompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}") - } - } - - private fun insertMethodsSignatures(classPrompt: String, interestingPsiClasses: MutableSet): String { - val keyword = "\$${PromptKeyword.METHODS.text}" - - if (isPromptValid(PromptKeyword.METHODS, classPrompt)) { - var fullText = "" - for (interestingPsiClass: PsiClass in interestingPsiClasses) { - if (interestingPsiClass.qualifiedName!!.startsWith("java")) { - continue - } - - fullText += "=== methods in ${interestingPsiClass.qualifiedName!!}:\n" - for (currentPsiMethod in interestingPsiClass.allMethods) { - // Skip java methods - if (currentPsiMethod.containingClass!!.qualifiedName!!.startsWith("java")) { - continue - } - fullText += " - ${currentPsiMethod.getSignatureString()}\n" - } - } - return classPrompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}") - } - } - - private fun PsiMethod.getSignatureString(): String { - val bodyStart = body?.startOffsetInParent ?: this.textLength - return text.substring(0, bodyStart).replace('\n', ' ').trim() - } - - private fun insertPolymorphismRelations( - classPrompt: String, - polymorphismRelations: MutableMap>, - ): String { - val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" - if (isPromptValid(PromptKeyword.POLYMORPHISM, classPrompt)) { - var fullText = "" - - polymorphismRelations.forEach { entry -> - for (currentSubClass in entry.value) { - currentSubClass.qualifiedName ?: continue - fullText += "${currentSubClass.qualifiedName} is a sub-class of ${entry.key.qualifiedName}.\n" - } - } - return classPrompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}") - } - } - - /** - * Returns a set of interesting PsiClasses based on the given PsiMethod. - * - * @param psiMethod the PsiMethod for which to find interesting PsiClasses - * @return a mutable set of interesting PsiClasses - */ - private fun getInterestingPsiClasses(psiMethod: PsiMethod): MutableSet { - val interestingMethods = mutableSetOf(psiMethod) - for (currentPsiMethod in cut.allMethods) { - if (currentPsiMethod.isConstructor) interestingMethods.add(currentPsiMethod) - } - val interestingPsiClasses = mutableSetOf(cut) - interestingMethods.forEach { methodIt -> - methodIt.parameterList.parameters.forEach { paramIt -> - PsiTypesUtil.getPsiClass(paramIt.type)?.let { - if (it.qualifiedName != null && !it.qualifiedName!!.startsWith("java.")) { - interestingPsiClasses.add(it) - } - } - } - } - return interestingPsiClasses - } - - /** - * Retrieves a set of interesting PsiClasses based on a given cutPsiClass and a list of classesToTest. - * - * @param classesToTest The list of classes to test for interesting PsiClasses. - * @return The set of interesting PsiClasses found during the search. - */ - private fun getInterestingPsiClasses(classesToTest: MutableList): MutableSet { - val interestingPsiClasses: MutableSet = mutableSetOf() - - var currentLevelClasses = mutableListOf().apply { addAll(classesToTest) } - - repeat(SettingsArguments.maxInputParamsDepth(project)) { - val tempListOfClasses = mutableSetOf() - - currentLevelClasses.forEach { classIt -> - classIt.methods.forEach { methodIt -> - methodIt.parameterList.parameters.forEach { paramIt -> - PsiTypesUtil.getPsiClass(paramIt.type)?.let { - if (!interestingPsiClasses.contains(it) && it.qualifiedName != null && - !it.qualifiedName!!.startsWith("java.") - ) { - tempListOfClasses.add(it) - } - } - } - } - } - currentLevelClasses = mutableListOf().apply { addAll(tempListOfClasses) } - interestingPsiClasses.addAll(tempListOfClasses) - } - - return interestingPsiClasses - } - - /** - * Retrieves the polymorphism relations between a given set of interesting PsiClasses and a cut PsiClass. - * - * @param project The project context in which the PsiClasses exist. - * @param interestingPsiClasses The set of PsiClasses that are considered interesting. - * @param cutPsiClass The cut PsiClass to determine polymorphism relations against. - * @return A mutable map where the key represents an interesting PsiClass and the value is a list of its detected subclasses. - */ - private fun getPolymorphismRelations( - project: Project, - interestingPsiClasses: MutableSet, - cutPsiClass: PsiClass, - ): MutableMap> { - val polymorphismRelations: MutableMap> = mutableMapOf() - - val psiClassesToVisit: ArrayDeque = ArrayDeque(listOf(cutPsiClass)) - interestingPsiClasses.add(cutPsiClass) - - interestingPsiClasses.forEach { currentInterestingClass -> - val scope = GlobalSearchScope.projectScope(project) - val query = ClassInheritorsSearch.search(currentInterestingClass, scope, false) - val detectedSubClasses: Collection = query.findAll() - - detectedSubClasses.forEach { detectedSubClass -> - if (!polymorphismRelations.contains(currentInterestingClass)) { - polymorphismRelations[currentInterestingClass] = ArrayList() - } - polymorphismRelations[currentInterestingClass]?.add(detectedSubClass) - if (!psiClassesToVisit.contains(detectedSubClass)) { - psiClassesToVisit.addLast(detectedSubClass) - } - } - } - - return polymorphismRelations - } - - /** - * Retrieves a PsiMethod matching the given method descriptor within the provided PsiClass. - * - * @param psiClass The PsiClass in which to search for the method. - * @param methodDescriptor The method descriptor to match against. - * @return The matching PsiMethod if found, otherwise an empty string. - */ - private fun getPsiMethod(psiClass: PsiClass, methodDescriptor: String): PsiMethod? { - for (currentPsiMethod in psiClass.allMethods) { - if (generateMethodDescriptor(currentPsiMethod) == methodDescriptor) return currentPsiMethod - } - return null - } - - /** - * Returns the method descriptor of the method containing the given line number in the specified PsiClass. - * - * @param psiClass the PsiClass containing the method - * @param lineNumber the line number within the file where the method is located - * @return the method descriptor as a String, or an empty string if no method is found - */ - private fun getMethodDescriptor(psiClass: PsiClass, lineNumber: Int): String { - for (currentPsiMethod in psiClass.allMethods) { - if (isLineInPsiMethod(currentPsiMethod, lineNumber)) return generateMethodDescriptor(currentPsiMethod) - } - return "" - } - - /** - * Checks if the given line number is within the range of the specified PsiMethod. - * - * @param method The PsiMethod to check. - * @param lineNumber The line number to check. - * @return `true` if the line number is within the range of the method, `false` otherwise. - */ - private fun isLineInPsiMethod(method: PsiMethod, lineNumber: Int): Boolean { - val psiFile = method.containingFile ?: return false - val document = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return false - val textRange = method.textRange - val startLine = document.getLineNumber(textRange.startOffset) + 1 - val endLine = document.getLineNumber(textRange.endOffset) + 1 - return lineNumber in startLine..endLine - } - - /** - * Returns the full text of a given class including the package, imports, and class code. - * - * @param cl The PsiClass object representing the class. - * @return The full text of the class. - */ - private fun getClassFullText(cl: PsiClass): String { - var fullText = "" - val fileText = cl.containingFile.text - - // get package - packagePattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n\n" - } - - // get imports - importPattern.findAll(fileText, 0).map { - it.groupValues[0] - }.forEach { - fullText += "$it\n" - } - - // Add class code - fullText += cl.text - - return fullText - } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestsAssembler.kt index 536584a27..05088fe4a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestsAssembler.kt @@ -8,6 +8,8 @@ import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.project.Project import com.intellij.util.io.HttpRequests import org.jetbrains.research.testspark.bundles.TestSparkBundle +import org.jetbrains.research.testspark.core.generation.importPattern +import org.jetbrains.research.testspark.core.generation.runWithPattern import org.jetbrains.research.testspark.services.TestGenerationDataService import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIChoice import org.jetbrains.research.testspark.tools.llm.test.TestCaseGeneratedByLLM