
Use Kotlin's explicit API in vertexAI
rlazo committed Sep 26, 2024
1 parent 552132b commit 39620ae
Showing 28 changed files with 242 additions and 194 deletions.
18 changes: 18 additions & 0 deletions firebase-vertexai/firebase-vertexai.gradle.kts
@@ -16,6 +16,9 @@

@file:Suppress("UnstableApiUsage")

+import org.jetbrains.kotlin.gradle.tasks.KotlinCompile


plugins {
id("firebase-library")
id("kotlin-android")
@@ -56,6 +59,21 @@ android {
}
}

+// Enable Kotlin "Explicit API Mode". This causes the Kotlin compiler to fail if any
+// classes, methods, or properties have implicit `public` visibility. This check helps
+// avoid accidentally leaking elements into the public API, requiring that any public
+// element be explicitly declared as `public`.
+// https://github.com/Kotlin/KEEP/blob/master/proposals/explicit-api-mode.md
+// https://chao2zhang.medium.com/explicit-api-mode-for-kotlin-on-android-b8264fdd76d1
+tasks.withType<KotlinCompile>().all {
+  if (!name.contains("test", ignoreCase = true)) {
+    if (!kotlinOptions.freeCompilerArgs.contains("-Xexplicit-api=strict")) {
+      kotlinOptions.freeCompilerArgs += "-Xexplicit-api=strict"
+    }
+  }
+}
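
Side note: newer versions of the Kotlin Gradle plugin expose explicit API mode as a first-class DSL, so manually inspecting `freeCompilerArgs` is not the only option. A minimal sketch, assuming a plugin version that ships the `explicitApi` DSL (roughly 1.4+); how the DSL treats test compilations has varied across plugin versions, which may be why this commit filters task names by hand:

```kotlin
// Sketch: equivalent configuration via the Kotlin Gradle plugin DSL.
kotlin {
  explicitApi() // strict mode; equivalent to -Xexplicit-api=strict
  // explicitApiWarning() // alternative: report violations as warnings
}
```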


dependencies {
val ktorVersion = "2.3.2"

Chat.kt

@@ -42,7 +42,10 @@ import kotlinx.coroutines.flow.onEach
* @param model The model to use for the interaction
* @property history The previous interactions with the model
*/
-class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
+public class Chat(
+  private val model: GenerativeModel,
+  public val history: MutableList<Content> = ArrayList()
+) {
private var lock = Semaphore(1)

/**
@@ -53,7 +56,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
-suspend fun sendMessage(prompt: Content): GenerateContentResponse {
+public suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
attemptLock()
try {
@@ -72,7 +75,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @param prompt The text to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
-suspend fun sendMessage(prompt: String): GenerateContentResponse {
+public suspend fun sendMessage(prompt: String): GenerateContentResponse {
val content = content { text(prompt) }
return sendMessage(content)
}
@@ -83,7 +86,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @param prompt The image to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
-suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
+public suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
val content = content { image(prompt) }
return sendMessage(content)
}
@@ -96,7 +99,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
-fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
+public fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
prompt.assertComesFromUser()
attemptLock()

@@ -149,7 +152,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
-fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
+public fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
val content = content { text(prompt) }
return sendMessageStream(content)
}
@@ -161,7 +164,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
-fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
+public fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
val content = content { image(prompt) }
return sendMessageStream(content)
}
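
Worth noting: explicit API mode only changes declarations, so consumer code is untouched by this diff. A hypothetical call site (assuming an already-configured `GenerativeModel` named `model`) compiles the same before and after:

```kotlin
import com.google.firebase.vertexai.GenerativeModel

// Hypothetical consumer snippet; `model` is assumed to be configured elsewhere.
suspend fun chatDemo(model: GenerativeModel) {
  val chat = model.startChat()
  val reply = chat.sendMessage("Hello!") // suspends until the response arrives
  println(reply.text)
  // Streaming variant: responses are emitted as the model produces them.
  chat.sendMessageStream("Tell me more").collect { chunk -> print(chunk.text) }
}
```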
FirebaseVertexAI.kt

@@ -31,7 +31,7 @@ import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig

/** Entry point for all _Vertex AI for Firebase_ functionality. */
-class FirebaseVertexAI
+public class FirebaseVertexAI
internal constructor(
private val firebaseApp: FirebaseApp,
private val location: String,
@@ -51,7 +51,7 @@ internal constructor(
* @param systemInstruction contains a [Content] that directs the model to behave a certain way
*/
@JvmOverloads
-fun generativeModel(
+public fun generativeModel(
modelName: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
@@ -77,13 +77,13 @@ internal constructor(
)
}

-companion object {
+public companion object {
/** The [FirebaseVertexAI] instance for the default [FirebaseApp] */
@JvmStatic
-val instance: FirebaseVertexAI
+public val instance: FirebaseVertexAI
get() = getInstance(location = "us-central1")

-@JvmStatic fun getInstance(app: FirebaseApp): FirebaseVertexAI = getInstance(app)
+@JvmStatic public fun getInstance(app: FirebaseApp): FirebaseVertexAI = getInstance(app)

/**
* Returns the [FirebaseVertexAI] instance for the provided [FirebaseApp] and [location]
@@ -93,19 +93,19 @@
*/
@JvmStatic
@JvmOverloads
-fun getInstance(app: FirebaseApp = Firebase.app, location: String): FirebaseVertexAI {
+public fun getInstance(app: FirebaseApp = Firebase.app, location: String): FirebaseVertexAI {
val multiResourceComponent = app[FirebaseVertexAIMultiResourceComponent::class.java]
return multiResourceComponent.get(location)
}
}
}

/** Returns the [FirebaseVertexAI] instance of the default [FirebaseApp]. */
-val Firebase.vertexAI: FirebaseVertexAI
+public val Firebase.vertexAI: FirebaseVertexAI
get() = FirebaseVertexAI.instance

/** Returns the [FirebaseVertexAI] instance of a given [FirebaseApp]. */
-fun Firebase.vertexAI(
+public fun Firebase.vertexAI(
app: FirebaseApp = Firebase.app,
location: String = "us-central1"
): FirebaseVertexAI = FirebaseVertexAI.getInstance(app, location)
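
For context, the entry points above are used like this (sketch; the model name and location are illustrative placeholders, and `Firebase` refers to the `com.google.firebase.Firebase` accessor):

```kotlin
import com.google.firebase.Firebase
import com.google.firebase.vertexai.vertexAI

// Instance for the default FirebaseApp and default location ("us-central1").
val defaultAi = Firebase.vertexAI

// Instance pinned to an explicit location; the model name is a placeholder.
val model = Firebase.vertexAI(location = "europe-west1")
  .generativeModel("gemini-1.5-flash")
```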
GenerativeModel.kt

@@ -50,7 +50,7 @@ import kotlinx.coroutines.tasks.await
/**
* A controller for communicating with the API of a given multimodal model (for example, Gemini).
*/
-class GenerativeModel
+public class GenerativeModel
internal constructor(
private val modelName: String,
private val generationConfig: GenerationConfig? = null,
@@ -128,7 +128,7 @@ internal constructor(
* @return A [GenerateContentResponse]. Function should be called within a suspend context to
* properly manage concurrency.
*/
-suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
+public suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
try {
controller.generateContent(constructRequest(*prompt)).toPublic().validate()
} catch (e: Throwable) {
@@ -141,7 +141,7 @@ internal constructor(
* @param prompt [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
*/
-fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
+public fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
controller
.generateContentStream(constructRequest(*prompt))
.catch { throw FirebaseVertexAIException.from(it) }
@@ -154,7 +154,7 @@ internal constructor(
* @return A [GenerateContentResponse] after some delay. Function should be called within a
* suspend context to properly manage concurrency.
*/
-suspend fun generateContent(prompt: String): GenerateContentResponse =
+public suspend fun generateContent(prompt: String): GenerateContentResponse =
generateContent(content { text(prompt) })

/**
@@ -163,7 +163,7 @@ internal constructor(
* @param prompt The text to be converted into a single piece of [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
*/
-fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
+public fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
generateContentStream(content { text(prompt) })

/**
@@ -173,7 +173,7 @@ internal constructor(
* @return A [GenerateContentResponse] after some delay. Function should be called within a
* suspend context to properly manage concurrency.
*/
-suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
+public suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
generateContent(content { image(prompt) })

/**
@@ -182,19 +182,20 @@ internal constructor(
* @param prompt The image to be converted into a single piece of [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
*/
-fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
+public fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
generateContentStream(content { image(prompt) })

/** Creates a [Chat] instance which internally tracks the ongoing conversation with the model */
-fun startChat(history: List<Content> = emptyList()): Chat = Chat(this, history.toMutableList())
+public fun startChat(history: List<Content> = emptyList()): Chat =
+  Chat(this, history.toMutableList())

/**
* Counts the amount of tokens in a prompt.
*
* @param prompt A group of [Content] to count tokens of.
* @return A [CountTokensResponse] containing the amount of tokens in the prompt.
*/
-suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
+public suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
try {
return controller.countTokens(constructCountTokensRequest(*prompt)).toPublic()
} catch (e: Throwable) {
@@ -208,7 +209,7 @@ internal constructor(
* @param prompt The text to be converted to a single piece of [Content] to count the tokens of.
* @return A [CountTokensResponse] containing the amount of tokens in the prompt.
*/
-suspend fun countTokens(prompt: String): CountTokensResponse {
+public suspend fun countTokens(prompt: String): CountTokensResponse {
return countTokens(content { text(prompt) })
}

@@ -218,7 +219,7 @@ internal constructor(
* @param prompt The image to be converted to a single piece of [Content] to count the tokens of.
* @return A [CountTokensResponse] containing the amount of tokens in the prompt.
*/
-suspend fun countTokens(prompt: Bitmap): CountTokensResponse {
+public suspend fun countTokens(prompt: Bitmap): CountTokensResponse {
return countTokens(content { image(prompt) })
}

@@ -247,7 +248,7 @@
?.let { throw ResponseStoppedException(this) }
}

-companion object {
+private companion object {
private val TAG = GenerativeModel::class.java.simpleName
}
}
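
A short sketch of the now explicitly `public` model surface (hypothetical snippet; the prompt text is arbitrary):

```kotlin
import com.google.firebase.vertexai.GenerativeModel

suspend fun modelDemo(model: GenerativeModel) {
  // One-shot generation.
  val response = model.generateContent("Summarize explicit API mode in one line.")
  println(response.text)

  // Token accounting for a prompt before sending it.
  val count = model.countTokens("Summarize explicit API mode in one line.")
  println("prompt size: ${count.totalTokens} tokens")
}
```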
@@ -41,7 +41,7 @@ internal enum class HarmCategory
@SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT
}

-typealias Base64 = String
+internal typealias Base64 = String

@ExperimentalSerializationApi
@Serializable
ChatFutures.kt

@@ -29,25 +29,25 @@ import org.reactivestreams.Publisher
*
* @see from
*/
-abstract class ChatFutures internal constructor() {
+public abstract class ChatFutures internal constructor() {

/**
* Generates a response from the backend with the provided [Content], and any previous ones
* sent/returned from this chat.
*
* @param prompt A [Content] to send to the model.
*/
-abstract fun sendMessage(prompt: Content): ListenableFuture<GenerateContentResponse>
+public abstract fun sendMessage(prompt: Content): ListenableFuture<GenerateContentResponse>

/**
* Generates a streaming response from the backend with the provided [Content].
*
* @param prompt A [Content] to send to the model.
*/
-abstract fun sendMessageStream(prompt: Content): Publisher<GenerateContentResponse>
+public abstract fun sendMessageStream(prompt: Content): Publisher<GenerateContentResponse>

/** Returns the [Chat] instance that was used to create this instance */
-abstract fun getChat(): Chat
+public abstract fun getChat(): Chat

private class FuturesImpl(private val chat: Chat) : ChatFutures() {
override fun sendMessage(prompt: Content): ListenableFuture<GenerateContentResponse> =
@@ -59,9 +59,9 @@ abstract class ChatFutures internal constructor() {
override fun getChat(): Chat = chat
}

-companion object {
+public companion object {

/** @return a [ChatFutures] created around the provided [Chat] */
-@JvmStatic fun from(chat: Chat): ChatFutures = FuturesImpl(chat)
+@JvmStatic public fun from(chat: Chat): ChatFutures = FuturesImpl(chat)
}
}
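
The futures wrapper targets Java callers, but its shape is easy to see from Kotlin too. A sketch (the single-thread executor is an illustrative choice; `ListenableFuture.addListener` is used to avoid a callback class):

```kotlin
import com.google.firebase.vertexai.Chat
import com.google.firebase.vertexai.java.ChatFutures
import com.google.firebase.vertexai.type.content
import java.util.concurrent.Executors

// Sketch: driving a Chat through the callback-based wrapper.
fun futuresDemo(chat: Chat) {
  val futures = ChatFutures.from(chat)
  val future = futures.sendMessage(content { text("Hello from Java-land!") })
  val executor = Executors.newSingleThreadExecutor()
  // The listener runs after completion, so future.get() does not block here.
  future.addListener({ println(future.get().text) }, executor)
}
```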
GenerativeModelFutures.kt

@@ -31,41 +31,45 @@ import org.reactivestreams.Publisher
*
* @see from
*/
-abstract class GenerativeModelFutures internal constructor() {
+public abstract class GenerativeModelFutures internal constructor() {

/**
* Generates a response from the backend with the provided [Content].
*
* @param prompt A group of [Content] to send to the model.
*/
-abstract fun generateContent(vararg prompt: Content): ListenableFuture<GenerateContentResponse>
+public abstract fun generateContent(
+  vararg prompt: Content
+): ListenableFuture<GenerateContentResponse>

/**
* Generates a streaming response from the backend with the provided [Content].
*
* @param prompt A group of [Content] to send to the model.
*/
-abstract fun generateContentStream(vararg prompt: Content): Publisher<GenerateContentResponse>
+public abstract fun generateContentStream(
+  vararg prompt: Content
+): Publisher<GenerateContentResponse>

/**
* Counts the number of tokens used in a prompt.
*
* @param prompt A group of [Content] to count tokens of.
*/
-abstract fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse>
+public abstract fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse>

/** Creates a chat instance which internally tracks the ongoing conversation with the model */
-abstract fun startChat(): ChatFutures
+public abstract fun startChat(): ChatFutures

/**
* Creates a chat instance which internally tracks the ongoing conversation with the model
*
* @param history an existing history of context to use as a starting point
*/
-abstract fun startChat(history: List<Content>): ChatFutures
+public abstract fun startChat(history: List<Content>): ChatFutures

/** Returns the [GenerativeModel] instance that was used to create this object */
-abstract fun getGenerativeModel(): GenerativeModel
+public abstract fun getGenerativeModel(): GenerativeModel

private class FuturesImpl(private val model: GenerativeModel) : GenerativeModelFutures() {
override fun generateContent(
@@ -86,9 +90,9 @@ abstract class GenerativeModelFutures internal constructor() {
override fun getGenerativeModel(): GenerativeModel = model
}

-companion object {
+public companion object {

/** @return a [GenerativeModelFutures] created around the provided [GenerativeModel] */
-@JvmStatic fun from(model: GenerativeModel): GenerativeModelFutures = FuturesImpl(model)
+@JvmStatic public fun from(model: GenerativeModel): GenerativeModelFutures = FuturesImpl(model)
}
}
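
And the model-level counterpart, for completeness (sketch; the executor is supplied by the caller):

```kotlin
import com.google.firebase.vertexai.GenerativeModel
import com.google.firebase.vertexai.java.GenerativeModelFutures
import com.google.firebase.vertexai.type.content
import java.util.concurrent.Executor

// Sketch: counting tokens through the futures wrapper.
fun countDemo(model: GenerativeModel, executor: Executor) {
  val futures = GenerativeModelFutures.from(model)
  val future = futures.countTokens(content { text("ping") })
  future.addListener({ println("tokens: ${future.get().totalTokens}") }, executor)
}
```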
(Diff truncated: 21 of the 28 changed files are not shown.)
