Skip to content

Commit

Permalink
Merge pull request #3 from StanfordBDHG/upstream/0.2.9
Browse files Browse the repository at this point in the history
Lift to upstream 0.2.9 (support for GPT-4o)
  • Loading branch information
vishnuravi authored May 15, 2024
2 parents 29316ba + 9fc0277 commit 1ad95dd
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 36 deletions.
4 changes: 2 additions & 2 deletions Demo/Demo.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 16.4;
IPHONEOS_DEPLOYMENT_TARGET = 17.0;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 13.3;
Expand Down Expand Up @@ -354,7 +354,7 @@
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 16.4;
IPHONEOS_DEPLOYMENT_TARGET = 17.0;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 13.3;
Expand Down
4 changes: 2 additions & 2 deletions Demo/DemoChat/Package.swift
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// swift-tools-version: 5.8
// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "DemoChat",
platforms: [.macOS(.v13), .iOS(.v16)],
platforms: [.macOS(.v13), .iOS(.v17)],
products: [
.library(
name: "DemoChat",
Expand Down
2 changes: 1 addition & 1 deletion Demo/DemoChat/Sources/UI/DetailView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct DetailView: View {
@State private var showsModelSelectionSheet = false
@State private var selectedChatModel: Model = .gpt4_0613

private static let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4]
private static let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4, .gpt4_o]

let conversation: Conversation
let error: Error?
Expand Down
4 changes: 0 additions & 4 deletions Sources/OpenAI/Public/Models/AudioTranscriptionQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ public enum ResponseFormat: String, Codable, Equatable, CaseIterable {
switch self {
case .mpga:
fileName += Self.mp3.rawValue
case .m4a:
fileName += Self.mp4.rawValue
default:
fileName += self.rawValue
}
Expand All @@ -72,8 +70,6 @@ public enum ResponseFormat: String, Codable, Equatable, CaseIterable {
switch self {
case .mpga:
contentType += Self.mp3.rawValue
case .m4a:
contentType += Self.mp4.rawValue
default:
contentType += self.rawValue
}
Expand Down
67 changes: 52 additions & 15 deletions Sources/OpenAI/Public/Models/ChatQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ public struct ChatQuery: Equatable, Codable, Streamable {
case assistant(Self.ChatCompletionAssistantMessageParam)
case tool(Self.ChatCompletionToolMessageParam)

public var content: Self.ChatCompletionUserMessageParam.Content? { get { // TODO: String type except for .user
public var content: Self.ChatCompletionUserMessageParam.Content? { get {
switch self {
case .system(let systemMessage):
return Self.ChatCompletionUserMessageParam.Content.string(systemMessage.content)
case .user(let userMessage):
return userMessage.content // TODO: Content type
return userMessage.content
case .assistant(let assistantMessage):
if let content = assistantMessage.content {
return Self.ChatCompletionUserMessageParam.Content.string(content)
Expand Down Expand Up @@ -178,7 +178,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
public init?(
role: Role,
content: String? = nil,
imageUrl: URL? = nil,
name: String? = nil,
toolCalls: [Self.ChatCompletionAssistantMessageParam.ChatCompletionMessageToolCallParam]? = nil,
toolCallId: String? = nil
Expand All @@ -193,8 +192,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
case .user:
if let content {
self = .user(.init(content: .init(string: content), name: name))
} else if let imageUrl {
self = .user(.init(content: .init(chatCompletionContentPartImageParam: .init(imageUrl: .init(url: imageUrl.absoluteString, detail: .auto))), name: name))
} else {
return nil
}
Expand All @@ -209,6 +206,20 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}

public init?(
role: Role,
content: [ChatCompletionUserMessageParam.Content.VisionContent],
name: String? = nil
) {
switch role {
case .user:
self = .user(.init(content: .vision(content), name: name))
default:
return nil
}

}

private init?(
content: String,
role: Role,
Expand Down Expand Up @@ -330,8 +341,7 @@ public struct ChatQuery: Equatable, Codable, Streamable {

public enum Content: Codable, Equatable {
case string(String)
case chatCompletionContentPartTextParam(ChatCompletionContentPartTextParam)
case chatCompletionContentPartImageParam(ChatCompletionContentPartImageParam)
case vision([VisionContent])

public var string: String? { get {
switch self {
Expand All @@ -342,6 +352,33 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}}

public init(string: String) {
self = .string(string)
}

public init(vision: [VisionContent]) {
self = .vision(vision)
}

public enum CodingKeys: CodingKey {
case string
case vision
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .string(let a0):
try container.encode(a0)
case .vision(let a0):
try container.encode(a0)
}
}

public enum VisionContent: Codable, Equatable {
case chatCompletionContentPartTextParam(ChatCompletionContentPartTextParam)
case chatCompletionContentPartImageParam(ChatCompletionContentPartImageParam)

public var text: String? { get {
switch self {
case .chatCompletionContentPartTextParam(let text):
Expand All @@ -360,10 +397,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}}

public init(string: String) {
self = .string(string)
}

public init(chatCompletionContentPartTextParam: ChatCompletionContentPartTextParam) {
self = .chatCompletionContentPartTextParam(chatCompletionContentPartTextParam)
}
Expand All @@ -375,8 +408,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .string(let a0):
try container.encode(a0)
case .chatCompletionContentPartTextParam(let a0):
try container.encode(a0)
case .chatCompletionContentPartImageParam(let a0):
Expand All @@ -385,7 +416,6 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}

enum CodingKeys: CodingKey {
case string
case chatCompletionContentPartTextParam
case chatCompletionContentPartImageParam
}
Expand All @@ -409,7 +439,7 @@ public struct ChatQuery: Equatable, Codable, Streamable {

public init(imageUrl: ImageURL) {
self.imageUrl = imageUrl
self.type = "imageUrl"
self.type = "image_url"
}

public struct ImageURL: Codable, Equatable {
Expand All @@ -424,6 +454,12 @@ public struct ChatQuery: Equatable, Codable, Streamable {
self.detail = detail
}

public init(url: Data, detail: Detail) {
self.init(
url: "data:image/jpeg;base64,\(url.base64EncodedString())",
detail: detail)
}

public enum Detail: String, Codable, Equatable, CaseIterable {
case auto
case low
Expand All @@ -438,6 +474,7 @@ public struct ChatQuery: Equatable, Codable, Streamable {
}
}
}
}

internal struct ChatCompletionMessageParam: Codable, Equatable {
typealias Role = ChatQuery.ChatCompletionMessageParam.Role
Expand Down
11 changes: 3 additions & 8 deletions Sources/OpenAI/Public/Models/ChatResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,10 @@ extension ChatQuery.ChatCompletionMessageParam.ChatCompletionUserMessageParam.Co
return
} catch {}
do {
let text = try container.decode(ChatCompletionContentPartTextParam.self)
self = .chatCompletionContentPartTextParam(text)
let vision = try container.decode([VisionContent].self)
self = .vision(vision)
return
} catch {}
do {
let image = try container.decode(ChatCompletionContentPartImageParam.self)
self = .chatCompletionContentPartImageParam(image)
return
} catch {}
throw DecodingError.typeMismatch(Self.self, .init(codingPath: [Self.CodingKeys.string, CodingKeys.chatCompletionContentPartTextParam, CodingKeys.chatCompletionContentPartImageParam], debugDescription: "Content: expected String, ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam"))
throw DecodingError.typeMismatch(Self.self, .init(codingPath: [Self.CodingKeys.string, Self.CodingKeys.vision], debugDescription: "Content: expected String || Vision"))
}
}
9 changes: 8 additions & 1 deletion Sources/OpenAI/Public/Models/Models/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@ public extension Model {
// Chat Completion
// GPT-4

/// `gpt-4-turbo`, the latest gpt-4 model with improved instruction following, JSON mode, reproducible outputs, parallel function calling and more. Maximum of 4096 output tokens
/// `gpt-4o`, currently the most advanced, multimodal flagship model that's cheaper and faster than GPT-4 Turbo.
static let gpt4_o = "gpt-4o"

/// `gpt-4-turbo`, The latest GPT-4 Turbo model with vision capabilities. Vision requests can now use JSON mode and function calling and more. Context window: 128,000 tokens
static let gpt4_turbo = "gpt-4-turbo"

/// `gpt-4-turbo`, gpt-4 model with improved instruction following, JSON mode, reproducible outputs, parallel function calling and more. Maximum of 4096 output tokens
@available(*, deprecated, message: "Please upgrade to the newer model")
static let gpt4_turbo_preview = "gpt-4-turbo-preview"

/// `gpt-4-vision-preview`, able to understand images, in addition to all other GPT-4 Turbo capabilities.
Expand Down
8 changes: 6 additions & 2 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,9 @@ class OpenAITests: XCTestCase {
let jsonRequest = JSONRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token)")
let unwrappedToken = try XCTUnwrap(configuration.token)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(unwrappedToken)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Content-Type"), "application/json")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Organization"), configuration.organizationIdentifier)
XCTAssertEqual(urlRequest.timeoutInterval, configuration.timeoutInterval)
Expand All @@ -385,7 +387,9 @@ class OpenAITests: XCTestCase {
let jsonRequest = MultipartFormDataRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token)")
let unwrappedToken = try XCTUnwrap(configuration.token)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(unwrappedToken)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Organization"), configuration.organizationIdentifier)
XCTAssertEqual(urlRequest.timeoutInterval, configuration.timeoutInterval)
}
Expand Down
46 changes: 45 additions & 1 deletion Tests/OpenAITests/OpenAITestsDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,51 @@ class OpenAITestsDecoder: XCTestCase {

XCTAssertEqual(imageQueryAsDict, expectedValueAsDict)
}


func testChatQueryWithVision() async throws {
let chatQuery = ChatQuery(messages: [
// .init(role: .user, content: [
// .chatCompletionContentPartTextParam(.init(text: "What's in this image?")),
// .chatCompletionContentPartImageParam(.init(imageUrl: .init(url: "https://some.url/image.jpeg", detail: .auto)))
// ])!
.user(.init(content: .vision([
.chatCompletionContentPartTextParam(.init(text: "What's in this image?")),
.chatCompletionContentPartImageParam(.init(imageUrl: .init(url: "https://some.url/image.jpeg", detail: .auto)))
])))
], model: Model.gpt4_vision_preview, maxTokens: 300)
let expectedValue = """
{
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://some.url/image.jpeg",
"detail": "auto"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}
"""

        // To compare serialized JSONs we first convert them both into NSDictionary which are comparable (unlike native swift dictionaries)
let chatQueryAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(chatQuery))
let expectedValueAsDict = try jsonDataAsNSDictionary(expectedValue.data(using: .utf8)!)

XCTAssertEqual(chatQueryAsDict, expectedValueAsDict)
}

func testChatQueryWithFunctionCall() async throws {
let chatQuery = ChatQuery(
messages: [
Expand Down

0 comments on commit 1ad95dd

Please sign in to comment.