Skip to content

Commit

Permalink
Add lmstudio to langchain support
Browse files Browse the repository at this point in the history
  • Loading branch information
daviddenton committed Jul 9, 2024
1 parent 35bdb43 commit b7a8303
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 27 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ changes with their rationale when appropriate. Given version `A.B.C.D`, breaking
- **http4k-connect-amazon-cognito*** - [Breaking] AWS Cognito: Add support for server side authentication (
AdminInitiateAuth and AdminRespondToAuthChallenge). H/T @markth0mas
- **http4k-connect-ai-**** - [Breaking] Repackaged `ModelName` to common location. Just update imports!
- **http4k-connect-ai-langchain** - [Breaking] Added support for LmStudio chat and embedding models. Breaking change:
  `ChatModelOptions` has been renamed to `OpenAiChatModelOptions`.
- **http4k-connect-ai-lmstudio*** - [New module!] LmStudio adapter module and fake, so you can connect to a locally
  running LLM server hosting any model.

Expand Down
2 changes: 2 additions & 0 deletions ai/langchain/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ dependencies {
api(project(":http4k-connect-amazon-s3"))
api(project(":http4k-connect-ai-openai"))
api(project(":http4k-connect-ai-ollama"))
api(project(":http4k-connect-ai-lmstudio"))
api("dev.langchain4j:langchain4j-core:_")

testImplementation("dev.langchain4j:langchain4j:_")
testImplementation(project(":http4k-connect-ai-openai-fake"))
testImplementation(project(":http4k-connect-ai-ollama-fake"))
testImplementation(project(":http4k-connect-ai-lmstudio-fake"))
testImplementation(project(":http4k-connect-amazon-s3-fake"))
testImplementation(project(path = ":http4k-connect-amazon-core", configuration = "testArtifacts"))
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package org.http4k.connect.langchain.chat

import dev.forkhandles.result4k.map
import dev.langchain4j.agent.tool.ToolExecutionRequest
import dev.langchain4j.agent.tool.ToolSpecification
import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.data.message.ChatMessage
import dev.langchain4j.data.message.ImageContent
import dev.langchain4j.data.message.ImageContent.DetailLevel.AUTO
import dev.langchain4j.data.message.ImageContent.DetailLevel.HIGH
import dev.langchain4j.data.message.ImageContent.DetailLevel.LOW
import dev.langchain4j.data.message.SystemMessage
import dev.langchain4j.data.message.TextContent
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.output.FinishReason
import dev.langchain4j.model.output.Response
import dev.langchain4j.model.output.TokenUsage
import org.http4k.connect.lmstudio.LmStudio
import org.http4k.connect.lmstudio.Role
import org.http4k.connect.lmstudio.TokenId
import org.http4k.connect.lmstudio.User
import org.http4k.connect.lmstudio.action.ContentType
import org.http4k.connect.lmstudio.action.Detail.auto
import org.http4k.connect.lmstudio.action.Detail.high
import org.http4k.connect.lmstudio.action.Detail.low
import org.http4k.connect.lmstudio.action.FinishReason.content_filter
import org.http4k.connect.lmstudio.action.FinishReason.length
import org.http4k.connect.lmstudio.action.FinishReason.stop
import org.http4k.connect.lmstudio.action.FinishReason.tool_calls
import org.http4k.connect.lmstudio.action.FunctionCall
import org.http4k.connect.lmstudio.action.FunctionSpec
import org.http4k.connect.lmstudio.action.ImageUrl
import org.http4k.connect.lmstudio.action.Message
import org.http4k.connect.lmstudio.action.MessageContent
import org.http4k.connect.lmstudio.action.ResponseFormat
import org.http4k.connect.lmstudio.action.Tool
import org.http4k.connect.lmstudio.action.ToolCall
import org.http4k.connect.lmstudio.chatCompletion
import org.http4k.connect.model.ModelName
import org.http4k.connect.orThrow
import org.http4k.core.Uri


/**
 * Options controlling a chat-completion request made against an LmStudio server.
 *
 * Field names mirror the underlying OpenAI-compatible chat completion API
 * (hence the non-idiomatic `top_p`), and defaults match that API's documented
 * defaults where one exists; `null` means "omit the field from the request".
 *
 * @param model the model to ask the LmStudio server to use (required)
 * @param stream whether to stream the response; the chat adapter always sends `false`
 * @param maxTokens cap on tokens generated, or null for the server default
 * @param temperature sampling temperature (0.0 = deterministic-ish, higher = more random)
 * @param top_p nucleus-sampling probability mass
 * @param n number of completion choices to request
 * @param stop stop sequence(s) — string or list per the wire API, hence `Any?`
 * @param presencePenalty penalises tokens already present in the text so far
 * @param frequencyPenalty penalises tokens proportionally to their frequency so far
 * @param logitBias per-token sampling bias map
 * @param user end-user identifier passed through to the server
 * @param responseFormat e.g. JSON-mode response format, if any
 * @param toolChoice tool-choice directive — string or object per the wire API, hence `Any?`
 * @param parallelToolCalls whether the model may issue tool calls in parallel
 */
data class LmStudioChatModelOptions(
    val model: ModelName,
    val stream: Boolean? = null,
    val maxTokens: Int? = null,
    val temperature: Double = 1.0,
    val top_p: Double = 1.0,
    val n: Int = 1,
    val stop: Any? = null,
    val presencePenalty: Double = 0.0,
    val frequencyPenalty: Double = 0.0,
    val logitBias: Map<TokenId, Double>? = null,
    val user: User? = null,
    val responseFormat: ResponseFormat? = null,
    val toolChoice: Any? = null,
    val parallelToolCalls: Boolean? = null,
)

/**
 * Adapts an http4k-connect [LmStudio] client to langchain4j's [ChatLanguageModel].
 *
 * Messages and tool specifications are converted to their http4k-connect wire
 * equivalents, the chat completion is invoked (always non-streaming), and the
 * first completion in the result is converted back into a langchain4j [Response].
 */
fun LmStudioChatLanguageModel(
    lmStudio: LmStudio,
    options: LmStudioChatModelOptions
) =
    object : ChatLanguageModel {
        override fun generate(p0: List<ChatMessage>) = generate(p0, emptyList())

        override fun generate(messages: List<ChatMessage>, toolSpecifications: List<ToolSpecification>?)
            : Response<AiMessage> {
            // Convert each langchain4j message to its http4k-connect equivalent.
            val wireMessages = messages.map { message ->
                when (message) {
                    is UserMessage -> message.toHttp4k()
                    is SystemMessage -> message.toHttp4k()
                    is AiMessage -> message.toHttp4k()
                    else -> error("unknown message type")
                }
            }

            // An empty tool list is treated the same as no tools at all.
            val wireTools = toolSpecifications
                ?.takeIf { it.isNotEmpty() }
                ?.map { spec -> spec.toHttp4k() }

            return with(options) {
                lmStudio.chatCompletion(
                    model,
                    wireMessages,
                    maxTokens,
                    temperature,
                    top_p,
                    n,
                    stop,
                    presencePenalty,
                    frequencyPenalty,
                    logitBias,
                    user,
                    false, // streaming is not supported by this adapter
                    responseFormat,
                    wireTools,
                    toolChoice,
                    parallelToolCalls
                )
            }
                .map { completions ->
                    completions.map { completion ->
                        Response(
                            // Concatenate the content of every choice into a single message.
                            AiMessage(completion.choices?.mapNotNull { it.message?.content }?.joinToString("") ?: ""),
                            completion.usage?.let { TokenUsage(it.prompt_tokens, it.completion_tokens, it.total_tokens) },
                            // The last choice's finish reason is mapped onto langchain4j's enum.
                            when (completion.choices?.last()?.finish_reason) {
                                stop -> FinishReason.STOP
                                length -> FinishReason.LENGTH
                                content_filter -> FinishReason.CONTENT_FILTER
                                tool_calls -> FinishReason.TOOL_EXECUTION
                                else -> FinishReason.OTHER
                            }
                        )
                    }.toList()
                }
                .orThrow()
                .first()
        }
    }

// Converts a langchain4j user message (which may mix text and image parts)
// into an http4k-connect Message with Role.User.
private fun UserMessage.toHttp4k(): Message {
    val parts = contents().map { content ->
        when (content) {
            is TextContent -> content.toHttp4k()
            is ImageContent -> content.toHttp4k()
            else -> error("unknown content type")
        }
    }
    return Message(Role.User, parts, name()?.let { User.of(it) }, null)
}

// A system message carries plain text only, so it maps to a single text content part.
private fun SystemMessage.toHttp4k(): Message =
    Message(Role.System, listOf(MessageContent(ContentType.text, text())))

// Converts an assistant message, carrying across any tool execution requests.
// An empty tool-call list is normalised to null so the field is omitted on the wire.
private fun AiMessage.toHttp4k(): Message =
    Message(
        Role.Assistant,
        listOf(MessageContent(ContentType.text, text())),
        tool_calls = toolExecutionRequests()?.map { it.toHttp4k() }?.takeIf { it.isNotEmpty() }
    )

// Maps a langchain4j tool execution request to the wire ToolCall; the type
// discriminator is always "function" in the OpenAI-compatible API.
private fun ToolExecutionRequest.toHttp4k(): ToolCall =
    ToolCall(id(), "function", FunctionCall(name(), arguments()))

// A text content part maps directly onto a text MessageContent.
private fun TextContent.toHttp4k(): MessageContent = MessageContent(ContentType.text, text())

// Converts an image content part, translating langchain4j's detail level enum
// into the wire API's lowercase equivalent.
private fun ImageContent.toHttp4k(): MessageContent {
    val detail = when (detailLevel()) {
        LOW -> low
        HIGH -> high
        AUTO -> auto
    }
    return MessageContent(
        ContentType.image_url,
        null,
        ImageUrl(Uri.of(image().url().toString()), detail)
    )
}

// Wraps a langchain4j tool specification in the wire Tool/FunctionSpec envelope.
private fun ToolSpecification.toHttp4k(): Tool =
    Tool(FunctionSpec(name(), parameters(), description()))
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.output.FinishReason
import dev.langchain4j.model.output.Response
import dev.langchain4j.model.output.TokenUsage
import org.http4k.connect.model.ModelName
import org.http4k.connect.openai.GPT3_5
import org.http4k.connect.openai.OpenAI
import org.http4k.connect.openai.Role
import org.http4k.connect.openai.TokenId
import org.http4k.connect.openai.User
import org.http4k.connect.openai.action.ContentType
import org.http4k.connect.openai.action.Detail.auto
Expand All @@ -32,15 +35,34 @@ import org.http4k.connect.openai.action.FunctionSpec
import org.http4k.connect.openai.action.ImageUrl
import org.http4k.connect.openai.action.Message
import org.http4k.connect.openai.action.MessageContent
import org.http4k.connect.openai.action.ResponseFormat
import org.http4k.connect.openai.action.Tool
import org.http4k.connect.openai.action.ToolCall
import org.http4k.connect.openai.chatCompletion
import org.http4k.connect.orThrow
import org.http4k.core.Uri


/**
 * Options controlling a chat-completion request made against OpenAI.
 *
 * Field names mirror the OpenAI chat completion API (hence the non-idiomatic
 * `top_p`); `null` means "omit the field from the request".
 *
 * @param model the model to use; defaults to GPT 3.5
 * @param stream whether to stream the response
 * @param maxTokens cap on tokens generated, or null for the server default
 * @param temperature sampling temperature (higher = more random)
 * @param top_p nucleus-sampling probability mass
 * @param n number of completion choices to request
 * @param stop stop sequence(s) — string or list per the wire API, hence `Any?`
 * @param presencePenalty penalises tokens already present in the text so far
 * @param frequencyPenalty penalises tokens proportionally to their frequency so far
 * @param logitBias per-token sampling bias map
 * @param user end-user identifier passed through to OpenAI
 * @param responseFormat e.g. JSON-mode response format, if any
 * @param toolChoice tool-choice directive — string or object per the wire API, hence `Any?`
 * @param parallelToolCalls whether the model may issue tool calls in parallel
 */
data class OpenAiChatModelOptions(
    val model: ModelName = ModelName.GPT3_5,
    val stream: Boolean? = null,
    val maxTokens: Int? = null,
    val temperature: Double = 1.0,
    val top_p: Double = 1.0,
    val n: Int = 1,
    val stop: Any? = null,
    val presencePenalty: Double = 0.0,
    val frequencyPenalty: Double = 0.0,
    val logitBias: Map<TokenId, Double>? = null,
    val user: User? = null,
    val responseFormat: ResponseFormat? = null,
    val toolChoice: Any? = null,
    val parallelToolCalls: Boolean? = null,
)

fun OpenAiChatLanguageModel(
openAi: OpenAI,
options: ChatModelOptions = ChatModelOptions()
options: OpenAiChatModelOptions = OpenAiChatModelOptions()
) =
object : ChatLanguageModel {
override fun generate(p0: List<ChatMessage>) = generate(p0, emptyList())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.http4k.connect.langchain.embedding

import dev.forkhandles.result4k.map
import dev.langchain4j.data.embedding.Embedding
import dev.langchain4j.model.embedding.EmbeddingModel
import dev.langchain4j.model.output.Response
import org.http4k.connect.lmstudio.LmStudio
import org.http4k.connect.lmstudio.createEmbeddings
import org.http4k.connect.model.ModelName
import org.http4k.connect.orThrow

/**
 * Adapts an http4k-connect [LmStudio] client to langchain4j's [EmbeddingModel].
 *
 * Extracts the raw text from each segment (a null segment list is treated as
 * empty), requests embeddings for the given [model], and wraps the resulting
 * vectors in a langchain4j [Response]. Failures are thrown as exceptions.
 */
fun LmStudioEmbeddingModel(lmStudio: LmStudio, model: ModelName) = EmbeddingModel { segments ->
    val texts = segments?.map { segment -> segment.text() } ?: emptyList()
    lmStudio.createEmbeddings(model, texts)
        .map { result -> Response(result.data.map { Embedding(it.embedding) }) }
        .orThrow()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.http4k.connect.langchain.chat

import org.http4k.connect.lmstudio.CHAT_MODEL
import org.http4k.connect.lmstudio.FakeLmStudio
import org.http4k.connect.lmstudio.Http
import org.http4k.connect.lmstudio.LmStudio
import org.http4k.connect.model.ModelName

// Runs the shared chat-model contract against the LmStudio adapter, backed by
// the fake server so no real LLM is required.
class LmStudioChatLanguageModelTest : ChatLanguageModelContract {
    override val model by lazy {
        LmStudioChatLanguageModel(
            LmStudio.Http(FakeLmStudio()),
            // temperature 0.0 — presumably to keep responses deterministic for the contract
            LmStudioChatModelOptions(ModelName.CHAT_MODEL, temperature = 0.0)
        )
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class OpenAiChatLanguageModelTest : ChatLanguageModelContract {
override val model by lazy {
OpenAiChatLanguageModel(
OpenAI.Http(OpenAIToken.of("hello"), FakeOpenAI()),
ChatModelOptions(temperature = 0.0)
OpenAiChatModelOptions(temperature = 0.0)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RealOpenAiChatLanguageModelTest : ChatLanguageModelContract {
override val model by lazy {
OpenAiChatLanguageModel(
OpenAI.Http(apiKey(ENV)!!, JavaHttpClient().debug()),
ChatModelOptions(temperature = 0.0)
OpenAiChatModelOptions(temperature = 0.0)
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.http4k.connect.langchain.embedding

import org.http4k.connect.lmstudio.CHAT_MODEL
import org.http4k.connect.lmstudio.FakeLmStudio
import org.http4k.connect.lmstudio.Http
import org.http4k.connect.lmstudio.LmStudio
import org.http4k.connect.model.ModelName

// Runs the shared embedding-model contract against the LmStudio adapter,
// backed by the fake server so no real LLM is required.
class LmStudioEmbeddingModelTest : EmbeddingModelContract {
    override val model = LmStudioEmbeddingModel(LmStudio.Http(FakeLmStudio()), ModelName.CHAT_MODEL)
}

0 comments on commit b7a8303

Please sign in to comment.