Skip to content

Commit

Permalink
Apply formatting to created files
Browse files Browse the repository at this point in the history
  • Loading branch information
Vladislav0Art committed Feb 16, 2024
1 parent a3b85c2 commit 79c723e
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package org.jetbrains.research.testspark.core.generation.prompt

import com.intellij.openapi.project.Project
import com.intellij.openapi.util.TextRange
import com.intellij.psi.*
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.PsiMethod
import com.intellij.psi.search.GlobalSearchScope
import com.intellij.psi.search.searches.ClassInheritorsSearch
import com.intellij.psi.util.PsiTypesUtil
Expand All @@ -12,15 +14,13 @@ import org.jetbrains.research.testspark.core.generation.prompt.configuration.Gen
import org.jetbrains.research.testspark.core.generation.prompt.configuration.PromptTemplates
import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor


class PromptGenerator(
private val project: Project,
private val cut: PsiClass,
private val classesToTest: MutableList<PsiClass>,
private val settings: GenerationSettings,
private val promptTemplates: PromptTemplates
private val promptTemplates: PromptTemplates,
) {

/**
* Generates a prompt for generating unit tests in Java for a given class.
*
Expand Down Expand Up @@ -58,15 +58,15 @@ class PromptGenerator(
methodPrompt = insertMockingFramework(methodPrompt)
methodPrompt = insertCodeUnderTest(methodPrompt, psiMethod.text)
methodPrompt = insertMethodsSignatures(methodPrompt, getInterestingPsiClasses(psiMethod))
methodPrompt = insertPolymorphismRelations(
methodPrompt,
getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut),
)
methodPrompt =
insertPolymorphismRelations(
methodPrompt,
getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut),
)

return methodPrompt
}


/**
* Generates a prompt for a specific line number in the code.
*
Expand All @@ -90,23 +90,26 @@ class PromptGenerator(
linePrompt = insertMockingFramework(linePrompt)
linePrompt = insertCodeUnderTest(linePrompt, psiMethod.text)
linePrompt = insertMethodsSignatures(linePrompt, getInterestingPsiClasses(psiMethod))
linePrompt = insertPolymorphismRelations(
linePrompt,
getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut),
)
linePrompt =
insertPolymorphismRelations(
linePrompt,
getPolymorphismRelations(project, getInterestingPsiClasses(classesToTest), cut),
)

return linePrompt
}


/**
* Returns the method descriptor of the method containing the given line number in the specified PsiClass.
*
* @param psiClass the PsiClass containing the method
* @param lineNumber the line number within the file where the method is located
* @return the method descriptor as a String, or an empty string if no method is found
*/
private fun getMethodDescriptor(psiClass: PsiClass, lineNumber: Int): String {
private fun getMethodDescriptor(
psiClass: PsiClass,
lineNumber: Int,
): String {
for (currentPsiMethod in psiClass.allMethods) {
if (isLineInPsiMethod(currentPsiMethod, lineNumber)) return generateMethodDescriptor(currentPsiMethod)
}
Expand All @@ -120,7 +123,10 @@ class PromptGenerator(
* @param lineNumber The line number to check.
* @return `true` if the line number is within the range of the method, `false` otherwise.
*/
private fun isLineInPsiMethod(method: PsiMethod, lineNumber: Int): Boolean {
private fun isLineInPsiMethod(
method: PsiMethod,
lineNumber: Int,
): Boolean {
val psiFile = method.containingFile ?: return false
val document = PsiDocumentManager.getInstance(psiFile.project).getDocument(psiFile) ?: return false
val textRange = method.textRange
Expand All @@ -136,7 +142,10 @@ class PromptGenerator(
* @param methodDescriptor The method descriptor to match against.
* @return The matching PsiMethod if found, otherwise an empty string.
*/
private fun getPsiMethod(psiClass: PsiClass, methodDescriptor: String): PsiMethod? {
private fun getPsiMethod(
psiClass: PsiClass,
methodDescriptor: String,
): PsiMethod? {
for (currentPsiMethod in psiClass.allMethods) {
if (generateMethodDescriptor(currentPsiMethod) == methodDescriptor) return currentPsiMethod
}
Expand All @@ -154,6 +163,7 @@ class PromptGenerator(
for (currentPsiMethod in cut.allMethods) {
if (currentPsiMethod.isConstructor) interestingMethods.add(currentPsiMethod)
}

val interestingPsiClasses = mutableSetOf(cut)
interestingMethods.forEach { methodIt ->
methodIt.parameterList.parameters.forEach { paramIt ->
Expand Down Expand Up @@ -201,15 +211,16 @@ class PromptGenerator(
return interestingPsiClasses
}


private fun isPromptValid(keyword: PromptKeyword, prompt: String): Boolean {
private fun isPromptValid(
keyword: PromptKeyword,
prompt: String,
): Boolean {
val keywordText = keyword.text
val isMandatory = keyword.mandatory

return (prompt.contains(keywordText) || !isMandatory)
}


private fun insertLanguage(classPrompt: String): String {
if (isPromptValid(PromptKeyword.LANGUAGE, classPrompt)) {
val keyword = "\$${PromptKeyword.LANGUAGE.text}"
Expand All @@ -219,7 +230,10 @@ class PromptGenerator(
}
}

private fun insertName(classPrompt: String, classDisplayName: String): String {
private fun insertName(
classPrompt: String,
classDisplayName: String,
): String {
if (isPromptValid(PromptKeyword.NAME, classPrompt)) {
val keyword = "\$${PromptKeyword.NAME.text}"
return classPrompt.replace(keyword, classDisplayName, ignoreCase = false)
Expand All @@ -246,7 +260,10 @@ class PromptGenerator(
}
}

private fun insertCodeUnderTest(classPrompt: String, classFullText: String): String {
private fun insertCodeUnderTest(
classPrompt: String,
classFullText: String,
): String {
if (isPromptValid(PromptKeyword.CODE, classPrompt)) {
val keyword = "\$${PromptKeyword.CODE.text}"
var fullText = "```\n${classFullText}\n```\n"
Expand All @@ -256,16 +273,19 @@ class PromptGenerator(
val superClass = classesToTest[i - 1]

fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " +
"The source code of ${superClass.qualifiedName} is:\n```\n${getClassFullText(superClass)}\n" +
"```\n"
"The source code of ${superClass.qualifiedName} is:\n```\n${getClassFullText(superClass)}\n" +
"```\n"
}
return classPrompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}")
}
}

private fun insertMethodsSignatures(classPrompt: String, interestingPsiClasses: MutableSet<PsiClass>): String {
private fun insertMethodsSignatures(
classPrompt: String,
interestingPsiClasses: MutableSet<PsiClass>,
): String {
val keyword = "\$${PromptKeyword.METHODS.text}"

if (isPromptValid(PromptKeyword.METHODS, classPrompt)) {
Expand Down Expand Up @@ -381,4 +401,4 @@ class PromptGenerator(

return fullText
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.jetbrains.research.testspark.core.generation.prompt


enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) {
NAME("NAME", "The name of the code under test (Class name, method name, line number)", true),
CODE("CODE", "The code under test (Class, method, or line)", true),
Expand Down Expand Up @@ -29,4 +28,4 @@ enum class PromptKeyword(val text: String, val description: String, val mandator
val endOffset = startOffset + textToHighlight.length
return Pair(startOffset, endOffset)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package org.jetbrains.research.testspark.core.generation.prompt.configuration

data class GenerationSettings(
val maxInputParamsDepth: Int
val maxInputParamsDepth: Int,
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ data class PromptTemplates(
val classPrompt: String,
val methodPrompt: String,
val linePrompt: String,
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,13 @@ import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Computable
import com.intellij.openapi.util.TextRange
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.PsiMethod
import com.intellij.psi.search.GlobalSearchScope
import com.intellij.psi.search.searches.ClassInheritorsSearch
import com.intellij.psi.util.PsiTypesUtil
import org.jetbrains.research.testspark.bundles.TestSparkBundle
import org.jetbrains.research.testspark.core.generation.importPattern
import org.jetbrains.research.testspark.core.generation.packagePattern
import org.jetbrains.research.testspark.core.generation.prompt.PromptGenerator
import org.jetbrains.research.testspark.core.generation.prompt.configuration.GenerationSettings
import org.jetbrains.research.testspark.core.generation.prompt.configuration.PromptTemplates
import org.jetbrains.research.testspark.data.CodeType
import org.jetbrains.research.testspark.data.FragmentToTestData
import org.jetbrains.research.testspark.core.helpers.generateMethodDescriptor
import org.jetbrains.research.testspark.services.PromptKeyword
import org.jetbrains.research.testspark.services.SettingsApplicationService
import org.jetbrains.research.testspark.services.TestGenerationDataService
import org.jetbrains.research.testspark.settings.SettingsApplicationState
Expand All @@ -46,29 +36,31 @@ class PromptManager(
private val llmErrorManager: LLMErrorManager = LLMErrorManager()

fun generatePrompt(codeType: FragmentToTestData): String {
val promptGenerator = PromptGenerator(
project,
cut,
classesToTest,
GenerationSettings(
maxInputParamsDepth = SettingsArguments.maxInputParamsDepth(project)
),
PromptTemplates(
classPrompt = settingsState.classPrompt,
methodPrompt = settingsState.methodPrompt,
linePrompt = settingsState.linePrompt,
val promptGenerator =
PromptGenerator(
project,
cut,
classesToTest,
GenerationSettings(
maxInputParamsDepth = SettingsArguments.maxInputParamsDepth(project),
),
PromptTemplates(
classPrompt = settingsState.classPrompt,
methodPrompt = settingsState.methodPrompt,
linePrompt = settingsState.linePrompt,
),
)
)

val prompt = ApplicationManager.getApplication().runReadAction(
Computable {
when (codeType.type!!) {
CodeType.CLASS -> promptGenerator.generatePromptForClass()
CodeType.METHOD -> promptGenerator.generatePromptForMethod(codeType.objectDescription)
CodeType.LINE -> promptGenerator.generatePromptForLine(codeType.objectIndex)
}
}
)
val prompt =
ApplicationManager.getApplication().runReadAction(
Computable {
when (codeType.type!!) {
CodeType.CLASS -> promptGenerator.generatePromptForClass()
CodeType.METHOD -> promptGenerator.generatePromptForMethod(codeType.objectDescription)
CodeType.LINE -> promptGenerator.generatePromptForLine(codeType.objectIndex)
}
},
)

log.info("Prompt is:\n$prompt")
return prompt
Expand Down

0 comments on commit 79c723e

Please sign in to comment.