Skip to content

Commit

Permalink
Remove loadRoboAi
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Oct 24, 2024
1 parent 5c32eb1 commit 9cd7a03
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,11 @@ fun processOutputImageAndReport(
)
val aiOptions = compareOptions.aiCompareOptions
val aiResult = if (aiOptions != null && aiOptions.aiConditions.isNotEmpty()) {
val comparisonResultFactory = aiComparisonResultFactory ?:
if(aiOptions.aiModel is AiComparisonResultFactory) {
aiOptions.aiModel
} else {
null
} ?: throw NotImplementedError("aiCompareCanvasFactory is not implemented. Did you add roborazzi-ai dependency and (call loadRoboAi() or use RoborazziRule)?")
val comparisonResultFactory = if (aiOptions.aiModel is AiComparisonResultFactory) {
aiOptions.aiModel
} else {
throw NotImplementedError("aiCompareCanvasFactory is not implemented. Did you add roborazzi-ai dependency and (call loadRoboAi() or use RoborazziRule)?")
}
val aiResult = comparisonResultFactory.invoke(comparisonFile.absolutePath, aiOptions)
aiResult
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ INPUT_PROMPT
},
) {
interface AiModel {
data class Gemini(
val apiKey: String,
val modelName: String = "gemini-1.5-pro"
) : AiModel

/**
* You can use this model if you want to use other models.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,3 @@ fun interface AiComparisonResultFactory {
aiCompareOptions: AiCompareOptions
): AiComparisonResult
}

var aiComparisonResultFactory: AiComparisonResultFactory? = null
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@file:JvmName("RoborazziAi")
package com.github.takahirom.roborazzi

import com.github.takahirom.roborazzi.AiCompareOptions.AiModel
import dev.shreyaspatil.ai.client.generativeai.GenerativeModel
import dev.shreyaspatil.ai.client.generativeai.type.FunctionType
import dev.shreyaspatil.ai.client.generativeai.type.PlatformImage
Expand All @@ -12,14 +13,89 @@ import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlin.jvm.JvmName

@InternalRoborazziApi
val loaded = run {
aiComparisonResultFactory = AiComparisonResultFactory { comparisonImageFilePath, aiOptions ->
createAiResult(aiOptions, comparisonImageFilePath)
data class GeminiAiModel(
val apiKey: String,
val modelName: String = "gemini-1.5-pro"
) : AiModel, AiComparisonResultFactory {
override fun invoke(
comparisonImageFilePath: String,
aiCompareOptions: AiCompareOptions
): AiComparisonResult {
val systemPrompt = aiCompareOptions.systemPrompt
val generativeModel = GenerativeModel(
modelName = modelName,
apiKey = apiKey,
systemInstruction = content {
text(systemPrompt)
},
generationConfig = generationConfig {
maxOutputTokens = 8192
responseMimeType = "application/json"
responseSchema = Schema(
name = "content",
description = "content",
type = FunctionType.ARRAY,
items = Schema(
name = "assert_results",
description = "An array of assertion results",
type = FunctionType.OBJECT,
properties = mapOf(
"fulfillment_percent" to Schema.int(
name = "fulfillment_percent",
description = "A fulfillment percentage from 0 to 100",
),
"explanation" to Schema(
name = "explanation",
description = "A brief explanation of how this percentage was determined. If fulfillment_percent is 100, this field should be empty.",
type = FunctionType.STRING,
nullable = true,
)
),
required = listOf("fulfillment_percent")
),
)
},
)

val template = aiCompareOptions.promptTemplate

val inputPrompt = aiCompareOptions.inputPrompt(aiCompareOptions)
val inputContent = content {
image(readByteArrayFromFile(comparisonImageFilePath))
val prompt = template.replace("INPUT_PROMPT", inputPrompt)
text(prompt)

debugLog {
"RoborazziAi: prompt:$prompt"
}
}

val response = runBlocking { generativeModel.generateContent(inputContent) }
debugLog {
"RoborazziAi: response: ${response.text}"
}
val geminiResult = CaptureResults.json.decodeFromString<Array<GeminiAiConditionResult>>(
requireNotNull(
response.text
)
)
return AiComparisonResult(
aiConditionResults = aiCompareOptions.aiConditions.mapIndexed { index, it ->
val assertResult = geminiResult.getOrNull(index) ?: GeminiAiConditionResult(
fulfillmentPercent = 0,
explanation = "AI model did not return a result for this assertion"
)
AiConditionResult(
assertPrompt = it.assertPrompt,
requiredFulfillmentPercent = it.requiredFulfillmentPercent,
fulfillmentPercent = assertResult.fulfillmentPercent,
explanation = assertResult.explanation,
)
}
)
}
}

fun loadRoboAi() = loaded

@Serializable
data class GeminiAiConditionResult(
Expand All @@ -29,94 +105,3 @@ data class GeminiAiConditionResult(
)

expect fun readByteArrayFromFile(filePath: String): PlatformImage

@InternalRoborazziApi
fun createAiResult(
aiCompareOptions: AiCompareOptions,
comparisonImageFilePath: String,
): AiComparisonResult {
when (val aiModel = aiCompareOptions.aiModel) {
is AiCompareOptions.AiModel.Gemini -> {
val systemPrompt = aiCompareOptions.systemPrompt
val generativeModel = GenerativeModel(
modelName = aiModel.modelName,
apiKey = aiModel.apiKey,
systemInstruction = content {
text(systemPrompt)
},
generationConfig = generationConfig {
maxOutputTokens = 8192
responseMimeType = "application/json"
responseSchema = Schema(
name = "content",
description = "content",
type = FunctionType.ARRAY,
items = Schema(
name = "assert_results",
description = "An array of assertion results",
type = FunctionType.OBJECT,
properties = mapOf(
"fulfillment_percent" to Schema.int(
name = "fulfillment_percent",
description = "A fulfillment percentage from 0 to 100",
),
"explanation" to Schema(
name = "explanation",
description = "A brief explanation of how this percentage was determined. If fulfillment_percent is 100, this field should be empty.",
type = FunctionType.STRING,
nullable = true,
)
),
required = listOf("fulfillment_percent")
),
)
},
)

val template = aiCompareOptions.promptTemplate

val inputPrompt = aiCompareOptions.inputPrompt(aiCompareOptions)
val inputContent = content {
image(readByteArrayFromFile(comparisonImageFilePath))
val prompt = template.replace("INPUT_PROMPT", inputPrompt)
text(prompt)

debugLog {
"RoborazziAi: prompt:$prompt"
}
}

val response = runBlocking { generativeModel.generateContent(inputContent) }
debugLog {
"RoborazziAi: response: ${response.text}"
}
val geminiResult = CaptureResults.json.decodeFromString<Array<GeminiAiConditionResult>>(
requireNotNull(
response.text
)
)
return AiComparisonResult(
aiConditionResults = aiCompareOptions.aiConditions.mapIndexed { index, it ->
val assertResult = geminiResult.getOrNull(index) ?: GeminiAiConditionResult(
fulfillmentPercent = 0,
explanation = "AI model did not return a result for this assertion"
)
AiConditionResult(
assertPrompt = it.assertPrompt,
requiredFulfillmentPercent = it.requiredFulfillmentPercent,
fulfillmentPercent = assertResult.fulfillmentPercent,
explanation = assertResult.explanation,
)
}
)
}

is AiCompareOptions.AiModel.Manual -> {
return aiModel(comparisonImageFilePath, aiCompareOptions)
}

else -> {
throw NotImplementedError("aiCompareCanvasFactory for $aiModel is not implemented in this version")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import androidx.test.espresso.Espresso.onView
import androidx.test.espresso.matcher.ViewMatchers
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.github.takahirom.roborazzi.AiCompareOptions
import com.github.takahirom.roborazzi.GeminiAiModel
import com.github.takahirom.roborazzi.ROBORAZZI_DEBUG
import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers
import com.github.takahirom.roborazzi.RoborazziOptions
Expand Down Expand Up @@ -33,7 +34,7 @@ class AiTest {
roborazziOptions = RoborazziOptions(
compareOptions = RoborazziOptions.CompareOptions(
aiCompareOptions = AiCompareOptions(
aiModel = AiCompareOptions.AiModel.Gemini(
aiModel = GeminiAiModel(
apiKey = System.getenv("gemini_api_key") ?: ""
),
)
Expand All @@ -43,7 +44,7 @@ class AiTest {
)

@Test
fun captureWithAi() {
fun captureWithAi2() {
ROBORAZZI_DEBUG = true
if (System.getenv("gemini_api_key") == null) {
println("Skip the test because gemini_api_key is not set.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ import com.dropbox.differ.ImageComparator
import com.dropbox.differ.SimpleImageComparator
import com.github.takahirom.roborazzi.AiCompareOptions
import com.github.takahirom.roborazzi.Dump
import com.github.takahirom.roborazzi.GeminiAiModel
import com.github.takahirom.roborazzi.RoboComponent
import com.github.takahirom.roborazzi.RobolectricDeviceQualifiers
import com.github.takahirom.roborazzi.RoborazziOptions
import com.github.takahirom.roborazzi.captureRoboAllImage
import com.github.takahirom.roborazzi.captureRoboGif
import com.github.takahirom.roborazzi.captureRoboImage
import com.github.takahirom.roborazzi.captureRoboLastImage
import com.github.takahirom.roborazzi.loadRoboAi
import com.github.takahirom.roborazzi.roboOutputName
import com.github.takahirom.roborazzi.roborazziSystemPropertyOutputDirectory
import com.github.takahirom.roborazzi.withComposeTestTag
Expand Down Expand Up @@ -61,8 +61,6 @@ class ManualTest {
@Test
@Config
fun captureWithAi() {
loadRoboAi()

onView(ViewMatchers.isRoot())
.captureRoboImage(
roborazziOptions = RoborazziOptions(
Expand All @@ -79,7 +77,7 @@ class ManualTest {
requiredFulfillmentPercent = 90,
),
),
aiModel = AiCompareOptions.AiModel.Gemini(
aiModel = GeminiAiModel(
apiKey = System.getenv("gemini_api_key")!!,
),
)
Expand Down

0 comments on commit 9cd7a03

Please sign in to comment.