From 4f0342b481f1e0abc3258413f4eb5080210016d7 Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Sat, 29 Jul 2023 20:46:15 -0500 Subject: [PATCH 1/8] Add support for arbitrary HTTP headers to enable proxies like Helicone. Refactor URLRequestBuildable and its implementations to separate configuration concerns from URL building concerns. --- Sources/OpenAI/OpenAI.swift | 39 ++++++++++++------- Sources/OpenAI/Private/BaseRequest.swift | 36 +++++++++++++++++ Sources/OpenAI/Private/Empty.swift | 14 +++++++ Sources/OpenAI/Private/JSONRequest.swift | 39 +++++++------------ .../Private/MultipartFormDataRequest.swift | 39 +++++++------------ Sources/OpenAI/Private/StreamingSession.swift | 20 +++++++++- .../OpenAI/Private/URLRequestBuildable.swift | 2 +- 7 files changed, 123 insertions(+), 66 deletions(-) create mode 100644 Sources/OpenAI/Private/BaseRequest.swift create mode 100644 Sources/OpenAI/Private/Empty.swift diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index 5e88ffcf..26484d13 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -26,11 +26,20 @@ final public class OpenAI: OpenAIProtocol { /// Default request timeout public let timeoutInterval: TimeInterval - public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", timeoutInterval: TimeInterval = 60.0) { + /// A dictionary of HTTP headers to include in requests to OpenAI (or any proxy server the request may be sent to) + public var additionalHeaders: [String: String] + + public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", timeoutInterval: TimeInterval = 60.0, additionalHeaders: [String: String]? = nil) { self.token = token self.organizationIdentifier = organizationIdentifier self.host = host self.timeoutInterval = timeoutInterval + + self.additionalHeaders = additionalHeaders ?? 
[:] + self.additionalHeaders["Authorization"] = "Bearer \(token)" + if let organizationIdentifier = organizationIdentifier { + self.additionalHeaders["OpenAI-Organization"] = organizationIdentifier + } } } @@ -57,51 +66,51 @@ final public class OpenAI: OpenAIProtocol { } public func completions(query: CompletionsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .completions)), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .completions), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func completionsStream(query: CompletionsQuery, onResult: @escaping (Result) -> Void, completion: ((Error?) -> Void)?) { - performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .completions)), onResult: onResult, completion: completion) + performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .completions), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion) } public func images(query: ImagesQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .images)), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .images), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func embeddings(query: EmbeddingsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .embeddings)), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .embeddings), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } 
public func chats(query: ChatQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .chats)), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .chats), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func chatsStream(query: ChatQuery, onResult: @escaping (Result) -> Void, completion: ((Error?) -> Void)?) { - performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .chats)), onResult: onResult, completion: completion) + performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .chats), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion) } public func edits(query: EditsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .edits)), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .edits), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func model(query: ModelQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(url: buildURL(path: .models.withPath(query.model)), method: "GET"), completion: completion) + performRequest(request: JSONRequest(url: buildURL(path: .models.withPath(query.model)), method: "GET", headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func models(completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(url: buildURL(path: .models), method: "GET"), completion: completion) + performRequest(request: JSONRequest(url: buildURL(path: .models), method: "GET", headers: configuration.additionalHeaders, 
timeoutInterval: configuration.timeoutInterval), completion: completion) } public func moderations(query: ModerationsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .moderations)), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .moderations), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func audioTranscriptions(query: AudioTranscriptionQuery, completion: @escaping (Result) -> Void) { - performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranscriptions)), completion: completion) + performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranscriptions), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } public func audioTranslations(query: AudioTranslationQuery, completion: @escaping (Result) -> Void) { - performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranslations)), completion: completion) + performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranslations), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) } } @@ -109,7 +118,7 @@ extension OpenAI { func performRequest(request: any URLRequestBuildable, completion: @escaping (Result) -> Void) { do { - let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval) + let request = try request.build() let task = session.dataTask(with: request) { data, _, error in if let error = error { completion(.failure(error)) @@ -145,7 +154,7 @@ extension OpenAI { func performSteamingRequest(request: any URLRequestBuildable, onResult: @escaping (Result) -> Void, 
completion: ((Error?) -> Void)?) { do { - let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval) + let request = try request.build() let session = StreamingSession(urlRequest: request) session.onReceiveContent = {_, object in onResult(.success(object)) diff --git a/Sources/OpenAI/Private/BaseRequest.swift b/Sources/OpenAI/Private/BaseRequest.swift new file mode 100644 index 00000000..bc04fc15 --- /dev/null +++ b/Sources/OpenAI/Private/BaseRequest.swift @@ -0,0 +1,36 @@ +// +// BaseRequest.swift +// +// +// Created by Benjamin Truitt on 7/28/23. +// + +import Foundation +#if canImport(FoundationNetworking) +import FoundationNetworking +#endif + +protocol RequestBuildable { + var url: URL { get } + var headers: [String: String] { get } + var method: String { get } + var timeoutInterval: TimeInterval { get } + + func getBody() throws -> Data? +} + +protocol BaseRequest: RequestBuildable { + func build() throws -> URLRequest +} + +extension BaseRequest { + func build() throws -> URLRequest { + var request = URLRequest(url: url, timeoutInterval: timeoutInterval) + + request.httpMethod = method + request.httpBody = try getBody() + request.allHTTPHeaderFields = headers + + return request + } +} diff --git a/Sources/OpenAI/Private/Empty.swift b/Sources/OpenAI/Private/Empty.swift new file mode 100644 index 00000000..1dcdc608 --- /dev/null +++ b/Sources/OpenAI/Private/Empty.swift @@ -0,0 +1,14 @@ +// +// Empty.swift +// +// +// Created by Benjamin Truitt on 7/29/23. 
+// + +import Foundation + +struct Empty: Encodable { + func encode(to encoder: Encoder) throws { + // Do nothing + } +} diff --git a/Sources/OpenAI/Private/JSONRequest.swift b/Sources/OpenAI/Private/JSONRequest.swift index 526f95c9..8ac344d6 100644 --- a/Sources/OpenAI/Private/JSONRequest.swift +++ b/Sources/OpenAI/Private/JSONRequest.swift @@ -5,37 +5,26 @@ // Created by Sergii Kryvoblotskyi on 12/19/22. // + import Foundation -#if canImport(FoundationNetworking) -import FoundationNetworking -#endif -final class JSONRequest { - - let body: Codable? - let url: URL - let method: String - - init(body: Codable? = nil, url: URL, method: String = "POST") { +struct JSONRequest: BaseRequest, URLRequestBuildable { + var body: BodyType? + var url: URL + var method: String = "POST" + var headers: [String: String] + var timeoutInterval: TimeInterval + + init(body: BodyType? = nil, url: URL, method: String = "POST", headers: [String: String]?, timeoutInterval: TimeInterval) { self.body = body self.url = url self.method = method + self.headers = headers ?? [:] + self.headers["Content-Type"] = "application/json" + self.timeoutInterval = timeoutInterval } -} - -extension JSONRequest: URLRequestBuildable { - func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest { - var request = URLRequest(url: url, timeoutInterval: timeoutInterval) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") - if let organizationIdentifier { - request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization") - } - request.httpMethod = method - if let body = body { - request.httpBody = try JSONEncoder().encode(body) - } - return request + func getBody() throws -> Data? 
{ + return try body.map { try JSONEncoder().encode($0) } } } diff --git a/Sources/OpenAI/Private/MultipartFormDataRequest.swift b/Sources/OpenAI/Private/MultipartFormDataRequest.swift index 13764a58..cbbae3d9 100644 --- a/Sources/OpenAI/Private/MultipartFormDataRequest.swift +++ b/Sources/OpenAI/Private/MultipartFormDataRequest.swift @@ -6,36 +6,27 @@ // import Foundation -#if canImport(FoundationNetworking) -import FoundationNetworking -#endif -final class MultipartFormDataRequest { +struct MultipartFormDataRequest: BaseRequest, URLRequestBuildable { + var body: MultipartFormDataBodyEncodable? + var url: URL + var headers: [String: String] + var method: String = "POST" + var timeoutInterval: TimeInterval + var boundary: String = UUID().uuidString - let body: MultipartFormDataBodyEncodable - let url: URL - let method: String - - init(body: MultipartFormDataBodyEncodable, url: URL, method: String = "POST") { + init(body: BodyType?, url: URL, method: String = "POST", headers: [String: String]?, timeoutInterval: TimeInterval) { self.body = body self.url = url self.method = method + self.headers = headers ?? [:] + self.headers["Content-Type"] = "multipart/form-data; boundary=\(boundary)" + self.timeoutInterval = timeoutInterval } -} - -extension MultipartFormDataRequest: URLRequestBuildable { - func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest { - var request = URLRequest(url: url) - let boundary: String = UUID().uuidString - request.timeoutInterval = timeoutInterval - request.httpMethod = method - request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") - request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type") - if let organizationIdentifier { - request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization") - } - request.httpBody = body.encode(boundary: boundary) - return request + func getBody() throws -> Data? 
{ + return body?.encode(boundary: boundary) } } + + diff --git a/Sources/OpenAI/Private/StreamingSession.swift b/Sources/OpenAI/Private/StreamingSession.swift index 77b7ba0e..1a0f90f9 100644 --- a/Sources/OpenAI/Private/StreamingSession.swift +++ b/Sources/OpenAI/Private/StreamingSession.swift @@ -28,6 +28,8 @@ final class StreamingSession: NSObject, Identifiable, URLSe return session }() + private var accumulatedData = "" // This will hold incomplete JSON data + init(urlRequest: URLRequest) { self.urlRequest = urlRequest } @@ -47,7 +49,12 @@ final class StreamingSession: NSObject, Identifiable, URLSe onProcessingError?(self, StreamingError.unknownContent) return } - let jsonObjects = stringContent + + print(stringContent) + + accumulatedData.append(stringContent) + + let jsonObjects = accumulatedData .components(separatedBy: "data:") .filter { $0.isEmpty == false } .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } @@ -67,8 +74,19 @@ final class StreamingSession: NSObject, Identifiable, URLSe do { let decoder = JSONDecoder() let object = try decoder.decode(ResultType.self, from: jsonData) + accumulatedData = "" // Reset accumulatedData, since the jsonData is complete and valid. onReceiveContent?(self, object) + } catch let error as DecodingError { + if case .dataCorrupted = error { + // Invalid JSON - this isn't an error condition, we simply don't have all the data yet, so we'll wait for + // this function to be called again, and will append the data we subsequently receive to the accumulatedData + // variable so that we can try to process that longer, more complete string at that point. 
+ } else { + // Handle other decoding errors + apiError = error + } } catch { + // Handle non-DecodingErrors apiError = error } diff --git a/Sources/OpenAI/Private/URLRequestBuildable.swift b/Sources/OpenAI/Private/URLRequestBuildable.swift index a10f3109..d362eefe 100644 --- a/Sources/OpenAI/Private/URLRequestBuildable.swift +++ b/Sources/OpenAI/Private/URLRequestBuildable.swift @@ -14,5 +14,5 @@ protocol URLRequestBuildable { associatedtype ResultType - func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest + func build() throws -> URLRequest } From b882637a07194a5777f5004fec1cd71d65c24eae Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Sat, 29 Jul 2023 22:16:18 -0500 Subject: [PATCH 2/8] Separate the client configuration and request headers construction so that the code better adheres to the Single Responsibility Principle. --- Sources/OpenAI/OpenAI.swift | 42 +++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index 26484d13..c5a902a5 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -33,13 +33,8 @@ final public class OpenAI: OpenAIProtocol { self.token = token self.organizationIdentifier = organizationIdentifier self.host = host - self.timeoutInterval = timeoutInterval - + self.timeoutInterval = timeoutInterval self.additionalHeaders = additionalHeaders ?? 
[:] - self.additionalHeaders["Authorization"] = "Bearer \(token)" - if let organizationIdentifier = organizationIdentifier { - self.additionalHeaders["OpenAI-Organization"] = organizationIdentifier - } } } @@ -66,51 +61,62 @@ final public class OpenAI: OpenAIProtocol { } public func completions(query: CompletionsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .completions), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .completions), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func completionsStream(query: CompletionsQuery, onResult: @escaping (Result) -> Void, completion: ((Error?) -> Void)?) { - performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .completions), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion) + performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .completions), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion) } public func images(query: ImagesQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .images), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .images), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func embeddings(query: EmbeddingsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .embeddings), 
headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .embeddings), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func chats(query: ChatQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .chats), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .chats), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func chatsStream(query: ChatQuery, onResult: @escaping (Result) -> Void, completion: ((Error?) -> Void)?) { - performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .chats), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion) + performSteamingRequest(request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .chats), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion) } public func edits(query: EditsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .edits), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .edits), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func model(query: ModelQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(url: buildURL(path: .models.withPath(query.model)), method: "GET", headers: 
configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(url: buildURL(path: .models.withPath(query.model)), method: "GET", headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func models(completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(url: buildURL(path: .models), method: "GET", headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(url: buildURL(path: .models), method: "GET", headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func moderations(query: ModerationsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .moderations), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: JSONRequest(body: query, url: buildURL(path: .moderations), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func audioTranscriptions(query: AudioTranscriptionQuery, completion: @escaping (Result) -> Void) { - performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranscriptions), headers: configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranscriptions), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) } public func audioTranslations(query: AudioTranslationQuery, completion: @escaping (Result) -> Void) { - performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranslations), headers: 
configuration.additionalHeaders, timeoutInterval: configuration.timeoutInterval), completion: completion) + performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .audioTranslations), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion) + } +} + +extension OpenAI { + func generateHeaders() -> [String: String] { + var headers = configuration.additionalHeaders + headers["Authorization"] = "Bearer \(configuration.token)" + if let organizationIdentifier = configuration.organizationIdentifier { + headers["OpenAI-Organization"] = organizationIdentifier + } + return headers } } From 2334b8cc30a1511cad4e02ea23cf5b058f3a23de Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Sat, 29 Jul 2023 22:24:32 -0500 Subject: [PATCH 3/8] Mark generateHeaders() private to limit visibility, since it should only be used by the OpenAI class. --- Sources/OpenAI/OpenAI.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index c5a902a5..8d707454 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -33,7 +33,7 @@ final public class OpenAI: OpenAIProtocol { self.token = token self.organizationIdentifier = organizationIdentifier self.host = host - self.timeoutInterval = timeoutInterval + self.timeoutInterval = timeoutInterval self.additionalHeaders = additionalHeaders ?? 
[:] } } @@ -110,7 +110,7 @@ final public class OpenAI: OpenAIProtocol { } extension OpenAI { - func generateHeaders() -> [String: String] { + private func generateHeaders() -> [String: String] { var headers = configuration.additionalHeaders headers["Authorization"] = "Bearer \(configuration.token)" if let organizationIdentifier = configuration.organizationIdentifier { From 31532ccf1f309254bb8b686849647aa08fd9c890 Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Sat, 29 Jul 2023 22:50:16 -0500 Subject: [PATCH 4/8] Change private to internal for generateHeaders() visibility modifier to enable future extension and/or testing. --- Sources/OpenAI/OpenAI.swift | 33 ++------------- .../Private/URLSessionDataTaskManager.swift | 40 +++++++++++++++++++ 2 files changed, 44 insertions(+), 29 deletions(-) create mode 100644 Sources/OpenAI/Private/URLSessionDataTaskManager.swift diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index 8d707454..c96bcb9f 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -39,6 +39,7 @@ final public class OpenAI: OpenAIProtocol { private let session: URLSessionProtocol + private let dataTaskManager: URLSessionDataTaskManager private var streamingSessions: [NSObject] = [] public let configuration: Configuration @@ -54,6 +55,7 @@ final public class OpenAI: OpenAIProtocol { init(configuration: Configuration, session: URLSessionProtocol) { self.configuration = configuration self.session = session + self.dataTaskManager = URLSessionDataTaskManager(session: session) } public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) { @@ -110,7 +112,7 @@ final public class OpenAI: OpenAIProtocol { } extension OpenAI { - private func generateHeaders() -> [String: String] { + internal func generateHeaders() -> [String: String] { var headers = configuration.additionalHeaders headers["Authorization"] = "Bearer \(configuration.token)" if let organizationIdentifier = 
configuration.organizationIdentifier { @@ -125,34 +127,7 @@ extension OpenAI { func performRequest(request: any URLRequestBuildable, completion: @escaping (Result) -> Void) { do { let request = try request.build() - let task = session.dataTask(with: request) { data, _, error in - if let error = error { - completion(.failure(error)) - return - } - guard let data = data else { - completion(.failure(OpenAIError.emptyData)) - return - } - - var apiError: Error? = nil - do { - let decoded = try JSONDecoder().decode(ResultType.self, from: data) - completion(.success(decoded)) - } catch { - apiError = error - } - - if let apiError = apiError { - do { - let decoded = try JSONDecoder().decode(APIErrorResponse.self, from: data) - completion(.failure(decoded)) - } catch { - completion(.failure(apiError)) - } - } - } - task.resume() + dataTaskManager.performDataTask(with: request, completion: completion) } catch { completion(.failure(error)) } diff --git a/Sources/OpenAI/Private/URLSessionDataTaskManager.swift b/Sources/OpenAI/Private/URLSessionDataTaskManager.swift new file mode 100644 index 00000000..d13d8662 --- /dev/null +++ b/Sources/OpenAI/Private/URLSessionDataTaskManager.swift @@ -0,0 +1,40 @@ +// +// URLSessionDataTaskManager.swift +// +// +// Created by Benjamin Truitt on 7/29/23. +// + +import Foundation + +public class URLSessionDataTaskManager { + + private var session: URLSessionProtocol + + init(session: URLSessionProtocol) { + self.session = session + } + + func performDataTask(with request: URLRequest, completion: @escaping (Result) -> Void) { + let task = session.dataTask(with: request) { (data: Data?, response: URLResponse?, error: Error?) 
in + if let error = error { + completion(.failure(error)) + } else if let data = data { + do { + let decoded = try JSONDecoder().decode(ResultType.self, from: data) + completion(.success(decoded)) + } catch { + do { + let decoded = try JSONDecoder().decode(APIErrorResponse.self, from: data) + completion(.failure(decoded)) + } catch let decodingError { + completion(.failure(decodingError)) + } + } + } else { + completion(.failure(OpenAIError.emptyData)) + } + } + task.resume() + } +} From bafdc5390ce01dd36af452b71612bce25c51b487 Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Sun, 30 Jul 2023 00:03:57 -0500 Subject: [PATCH 5/8] Remove print() statement accidentally left in from debugging. --- Sources/OpenAI/Private/StreamingSession.swift | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Sources/OpenAI/Private/StreamingSession.swift b/Sources/OpenAI/Private/StreamingSession.swift index 1a0f90f9..1c919724 100644 --- a/Sources/OpenAI/Private/StreamingSession.swift +++ b/Sources/OpenAI/Private/StreamingSession.swift @@ -48,9 +48,7 @@ final class StreamingSession: NSObject, Identifiable, URLSe guard let stringContent = String(data: data, encoding: .utf8) else { onProcessingError?(self, StreamingError.unknownContent) return - } - - print(stringContent) + } accumulatedData.append(stringContent) From e2d5dfadf6108d27d8222bfc208c74071677b14b Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Sun, 6 Aug 2023 22:34:14 -0500 Subject: [PATCH 6/8] Refactor to reduce cognitive complexity. 
--- Sources/OpenAI/Private/StreamingSession.swift | 93 ++++++++++++------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/Sources/OpenAI/Private/StreamingSession.swift b/Sources/OpenAI/Private/StreamingSession.swift index 1c919724..57408610 100644 --- a/Sources/OpenAI/Private/StreamingSession.swift +++ b/Sources/OpenAI/Private/StreamingSession.swift @@ -50,52 +50,73 @@ final class StreamingSession: NSObject, Identifiable, URLSe return } + // As data comes in, we process it by trying to decode to JSON. Since we may not have yet received a fully valid + // JSON string yet, we keep appending content if we're unable to decode due to a "dataCorrupted" error, trying again + // on the next pass in case that new data has completed the string to create valid (parsable) JSON, and so on. + // Note that while OpenAI's API doesn't appear to return partial string fragments, proxy service such as Helicone + // do seem to do so, making this logic necessary. accumulatedData.append(stringContent) - + processAccumulatedData() + } + + func processAccumulatedData() { let jsonObjects = accumulatedData .components(separatedBy: "data:") .filter { $0.isEmpty == false } .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } - guard jsonObjects.isEmpty == false, jsonObjects.first != streamingCompletionMarker else { + guard !jsonObjects.isEmpty, jsonObjects.first != streamingCompletionMarker else { return } - jsonObjects.forEach { jsonContent in - guard jsonContent != streamingCompletionMarker else { - return - } - guard let jsonData = jsonContent.data(using: .utf8) else { - onProcessingError?(self, StreamingError.unknownContent) - return - } - - var apiError: Error? = nil - do { - let decoder = JSONDecoder() - let object = try decoder.decode(ResultType.self, from: jsonData) - accumulatedData = "" // Reset accumulatedData, since the jsonData is complete and valid. 
+ jsonObjects.forEach { jsonContent in + processJsonContent(jsonContent: jsonContent) + } + } + + private func processJsonContent(jsonContent: String) { + guard jsonContent != streamingCompletionMarker else { + return + } + + guard let jsonData = jsonContent.data(using: .utf8) else { + onProcessingError?(self, StreamingError.unknownContent) + return + } + + do { + if let object = try decodeResultType(from: jsonData) { + accumulatedData = "" // Successfully decoded, so reset accumulatedData. onReceiveContent?(self, object) - } catch let error as DecodingError { - if case .dataCorrupted = error { - // Invalid JSON - this isn't an error condition, we simply don't have all the data yet, so we'll wait for - // this function to be called again, and will append the data we subsequently receive to the accumulatedData - // variable so that we can try to process that longer, more complete string at that point. - } else { - // Handle other decoding errors - apiError = error - } - } catch { - // Handle non-DecodingErrors - apiError = error } - - if let apiError = apiError { - do { - let decoded = try JSONDecoder().decode(APIErrorResponse.self, from: data) - onProcessingError?(self, decoded) - } catch { - onProcessingError?(self, apiError) - } + } catch let apiError { + handleApiError(apiError, data: jsonData) + } + } + + private func decodeResultType(from data: Data) throws -> ResultType? { + let decoder = JSONDecoder() + do { + return try decoder.decode(ResultType.self, from: data) + } catch let error as DecodingError { + if case .dataCorrupted = error { + // Invalid JSON - this isn't an error condition, we simply don't have all the data yet. + // We'll wait for this function to be called again, and will append the data we subsequently + // receive to the accumulatedData variable, so that we can try to process that longer, + // more complete string at that point. + + return nil // Return nil to indicate that the data is not yet complete. 
+ } else { + // Handle other decoding errors + throw error } } } + + private func handleApiError(_ error: Error, data: Data) { + do { + let decoded = try JSONDecoder().decode(APIErrorResponse.self, from: data) + onProcessingError?(self, decoded) + } catch { + onProcessingError?(self, error) + } + } } From 9413cce79728813c095f3a472c83bfd98187f16f Mon Sep 17 00:00:00 2001 From: SDimka Date: Sun, 19 May 2024 13:46:29 +0500 Subject: [PATCH 7/8] Feat: Get token usage data for streamed chat completion response. --- Sources/OpenAI/Public/Models/ChatQuery.swift | 21 ++++++++++++++++++- .../Public/Models/ChatStreamResult.swift | 15 +++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/Sources/OpenAI/Public/Models/ChatQuery.swift b/Sources/OpenAI/Public/Models/ChatQuery.swift index c7a88649..a36e0f6f 100644 --- a/Sources/OpenAI/Public/Models/ChatQuery.swift +++ b/Sources/OpenAI/Public/Models/ChatQuery.swift @@ -67,6 +67,10 @@ public struct ChatQuery: Equatable, Codable, Streamable { /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. /// https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format public var stream: Bool + /// If set, `in stream mode`, last chunk will contain usage data for the entire request. Can be obtained from ChatStreamResult.usage. + /// Official doc: + /// https://cookbook.openai.com/examples/how_to_stream_completions#4-how-to-get-token-usage-data-for-streamed-chat-completion-response + public let streamOptions: Self.StreamOptions? public init( messages: [Self.ChatCompletionMessageParam], @@ -86,7 +90,8 @@ public struct ChatQuery: Equatable, Codable, Streamable { topLogprobs: Int? = nil, topP: Double? = nil, user: String? = nil, - stream: Bool = false + stream: Bool = false, + streamOptions: Self.StreamOptions? 
= nil ) { self.messages = messages self.model = model @@ -106,6 +111,19 @@ public struct ChatQuery: Equatable, Codable, Streamable { self.topP = topP self.user = user self.stream = stream + self.streamOptions = streamOptions + } + + public struct StreamOptions: Codable, Equatable { + public var includeUsage: Bool + + public init(includeUsage: Bool) { + self.includeUsage = includeUsage + } + + public enum CodingKeys: String, CodingKey { + case includeUsage = "include_usage" + } } public enum ChatCompletionMessageParam: Codable, Equatable { @@ -851,5 +869,6 @@ public struct ChatQuery: Equatable, Codable, Streamable { case topP = "top_p" case user case stream + case streamOptions = "stream_options" } } diff --git a/Sources/OpenAI/Public/Models/ChatStreamResult.swift b/Sources/OpenAI/Public/Models/ChatStreamResult.swift index 3457b089..977b1ba7 100644 --- a/Sources/OpenAI/Public/Models/ChatStreamResult.swift +++ b/Sources/OpenAI/Public/Models/ChatStreamResult.swift @@ -115,6 +115,18 @@ public struct ChatStreamResult: Codable, Equatable { case logprobs } } + + public struct Usage: Codable, Equatable { + public let completionTokens: Int + public let promptTokens: Int + public let totalTokens: Int + + public enum CodingKeys: String, CodingKey { + case completionTokens = "completion_tokens" + case promptTokens = "prompt_tokens" + case totalTokens = "total_tokens" + } + } /// A unique identifier for the chat completion. Each chunk has the same ID. public let id: String @@ -130,6 +142,8 @@ public struct ChatStreamResult: Codable, Equatable { public let choices: [Choice] /// This fingerprint represents the backend configuration that the model runs with. Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. public let systemFingerprint: String? + /// In stream mode, if `streamOptions` are set in the request, return token usage data. + public let usage: Self.Usage? 
public enum CodingKeys: String, CodingKey { case id @@ -138,5 +152,6 @@ public struct ChatStreamResult: Codable, Equatable { case model case choices case systemFingerprint = "system_fingerprint" + case usage } } From 52fe7922a940a687add059e2452ecd5cf2e8e4ca Mon Sep 17 00:00:00 2001 From: Benjamin Truitt Date: Wed, 17 Jul 2024 20:09:45 -0500 Subject: [PATCH 8/8] Add new chat models. --- .../OpenAI/Public/Models/Models/Models.swift | 93 +++++++++++++++---- .../Public/Models/StreamableQuery.swift | 2 +- 2 files changed, 74 insertions(+), 21 deletions(-) diff --git a/Sources/OpenAI/Public/Models/Models/Models.swift b/Sources/OpenAI/Public/Models/Models/Models.swift index 2b356889..2d5d0031 100644 --- a/Sources/OpenAI/Public/Models/Models/Models.swift +++ b/Sources/OpenAI/Public/Models/Models/Models.swift @@ -1,40 +1,80 @@ // // Models.swift -// +// // // Created by Sergii Kryvoblotskyi on 12/19/22. // -import Foundation - +/// Defines all available OpenAI models supported by the library. public typealias Model = String + public extension Model { + // Chat Completion + // GPT-4 + + /// `gpt-4o`, currently the most advanced, multimodal flagship model that's cheaper and faster than GPT-4 Turbo. + static let gpt4_o = "gpt-4o" + + /// `gpt-4-turbo`, the latest gpt-4 model with improved instruction following, JSON mode, reproducible outputs, parallel function calling and more. Maximum of 4096 output tokens + static let gpt4_turbo_preview = "gpt-4-turbo-preview" + + /// `gpt-4-vision-preview`, able to understand images, in addition to all other GPT-4 Turbo capabilities. + static let gpt4_vision_preview = "gpt-4-vision-preview" + + /// `gpt-4-turbo-2024-04-09`, GPT-4 Turbo with Vision model. Vision requests can now use JSON mode and function calling. gpt-4-turbo currently points to this version. + static let gpt4_turbo_2024_04_09 = "gpt-4-turbo-2024-04-09" + + /// Snapshot of `gpt-4-turbo-preview` from January 25th 2024. 
This model reduces cases of “laziness” where the model doesn’t complete a task. Also fixes the bug impacting non-English UTF-8 generations. Maximum of 4096 output tokens + static let gpt4_0125_preview = "gpt-4-0125-preview" - // Chat Completions + /// Snapshot of `gpt-4-turbo-preview` from November 6th 2023. Improved instruction following, JSON mode, reproducible outputs, parallel function calling and more. Maximum of 4096 output tokens + @available(*, deprecated, message: "Please upgrade to the newer model") + static let gpt4_1106_preview = "gpt-4-1106-preview" - /// More capable than any GPT-3.5 model, able to do more complex tasks, and optimized for chat. Will be updated with our latest model iteration 2 weeks after it is released. + /// Most capable `gpt-4` model, outperforms any GPT-3.5 model, able to do more complex tasks, and optimized for chat. static let gpt4 = "gpt-4" - /// Snapshot of gpt-4 from March 14th 2023. Unlike gpt-4, this model will not receive updates, and will only be supported for a three month period ending on June 14th 2023. + + /// Snapshot of `gpt-4` from June 13th 2023 with function calling data. Unlike `gpt-4`, this model will not receive updates, and will be deprecated 3 months after a new version is released. + static let gpt4_0613 = "gpt-4-0613" + + /// Snapshot of `gpt-4` from March 14th 2023. Unlike gpt-4, this model will not receive updates, and will only be supported for a three month period ending on June 14th 2023. @available(*, deprecated, message: "Please upgrade to the newer model") static let gpt4_0314 = "gpt-4-0314" - /// Snapshot of gpt-4 from June 13th 2023 with function calling data. Unlike gpt-4, this model will not receive updates, and will be deprecated 3 months after a new version is released. - static let gpt4_0613 = "gpt-4-0613" - /// Same capabilities as the base gpt-4 mode but with 4x the context length. Will be updated with our latest model iteration. 
+ + /// Same capabilities as the base `gpt-4` model but with 4x the context length. Will be updated with our latest model iteration. static let gpt4_32k = "gpt-4-32k" - /// Snapshot of gpt-4-32 from March 14th 2023. Unlike gpt-4-32k, this model will not receive updates, and will only be supported for a three month period ending on June 14th 2023. - static let gpt4_32k_0314 = "gpt-4-32k-0314" - /// Snapshot of gpt-4-32 from June 13th 2023. Unlike gpt-4-32k, this model will not receive updates, and will be deprecated 3 months after a new version is released. + + /// Snapshot of `gpt-4-32k` from June 13th 2023. Unlike `gpt-4-32k`, this model will not receive updates, and will be deprecated 3 months after a new version is released. static let gpt4_32k_0613 = "gpt-4-32k-0613" - /// Most capable GPT-3.5 model and optimized for chat at 1/10th the cost of text-davinci-003. Will be updated with our latest model iteration. + + /// Snapshot of `gpt-4-32k` from March 14th 2023. Unlike `gpt-4-32k`, this model will not receive updates, and will only be supported for a three month period ending on June 14th 2023. + @available(*, deprecated, message: "Please upgrade to the newer model") + static let gpt4_32k_0314 = "gpt-4-32k-0314" + + // GPT-3.5 + + /// Most capable `gpt-3.5-turbo` model and optimized for chat. Will be updated with our latest model iteration. static let gpt3_5Turbo = "gpt-3.5-turbo" - /// Snapshot of gpt-3.5-turbo from March 1st 2023. Unlike gpt-3.5-turbo, this model will not receive updates, and will only be supported for a three month period ending on June 1st 2023. + + /// Snapshot of `gpt-3.5-turbo` from January 25th 2024. Decreased prices by 50%. Various improvements including higher accuracy at responding in requested formats and a fix for a bug which caused a text encoding issue for non-English language function calls. + static let gpt3_5Turbo_0125 = "gpt-3.5-turbo-0125" + + /// Snapshot of `gpt-3.5-turbo` from November 6th 2023. 
The latest `gpt-3.5-turbo` model with improved instruction following, JSON mode, reproducible outputs, parallel function calling and more. + @available(*, deprecated, message: "Please upgrade to the newer model") + static let gpt3_5Turbo_1106 = "gpt-3.5-turbo-1106" + + /// Snapshot of `gpt-3.5-turbo` from June 13th 2023 with function calling data. Unlike `gpt-3.5-turbo`, this model will not receive updates, and will be deprecated 3 months after a new version is released. + @available(*, deprecated, message: "Please upgrade to the newer model") + static let gpt3_5Turbo_0613 = "gpt-3.5-turbo-0613" + + /// Snapshot of `gpt-3.5-turbo` from March 1st 2023. Unlike `gpt-3.5-turbo`, this model will not receive updates, and will only be supported for a three month period ending on June 1st 2023. @available(*, deprecated, message: "Please upgrade to the newer model") - static let gpt3_5Turbo0301 = "gpt-3.5-turbo-0301" - /// Snapshot of gpt-3.5-turbo from June 13th 2023 with function calling data. Unlike gpt-3.5-turbo, this model will not receive updates, and will be deprecated 3 months after a new version is released. - static let gpt3_5Turbo0613 = "gpt-3.5-turbo-0613" - /// Same capabilities as the standard gpt-3.5-turbo model but with 4 times the context. + static let gpt3_5Turbo_0301 = "gpt-3.5-turbo-0301" + + /// Same capabilities as the standard `gpt-3.5-turbo` model but with 4 times the context. static let gpt3_5Turbo_16k = "gpt-3.5-turbo-16k" - /// Snapshot of gpt-3.5-turbo-16k from June 13th 2023. Unlike gpt-3.5-turbo-16k, this model will not receive updates, and will be deprecated 3 months after a new version is released. + + /// Snapshot of `gpt-3.5-turbo-16k` from June 13th 2023. Unlike `gpt-3.5-turbo-16k`, this model will not receive updates, and will be deprecated 3 months after a new version is released. 
static let gpt3_5Turbo_16k_0613 = "gpt-3.5-turbo-16k-0613" // Completions @@ -55,9 +95,20 @@ public extension Model { static let textDavinci_001 = "text-davinci-001" static let codeDavinciEdit_001 = "code-davinci-edit-001" + // Speech + + /// The latest text to speech model, optimized for speed. + static let tts_1 = "tts-1" + /// The latest text to speech model, optimized for quality. + static let tts_1_hd = "tts-1-hd" + // Transcriptions / Translations static let whisper_1 = "whisper-1" + + // Image Generation + static let dall_e_2 = "dall-e-2" + static let dall_e_3 = "dall-e-3" // Fine Tunes @@ -76,6 +127,8 @@ public extension Model { static let textSearchAda = "text-search-ada-doc-001" static let textSearchBabbageDoc = "text-search-babbage-doc-001" static let textSearchBabbageQuery001 = "text-search-babbage-query-001" + static let textEmbedding3 = "text-embedding-3-small" + static let textEmbedding3Large = "text-embedding-3-large" // Moderations @@ -83,5 +136,5 @@ public extension Model { static let textModerationStable = "text-moderation-stable" /// Most capable moderation model. Accuracy will be slightly higher than the stable model. static let textModerationLatest = "text-moderation-latest" - static let moderation = "text-moderation-001" + static let moderation = "text-moderation-007" } diff --git a/Sources/OpenAI/Public/Models/StreamableQuery.swift b/Sources/OpenAI/Public/Models/StreamableQuery.swift index 1210432f..829cbf7f 100644 --- a/Sources/OpenAI/Public/Models/StreamableQuery.swift +++ b/Sources/OpenAI/Public/Models/StreamableQuery.swift @@ -17,7 +17,7 @@ extension Streamable { func makeStreamable() -> Self { var copy = self - copy.stream = true + copy.stream = true return copy } }