Skip to content

Commit

Permalink
Merge pull request #463 from google-ai-edge/mrschmidt/spaces
Browse files Browse the repository at this point in the history
Trim spaces in UI display
  • Loading branch information
khanhlvg authored Sep 25, 2024
2 parents bdafe5d + 8b6ec73 commit 7e69d9e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import java.util.UUID
*/
data class ChatMessage(
val id: String = UUID.randomUUID().toString(),
val message: String = "",
val rawMessage: String = "",
val author: String,
val isLoading: Boolean = false
) {
val isFromUser: Boolean
get() = author == USER_PREFIX
val message: String
get() = rawMessage.trim()
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ChatUiState(

// Prompt the model with the current chat history
override val fullPrompt: String
get() = _messages.joinToString(separator = "\n") { it.message }
get() = _messages.joinToString(separator = "\n") { it.rawMessage }

override fun createLoadingMessage(): String {
val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
Expand All @@ -54,14 +54,14 @@ class ChatUiState(
override fun appendMessage(id: String, text: String, done: Boolean) {
val index = _messages.indexOfFirst { it.id == id }
if (index != -1) {
val newText = _messages[index].message + text
_messages[index] = _messages[index].copy(message = newText, isLoading = false)
val newText = _messages[index].rawMessage + text
_messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
}
}

override fun addMessage(text: String, author: String): String {
val chatMessage = ChatMessage(
message = text,
rawMessage = text,
author = author
)
_messages.add(chatMessage)
Expand All @@ -85,7 +85,7 @@ class GemmaUiState(
_messages. apply{
for (i in indices) {
this[i] = this[i].copy(
message = this[i].message.replace(START_TURN + this[i].author + "\n", "")
rawMessage = this[i].rawMessage.replace(START_TURN + this[i].author + "\n", "")
.replace(END_TURN, "")
)
}
Expand All @@ -95,7 +95,7 @@ class GemmaUiState(

// Only using the last 4 messages to keep input + output short
override val fullPrompt: String
get() = _messages.takeLast(4).joinToString(separator = "\n") { it.message }
get() = _messages.takeLast(4).joinToString(separator = "\n") { it.rawMessage }

override fun createLoadingMessage(): String {
val chatMessage = ChatMessage(author = MODEL_PREFIX, isLoading = true)
Expand All @@ -112,18 +112,18 @@ class GemmaUiState(
if (index != -1) {
val newText = if (done) {
// Append the Suffix when model is done generating the response
_messages[index].message + text + END_TURN
_messages[index].rawMessage + text + END_TURN
} else {
// Append the text
_messages[index].message + text
_messages[index].rawMessage + text
}
_messages[index] = _messages[index].copy(message = newText, isLoading = false)
_messages[index] = _messages[index].copy(rawMessage = newText, isLoading = false)
}
}

override fun addMessage(text: String, author: String): String {
val chatMessage = ChatMessage(
message = "$START_TURN$author\n$text$END_TURN",
rawMessage = "$START_TURN$author\n$text$END_TURN",
author = author
)
_messages.add(chatMessage)
Expand Down

0 comments on commit 7e69d9e

Please sign in to comment.