Skip to content

Commit

Permalink
Introduce Roborazzi AI prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Oct 21, 2024
1 parent 0f528a1 commit 6ef71df
Show file tree
Hide file tree
Showing 20 changed files with 536 additions and 23 deletions.
5 changes: 5 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mavenPublish = "0.25.3"
composeCompiler = "1.5.10"
composeMultiplatform = "1.6.2"
robolectric = "4.12.2"
generativeaiGoogle = "0.9.0-1.0.1"
robolectric-android-all = "Q-robolectric-5415296"

roborazzi-for-replacing-by-include-build = "1.0.0"
Expand Down Expand Up @@ -37,6 +38,7 @@ google-android-material = "1.5.0"
junit = "4.13.2"
ktor-serialization-kotlinx-xml = "2.3.11"
kotlinx-serialization = "1.6.3"
kotlinx-coroutines = "1.6.0"
squareup-okhttp = "5.0.0-alpha.11"
kotlinx-io = "0.3.3"
webjar-material-design-icons = "4.0.0"
Expand All @@ -55,6 +57,8 @@ kotlin-stdlib-jdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8" }
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test" }
kotlin-test-junit = { module = "org.jetbrains.kotlin:kotlin-test-junit" }

generativeai-google = { module = "dev.shreyaspatil.generativeai:generativeai-google", version.ref = "generativeaiGoogle" }

# for sample
composable-preview-scanner = { module = "io.github.sergio-sastre.ComposablePreviewScanner:android", version.ref = "composable-preview-scanner" }

Expand Down Expand Up @@ -90,6 +94,7 @@ dropbox-differ = { module = "com.dropbox.differ:differ", version.ref = "dropbox-
google-android-material = { module = "com.google.android.material:material", version.ref = "google-android-material" }
junit = { module = "junit:junit", version.ref = "junit" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-serialization" }
kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" }
ktor-serialization-kotlinx-xml = { module = "io.ktor:ktor-serialization-kotlinx-xml", version.ref = "ktor-serialization-kotlinx-xml" }
squareup-okhttp = { module = "com.squareup.okhttp3:okhttp", version.ref = "squareup-okhttp" }
squareup-okhttp-coroutines = { module = "com.squareup.okhttp3:okhttp-coroutines", version.ref = "squareup-okhttp" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ data class RoborazziOptions(
val outputDirectoryPath: String = roborazziSystemPropertyOutputDirectory(),
val imageComparator: ImageComparator = DefaultImageComparator,
val comparisonStyle: ComparisonStyle = ComparisonStyle.Grid(),
val aiOptions: AiOptions? = null,
val resultValidator: (result: ImageComparator.ComparisonResult) -> Boolean = DefaultResultValidator,
) {

@ExperimentalRoborazziApi
sealed interface ComparisonStyle {
@ExperimentalRoborazziApi
Expand Down Expand Up @@ -151,6 +153,36 @@ data class RoborazziOptions(
} else {
JsonOutputCaptureResultReporter().report(captureResult, roborazziTaskType)
}
AiCaptureResultReporter().report(captureResult, roborazziTaskType)
}
}

class AiCaptureResultReporter : CaptureResultReporter {
override fun report(captureResult: CaptureResult, roborazziTaskType: RoborazziTaskType) {
val aiResult = when (captureResult) {
is CaptureResult.Changed -> {
captureResult.aiResult
}

is CaptureResult.Added -> {
captureResult.aiResult
}

else -> {
null
}
}
aiResult?.aiAssertions?.forEach { aiAssertion ->
if (aiAssertion.fulfillmentPercent < aiAssertion.requiredFulfillmentPercent) {
throw AssertionError(
"The generated image did not meet the required prompt fulfillment percentage.\n" +
"prompt:${aiAssertion.assertPrompt}\n" +
"aiAssertion.fulfillmentPercent:${aiAssertion.fulfillmentPercent}\n" +
"requiredFulfillmentPercent:${aiAssertion.requiredFulfillmentPercent}\n" +
"explanation:${aiAssertion.explanation}"
)
}
}
}
}

Expand Down Expand Up @@ -210,6 +242,32 @@ data class RoborazziOptions(
}

internal val shouldTakeBitmap: Boolean = captureType.shouldTakeScreenshot()

fun addedCompareAiAssertion(
assert: String,
requiredFulfillmentPercent: Int
): RoborazziOptions {
return copy(
compareOptions = compareOptions.copy(
aiOptions = compareOptions.aiOptions!!.copy(
aiAssertions = compareOptions.aiOptions.aiAssertions + AiOptions.AiAssertion(
assertPrompt = assert,
requiredFulfillmentPercent = requiredFulfillmentPercent
)
)
)
)
}

fun addedCompareAiAssertions(vararg assertions: AiOptions.AiAssertion): RoborazziOptions {
return copy(
compareOptions = compareOptions.copy(
aiOptions = compareOptions.aiOptions!!.copy(
aiAssertions = compareOptions.aiOptions.aiAssertions + assertions
)
)
)
}
}

expect fun canScreenshot(): Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,16 @@ fun processOutputImageAndReport(
// Only used by CaptureResult.Changed
var diffPercentage: Float? = null

val compareOptions = roborazziOptions.compareOptions
val changed = if (height == goldenRoboCanvas.height && width == goldenRoboCanvas.width) {
val comparisonResult: ImageComparator.ComparisonResult =
newRoboCanvas.differ(
other = goldenRoboCanvas,
resizeScale = resizeScale,
imageComparator = roborazziOptions.compareOptions.imageComparator
imageComparator = compareOptions.imageComparator
)
diffPercentage = comparisonResult.pixelDifferences.toFloat() / comparisonResult.pixelCount
val changed = !roborazziOptions.compareOptions.resultValidator(comparisonResult)
val changed = !compareOptions.resultValidator(comparisonResult)
reportLog("${goldenFile.name} The differ result :$comparisonResult changed:$changed")
changed
} else {
Expand All @@ -106,7 +107,7 @@ fun processOutputImageAndReport(

val result: CaptureResult = if (changed) {
val comparisonFile = File(
roborazziOptions.compareOptions.outputDirectoryPath,
compareOptions.outputDirectoryPath,
goldenFile.nameWithoutExtension + "_compare." + goldenFile.extension
)
val comparisonCanvas = comparisonCanvasFactory(
Expand All @@ -121,6 +122,14 @@ fun processOutputImageAndReport(
resizeScale = resizeScale,
contextData = contextData
)
val aiOptions = compareOptions.aiOptions
val aiResult = if (aiOptions != null && aiOptions.aiAssertions.isNotEmpty()) {
val aiResult = aiCompareResultFactory?.invoke(comparisonFile.absolutePath, aiOptions)
?: throw NotImplementedError("aiCompareCanvasFactory is not implemented. Did you add roborazzi-ai dependency and (call loadRoboAi() or use RoborazziRule)?")
aiResult
} else {
null
}
debugLog {
"processOutputImageAndReport(): compareCanvas is saved " +
"compareFile:${comparisonFile.absolutePath}"
Expand All @@ -132,7 +141,7 @@ fun processOutputImageAndReport(
goldenFile
} else {
File(
roborazziOptions.compareOptions.outputDirectoryPath,
compareOptions.outputDirectoryPath,
goldenFile.nameWithoutExtension + "_actual." + goldenFile.extension
)
}
Expand All @@ -153,6 +162,7 @@ fun processOutputImageAndReport(
goldenFile = goldenFile.absolutePath,
timestampNs = System.nanoTime(),
diffPercentage = diffPercentage,
aiResult = aiResult,
contextData = contextData,
)
} else {
Expand All @@ -161,6 +171,7 @@ fun processOutputImageAndReport(
actualFile = actualFile.absolutePath,
goldenFile = goldenFile.absolutePath,
timestampNs = System.nanoTime(),
aiResult = aiResult,
contextData = contextData,
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.github.takahirom.roborazzi

/**
* If you want to use AI to compare images, you can specify the model and prompt.
*/
data class AiOptions(
val aiModel: AiModel,
val aiAssertions: List<AiAssertion> = emptyList(),
val inputPrompt: (AiOptions) -> String = { aiOptions ->
buildString {
aiOptions.aiAssertions.forEachIndexed { index, aiAssertion ->
appendLine("Assertion ${index + 1}: ${aiAssertion.assertPrompt}\n")
}
}
},
val template: String = """
Evaluate the following assertion for fulfillment in the new image.
The evaluation should be based on the comparison between the original image on the left and the new image on the right, with differences highlighted in red in the center. Focus on whether the new image fulfills the requirement specified in the user input.
Output:
For each assertion:
A fulfillment percentage from 0 to 100.
A brief explanation of how this percentage was determined.
Assertions:
INPUT_PROMPT
"""
) {
interface AiModel {
data class Gemini(
val apiKey: String,
val modelName: String = "gemini-1.5-pro"
) : AiModel

/**
* You can use this model if you want to use other models.
*/
interface Manual : AiModel, AiCompareResultFactory
}

data class AiAssertion(
val assertPrompt: String,
/**
* If null, the AI result is not validated. But they are still included in the report.
*/
val requiredFulfillmentPercent: Int
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.github.takahirom.roborazzi

fun interface AiCompareResultFactory {
operator fun invoke(
comparisonImageFilePath: String,
aiOptions: AiOptions
): AiResult
}

var aiCompareResultFactory: AiCompareResultFactory? =
AiCompareResultFactory { comparisonImageFilePath, aiOptions ->
throw NotImplementedError("aiCompareCanvasFactory is not implemented")
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ sealed interface CaptureResult {
val compareFile: String?
val actualFile: String?
val goldenFile: String?
val contextData: Map<String,@Contextual Any>
val contextData: Map<String, @Contextual Any>

val reportFile: String
get() = when (val result = this) {
Expand All @@ -35,11 +35,11 @@ sealed interface CaptureResult {
@Serializable
data class Recorded(
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "recorded"
override val actualFile: String?
Expand All @@ -51,45 +51,49 @@ sealed interface CaptureResult {
@Serializable
data class Added(
@SerialName("compare_file_path")
override val compareFile:@Contextual String,
override val compareFile: @Contextual String,
@SerialName("actual_file_path")
override val actualFile:@Contextual String,
override val actualFile: @Contextual String,
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("ai_result")
val aiResult: AiResult?,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "added"
}

@Serializable
data class Changed(
@SerialName("compare_file_path")
override val compareFile:@Contextual String,
override val compareFile: @Contextual String,
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("actual_file_path")
override val actualFile:@Contextual String,
override val actualFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("diff_percentage")
val diffPercentage: Float?,
@SerialName("ai_result")
val aiResult: AiResult?,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "changed"
}

@Serializable
data class Unchanged(
@SerialName("golden_file_path")
override val goldenFile:@Contextual String,
override val goldenFile: @Contextual String,
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("context_data")
override val contextData: Map<String,@Contextual Any>
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
override val type = "unchanged"
override val actualFile: String?
Expand Down Expand Up @@ -122,12 +126,26 @@ sealed interface CaptureResult {
require(decoder is JsonDecoder)
val type = decoder.decodeJsonElement().jsonObject["type"]!!.jsonPrimitive.content
return when (type) {
"recorded" -> decoder.decodeSerializableValue(Recorded.serializer())
"changed" -> decoder.decodeSerializableValue(Changed.serializer())
"unchanged" -> decoder.decodeSerializableValue(Unchanged.serializer())
"added" -> decoder.decodeSerializableValue(Added.serializer())
else -> throw IllegalArgumentException("Unknown type $type")
"recorded" -> decoder.decodeSerializableValue(Recorded.serializer())
"changed" -> decoder.decodeSerializableValue(Changed.serializer())
"unchanged" -> decoder.decodeSerializableValue(Unchanged.serializer())
"added" -> decoder.decodeSerializableValue(Added.serializer())
else -> throw IllegalArgumentException("Unknown type $type")
}
}
}
}

@Serializable
data class AiResult(
val aiAssertions: List<AiAssertion> = emptyList()
)

@Serializable
data class AiAssertion(
val assertPrompt: String,
@SerialName("required_fulfillment_percent")
val requiredFulfillmentPercent: Int,
val fulfillmentPercent: Int,
val explanation: String?,
)
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ data class CaptureResults(
}

companion object {
@OptIn(ExperimentalSerializationApi::class)
val json = Json {
isLenient = true
encodeDefaults = true
ignoreUnknownKeys = true
classDiscriminator = "#class"
explicitNulls = false
serializersModule = SerializersModule {
contextual(Any::class, AnySerializer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class CaptureResultTest {
actualFile = "/actual_file",
goldenFile = "/golden_file",
timestampNs = 123456789,
aiResult = null,
contextData = mapOf(
"key" to 2,
"keyDouble" to 2.5,
Expand All @@ -47,6 +48,7 @@ class CaptureResultTest {
actualFile = "/actual_file",
timestampNs = 123456789,
diffPercentage = 0.123f,
aiResult = null,
contextData = mapOf("key" to Long.MAX_VALUE - 100),
),
CaptureResult.Unchanged(
Expand Down
1 change: 1 addition & 0 deletions roborazzi-ai/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build/
Loading

0 comments on commit 6ef71df

Please sign in to comment.