Skip to content

Commit

Permalink
Replace SSE function to library native function
Browse files Browse the repository at this point in the history
  • Loading branch information
Taewan-P committed Nov 23, 2024
1 parent 9ae36c3 commit 0dffff5
Showing 1 changed file with 16 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@ import dev.chungjungsoo.gptmobile.data.dto.anthropic.response.ErrorResponseChunk
import dev.chungjungsoo.gptmobile.data.dto.anthropic.response.MessageResponseChunk
import io.ktor.client.call.body
import io.ktor.client.plugins.sse.sse
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.accept
import io.ktor.client.request.headers
import io.ktor.client.request.setBody
import io.ktor.client.request.url
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.HttpStatement
import io.ktor.http.ContentType
import io.ktor.http.HttpMethod
import io.ktor.http.contentType
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.cancel
import io.ktor.utils.io.readUTF8Line
Expand All @@ -26,7 +22,6 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.encodeToJsonElement

Expand All @@ -48,39 +43,23 @@ class AnthropicAPIImpl @Inject constructor(
override fun streamChatMessage(messageRequest: MessageRequest): Flow<MessageResponseChunk> {
val body = Json.encodeToJsonElement(messageRequest)

val builder = HttpRequestBuilder().apply {
method = HttpMethod.Post
if (apiUrl.endsWith("/")) url("${apiUrl}v1/messages") else url("$apiUrl/v1/messages")
contentType(ContentType.Application.Json)
setBody(body)
accept(ContentType.Text.EventStream)
headers {
append(API_KEY_HEADER, token ?: "")
append(VERSION_HEADER, ANTHROPIC_VERSION)
}
}

runBlocking {
networkClient().sse(
host = apiUrl,
path = if (apiUrl.endsWith("/")) "v1/messages" else "/v1/messages"
) {
incoming.collect { event ->
val line = event.data
val value = when {
line?.startsWith(STREAM_END_TOKEN) == true -> break
line?.startsWith(STREAM_PREFIX) == true -> Json.decodeFromString(line.removePrefix(STREAM_PREFIX))
else -> continue
}
}
}
}

return flow {
return flow<MessageResponseChunk> {
try {
HttpStatement(builder = builder, client = networkClient()).execute {
streamEventsFrom(it)
}
networkClient()
.sse(
urlString = if (apiUrl.endsWith("/")) "${apiUrl}v1/messages" else "$apiUrl/v1/messages",
request = {
method = HttpMethod.Post
setBody(body)
accept(ContentType.Text.EventStream)
headers {
append(API_KEY_HEADER, token ?: "")
append(VERSION_HEADER, ANTHROPIC_VERSION)
}
}
) {
incoming.collect { event -> event.data?.let { line -> emit(Json.decodeFromString(line)) } }
}
} catch (e: Exception) {
emit(ErrorResponseChunk(error = ErrorDetail(type = "network_error", message = e.message ?: "")))
}
Expand Down

0 comments on commit 0dffff5

Please sign in to comment.