Skip to content

Commit

Permalink
Merge pull request #118 from JetBrains-Research/nikolaisv/ICTL-851
Browse files Browse the repository at this point in the history
ICTL-851 Add support for JUnit 5 for LLM-based test generation
  • Loading branch information
arksap2002 authored Feb 21, 2024
2 parents 50580a9 + b826b74 commit e332956
Show file tree
Hide file tree
Showing 19 changed files with 213 additions and 33 deletions.
4 changes: 4 additions & 0 deletions JUnitRunner/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ repositories {

dependencies {
implementation("junit:junit:4.13")

implementation("org.junit.jupiter:junit-jupiter-api:5.10.0")
implementation("org.junit.platform:junit-platform-launcher:1.10.0")
implementation("org.junit.jupiter:junit-jupiter-engine:5.10.0")
}

tasks.test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import org.junit.runner.Result;
import org.junit.runner.notification.Failure;

public class SingleJUnitTestRunner {
public static void main(String... args) throws ClassNotFoundException {
String[] classAndMethod = args[0].split("#");
public class SingleJUnitTestRunner4 {
public static void main(String... args) throws ClassNotFoundException {
String[] classAndMethod = args[0].split("#");
Request request = Request.method(Class.forName(classAndMethod[0]),
classAndMethod[1]);

Expand All @@ -18,5 +18,6 @@ public static void main(String... args) throws ClassNotFoundException {
}

System.exit(result.wasSuccessful() ? 0 : 1);
}
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.jetbrains.research;

import org.junit.platform.engine.discovery.MethodSelector;
import org.junit.platform.launcher.Launcher;
import org.junit.platform.launcher.LauncherDiscoveryRequest;
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder;
import org.junit.platform.launcher.core.LauncherFactory;
import org.junit.platform.launcher.listeners.SummaryGeneratingListener;
import org.junit.platform.launcher.listeners.TestExecutionSummary;

import java.util.Arrays;
import java.util.List;

import static org.junit.platform.engine.discovery.DiscoverySelectors.selectMethod;

public class SingleJUnitTestRunner5 {
public static void main(String... args) {
String classAndMethod = args[0];
MethodSelector methodSelector = selectMethod(classAndMethod);
LauncherDiscoveryRequest request =
LauncherDiscoveryRequestBuilder.request()
.selectors(methodSelector)
.build();

Launcher launcher = LauncherFactory.create();
SummaryGeneratingListener listener = new SummaryGeneratingListener();

launcher.registerTestExecutionListeners(listener);
launcher.execute(request);

TestExecutionSummary result = listener.getSummary();
List<TestExecutionSummary.Failure> failures = result.getFailures();
for (TestExecutionSummary.Failure failure : failures) {
failure.getException().printStackTrace(System.err);
System.err.println("\n ===");
}
System.exit(result.getTestsFailedCount() == 0 ? 0 : 1);

}
}
6 changes: 6 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ dependencies {
// validation dependencies
// https://mvnrepository.com/artifact/junit/junit
implementation("junit:junit:4.13")

// https://mvnrepository.com/artifact/org.junit.jupiter/junit-jupiter-api
implementation("org.junit.jupiter:junit-jupiter-api:5.10.0")
implementation("org.junit.platform:junit-platform-launcher:1.10.0")
implementation("org.junit.jupiter:junit-jupiter-engine:5.10.0")

// https://mvnrepository.com/artifact/org.jacoco/org.jacoco.core
implementation("org.jacoco:org.jacoco.core:0.8.8")
// https://mvnrepository.com/artifact/com.github.javaparser/javaparser-core
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,3 @@ val packagePattern =
pattern = "^package\\s+((?:[a-zA-Z_]\\w*\\.)*[a-zA-Z_](?:\\w*\\.?)*)(?:\\.\\*)?;",
options = setOf(RegexOption.MULTILINE),
)

val runWithPattern =
Regex(
pattern = "@RunWith\\([^)]*\\)",
options = setOf(RegexOption.MULTILINE),
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ package org.jetbrains.research.testspark.actions
import com.intellij.openapi.actionSystem.ActionUpdateThread
import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.actionSystem.CommonDataKeys
import com.intellij.openapi.roots.LibraryOrderEntry
import com.intellij.openapi.roots.ModuleRootManager
import com.intellij.openapi.roots.ProjectRootManager
import org.jetbrains.research.testspark.actions.evosuite.EvoSuitePanelFactory
import org.jetbrains.research.testspark.actions.llm.LLMPanelFactory
import org.jetbrains.research.testspark.data.JUnitVersion
import org.jetbrains.research.testspark.display.TestSparkIcons
import org.jetbrains.research.testspark.helpers.getCurrentListOfCodeTypes
import org.jetbrains.research.testspark.tools.Manager
Expand Down Expand Up @@ -59,6 +64,7 @@ class TestSparkAction : AnAction() {
* @property e The AnActionEvent object.
*/
class TestSparkActionWindow(val e: AnActionEvent) : JFrame("TestSpark") {

private val llmButton = JRadioButton("<html><b>${Llm().name}</b></html>")
private val evoSuiteButton = JRadioButton("<html><b>${EvoSuite().name}</b></html>")
private val testGeneratorButtonGroup = ButtonGroup()
Expand All @@ -74,11 +80,12 @@ class TestSparkAction : AnAction() {
private val evoSuitePanelFactory = EvoSuitePanelFactory()

init {
val junit = findJUnitDependency(e)
val panel = JPanel(cardLayout)

panel.add(getMainPanel(), "1")
panel.add(llmPanelFactory.getPanel(), "2")
panel.add(evoSuitePanelFactory.getPanel(), "3")
panel.add(llmPanelFactory.getPanel(junit), "2")
panel.add(evoSuitePanelFactory.getPanel(junit), "3")

addListeners(panel)

Expand All @@ -94,6 +101,26 @@ class TestSparkAction : AnAction() {
isVisible = true
}

private fun findJUnitDependency(e: AnActionEvent): JUnitVersion? {
val project = e.project!!
val virtualFile = e.getData(CommonDataKeys.VIRTUAL_FILE_ARRAY)?.firstOrNull() ?: return null

val index = ProjectRootManager.getInstance(project).fileIndex
val module = index.getModuleForFile(virtualFile) ?: return null

for (orderEntry in ModuleRootManager.getInstance(module).orderEntries) {
if (orderEntry is LibraryOrderEntry) {
val libraryName = orderEntry.library?.name ?: continue
for (junit in JUnitVersion.values()) {
if (libraryName.contains(junit.groupId)) {
return junit
}
}
}
}
return null
}

/**
* Returns the main panel for the test generator UI.
* This panel contains options for selecting the test generator and the code type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.research.testspark.actions.template.ToolPanelFactory
import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
import org.jetbrains.research.testspark.bundles.TestSparkToolTipsBundle
import org.jetbrains.research.testspark.data.ContentDigestAlgorithm
import org.jetbrains.research.testspark.data.JUnitVersion
import org.jetbrains.research.testspark.services.SettingsApplicationService
import java.awt.Font
import javax.swing.JButton
Expand Down Expand Up @@ -41,7 +42,7 @@ class EvoSuitePanelFactory : ToolPanelFactory {
*
* @return the JPanel containing the EvoSuite setup GUI components
*/
override fun getPanel(): JPanel {
override fun getPanel(junit: JUnitVersion?): JPanel {
val textTitle = JLabel("EvoSuite Setup")
textTitle.font = Font("Monochrome", Font.BOLD, 20)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import com.intellij.util.ui.FormBuilder
import org.jetbrains.research.testspark.actions.template.ToolPanelFactory
import org.jetbrains.research.testspark.bundles.TestSparkDefaultsBundle
import org.jetbrains.research.testspark.bundles.TestSparkLabelsBundle
import org.jetbrains.research.testspark.data.JUnitVersion
import org.jetbrains.research.testspark.display.JUnitCombobox
import org.jetbrains.research.testspark.helpers.addLLMPanelListeners
import org.jetbrains.research.testspark.helpers.getLLLMPlatforms
import org.jetbrains.research.testspark.helpers.stylizeMainComponents
Expand All @@ -23,6 +25,7 @@ class LLMPanelFactory : ToolPanelFactory {
private var platformSelector = ComboBox(arrayOf(TestSparkDefaultsBundle.defaultValue("openAI")))
private val backLlmButton = JButton("Back")
private val okLlmButton = JButton("OK")
private val junitSelector = JUnitCombobox()

private val settingsState = SettingsApplicationService.getInstance().state!!

Expand Down Expand Up @@ -56,7 +59,9 @@ class LLMPanelFactory : ToolPanelFactory {
*
* @return The JPanel object representing the LLM setup panel.
*/
override fun getPanel(): JPanel {
override fun getPanel(junit: JUnitVersion?): JPanel {
junitSelector.detected = junit

val textTitle = JLabel("LLM Setup")
textTitle.font = Font("Monochrome", Font.BOLD, 20)

Expand Down Expand Up @@ -97,6 +102,12 @@ class LLMPanelFactory : ToolPanelFactory {
10,
false,
)
.addLabeledComponent(
JBLabel(TestSparkLabelsBundle.defaultValue("junitVersion")),
junitSelector,
10,
false,
)
.addComponentFillVertically(bottomButtons, 10)
.panel
}
Expand All @@ -115,5 +126,6 @@ class LLMPanelFactory : ToolPanelFactory {
settingsState.llmPlatforms[index].token = llmPlatforms[index].token
settingsState.llmPlatforms[index].model = llmPlatforms[index].model
}
settingsState.junitVersion = (junitSelector.selectedItem!! as JUnitVersion)
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package org.jetbrains.research.testspark.actions.template

import org.jetbrains.research.testspark.data.JUnitVersion
import javax.swing.JButton
import javax.swing.JPanel

interface ToolPanelFactory {
fun getPanel(): JPanel
fun getPanel(junit: JUnitVersion?): JPanel

fun getBackButton(): JButton

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.jetbrains.research.testspark.data

enum class JUnitVersion(
val groupId: String,
val version: Int,
val libJar: Set<String>,
val runWithAnnotationMeta: RunWithAnnotationMeta,
val showName: String = "JUnit $version",
) {
JUnit5(
"org.junit.jupiter",
5,
setOf(
"junit-jupiter-api-5.10.0.jar",
"junit-jupiter-engine-5.10.0.jar",
"junit-platform-commons-1.10.0.jar",
"junit-platform-engine-1.10.0.jar",
"junit-platform-launcher-1.10.0.jar",
),
RunWithAnnotationMeta("ExtendWith", "import org.junit.jupiter.api.extension.ExtendWith;"),
),
JUnit4(
"junit",
4,
setOf("junit-4.13.jar"),
RunWithAnnotationMeta("RunWith", "import org.junit.runner.RunWith;"),
),
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.jetbrains.research.testspark.data

data class RunWithAnnotationMeta(val annotationName: String, val import: String) {

val regex = annotationRegex(annotationName)

fun extract(line: String): String? {
val detectedRunWith = regex.find(line, startIndex = 0)?.groupValues?.get(0) ?: return null
return detectedRunWith
.split("@$annotationName(")[1]
.split(")")[0]
}

companion object {
private fun annotationRegex(annotationName: String) = Regex(
pattern = "@$annotationName\\([^)]*\\)",
options = setOf(RegexOption.MULTILINE),
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.jetbrains.research.testspark.display

import com.intellij.openapi.ui.ComboBox
import org.jetbrains.research.testspark.data.JUnitVersion
import java.awt.Component
import javax.swing.DefaultListCellRenderer
import javax.swing.JList

class JUnitCombobox : ComboBox<JUnitVersion>(JUnitVersion.values()) {

var detected: JUnitVersion? = null
set(value) {
field = value
value?.let {
this.selectedItem = value
}
}

init {
renderer = object : DefaultListCellRenderer() {
override fun getListCellRendererComponent(
list: JList<*>?,
value: Any?,
index: Int,
isSelected: Boolean,
cellHasFocus: Boolean,
): Component {
var v = value
if (value is JUnitVersion) {
v = value.showName
if (value == detected) {
v += " (Detected)"
}
}
return super.getListCellRendererComponent(list, v, index, isSelected, cellHasFocus)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.intellij.openapi.util.io.FileUtilRt
import org.jetbrains.research.testspark.data.DataFilesUtil
import org.jetbrains.research.testspark.data.TestCase
import org.jetbrains.research.testspark.tools.getBuildPath
import org.jetbrains.research.testspark.tools.llm.SettingsArguments
import org.jetbrains.research.testspark.tools.llm.test.TestCaseGeneratedByLLM
import java.io.File
import java.util.UUID
Expand Down Expand Up @@ -42,13 +43,14 @@ class TestStorageProcessingService(private val project: Project) {
*/
private fun getPath(buildPath: String): String {
// create the path for the command
val junitPath = getLibrary("junit-4.13.jar")
val junitVersion = SettingsArguments.settingsState!!.junitVersion
val separator = DataFilesUtil.classpathSeparator
val junitPath = junitVersion.libJar.joinToString(separator.toString()) { getLibrary(it) }
val mockitoPath = getLibrary("mockito-core-5.0.0.jar")
val hamcrestPath = getLibrary("hamcrest-core-1.3.jar")
val byteBuddy = getLibrary("byte-buddy-1.14.6.jar")
val byteBuddyAgent = getLibrary("byte-buddy-agent-1.14.6.jar")
val sep = DataFilesUtil.classpathSeparator
return "$junitPath${sep}$hamcrestPath${sep}$mockitoPath${sep}$byteBuddy${sep}$byteBuddyAgent${sep}$buildPath"
return "$junitPath${separator}$hamcrestPath${separator}$mockitoPath${separator}$byteBuddy${separator}$byteBuddyAgent${separator}$buildPath"
}

/**
Expand Down Expand Up @@ -176,14 +178,16 @@ class TestStorageProcessingService(private val project: Project) {
var name = if (generatedTestPackage.isEmpty()) "" else "$generatedTestPackage."
name += "$className#$testCaseName"

val junitVersion = SettingsArguments.settingsState!!.junitVersion.version

// run the test method with jacoco agent
val testExecutionError = project.service<RunCommandLineService>().runCommandLine(
arrayListOf(
javaRunner.absolutePath,
"-javaagent:$jacocoAgentDir=destfile=$dataFileName.exec,append=false,includes=${project.service<ProjectContextService>().classFQN}",
"-cp",
"${getPath(projectBuildPath)}${getLibrary("JUnitRunner.jar")}${DataFilesUtil.classpathSeparator}$resultPath",
"org.jetbrains.research.SingleJUnitTestRunner",
"org.jetbrains.research.SingleJUnitTestRunner$junitVersion",
name,
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.jetbrains.research.testspark.settings

import org.jetbrains.research.testspark.bundles.TestSparkDefaultsBundle
import org.jetbrains.research.testspark.data.ContentDigestAlgorithm
import org.jetbrains.research.testspark.data.JUnitVersion
import org.jetbrains.research.testspark.tools.llm.generation.LLMPlatform
import org.jetbrains.research.testspark.tools.llm.generation.grazie.GraziePlatform
import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIPlatform
Expand Down Expand Up @@ -35,6 +36,7 @@ data class SettingsApplicationState(
var classPrompt: String = DefaultSettingsApplicationState.classPrompt,
var methodPrompt: String = DefaultSettingsApplicationState.methodPrompt,
var linePrompt: String = DefaultSettingsApplicationState.linePrompt,
var junitVersion: JUnitVersion = DefaultSettingsApplicationState.junitVersion,
) {

/**
Expand Down Expand Up @@ -66,6 +68,7 @@ data class SettingsApplicationState(
val classPrompt: String = TestSparkDefaultsBundle.defaultValue("classPrompt")
val methodPrompt: String = TestSparkDefaultsBundle.defaultValue("methodPrompt")
val linePrompt: String = TestSparkDefaultsBundle.defaultValue("linePrompt")
val junitVersion: JUnitVersion = JUnitVersion.JUnit4
}

// TODO remove from here
Expand Down
Loading

0 comments on commit e332956

Please sign in to comment.