Skip to content

Commit

Permalink
Update models for completion to support single message
Browse files Browse the repository at this point in the history
  • Loading branch information
daviddenton committed Sep 7, 2024
1 parent 7fc8147 commit 8148adb
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 23 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ changes with their rationale when appropriate. Given version `A.B.C.D`, breaking
- **http4k-connect-*** - Upgrade dependencies including Kotlin to 2.0.20
- **http4k-connect-ai-openai*** - [Breaking] Tightened up types for completion requests.
- **http4k-connect-ai-azure*** - [Breaking] Tightened up types for completion requests.
- **http4k-connect-ai-lmstudio*** - [Breaking] Tightened up types for completion requests.
- **http4k-connect-ai-ollama*** - [Breaking] Tightened up types for completion requests.

### v5.23.0.0
- **http4k-connect-*** - Upgrade dependencies including Kotlin to 2.0.20
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,18 @@ sealed class Content {
}

@JsonSerializable
data class Message(val role: Role, val content: List<Content>)
data class Message(val role: Role, val content: List<Content>) {
companion object {
fun User(content: Content) = Message(Role.User, listOf(content))
fun User(content: List<Content>) = Message(Role.User, content)
fun System(content: Content) = Message(Role.System, listOf(content))
fun System(content: List<Content>) = Message(Role.System, content)
fun Assistant(content: Content) = Message(Role.Assistant, listOf(content))
fun Assistant(content: List<Content>) = Message(Role.Assistant, content)
fun Tool(content: Content) = Message(Role.Tool, listOf(content))
fun Tool(content: List<Content>) = Message(Role.Tool, content)
}
}

@JsonSerializable
data class Tool(val name: ToolName, val description: String, val inputSchema: Schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import org.http4k.connect.anthropic.action.MessageGenerationEvent
import org.http4k.connect.anthropic.action.Source
import org.http4k.connect.model.Base64Blob
import org.http4k.connect.model.ModelName
import org.http4k.connect.model.Role
import org.http4k.connect.successValue
import org.http4k.testing.ApprovalTest
import org.junit.jupiter.api.Test
Expand All @@ -28,8 +27,8 @@ interface AnthropicAIContract {
val responses = anthropicAi.messageCompletion(
ModelName.of("claude-3-5-sonnet-20240620"),
listOf(
Message(
Role.User, listOf(
Message.User(
listOf(
Content.Image(
Source(
Base64Blob.encode(resourceLoader.stream("dog.png")),
Expand All @@ -51,9 +50,7 @@ interface AnthropicAIContract {
val responses = anthropicAi.messageCompletionStream(
ModelName.of("claude-3-5-sonnet-20240620"),
listOf(
Message(
Role.User, listOf(Content.Text("You are Leonardo Da Vinci"))
)
Message.User(listOf(Content.Text("You are Leonardo Da Vinci")))
),
100,
).successValue().toList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ data class ChatCompletion(
val tool_choice: Any? = null,
val parallel_tool_calls: Boolean? = null,
) : LmStudioAction<Sequence<CompletionResponse>> {
constructor(model: ModelName, message: Message, max_tokens: Int = 16, stream: Boolean = true)
: this(model, listOf(message), max_tokens, stream)

constructor(model: ModelName, messages: List<Message>, max_tokens: Int = 16, stream: Boolean = true) : this(
model,
messages,
Expand Down Expand Up @@ -79,17 +82,39 @@ data class ResponseFormat(
@JsonSerializable
data class Message(
val role: Role?,
val content: List<MessageContent>,
val content: List<MessageContent>? = null,
val name: User? = null,
val refusal: String? = null,
val tool_calls: List<ToolCall>? = null
) {
@Deprecated("Use relevant companion constructor instead")
constructor(
role: Role,
text: String,
name: User? = null,
tool_calls: List<ToolCall>? = null
) :
this(role, listOf(MessageContent(ContentType.text, text)), name, tool_calls)
this(role, listOf(MessageContent(ContentType.text, text)), name, null, tool_calls)

companion object {
fun User(content: String, name: User? = null) = User(listOf(MessageContent(ContentType.text, content)), name)
fun User(content: List<MessageContent>, name: User? = null) = Message(Role.User, content, name, null)

fun System(content: String, name: User? = null) =
System(listOf(MessageContent(ContentType.text, content)), name)

fun System(content: List<MessageContent>, name: User? = null) = Message(Role.System, content, name)

fun Assistant(content: String, name: User? = null, refusal: String? = null) =
Assistant(listOf(MessageContent(ContentType.text, content)), name, refusal)

fun Assistant(content: List<MessageContent>, name: User? = null, refusal: String? = null) =
Message(Role.Assistant, content, name, refusal)

@JvmName("AssistantToolCalls")
fun Assistant(tool_calls: List<ToolCall>, name: User? = null, refusal: String? = null) =
Message(Role.Assistant, null, name, refusal, tool_calls)
}
}

@JsonSerializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ interface LmStudioContract {
val responses = lmStudio.chatCompletion(
ModelName.CHAT_MODEL,
listOf(
Message(System, "You are Leonardo Da Vinci"),
Message(Companion.User, "What is your favourite colour?")
Message.System("You are Leonardo Da Vinci"),
Message.User("What is your favourite colour?")
),
1000,
stream = false
Expand All @@ -52,11 +52,8 @@ interface LmStudioContract {
val responses = lmStudio.chatCompletion(
ModelName.CHAT_MODEL,
listOf(
Message(System, "You are Leonardo Da Vinci"),
Message(
Role.User,
"What is your favourite colour?"
)
Message.System("You are Leonardo Da Vinci"),
Message.User("What is your favourite colour?")
),
1000,
stream = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ fun interface ChatCompletionGenerator : (ChatCompletion) -> List<Choice> {
val ChatCompletionGenerator.Companion.ReverseInput
get() = ChatCompletionGenerator { req ->
req.messages.flatMap { m ->
m.content.mapIndexed { i, content ->
m.content?.mapIndexed { i, content ->
Choice(i, ChoiceDetail(System, content.text?.reversed() ?: "", null), null, stop)
}
} ?: emptyList()
}
}

Expand All @@ -40,7 +40,7 @@ fun ChatCompletionGenerator.Companion.LoremIpsum(random: Random = Random(0)) = C
*/
val ChatCompletionGenerator.Companion.Echo
get() = ChatCompletionGenerator { req ->
req.choices(req.messages.first { it.role == User }.content.first().text ?: "")
req.choices(req.messages.first { it.role == User }.content?.first()?.text ?: "")
}

private fun ChatCompletion.choices(msg: String) = (if (stream) msg.split(" ").map { "$it " } else listOf(msg))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@ data class ChatCompletion(
val keep_alive: String? = null,
val options: ModelOptions? = null
) : OllamaAction<Sequence<ChatCompletionResponse>> {
constructor(
model: ModelName,
messages: Message,
stream: Boolean? = false,
format: ResponseFormat? = null,
keep_alive: String? = null,
options: ModelOptions? = null
) : this(model, listOf(messages), stream, format, keep_alive, options)

override fun toRequest() = Request(POST, "/api/chat")
.with(autoBody<ChatCompletion>().toLens() of this)

override fun toResult(response: Response) = toCompletionSequence(response, OllamaMoshi, "", "__FAKE_HTTP4k_STOP_SIGNAL__")
override fun toResult(response: Response) =
toCompletionSequence(response, OllamaMoshi, "", "__FAKE_HTTP4k_STOP_SIGNAL__")
}

@JsonSerializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ enum class ResponseFormat {
}

@JsonSerializable
data class Message(val role: Role, val content: String, val images: List<Base64Blob>? = null)
data class Message(val role: Role, val content: String, val images: List<Base64Blob>? = null) {
companion object {
fun User(content: String, images: List<Base64Blob>? = null) = Message(Role.User, content, images)
fun System(content: String, images: List<Base64Blob>? = null) = Message(Role.System, content, images)
fun Assistant(content: String, images: List<Base64Blob>? = null) = Message(Role.Assistant, content, images)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ interface OllamaContract {
fun `get chat response non-stream`() {
val responses = ollama.chatCompletion(
modelName,
listOf(Message(User, "count to five", null)),
listOf(Message.User("count to five")),
false,
null,
null,
Expand All @@ -89,7 +89,7 @@ interface OllamaContract {
fun `get chat response stream`() {
val responses = ollama.chatCompletion(
modelName,
listOf(Message(User, "count to five", null)),
listOf(Message.User("count to five")),
true,
null,
null,
Expand Down

0 comments on commit 8148adb

Please sign in to comment.