Skip to content

Commit

Permalink
Merge branch 'main' into rl.function.call.wrap
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo authored Sep 27, 2024
2 parents 28afaa1 + 1f6afc7 commit 744b00b
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,20 @@ sealed interface DataConnectExecutable {
"aea3583ebe1a36938eec5164de79405951ddf05b70a857ddb4f346f1424666f1d96" +
"989a5f81326c7e2aef4a195d31ff356fdf2331ed98fa1048c4bd469cbfd97"
)
"1.3.9" ->
VerificationInfo(
fileSizeInBytes = 24_977_560L,
sha512DigestHex =
"4558928c2a84b54113e0d6918907eb75bdeb9bd059dcc4b6f22cb4a7c9c7421a357" +
"7f3b0d2eeb246b1df739b38f1eb91e5a6166b0e559707746d79e6ccdf9ed4"
)
"1.4.0" ->
VerificationInfo(
fileSizeInBytes = 25_018_520L,
sha512DigestHex =
"c06ccade89cb46459452f71c6d49a01b4b30c9f96cc4cb770ed168e7420ef0cb368" +
"cd602ff596137e6586270046cf0ffd9f8d294e44b036e5c5b373a074b7e5a"
)
else ->
throw DataConnectGradleException(
"3svd27ch8y",
Expand All @@ -86,10 +100,16 @@ sealed interface DataConnectExecutable {
data class Version(val version: String, val verificationInfo: VerificationInfo?) :
DataConnectExecutable {
companion object {

private const val DEFAULT_VERSION = "1.4.0"

fun forVersionWithDefaultVerificationInfo(version: String): Version {
val verificationInfo = DataConnectExecutable.VerificationInfo.forVersion(version)
return Version(version, verificationInfo)
}

fun forDefaultVersionWithDefaultVerificationInfo(): Version =
forVersionWithDefaultVerificationInfo(DEFAULT_VERSION)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DataConnectProviders(
.orElse(versionValueFromGradleProperty)
.orElse(valueFromVariant)
.orElse(valueFromProject)
.orElse(DataConnectExecutable.Version.forVersionWithDefaultVerificationInfo("1.3.8"))
.orElse(DataConnectExecutable.Version.forDefaultVersionWithDefaultVerificationInfo())
}

val postgresConnectionUrl: Provider<String> = run {
Expand Down
23 changes: 23 additions & 0 deletions firebase-vertexai/consumer-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

-keep class com.google.firebase.vertexai.common.** { *; }
10 changes: 10 additions & 0 deletions firebase-vertexai/firebase-vertexai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,19 @@ android {
defaultConfig {
minSdk = 21
targetSdk = 34
consumerProguardFiles("consumer-rules.pro")
multiDexEnabled = true
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
Expand Down
21 changes: 21 additions & 0 deletions firebase-vertexai/proguard-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
package com.google.firebase.vertexai.common

import android.util.Log
import androidx.annotation.VisibleForTesting
import com.google.firebase.Firebase
import com.google.firebase.options
import com.google.firebase.vertexai.common.server.FinishReason
import com.google.firebase.vertexai.common.server.GRpcError
import com.google.firebase.vertexai.common.server.GRpcErrorDetails
import com.google.firebase.vertexai.common.util.decodeToFlow
import com.google.firebase.vertexai.common.util.fullModelName
import com.google.firebase.vertexai.type.RequestOptions
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.client.engine.okhttp.OkHttp
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
Expand All @@ -39,12 +40,9 @@ import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsChannel
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.headersOf
import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.ByteChannel
import kotlin.math.max
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
Expand Down Expand Up @@ -94,24 +92,6 @@ internal constructor(
headerProvider: HeaderProvider? = null,
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)

@VisibleForTesting(otherwise = VisibleForTesting.NONE)
constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String,
headerProvider: HeaderProvider?,
channel: ByteChannel,
status: HttpStatusCode,
) : this(
key,
model,
requestOptions,
MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) },
apiClient,
headerProvider,
)

private val model = fullModelName(model)

private val client =
Expand Down Expand Up @@ -263,12 +243,32 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) {
throw ServiceDisabledException(message)
getServiceDisabledErrorDetailsOrNull(error)?.let {
val errorMessage =
if (it.metadata?.get("service") == "firebasevertexai.googleapis.com") {
"""
The Vertex AI for Firebase SDK requires the Firebase Vertex AI API
`firebasevertexai.googleapis.com` to be enabled for your project. Enable it by visiting
the Firebase Console at https://console.firebase.google.com/project/${Firebase.options.projectId}/genai/vertex then
retry. If you enabled this API recently, wait a few minutes for the action to propagate
to our systems and retry.
"""
.trimIndent()
} else {
error.message
}

throw ServiceDisabledException(errorMessage)
}
throw ServerException(message)
}

private fun getServiceDisabledErrorDetailsOrNull(error: GRpcError): GRpcErrorDetails? {
return error.details?.firstOrNull {
it.reason == "SERVICE_DISABLED" && it.domain == "googleapis.com"
}
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,9 @@ internal data class GRpcError(
val details: List<GRpcErrorDetails>? = null
)

@Serializable internal data class GRpcErrorDetails(val reason: String? = null)
@Serializable
internal data class GRpcErrorDetails(
val reason: String? = null,
val domain: String? = null,
val metadata: Map<String, String>? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import kotlin.time.toDuration
class RequestOptions
internal constructor(
internal val timeout: Duration,
internal val endpoint: String = "https://firebaseml.googleapis.com",
internal val apiVersion: String = "v2beta",
internal val endpoint: String = "https://firebasevertexai.googleapis.com",
internal val apiVersion: String = "v1beta",
) {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ internal class RequestFormatTests {
}
}

mockEngine.requestHistory.first().url.host shouldBe "firebaseml.googleapis.com"
mockEngine.requestHistory.first().url.host shouldBe "firebasevertexai.googleapis.com"
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ import com.google.firebase.vertexai.common.shared.TextPart
import com.google.firebase.vertexai.type.RequestOptions
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.nulls.shouldNotBeNull
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
Expand Down Expand Up @@ -106,10 +110,11 @@ internal fun commonTest(
"super_cool_test_key",
"gemini-pro",
requestOptions,
MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
},
TEST_CLIENT_ID,
null,
channel,
status,
)
CommonTestScope(channel, apiController).block()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.type.RequestOptions
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.nulls.shouldNotBeNull
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
Expand Down Expand Up @@ -93,10 +97,11 @@ internal fun commonTest(
"super_cool_test_key",
"gemini-pro",
requestOptions,
MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
},
TEST_CLIENT_ID,
null,
channel,
status,
)
val model = GenerativeModel("cool-model-name", controller = apiController)
CommonTestScope(channel, model).block()
Expand Down

0 comments on commit 744b00b

Please sign in to comment.