Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for more Assistants API features #3

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions Demo/DemoChat/Sources/AssistantStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ public final class AssistantStore: ObservableObject {
// MARK: Models

@MainActor
func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? {
func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration], fileIds: [String]? = nil) async -> String? {
do {
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel)
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions)
let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds)
let response = try await openAIClient.assistants(query: query, method: "POST", after: nil)
let response = try await openAIClient.assistantCreate(query: query)

// Refresh assistants with one just created (or modified)
let _ = await getAssistants()
Expand All @@ -47,11 +47,11 @@ public final class AssistantStore: ObservableObject {
}

@MainActor
func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? {
func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration], fileIds: [String]? = nil) async -> String? {
do {
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel)
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions)
let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds)
let response = try await openAIClient.assistantModify(query: query, asstId: asstId)
let response = try await openAIClient.assistantModify(query: query, assistantId: asstId)

// Returns assistantId
return response.id
Expand All @@ -66,15 +66,24 @@ public final class AssistantStore: ObservableObject {
@MainActor
func getAssistants(limit: Int = 20, after: String? = nil) async -> [Assistant] {
do {
let response = try await openAIClient.assistants(query: nil, method: "GET", after: after)
let response = try await openAIClient.assistants(after: after)

var assistants = [Assistant]()
for result in response.data ?? [] {
let codeInterpreter = result.tools?.filter { $0.toolType == "code_interpreter" }.first != nil
let retrieval = result.tools?.filter { $0.toolType == "retrieval" }.first != nil
let tools = result.tools ?? []
let codeInterpreter = tools.contains { $0 == .codeInterpreter }
let retrieval = tools.contains { $0 == .retrieval }
let functions = tools.compactMap {
switch $0 {
case let .function(declaration):
return declaration
default:
return nil
}
}
let fileIds = result.fileIds ?? []

assistants.append(Assistant(id: result.id, name: result.name, description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds))
assistants.append(Assistant(id: result.id, name: result.name ?? "", description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds, functions: functions))
}
if after == nil {
availableAssistants = assistants
Expand Down Expand Up @@ -112,14 +121,14 @@ public final class AssistantStore: ObservableObject {
}
}

func createToolsArray(codeInterpreter: Bool, retrieval: Bool) -> [Tool] {
func createToolsArray(codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration]) -> [Tool] {
var tools = [Tool]()
if codeInterpreter {
tools.append(Tool(toolType: "code_interpreter"))
tools.append(.codeInterpreter)
}
if retrieval {
tools.append(Tool(toolType: "retrieval"))
tools.append(.retrieval)
}
return tools
return tools + functions.map { .function($0) }
}
}
125 changes: 80 additions & 45 deletions Demo/DemoChat/Sources/ChatStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public final class ChatStore: ObservableObject {
conversations[conversationIndex].messages.append(localMessage)

do {
let threadsQuery = ThreadsQuery(messages: [Chat(role: message.role, content: message.content)])
let threadsQuery = ThreadsQuery(messages: [MessageQuery(role: .user, content: message.content)])
let threadsResult = try await openAIClient.threads(query: threadsQuery)

guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")}
Expand All @@ -117,7 +117,7 @@ public final class ChatStore: ObservableObject {
guard let currentThreadId else { return print("No thread to add message to.")}

let _ = try await openAIClient.threadsAddMessage(threadId: currentThreadId,
query: ThreadAddMessageQuery(role: message.role.rawValue, content: message.content))
query: MessageQuery(role: message.role, content: message.content))

guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")}

Expand Down Expand Up @@ -150,7 +150,7 @@ public final class ChatStore: ObservableObject {
return
}

let weatherFunction = ChatFunctionDeclaration(
let weatherFunction = FunctionDeclaration(
name: "getWeatherData",
description: "Get the current weather in a given location",
parameters: .init(
Expand Down Expand Up @@ -243,19 +243,19 @@ public final class ChatStore: ObservableObject {
let result = try await openAIClient.runRetrieve(threadId: currentThreadId ?? "", runId: currentRunId ?? "")

// TESTING RETRIEVAL OF RUN STEPS
handleRunRetrieveSteps()
try await handleRunRetrieveSteps()

switch result.status {
// Get threadsMesages.
case "completed":
case .completed:
handleCompleted()
break
case "failed":
case .failed:
// Handle more gracefully with a popup dialog or failure indicator
await MainActor.run {
self.stopPolling()
}
break
case .requiresAction:
try await handleRequiresAction(result)
default:
// Handle additional statuses "requires_action", "queued" ?, "expired", "cancelled"
// https://platform.openai.com/docs/assistants/how-it-works/runs-and-run-steps
Expand Down Expand Up @@ -287,7 +287,7 @@ public final class ChatStore: ObservableObject {
for innerItem in item.content {
let message = Message(
id: item.id,
role: Chat.Role(rawValue: role) ?? .user,
role: role,
content: innerItem.text?.value ?? "",
createdAt: Date(),
isLocal: false // Messages from the server are not local
Expand All @@ -308,54 +308,89 @@ public final class ChatStore: ObservableObject {
}
}

// Store the function call as a message and submit tool outputs with a simple done message.
private func handleRequiresAction(_ result: RunResult) async throws {
guard let currentThreadId, let currentRunId else {
return
}

guard let toolCalls = result.requiredAction?.submitToolOutputs.toolCalls else {
return
}

var toolOutputs = [RunToolOutputsQuery.ToolOutput]()

for toolCall in toolCalls {
let msgContent = "function\nname: \(toolCall.function.name ?? "")\nargs: \(toolCall.function.arguments ?? "{}")"

let runStepMessage = Message(
id: toolCall.id,
role: .assistant,
content: msgContent,
createdAt: Date(),
isRunStep: true
)
await addOrUpdateRunStepMessage(runStepMessage)

// Just return a generic "Done" output for now
toolOutputs.append(.init(toolCallId: toolCall.id, output: "Done"))
}

let query = RunToolOutputsQuery(toolOutputs: toolOutputs)
_ = try await openAIClient.runSubmitToolOutputs(threadId: currentThreadId, runId: currentRunId, query: query)
}

// The run retrieval steps are fetched in a separate task. This request is fetched, checking for new run steps, each time the run is fetched.
private func handleRunRetrieveSteps() {
Task {
guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else {
return
}
var before: String?
private func handleRunRetrieveSteps() async throws {
var before: String?
// if let lastRunStepMessage = self.conversations[conversationIndex].messages.last(where: { $0.isRunStep == true }) {
// before = lastRunStepMessage.id
// }

let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before)
let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before)

for item in stepsResult.data.reversed() {
let toolCalls = item.stepDetails.toolCalls?.reversed() ?? []
for item in stepsResult.data.reversed() {
let toolCalls = item.stepDetails.toolCalls?.reversed() ?? []

for step in toolCalls {
// TODO: Depending on the type of tool tha is used we can add additional information here
// ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info.
let msgContent: String
switch step.type {
case "retrieval":
msgContent = "RUN STEP: \(step.type)"
for step in toolCalls {
// TODO: Depending on the type of tool tha is used we can add additional information here
// ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info.
let msgContent: String
switch step.type {
case .retrieval:
msgContent = "RUN STEP: \(step.type)"

case "code_interpreter":
msgContent = "code_interpreter\ninput:\n\(step.code?.input ?? "")\noutputs: \(step.code?.outputs?.first?.logs ?? "")"
case .codeInterpreter:
let code = step.codeInterpreter
msgContent = "code_interpreter\ninput:\n\(code?.input ?? "")\noutputs: \(code?.outputs?.first?.logs ?? "")"

default:
msgContent = "RUN STEP: \(step.type)"
case .function:
msgContent = "function\nname: \(step.function?.name ?? "")\nargs: \(step.function?.arguments ?? "{}")"

}
let runStepMessage = Message(
id: step.id,
role: .assistant,
content: msgContent,
createdAt: Date(),
isRunStep: true
)
await MainActor.run {
if let localMessageIndex = self.conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == step.id }) {
self.conversations[conversationIndex].messages[localMessageIndex] = runStepMessage
}
else {
self.conversations[conversationIndex].messages.append(runStepMessage)
}
}
}
let runStepMessage = Message(
id: step.id,
role: .assistant,
content: msgContent,
createdAt: Date(),
isRunStep: true
)
await addOrUpdateRunStepMessage(runStepMessage)
}
}
}

@MainActor
private func addOrUpdateRunStepMessage(_ message: Message) async {
guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else {
return
}

if let localMessageIndex = conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == message.id }) {
conversations[conversationIndex].messages[localMessageIndex] = message
}
else {
conversations[conversationIndex].messages.append(message)
}
}
}
13 changes: 12 additions & 1 deletion Demo/DemoChat/Sources/Models/Assistant.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
//

import Foundation
import OpenAI

struct Assistant: Hashable {
init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil) {
init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil, functions: [FunctionDeclaration] = []) {
self.id = id
self.name = name
self.description = description
self.instructions = instructions
self.codeInterpreter = codeInterpreter
self.retrieval = retrieval
self.fileIds = fileIds
self.functions = functions
}

typealias ID = String
Expand All @@ -27,7 +29,16 @@ struct Assistant: Hashable {
let fileIds: [String]?
var codeInterpreter: Bool
var retrieval: Bool
var functions: [FunctionDeclaration]
}


extension Assistant: Equatable, Identifiable {}

extension FunctionDeclaration: Hashable {
public func hash(into hasher: inout Hasher) {
hasher.combine(name)
hasher.combine(description)
hasher.combine(parameters)
}
}
Loading