Skip to content

Commit

Permalink
Merge pull request #1 from StanfordBDHG/feat/fog-support
Browse files Browse the repository at this point in the history
Add custom TLS verification support for streaming queries
  • Loading branch information
philippzagar authored Mar 26, 2024
2 parents 5861c59 + a6a9437 commit 4866368
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 85 deletions.
38 changes: 0 additions & 38 deletions .github/ISSUE_TEMPLATE/bug_report.md

This file was deleted.

20 changes: 0 additions & 20 deletions .github/ISSUE_TEMPLATE/feature_request.md

This file was deleted.

13 changes: 0 additions & 13 deletions .github/PULL_REQUEST_TEMPLATE.md

This file was deleted.

10 changes: 9 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
// swift-tools-version: 5.7
// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "OpenAI",
defaultLocalization: "en",
platforms: [
.iOS(.v17),
.visionOS(.v1),
.tvOS(.v17),
.watchOS(.v10),
.macOS(.v14)
],
products: [
.library(
name: "OpenAI",
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

![logo](https://user-images.githubusercontent.com/1411778/218319355-f56b6bd4-961a-4d8f-82cd-6dbd43111d7f.png)

> [!NOTE]
> This repository is a [StanfordBDHG](https://github.com/StanfordBDHG) fork of the [MacPaw OpenAI project](https://github.com/MacPaw/OpenAI), adding support for:
> - Custom CA certificate verification for HTTPS requests.
> - Support for the visionOS and macOS SDKs.
> - Minor bugfixes and adjustments throughout the package.
___

![Swift Workflow](https://github.com/MacPaw/OpenAI/actions/workflows/swift.yml/badge.svg)
Expand Down
59 changes: 51 additions & 8 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

//
// OpenAI.swift
//
Expand All @@ -22,18 +23,47 @@ final public class OpenAI: OpenAIProtocol {

/// API host. Set this property if you use some kind of proxy or your own server. Default is api.openai.com
public let host: String
/// API port. Set this property if you use some kind of proxy or your own server. Default is 443
public let port: Int
/// API scheme. Set this property if you use some kind of proxy or your own server. Default is 443
public let scheme: String

/// Default request timeout
public let timeoutInterval: TimeInterval

public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", port: Int = 443, scheme: String = "https", timeoutInterval: TimeInterval = 60.0) {
/// Custom CA certificate that should be used for the TLS verification.
public let caCertificate: SecCertificate?

/// Expected API domain hostname, set in the HTTP header "Host". Useful if the to be connected IP should verify against a TLS token issued for a certain host domain.
public let expectedHost: String?

/// Create a new ``OpenAI`` instance.
///
/// - Parameters:
/// - token: OpenAI API token passed via the `Bearer` authentication HTTP header.
/// - organizationIdentifier: Optional OpenAI organization identifier.
/// - host: API host that the ``OpenAI`` client should connect to, defaults to `api.openai.com`.
/// - timeoutInterval: The maximum interval that a request is allowed to take until timeout.
/// - caCertificate: Optional custom CA certificate that should be used for the TLS verification. Useful if using a custom root CA certificate to sign the API host.
/// - expectedHost: Optional expected hostname to verify the received TLS token against. Useful for network requests to another domain or IP than the host issued the TLS token.
public init(
token: String,
organizationIdentifier: String? = nil,
host: String = "api.openai.com",
port: Int = 443,
scheme: String = "https",
timeoutInterval: TimeInterval = 60.0,
caCertificate: SecCertificate? = nil,
expectedHost: String? = nil
) {
self.token = token
self.organizationIdentifier = organizationIdentifier
self.host = host
self.port = port
self.scheme = scheme
self.timeoutInterval = timeoutInterval
self.caCertificate = caCertificate
self.expectedHost = expectedHost
}
}

Expand Down Expand Up @@ -123,12 +153,15 @@ final public class OpenAI: OpenAIProtocol {
}

extension OpenAI {

// As non-streaming inference requests are not currently supported by SpeziLLM,
// no need to adjust this function for custom TLS verification (required for Fog LLM functionality)
func performRequest<ResultType: Codable>(request: any URLRequestBuildable, completion: @escaping (Result<ResultType, Error>) -> Void) {
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
timeoutInterval: configuration.timeoutInterval,
expectedHost: configuration.expectedHost)

let task = session.dataTask(with: request) { data, _, error in
if let error = error {
return completion(.failure(error))
Expand All @@ -153,8 +186,9 @@ extension OpenAI {
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
let session = StreamingSession<ResultType>(urlRequest: request)
timeoutInterval: configuration.timeoutInterval,
expectedHost: configuration.expectedHost)
let session = StreamingSession<ResultType>(urlRequest: request, caCertificate: configuration.caCertificate, expectedHost: configuration.expectedHost)
session.onReceiveContent = {_, object in
onResult(.success(object))
}
Expand All @@ -172,11 +206,14 @@ extension OpenAI {
}
}

// As non-streaming inference requests are not currently supported by SpeziLLM,
// no need to adjust this function for custom TLS verification (required for Fog LLM functionality)
func performSpeechRequest(request: any URLRequestBuildable, completion: @escaping (Result<AudioSpeechResult, Error>) -> Void) {
do {
let request = try request.build(token: configuration.token,
organizationIdentifier: configuration.organizationIdentifier,
timeoutInterval: configuration.timeoutInterval)
timeoutInterval: configuration.timeoutInterval,
expectedHost: configuration.expectedHost)

let task = session.dataTask(with: request) { data, _, error in
if let error = error {
Expand All @@ -196,12 +233,18 @@ extension OpenAI {
}

extension OpenAI {

func buildURL(path: String) -> URL {
var components = URLComponents()
components.scheme = configuration.scheme
components.host = configuration.host
components.port = configuration.port

// If IPv6 address, wrap the IP with '[' and ']' as required by RFC 3986: https://datatracker.ietf.org/doc/html/rfc3986
if configuration.host.contains(":") && !configuration.host.hasPrefix("[") && !configuration.host.hasSuffix("]") {
components.host = "[\(configuration.host)]"
} else {
components.host = configuration.host
}

components.path = path
return components.url!
}
Expand Down
6 changes: 5 additions & 1 deletion Sources/OpenAI/Private/JSONRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ final class JSONRequest<ResultType> {

extension JSONRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval, expectedHost: String? = nil) throws -> URLRequest {
var request = URLRequest(url: url, timeoutInterval: timeoutInterval)
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}
// Set the expected host on the HTTP request within the `Host` header field, if present
if let expectedHost {
request.setValue(expectedHost, forHTTPHeaderField: "Host")
}
request.httpMethod = method
if let body = body {
request.httpBody = try JSONEncoder().encode(body)
Expand Down
6 changes: 5 additions & 1 deletion Sources/OpenAI/Private/MultipartFormDataRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ final class MultipartFormDataRequest<ResultType> {

extension MultipartFormDataRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval, expectedHost: String? = nil) throws -> URLRequest {
var request = URLRequest(url: url)
let boundary: String = UUID().uuidString
request.timeoutInterval = timeoutInterval
Expand All @@ -35,6 +35,10 @@ extension MultipartFormDataRequest: URLRequestBuildable {
if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}
// Set the expected host on the HTTP request within the `Host` header field, if present
if let expectedHost {
request.setValue(expectedHost, forHTTPHeaderField: "Host")
}
request.httpBody = body.encode(boundary: boundary)
return request
}
Expand Down
87 changes: 86 additions & 1 deletion Sources/OpenAI/Private/StreamingSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif
import Security

final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSessionDelegate, URLSessionDataDelegate {

Expand All @@ -29,9 +30,19 @@ final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSe
}()

private var previousChunkBuffer = ""
private let caCertificate: SecCertificate?
private let expectedHost: String?

init(urlRequest: URLRequest) {
/// Create an instance of the `StreamingSession`
///
/// - Parameters:
/// - urlRequest: Base `URLRequest`
/// - caCertificate: The optional, to-be-trusted custom CA certificate.
/// - expectedHost: The optional expected hostname to verify the received TLS token against. Useful for network requests to another domain or IP than the host issued the TLS token (e.g. within a local network with non-public hostnames and requests via IPs)
init(urlRequest: URLRequest, caCertificate: SecCertificate? = nil, expectedHost: String? = nil) {
self.urlRequest = urlRequest
self.caCertificate = caCertificate
self.expectedHost = expectedHost
}

func perform() {
Expand All @@ -44,6 +55,46 @@ final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSe
onComplete?(self, error)
}

/// Handle HTTP 401 and 403 status codes returned by OpenAI API implementations such as Ollama and completes the current request with an error.
func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) {
if let httpResponse = response as? HTTPURLResponse {
// Handle negative HTTP status code returned by various OpenAI API implementations
if httpResponse.statusCode == 401 {
// Propagate the HTTP error up the call stack
onComplete?(
self,
APIErrorResponse(
error: .init(
message: "HTTP 401: Unauthorized",
type: "unauthorized",
param: nil,
code: "401"
)
)
)
completionHandler(.cancel)
return
} else if httpResponse.statusCode == 403 {
// Propagate the HTTP error up the call stack
onComplete?(
self,
APIErrorResponse(
error: .init(
message: "HTTP 403: Forbidden",
type: "forbidden",
param: nil,
code: "403"
)
)
)
completionHandler(.cancel)
return
}
}

completionHandler(.allow)
}

func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
guard let stringContent = String(data: data, encoding: .utf8) else {
onProcessingError?(self, StreamingError.unknownContent)
Expand All @@ -52,6 +103,40 @@ final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSe
processJSON(from: stringContent)
}

/// Handle custom TLS certificate verification of `StreamingSession` requests.
///
/// Uses the `caCertificate` and `expectedHost` parameters of the `StreamingSession` to verify the server's authenticity and establish a secure SSL connection.
func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
) {
guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust,
let serverTrust = challenge.protectionSpace.serverTrust,
let caCertificate, let expectedHost else {
completionHandler(.performDefaultHandling, nil)
return
}

// Set the anchor certificate
let anchorCertificates: [SecCertificate] = [caCertificate]
SecTrustSetAnchorCertificates(serverTrust, anchorCertificates as CFArray)

SecTrustSetAnchorCertificatesOnly(serverTrust, true)

let policy = SecPolicyCreateSSL(true, expectedHost as CFString)
SecTrustSetPolicies(serverTrust, policy)

var error: CFError?
if SecTrustEvaluateWithError(serverTrust, &error) {
// Trust evaluation succeeded, proceed with the connection
completionHandler(.useCredential, URLCredential(trust: serverTrust))
} else {
// Trust evaluation failed, handle the error
print("OpenAI: Trust evaluation failed with error: \(error?.localizedDescription)")
completionHandler(.cancelAuthenticationChallenge, nil)
}
}
}

extension StreamingSession {
Expand Down
2 changes: 1 addition & 1 deletion Sources/OpenAI/Private/URLRequestBuildable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ protocol URLRequestBuildable {

associatedtype ResultType

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest
func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval, expectedHost: String?) throws -> URLRequest
}
Loading

0 comments on commit 4866368

Please sign in to comment.