Skip to content

Commit

Permalink
gpt-4-vision-preview support fix and test
Browse files Browse the repository at this point in the history
  • Loading branch information
James J Kalafus authored and tisfeng committed May 11, 2024
1 parent c5f1147 commit a373eae
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 24 deletions.
67 changes: 52 additions & 15 deletions Sources/OpenAI/Public/Models/ChatQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ public struct ChatQuery: Equatable, Codable, Streamable {
case assistant(Self.ChatCompletionAssistantMessageParam)
case tool(Self.ChatCompletionToolMessageParam)

public var content: Self.ChatCompletionUserMessageParam.Content? { get { // TODO: String type except for .user
public var content: Self.ChatCompletionUserMessageParam.Content? { get {
switch self {
case .system(let systemMessage):
return Self.ChatCompletionUserMessageParam.Content.string(systemMessage.content)
case .user(let userMessage):
return userMessage.content // TODO: Content type
return userMessage.content
case .assistant(let assistantMessage):
if let content = assistantMessage.content {
return Self.ChatCompletionUserMessageParam.Content.string(content)
Expand Down Expand Up @@ -178,7 +178,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
public init?(
role: Role,
content: String? = nil,
imageUrl: URL? = nil,
name: String? = nil,
toolCalls: [Self.ChatCompletionAssistantMessageParam.ChatCompletionMessageToolCallParam]? = nil,
toolCallId: String? = nil
Expand All @@ -193,8 +192,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
case .user:
if let content {
self = .user(.init(content: .init(string: content), name: name))
} else if let imageUrl {
self = .user(.init(content: .init(chatCompletionContentPartImageParam: .init(imageUrl: .init(url: imageUrl.absoluteString, detail: .auto))), name: name))
} else {
return nil
}
Expand All @@ -209,6 +206,20 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}

public init?(
role: Role,
content: [ChatCompletionUserMessageParam.Content.VisionContent],
name: String? = nil
) {
switch role {
case .user:
self = .user(.init(content: .vision(content), name: name))
default:
return nil
}

}

private init?(
content: String,
role: Role,
Expand Down Expand Up @@ -330,8 +341,7 @@ public struct ChatQuery: Equatable, Codable, Streamable {

public enum Content: Codable, Equatable {
case string(String)
case chatCompletionContentPartTextParam(ChatCompletionContentPartTextParam)
case chatCompletionContentPartImageParam(ChatCompletionContentPartImageParam)
case vision([VisionContent])

public var string: String? { get {
switch self {
Expand All @@ -342,6 +352,33 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}}

public init(string: String) {
self = .string(string)
}

public init(vision: [VisionContent]) {
self = .vision(vision)
}

public enum CodingKeys: CodingKey {
case string
case vision
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .string(let a0):
try container.encode(a0)
case .vision(let a0):
try container.encode(a0)
}
}

public enum VisionContent: Codable, Equatable {
case chatCompletionContentPartTextParam(ChatCompletionContentPartTextParam)
case chatCompletionContentPartImageParam(ChatCompletionContentPartImageParam)

public var text: String? { get {
switch self {
case .chatCompletionContentPartTextParam(let text):
Expand All @@ -360,10 +397,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}}

public init(string: String) {
self = .string(string)
}

public init(chatCompletionContentPartTextParam: ChatCompletionContentPartTextParam) {
self = .chatCompletionContentPartTextParam(chatCompletionContentPartTextParam)
}
Expand All @@ -375,8 +408,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .string(let a0):
try container.encode(a0)
case .chatCompletionContentPartTextParam(let a0):
try container.encode(a0)
case .chatCompletionContentPartImageParam(let a0):
Expand All @@ -385,7 +416,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}

enum CodingKeys: CodingKey {
case string
case chatCompletionContentPartTextParam
case chatCompletionContentPartImageParam
}
Expand All @@ -409,7 +439,7 @@ public struct ChatQuery: Equatable, Codable, Streamable {

public init(imageUrl: ImageURL) {
self.imageUrl = imageUrl
self.type = "imageUrl"
self.type = "image_url"
}

public struct ImageURL: Codable, Equatable {
Expand All @@ -424,6 +454,12 @@ public struct ChatQuery: Equatable, Codable, Streamable {
self.detail = detail
}

public init(url: Data, detail: Detail) {
self.init(
url: "data:image/jpeg;base64,\(url.base64EncodedString())",
detail: detail)
}

public enum Detail: String, Codable, Equatable, CaseIterable {
case auto
case low
Expand All @@ -438,6 +474,7 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}
}
}

internal struct ChatCompletionMessageParam: Codable, Equatable {
typealias Role = ChatQuery.ChatCompletionMessageParam.Role
Expand Down
11 changes: 3 additions & 8 deletions Sources/OpenAI/Public/Models/ChatResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,10 @@ extension ChatQuery.ChatCompletionMessageParam.ChatCompletionUserMessageParam.Co
return
} catch {}
do {
let text = try container.decode(ChatCompletionContentPartTextParam.self)
self = .chatCompletionContentPartTextParam(text)
let vision = try container.decode([VisionContent].self)
self = .vision(vision)
return
} catch {}
do {
let image = try container.decode(ChatCompletionContentPartImageParam.self)
self = .chatCompletionContentPartImageParam(image)
return
} catch {}
throw DecodingError.typeMismatch(Self.self, .init(codingPath: [Self.CodingKeys.string, CodingKeys.chatCompletionContentPartTextParam, CodingKeys.chatCompletionContentPartImageParam], debugDescription: "Content: expected String, ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam"))
throw DecodingError.typeMismatch(Self.self, .init(codingPath: [Self.CodingKeys.string, Self.CodingKeys.vision], debugDescription: "Content: expected String || Vision"))
}
}
46 changes: 45 additions & 1 deletion Tests/OpenAITests/OpenAITestsDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,51 @@ class OpenAITestsDecoder: XCTestCase {

XCTAssertEqual(imageQueryAsDict, expectedValueAsDict)
}


func testChatQueryWithVision() async throws {
let chatQuery = ChatQuery(messages: [
// .init(role: .user, content: [
// .chatCompletionContentPartTextParam(.init(text: "What's in this image?")),
// .chatCompletionContentPartImageParam(.init(imageUrl: .init(url: "https://some.url/image.jpeg", detail: .auto)))
// ])!
.user(.init(content: .vision([
.chatCompletionContentPartTextParam(.init(text: "What's in this image?")),
.chatCompletionContentPartImageParam(.init(imageUrl: .init(url: "https://some.url/image.jpeg", detail: .auto)))
])))
], model: Model.gpt4_vision_preview, maxTokens: 300)
let expectedValue = """
{
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://some.url/image.jpeg",
"detail": "auto"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}
"""

// To compare serialized JSONs we first convert them both into NSDictionary which are comparable (unline native swift dictionaries)
let chatQueryAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(chatQuery))
let expectedValueAsDict = try jsonDataAsNSDictionary(expectedValue.data(using: .utf8)!)

XCTAssertEqual(chatQueryAsDict, expectedValueAsDict)
}

func testChatQueryWithFunctionCall() async throws {
let chatQuery = ChatQuery(
messages: [
Expand Down

0 comments on commit a373eae

Please sign in to comment.