Skip to content

Commit

Permalink
Fix naming
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Oct 22, 2024
1 parent 6ef71df commit 1ce5650
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 86 deletions.
5 changes: 3 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ org.jetbrains.compose.experimental.uikit.enabled=true
kotlin.incremental.native=true

# To debug
roborazzi.test.record=true
#roborazzi.test.verify=true
#roborazzi.test.record=true
roborazzi.test.verify=true
#roborazzi.test.compare=true
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ data class RoborazziOptions(
val outputDirectoryPath: String = roborazziSystemPropertyOutputDirectory(),
val imageComparator: ImageComparator = DefaultImageComparator,
val comparisonStyle: ComparisonStyle = ComparisonStyle.Grid(),
val aiOptions: AiOptions? = null,
val aiCompareOptions: AiCompareOptions? = null,
val resultValidator: (result: ImageComparator.ComparisonResult) -> Boolean = DefaultResultValidator,
) {

Expand Down Expand Up @@ -161,25 +161,27 @@ data class RoborazziOptions(
override fun report(captureResult: CaptureResult, roborazziTaskType: RoborazziTaskType) {
val aiResult = when (captureResult) {
is CaptureResult.Changed -> {
captureResult.aiResult
captureResult.aiComparisonResult
}

is CaptureResult.Added -> {
captureResult.aiResult
captureResult.aiComparisonResult
}

else -> {
null
}
}
aiResult?.aiAssertions?.forEach { aiAssertion ->
if (aiAssertion.fulfillmentPercent < aiAssertion.requiredFulfillmentPercent) {
aiResult?.aiConditionResults
?.filter { conditionResult -> conditionResult.requiredFulfillmentPercent != null }
?.forEach { conditionResult ->
if (conditionResult.fulfillmentPercent < conditionResult.requiredFulfillmentPercent!!) {
throw AssertionError(
"The generated image did not meet the required prompt fulfillment percentage.\n" +
"prompt:${aiAssertion.assertPrompt}\n" +
"aiAssertion.fulfillmentPercent:${aiAssertion.fulfillmentPercent}\n" +
"requiredFulfillmentPercent:${aiAssertion.requiredFulfillmentPercent}\n" +
"explanation:${aiAssertion.explanation}"
"prompt:${conditionResult.assertPrompt}\n" +
"aiAssertion.fulfillmentPercent:${conditionResult.fulfillmentPercent}\n" +
"requiredFulfillmentPercent:${conditionResult.requiredFulfillmentPercent}\n" +
"explanation:${conditionResult.explanation}"
)
}
}
Expand Down Expand Up @@ -249,8 +251,8 @@ data class RoborazziOptions(
): RoborazziOptions {
return copy(
compareOptions = compareOptions.copy(
aiOptions = compareOptions.aiOptions!!.copy(
aiAssertions = compareOptions.aiOptions.aiAssertions + AiOptions.AiAssertion(
aiCompareOptions = compareOptions.aiCompareOptions!!.copy(
aiConditions = compareOptions.aiCompareOptions.aiConditions + AiCompareOptions.AiCondition(
assertPrompt = assert,
requiredFulfillmentPercent = requiredFulfillmentPercent
)
Expand All @@ -259,11 +261,11 @@ data class RoborazziOptions(
)
}

fun addedCompareAiAssertions(vararg assertions: AiOptions.AiAssertion): RoborazziOptions {
fun addedCompareAiAssertions(vararg assertions: AiCompareOptions.AiCondition): RoborazziOptions {
return copy(
compareOptions = compareOptions.copy(
aiOptions = compareOptions.aiOptions!!.copy(
aiAssertions = compareOptions.aiOptions.aiAssertions + assertions
aiCompareOptions = compareOptions.aiCompareOptions!!.copy(
aiConditions = compareOptions.aiCompareOptions.aiConditions + assertions
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ fun processOutputImageAndReport(
resizeScale = resizeScale,
contextData = contextData
)
val aiOptions = compareOptions.aiOptions
val aiResult = if (aiOptions != null && aiOptions.aiAssertions.isNotEmpty()) {
val aiResult = aiCompareResultFactory?.invoke(comparisonFile.absolutePath, aiOptions)
val aiOptions = compareOptions.aiCompareOptions
val aiResult = if (aiOptions != null && aiOptions.aiConditions.isNotEmpty()) {
val aiResult = aiComparisonResultFactory?.invoke(comparisonFile.absolutePath, aiOptions)
?: throw NotImplementedError("aiCompareCanvasFactory is not implemented. Did you add roborazzi-ai dependency and (call loadRoboAi() or use RoborazziRule)?")
aiResult
} else {
Expand Down Expand Up @@ -162,7 +162,7 @@ fun processOutputImageAndReport(
goldenFile = goldenFile.absolutePath,
timestampNs = System.nanoTime(),
diffPercentage = diffPercentage,
aiResult = aiResult,
aiComparisonResult = aiResult,
contextData = contextData,
)
} else {
Expand All @@ -171,7 +171,7 @@ fun processOutputImageAndReport(
actualFile = actualFile.absolutePath,
goldenFile = goldenFile.absolutePath,
timestampNs = System.nanoTime(),
aiResult = aiResult,
aiComparisonResult = aiResult,
contextData = contextData,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,26 @@ package com.github.takahirom.roborazzi
/**
* If you want to use AI to compare images, you can specify the model and prompt.
*/
data class AiOptions(
data class AiCompareOptions(
val aiModel: AiModel,
val aiAssertions: List<AiAssertion> = emptyList(),
val inputPrompt: (AiOptions) -> String = { aiOptions ->
buildString {
aiOptions.aiAssertions.forEachIndexed { index, aiAssertion ->
appendLine("Assertion ${index + 1}: ${aiAssertion.assertPrompt}\n")
}
}
},
val template: String = """
Evaluate the following assertion for fulfillment in the new image.
val aiConditions: List<AiCondition> = emptyList(),
val systemPrompt: String = """Evaluate the following assertion for fulfillment in the new image.
The evaluation should be based on the comparison between the original image on the left and the new image on the right, with differences highlighted in red in the center. Focus on whether the new image fulfills the requirement specified in the user input.
Output:
For each assertion:
A fulfillment percentage from 0 to 100.
A brief explanation of how this percentage was determined.
Assertions:
A brief explanation of how this percentage was determined.""",
val promptTemplate: String = """Assertions:
INPUT_PROMPT
"""
""",
val inputPrompt: (AiCompareOptions) -> String = { aiOptions ->
buildString {
aiOptions.aiConditions.forEachIndexed { index, aiAssertion ->
appendLine("Assertion ${index + 1}: ${aiAssertion.assertPrompt}\n")
}
}
},
) {
interface AiModel {
data class Gemini(
Expand All @@ -35,14 +33,14 @@ INPUT_PROMPT
/**
* You can use this model if you want to use other models.
*/
interface Manual : AiModel, AiCompareResultFactory
interface Manual : AiModel, AiComparisonResultFactory
}

data class AiAssertion(
data class AiCondition(
val assertPrompt: String,
/**
* If null, the AI result is not validated. But they are still included in the report.
*/
val requiredFulfillmentPercent: Int
val requiredFulfillmentPercent: Int?
)
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package com.github.takahirom.roborazzi

fun interface AiCompareResultFactory {
fun interface AiComparisonResultFactory {
operator fun invoke(
comparisonImageFilePath: String,
aiOptions: AiOptions
): AiResult
aiCompareOptions: AiCompareOptions
): AiComparisonResult
}

var aiCompareResultFactory: AiCompareResultFactory? =
AiCompareResultFactory { comparisonImageFilePath, aiOptions ->
var aiComparisonResultFactory: AiComparisonResultFactory? =
AiComparisonResultFactory { comparisonImageFilePath, aiOptions ->
throw NotImplementedError("aiCompareCanvasFactory is not implemented")
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ sealed interface CaptureResult {
@SerialName("timestamp")
override val timestampNs: Long,
@SerialName("ai_result")
val aiResult: AiResult?,
val aiComparisonResult: AiComparisonResult?,
@SerialName("context_data")
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
Expand All @@ -79,7 +79,7 @@ sealed interface CaptureResult {
@SerialName("diff_percentage")
val diffPercentage: Float?,
@SerialName("ai_result")
val aiResult: AiResult?,
val aiComparisonResult: AiComparisonResult?,
@SerialName("context_data")
override val contextData: Map<String, @Contextual Any>
) : CaptureResult {
Expand Down Expand Up @@ -137,15 +137,15 @@ sealed interface CaptureResult {
}

@Serializable
data class AiResult(
val aiAssertions: List<AiAssertion> = emptyList()
data class AiComparisonResult(
val aiConditionResults: List<AiConditionResult> = emptyList()
)

@Serializable
data class AiAssertion(
data class AiConditionResult(
val assertPrompt: String,
@SerialName("required_fulfillment_percent")
val requiredFulfillmentPercent: Int,
val requiredFulfillmentPercent: Int?,
val fulfillmentPercent: Int,
val explanation: String?,
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CaptureResultTest {
actualFile = "/actual_file",
goldenFile = "/golden_file",
timestampNs = 123456789,
aiResult = null,
aiComparisonResult = null,
contextData = mapOf(
"key" to 2,
"keyDouble" to 2.5,
Expand All @@ -48,7 +48,7 @@ class CaptureResultTest {
actualFile = "/actual_file",
timestampNs = 123456789,
diffPercentage = 0.123f,
aiResult = null,
aiComparisonResult = null,
contextData = mapOf("key" to Long.MAX_VALUE - 100),
),
CaptureResult.Unchanged(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ import kotlin.jvm.JvmName

@InternalRoborazziApi
val loaded = run {
aiCompareResultFactory = AiCompareResultFactory { comparisonImageFilePath, aiOptions ->
aiComparisonResultFactory = AiComparisonResultFactory { comparisonImageFilePath, aiOptions ->
createAiResult(aiOptions, comparisonImageFilePath)
}
}

fun loadRoboAi() = loaded

@Serializable
data class AssertionResult(
data class GeminiAiConditionResult(
@SerialName("fulfillment_percent")
val fulfillmentPercent: Int,
val explanation: String?,
Expand All @@ -32,14 +32,18 @@ expect fun readByteArrayFromFile(filePath: String): PlatformImage

@InternalRoborazziApi
fun createAiResult(
aiOptions: AiOptions,
aiCompareOptions: AiCompareOptions,
comparisonImageFilePath: String,
): AiResult {
when (val aiModel = aiOptions.aiModel) {
is AiOptions.AiModel.Gemini -> {
): AiComparisonResult {
when (val aiModel = aiCompareOptions.aiModel) {
is AiCompareOptions.AiModel.Gemini -> {
val systemPrompt = aiCompareOptions.systemPrompt
val generativeModel = GenerativeModel(
modelName = aiModel.modelName,
apiKey = aiModel.apiKey,
systemInstruction = content {
text(systemPrompt)
},
generationConfig = generationConfig {
maxOutputTokens = 8192
responseMimeType = "application/json"
Expand Down Expand Up @@ -69,9 +73,9 @@ fun createAiResult(
},
)

val template = aiOptions.template
val template = aiCompareOptions.promptTemplate

val inputPrompt = aiOptions.inputPrompt(aiOptions)
val inputPrompt = aiCompareOptions.inputPrompt(aiCompareOptions)
val inputContent = content {
image(readByteArrayFromFile(comparisonImageFilePath))
val prompt = template.replace("INPUT_PROMPT", inputPrompt)
Expand All @@ -86,18 +90,18 @@ fun createAiResult(
debugLog {
"RoborazziAi: response: ${response.text}"
}
val geminiResult = CaptureResults.json.decodeFromString<Array<AssertionResult>>(
val geminiResult = CaptureResults.json.decodeFromString<Array<GeminiAiConditionResult>>(
requireNotNull(
response.text
)
)
return AiResult(
aiAssertions = aiOptions.aiAssertions.mapIndexed { index, it ->
val assertResult = geminiResult.getOrNull(index) ?: AssertionResult(
return AiComparisonResult(
aiConditionResults = aiCompareOptions.aiConditions.mapIndexed { index, it ->
val assertResult = geminiResult.getOrNull(index) ?: GeminiAiConditionResult(
fulfillmentPercent = 0,
explanation = "AI model did not return a result for this assertion"
)
AiAssertion(
AiConditionResult(
assertPrompt = it.assertPrompt,
requiredFulfillmentPercent = it.requiredFulfillmentPercent,
fulfillmentPercent = assertResult.fulfillmentPercent,
Expand All @@ -107,8 +111,8 @@ fun createAiResult(
)
}

is AiOptions.AiModel.Manual -> {
return aiModel(comparisonImageFilePath, aiOptions)
is AiCompareOptions.AiModel.Manual -> {
return aiModel(comparisonImageFilePath, aiCompareOptions)
}

else -> {
Expand Down
Loading

0 comments on commit 1ce5650

Please sign in to comment.