From 79c723e787aa6c7ebe19dab0c78af10163c3b3bc Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Fri, 16 Feb 2024 16:51:14 +0100 Subject: [PATCH] Apply formatting to created files --- .../core/generation/prompt/PromptGenerator.kt | 72 ++++++++++++------- .../core/generation/prompt/PromptKeyword.kt | 3 +- .../configuration/GenerationSettings.kt | 2 +- .../prompt/configuration/PromptTemplates.kt | 2 +- .../tools/llm/generation/PromptManager.kt | 54 ++++++-------- 5 files changed, 72 insertions(+), 61 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt index 690bf27e1..f628ff306 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptGenerator.kt @@ -2,7 +2,9 @@ package org.jetbrains.research.testspark.core.generation.prompt import com.intellij.openapi.project.Project import com.intellij.openapi.util.TextRange -import com.intellij.psi.* +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiDocumentManager +import com.intellij.psi.PsiMethod import com.intellij.psi.search.GlobalSearchScope import com.intellij.psi.search.searches.ClassInheritorsSearch import com.intellij.psi.util.PsiTypesUtil @@ -12,15 +14,13 @@ import org.jetbrains.research.testspark.core.generation.prompt.configuration.Gen import org.jetbrains.research.testspark.core.generation.prompt.configuration.PromptTemplates import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor - class PromptGenerator( private val project: Project, private val cut: PsiClass, private val classesToTest: MutableList, private val settings: GenerationSettings, - private val promptTemplates: PromptTemplates + private val promptTemplates: PromptTemplates, ) { - /** * Generates a prompt for generating unit tests in Java for a given class. * @@ -58,15 +58,15 @@ class PromptGenerator( methodPrompt = insertMockingFramework(methodPrompt) methodPrompt = insertCodeUnderTest(methodPrompt, psiMethod.text) methodPrompt = insertMethodsSignatures(methodPrompt, getInterestingPsiClasses(psiMethod)) - methodPrompt = insertPolymorphismRelations( - methodPrompt, - getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), - ) + methodPrompt = + insertPolymorphismRelations( + methodPrompt, + getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), + ) return methodPrompt } - /** * Generates a prompt for a specific line number in the code. * @@ -90,15 +90,15 @@ class PromptGenerator( linePrompt = insertMockingFramework(linePrompt) linePrompt = insertCodeUnderTest(linePrompt, psiMethod.text) linePrompt = insertMethodsSignatures(linePrompt, getInterestingPsiClasses(psiMethod)) - linePrompt = insertPolymorphismRelations( - linePrompt, - getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), - ) + linePrompt = + insertPolymorphismRelations( + linePrompt, + getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut), + ) return linePrompt } - /** * Returns the method descriptor of the method containing the given line number in the specified PsiClass. * @@ -106,7 +106,10 @@ class PromptGenerator( * @param lineNumber the line number within the file where the method is located * @return the method descriptor as a String, or an empty string if no method is found */ - private fun getMethodDescriptor(psiClass: PsiClass, lineNumber: Int): String { + private fun getMethodDescriptor( + psiClass: PsiClass, + lineNumber: Int, + ): String { for (currentPsiMethod in psiClass.allMethods) { if (isLineInPsiMethod(currentPsiMethod, lineNumber)) return generateMethodDescriptor(currentPsiMethod) } @@ -120,7 +123,10 @@ class PromptGenerator( * @param lineNumber The line number to check. * @return `true` if the line number is within the range of the method, `false` otherwise. */ - private fun isLineInPsiMethod(method: PsiMethod, lineNumber: Int): Boolean { + private fun isLineInPsiMethod( + method: PsiMethod, + lineNumber: Int, + ): Boolean { val psiFile = method.containingFile ?: return false val document = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return false val textRange = method.textRange @@ -136,7 +142,10 @@ class PromptGenerator( * @param methodDescriptor The method descriptor to match against. * @return The matching PsiMethod if found, otherwise an empty string. */ - private fun getPsiMethod(psiClass: PsiClass, methodDescriptor: String): PsiMethod? { + private fun getPsiMethod( + psiClass: PsiClass, + methodDescriptor: String, + ): PsiMethod? { for (currentPsiMethod in psiClass.allMethods) { if (generateMethodDescriptor(currentPsiMethod) == methodDescriptor) return currentPsiMethod } @@ -154,6 +163,7 @@ class PromptGenerator( for (currentPsiMethod in cut.allMethods) { if (currentPsiMethod.isConstructor) interestingMethods.add(currentPsiMethod) } + val interestingPsiClasses = mutableSetOf(cut) interestingMethods.forEach { methodIt -> methodIt.parameterList.parameters.forEach { paramIt -> @@ -201,15 +211,16 @@ class PromptGenerator( return interestingPsiClasses } - - private fun isPromptValid(keyword: PromptKeyword, prompt: String): Boolean { + private fun isPromptValid( + keyword: PromptKeyword, + prompt: String, + ): Boolean { val keywordText = keyword.text val isMandatory = keyword.mandatory return (prompt.contains(keywordText) || !isMandatory) } - private fun insertLanguage(classPrompt: String): String { if (isPromptValid(PromptKeyword.LANGUAGE, classPrompt)) { val keyword = "\$${PromptKeyword.LANGUAGE.text}" @@ -219,7 +230,10 @@ class PromptGenerator( } } - private fun insertName(classPrompt: String, classDisplayName: String): String { + private fun insertName( + classPrompt: String, + classDisplayName: String, + ): String { if (isPromptValid(PromptKeyword.NAME, classPrompt)) { val keyword = "\$${PromptKeyword.NAME.text}" return classPrompt.replace(keyword, classDisplayName, ignoreCase = false) @@ -246,7 +260,10 @@ class PromptGenerator( } } - private fun insertCodeUnderTest(classPrompt: String, classFullText: String): String { + private fun insertCodeUnderTest( + classPrompt: String, + classFullText: String, + ): String { if (isPromptValid(PromptKeyword.CODE, classPrompt)) { val keyword = "\$${PromptKeyword.CODE.text}" var fullText = "```\n${classFullText}\n```\n" @@ -256,8 +273,8 @@ class PromptGenerator( val superClass = classesToTest[i - 1] fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + - "The source code of ${superClass.qualifiedName} is:\n```\n${getClassFullText(superClass)}\n" + - "```\n" + "The source code of ${superClass.qualifiedName} is:\n```\n${getClassFullText(superClass)}\n" + + "```\n" } return classPrompt.replace(keyword, fullText, ignoreCase = false) } else { @@ -265,7 +282,10 @@ class PromptGenerator( } } - private fun insertMethodsSignatures(classPrompt: String, interestingPsiClasses: MutableSet): String { + private fun insertMethodsSignatures( + classPrompt: String, + interestingPsiClasses: MutableSet, + ): String { val keyword = "\$${PromptKeyword.METHODS.text}" if (isPromptValid(PromptKeyword.METHODS, classPrompt)) { @@ -381,4 +401,4 @@ class PromptGenerator( return fullText } -} \ No newline at end of file +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt index 81f025bc6..5055060ba 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/PromptKeyword.kt @@ -1,6 +1,5 @@ package org.jetbrains.research.testspark.core.generation.prompt - enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) { NAME("NAME", "The name of the code under test (Class name, method name, line number)", true), CODE("CODE", "The code under test (Class, method, or line)", true), @@ -29,4 +28,4 @@ enum class PromptKeyword(val text: String, val description: String, val mandator val endOffset = startOffset + textToHighlight.length return Pair(startOffset, endOffset) } -} \ No newline at end of file +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt index 157c0c4ac..f04a8a37d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/GenerationSettings.kt @@ -1,5 +1,5 @@ package org.jetbrains.research.testspark.core.generation.prompt.configuration data class GenerationSettings( - val maxInputParamsDepth: Int + val maxInputParamsDepth: Int, ) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt index 47cf21081..4bc38f431 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/prompt/configuration/PromptTemplates.kt @@ -4,4 +4,4 @@ data class PromptTemplates( val classPrompt: String, val methodPrompt: String, val linePrompt: String, -) \ No newline at end of file +) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt index 0c624d736..d8cf013e0 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/PromptManager.kt @@ -5,23 +5,13 @@ import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import com.intellij.openapi.util.Computable -import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiClass -import com.intellij.psi.PsiDocumentManager -import com.intellij.psi.PsiMethod -import com.intellij.psi.search.GlobalSearchScope -import com.intellij.psi.search.searches.ClassInheritorsSearch -import com.intellij.psi.util.PsiTypesUtil import org.jetbrains.research.testspark.bundles.TestSparkBundle -import org.jetbrains.research.testspark.core.generation.importPattern -import org.jetbrains.research.testspark.core.generation.packagePattern import org.jetbrains.research.testspark.core.generation.prompt.PromptGenerator import org.jetbrains.research.testspark.core.generation.prompt.configuration.GenerationSettings import org.jetbrains.research.testspark.core.generation.prompt.configuration.PromptTemplates import org.jetbrains.research.testspark.data.CodeType import org.jetbrains.research.testspark.data.FragmentToTestData -import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor -import org.jetbrains.research.testspark.services.PromptKeyword import org.jetbrains.research.testspark.services.SettingsApplicationService import org.jetbrains.research.testspark.services.TestGenerationDataService import org.jetbrains.research.testspark.settings.SettingsApplicationState @@ -46,29 +36,31 @@ class PromptManager( private val llmErrorManager: LLMErrorManager = LLMErrorManager() fun generatePrompt(codeType: FragmentToTestData): String { - val promptGenerator = PromptGenerator( - project, - cut, - classesToTest, - GenerationSettings( - maxInputParamsDepth = SettingsArguments.maxInputParamsDepth(project) - ), - PromptTemplates( - classPrompt = settingsState.classPrompt, - methodPrompt = settingsState.methodPrompt, - linePrompt = settingsState.linePrompt, + val promptGenerator = + PromptGenerator( + project, + cut, + classesToTest, + GenerationSettings( + maxInputParamsDepth = SettingsArguments.maxInputParamsDepth(project), + ), + PromptTemplates( + classPrompt = settingsState.classPrompt, + methodPrompt = settingsState.methodPrompt, + linePrompt = settingsState.linePrompt, + ), ) - ) - val prompt = ApplicationManager.getApplication().runReadAction( - Computable { - when (codeType.type!!) { - CodeType.CLASS -> promptGenerator.generatePromptForClass() - CodeType.METHOD -> promptGenerator.generatePromptForMethod(codeType.objectDescription) - CodeType.LINE -> promptGenerator.generatePromptForLine(codeType.objectIndex) - } - } - ) + val prompt = + ApplicationManager.getApplication().runReadAction( + Computable { + when (codeType.type!!) { + CodeType.CLASS -> promptGenerator.generatePromptForClass() + CodeType.METHOD -> promptGenerator.generatePromptForMethod(codeType.objectDescription) + CodeType.LINE -> promptGenerator.generatePromptForLine(codeType.objectIndex) + } + }, + ) log.info("Prompt is:\n$prompt") return prompt