Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Roborazzi AI-Powered Image Verification #491

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ kotlin.incremental.native=true
# To debug
roborazzi.test.record=true
#roborazzi.test.verify=true
#roborazzi.test.compare=true
5 changes: 5 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mavenPublish = "0.25.3"
composeCompiler = "1.5.10"
composeMultiplatform = "1.6.2"
robolectric = "4.12.2"
generativeaiGoogle = "0.9.0-1.0.1"
robolectric-android-all = "Q-robolectric-5415296"

roborazzi-for-replacing-by-include-build = "1.0.0"
Expand Down Expand Up @@ -37,6 +38,7 @@ google-android-material = "1.5.0"
junit = "4.13.2"
ktor-serialization-kotlinx-xml = "2.3.11"
kotlinx-serialization = "1.6.3"
kotlinx-coroutines = "1.6.0"
squareup-okhttp = "5.0.0-alpha.11"
kotlinx-io = "0.3.3"
webjar-material-design-icons = "4.0.0"
Expand All @@ -55,6 +57,8 @@ kotlin-stdlib-jdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8" }
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test" }
kotlin-test-junit = { module = "org.jetbrains.kotlin:kotlin-test-junit" }

generativeai-google = { module = "dev.shreyaspatil.generativeai:generativeai-google", version.ref = "generativeaiGoogle" }

# for sample
composable-preview-scanner = { module = "io.github.sergio-sastre.ComposablePreviewScanner:android", version.ref = "composable-preview-scanner" }

Expand Down Expand Up @@ -90,6 +94,7 @@ dropbox-differ = { module = "com.dropbox.differ:differ", version.ref = "dropbox-
google-android-material = { module = "com.google.android.material:material", version.ref = "google-android-material" }
junit = { module = "junit:junit", version.ref = "junit" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-serialization" }
kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" }
ktor-serialization-kotlinx-xml = { module = "io.ktor:ktor-serialization-kotlinx-xml", version.ref = "ktor-serialization-kotlinx-xml" }
squareup-okhttp = { module = "com.squareup.okhttp3:okhttp", version.ref = "squareup-okhttp" }
squareup-okhttp-coroutines = { module = "com.squareup.okhttp3:okhttp-coroutines", version.ref = "squareup-okhttp" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ data class RoborazziOptions(
val outputDirectoryPath: String = roborazziSystemPropertyOutputDirectory(),
val imageComparator: ImageComparator = DefaultImageComparator,
val comparisonStyle: ComparisonStyle = ComparisonStyle.Grid(),
val aiCompareOptions: AiCompareOptions? = null,
val resultValidator: (result: ImageComparator.ComparisonResult) -> Boolean = DefaultResultValidator,
) {

@ExperimentalRoborazziApi
sealed interface ComparisonStyle {
@ExperimentalRoborazziApi
Expand Down Expand Up @@ -151,6 +153,38 @@ data class RoborazziOptions(
} else {
JsonOutputCaptureResultReporter().report(captureResult, roborazziTaskType)
}
AiCaptureResultReporter().report(captureResult, roborazziTaskType)
}
}

/**
 * [CaptureResultReporter] that enforces the AI condition thresholds attached to a capture result.
 *
 * The AI comparison result is read from [CaptureResult.Changed] and [CaptureResult.Added];
 * every other result type is ignored.
 */
class AiCaptureResultReporter : CaptureResultReporter {
  /**
   * Checks each AI condition result of [captureResult] against its required fulfillment percent.
   *
   * @param captureResult the capture result whose AI comparison result (if any) is validated.
   * @param roborazziTaskType not used by this reporter.
   * @throws AssertionError for the first condition, in list order, whose `fulfillmentPercent`
   * is lower than its non-null `requiredFulfillmentPercent`.
   */
  override fun report(captureResult: CaptureResult, roborazziTaskType: RoborazziTaskType) {
    val aiResult = when (captureResult) {
      is CaptureResult.Changed -> captureResult.aiComparisonResult
      is CaptureResult.Added -> captureResult.aiComparisonResult
      else -> null
    } ?: return // No AI result attached: nothing to validate.

    for (conditionResult in aiResult.aiConditionResults) {
      // A null threshold means this condition is not validated, so skip it.
      // Binding the non-null value here avoids a `!!` that relies on an earlier filter.
      val requiredPercent = conditionResult.requiredFulfillmentPercent ?: continue
      if (conditionResult.fulfillmentPercent < requiredPercent) {
        throw AssertionError(
          "The generated image did not meet the required prompt fulfillment percentage.\n" +
            "prompt:${conditionResult.assertPrompt}\n" +
            "aiAssertion.fulfillmentPercent:${conditionResult.fulfillmentPercent}\n" +
            "requiredFulfillmentPercent:$requiredPercent\n" +
            "explanation:${conditionResult.explanation}"
        )
      }
    }
  }
}

Expand Down Expand Up @@ -210,6 +244,32 @@ data class RoborazziOptions(
}

internal val shouldTakeBitmap: Boolean = captureType.shouldTakeScreenshot()

/**
 * Returns a copy of these options with one more AI condition appended to
 * `compareOptions.aiCompareOptions.aiConditions`.
 *
 * @param assert the assertion prompt for the new condition.
 * @param requiredFulfillmentPercent the fulfillment percent required for the new condition.
 * @return a new [RoborazziOptions]; the receiver is not modified.
 * @throws IllegalStateException if `compareOptions.aiCompareOptions` is null.
 */
fun addedCompareAiAssertion(
  assert: String,
  requiredFulfillmentPercent: Int
): RoborazziOptions {
  // aiCompareOptions is nullable, so fail with an actionable message instead of a bare NPE from `!!`.
  val currentAiOptions = checkNotNull(compareOptions.aiCompareOptions) {
    "compareOptions.aiCompareOptions is null. " +
      "Set AiCompareOptions on compareOptions before adding AI assertions."
  }
  val newCondition = AiCompareOptions.AiCondition(
    assertPrompt = assert,
    requiredFulfillmentPercent = requiredFulfillmentPercent
  )
  return copy(
    compareOptions = compareOptions.copy(
      aiCompareOptions = currentAiOptions.copy(
        aiConditions = currentAiOptions.aiConditions + newCondition
      )
    )
  )
}

/**
 * Returns a copy of these options with all [assertions] appended, in order, to
 * `compareOptions.aiCompareOptions.aiConditions`.
 *
 * @param assertions the AI conditions to append; may be empty.
 * @return a new [RoborazziOptions]; the receiver is not modified.
 * @throws IllegalStateException if `compareOptions.aiCompareOptions` is null.
 */
fun addedCompareAiAssertions(vararg assertions: AiCompareOptions.AiCondition): RoborazziOptions {
  // aiCompareOptions is nullable, so fail with an actionable message instead of a bare NPE from `!!`.
  val currentAiOptions = checkNotNull(compareOptions.aiCompareOptions) {
    "compareOptions.aiCompareOptions is null. " +
      "Set AiCompareOptions on compareOptions before adding AI assertions."
  }
  return copy(
    compareOptions = compareOptions.copy(
      aiCompareOptions = currentAiOptions.copy(
        aiConditions = currentAiOptions.aiConditions + assertions
      )
    )
  )
}
}

expect fun canScreenshot(): Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,16 @@ fun processOutputImageAndReport(
// Only used by CaptureResult.Changed
var diffPercentage: Float? = null

val compareOptions = roborazziOptions.compareOptions
val changed = if (height == goldenRoboCanvas.height && width == goldenRoboCanvas.width) {
val comparisonResult: ImageComparator.ComparisonResult =
newRoboCanvas.differ(
other = goldenRoboCanvas,
resizeScale = resizeScale,
imageComparator = roborazziOptions.compareOptions.imageComparator
imageComparator = compareOptions.imageComparator
)
diffPercentage = comparisonResult.pixelDifferences.toFloat() / comparisonResult.pixelCount
val changed = !roborazziOptions.compareOptions.resultValidator(comparisonResult)
val changed = !compareOptions.resultValidator(comparisonResult)
reportLog("${goldenFile.name} The differ result :$comparisonResult changed:$changed")
changed
} else {
Expand All @@ -106,7 +107,7 @@ fun processOutputImageAndReport(

val result: CaptureResult = if (changed) {
val comparisonFile = File(
roborazziOptions.compareOptions.outputDirectoryPath,
compareOptions.outputDirectoryPath,
goldenFile.nameWithoutExtension + "_compare." + goldenFile.extension
)
val comparisonCanvas = comparisonCanvasFactory(
Expand All @@ -121,6 +122,18 @@ fun processOutputImageAndReport(
resizeScale = resizeScale,
contextData = contextData
)
val aiOptions = compareOptions.aiCompareOptions
val aiResult = if (aiOptions != null && aiOptions.aiConditions.isNotEmpty()) {
val comparisonResultFactory = if (aiOptions.aiModel is AiComparisonResultFactory) {
aiOptions.aiModel
} else {
throw NotImplementedError("aiCompareCanvasFactory is not implemented. Did you add roborazzi-ai dependency and (call loadRoboAi() or use RoborazziRule)?")
}
val aiResult = comparisonResultFactory.invoke(comparisonFile.absolutePath, aiOptions)
aiResult
} else {
null
}
debugLog {
"processOutputImageAndReport(): compareCanvas is saved " +
"compareFile:${comparisonFile.absolutePath}"
Expand All @@ -132,7 +145,7 @@ fun processOutputImageAndReport(
goldenFile
} else {
File(
roborazziOptions.compareOptions.outputDirectoryPath,
compareOptions.outputDirectoryPath,
goldenFile.nameWithoutExtension + "_actual." + goldenFile.extension
)
}
Expand All @@ -153,6 +166,7 @@ fun processOutputImageAndReport(
goldenFile = goldenFile.absolutePath,
timestampNs = System.nanoTime(),
diffPercentage = diffPercentage,
aiComparisonResult = aiResult,
contextData = contextData,
)
} else {
Expand All @@ -161,6 +175,7 @@ fun processOutputImageAndReport(
actualFile = actualFile.absolutePath,
goldenFile = goldenFile.absolutePath,
timestampNs = System.nanoTime(),
aiComparisonResult = aiResult,
contextData = contextData,
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.github.takahirom.roborazzi

/**
 * Options for comparing images with an AI model: the model to use and the prompts sent to it.
 *
 * @property aiModel the model used for the AI comparison.
 * @property aiConditions the assertions to evaluate; empty by default.
 * @property systemPrompt instructions describing the evaluation task and the expected output.
 * @property promptTemplate prompt text containing the `INPUT_PROMPT` placeholder
 * (presumably substituted with the [inputPrompt] result by the model integration; confirm there).
 * @property inputPrompt builds the numbered assertion list text from the given options.
 */
data class AiCompareOptions(
  val aiModel: AiModel,
  val aiConditions: List<AiCondition> = emptyList(),
  val systemPrompt: String = """Evaluate the following assertion for fulfillment in the new image.
The evaluation should be based on the comparison between the original image on the left and the new image on the right, with differences highlighted in red in the center. Focus on whether the new image fulfills the requirement specified in the user input.

Output:
For each assertion:
A fulfillment percentage from 0 to 100.
A brief explanation of how this percentage was determined.""",
  val promptTemplate: String = """Assertions:
INPUT_PROMPT
""",
  val inputPrompt: (AiCompareOptions) -> String = { options ->
    // One "Assertion N: <prompt>" entry per condition, each followed by a blank line.
    val promptText = StringBuilder()
    for ((position, condition) in options.aiConditions.withIndex()) {
      promptText.appendLine("Assertion ${position + 1}: ${condition.assertPrompt}\n")
    }
    promptText.toString()
  },
) {
  /** Marker type for the AI model backing the comparison. */
  interface AiModel {
    /**
     * A model that supplies its own [AiComparisonResultFactory] implementation,
     * for callers who want to plug in a different model.
     */
    interface Manual : AiModel, AiComparisonResultFactory
  }

  /**
   * A single assertion for the AI to evaluate.
   *
   * @property assertPrompt the assertion text given to the AI.
   * @property requiredFulfillmentPercent the minimum fulfillment percent to require.
   * When null, the AI result is not validated, but it is still included in the report.
   */
  data class AiCondition(
    val assertPrompt: String,
    val requiredFulfillmentPercent: Int?
  )
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.github.takahirom.roborazzi

/**
 * Produces an [AiComparisonResult] for a comparison image.
 *
 * Declared as a `fun interface`, so it can be implemented with a lambda.
 */
fun interface AiComparisonResultFactory {
/**
 * Evaluates the image at [comparisonImageFilePath] using [aiCompareOptions].
 *
 * @param comparisonImageFilePath path of the comparison image file to evaluate.
 * @param aiCompareOptions the AI options (conditions and prompts) to apply.
 * @return the AI comparison result.
 */
operator fun invoke(
comparisonImageFilePath: String,
aiCompareOptions: AiCompareOptions
): AiComparisonResult
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ sealed interface CaptureResult {
val compareFile: String?
val actualFile: String?
val goldenFile: String?
val contextData: Map<String,@Contextual Any>
val contextData: Map<String, @Contextual Any>

val reportFile: String
get() = when (val result = this) {
Expand All @@ -35,11 +35,11 @@ sealed interface CaptureResult {
@Serializable
data class Recorded(
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "recorded"
override val actualFile: String?
Expand All @@ -51,45 +51,49 @@ sealed interface CaptureResult {
@Serializable
data class Added(
@SerialName("compare_file_path")
override val compareFile:@Contextual String,
override val compareFile: @Contextual String,
@SerialName("actual_file_path")
override val actualFile:@Contextual String,
override val actualFile: @Contextual String,
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("ai_result")
val aiComparisonResult: AiComparisonResult?,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "added"
}

@Serializable
data class Changed(
@SerialName("compare_file_path")
override val compareFile:@Contextual String,
override val compareFile: @Contextual String,
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("actual_file_path")
override val actualFile:@Contextual String,
override val actualFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("diff_percentage")
val diffPercentage: Float?,
@SerialName("ai_result")
val aiComparisonResult: AiComparisonResult?,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "changed"
}

@Serializable
data class Unchanged(
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "unchanged"
override val actualFile: String?
Expand Down Expand Up @@ -122,12 +126,26 @@ sealed interface CaptureResult {
require(decoder is JsonDecoder)
val type = decoder.decodeJsonElement().jsonObject["type"]!!.jsonPrimitive.content
return when (type) {
"recorded" -> decoder.decodeSerializableValue(Recorded.serializer())
"changed" -> decoder.decodeSerializableValue(Changed.serializer())
"unchanged" -> decoder.decodeSerializableValue(Unchanged.serializer())
"added" -> decoder.decodeSerializableValue(Added.serializer())
else -> throw IllegalArgumentException("Unknown type $type")
"recorded" -> decoder.decodeSerializableValue(Recorded.serializer())
"changed" -> decoder.decodeSerializableValue(Changed.serializer())
"unchanged" -> decoder.decodeSerializableValue(Unchanged.serializer())
"added" -> decoder.decodeSerializableValue(Added.serializer())
else -> throw IllegalArgumentException("Unknown type $type")
}
}
}
}

/**
 * Serializable container for the outcome of an AI comparison.
 *
 * @property aiConditionResults the per-condition results; defaults to an empty list.
 */
@Serializable
data class AiComparisonResult(
val aiConditionResults: List<AiConditionResult> = emptyList()
)

/**
 * Serializable result of evaluating a single AI condition.
 *
 * NOTE(review): only [requiredFulfillmentPercent] declares a snake_case `@SerialName`; the other
 * properties serialize under their Kotlin names. Confirm whether the mixed naming is intended.
 *
 * @property assertPrompt the assertion prompt this result refers to.
 * @property requiredFulfillmentPercent the required fulfillment percent, or null if none is required.
 * @property fulfillmentPercent the fulfillment percent reported for this condition.
 * @property explanation optional explanation accompanying the result.
 */
@Serializable
data class AiConditionResult(
val assertPrompt: String,
@SerialName("required_fulfillment_percent")
val requiredFulfillmentPercent: Int?,
val fulfillmentPercent: Int,
val explanation: String?,
)
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ data class CaptureResults(
}

companion object {
@OptIn(ExperimentalSerializationApi::class)
val json = Json {
isLenient = true
encodeDefaults = true
ignoreUnknownKeys = true
classDiscriminator = "#class"
explicitNulls = false
serializersModule = SerializersModule {
contextual(Any::class, AnySerializer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class CaptureResultTest {
actualFile = "/actual_file",
goldenFile = "/golden_file",
timestampNs = 123456789,
aiComparisonResult = null,
contextData = mapOf(
"key" to 2,
"keyDouble" to 2.5,
Expand All @@ -47,6 +48,7 @@ class CaptureResultTest {
actualFile = "/actual_file",
timestampNs = 123456789,
diffPercentage = 0.123f,
aiComparisonResult = null,
contextData = mapOf("key" to Long.MAX_VALUE - 100),
),
CaptureResult.Unchanged(
Expand Down
1 change: 1 addition & 0 deletions roborazzi-ai-gemini/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build/
Loading
Loading