Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom http headers #93

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
76 changes: 33 additions & 43 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@ final public class OpenAI: OpenAIProtocol {
/// Default request timeout
public let timeoutInterval: TimeInterval

public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", timeoutInterval: TimeInterval = 60.0) {
/// A dictionary of HTTP headers to include in requests to OpenAI (or any proxy server the request may be sent to)
public var additionalHeaders: [String: String]

public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", timeoutInterval: TimeInterval = 60.0, additionalHeaders: [String: String]? = nil) {
self.token = token
self.organizationIdentifier = organizationIdentifier
self.host = host
self.timeoutInterval = timeoutInterval
self.additionalHeaders = additionalHeaders ?? [:]
}
}

private let session: URLSessionProtocol
private let dataTaskManager: URLSessionDataTaskManager
private var streamingSessions: [NSObject] = []

public let configuration: Configuration
Expand All @@ -50,102 +55,87 @@ final public class OpenAI: OpenAIProtocol {
init(configuration: Configuration, session: URLSessionProtocol) {
self.configuration = configuration
self.session = session
self.dataTaskManager = URLSessionDataTaskManager(session: session)
}

public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) {
self.init(configuration: configuration, session: session as URLSessionProtocol)
}

public func completions(query: CompletionsQuery, completion: @escaping (Result<CompletionsResult, Error>) -> Void) {
performRequest(request: JSONRequest<CompletionsResult>(body: query, url: buildURL(path: .completions)), completion: completion)
performRequest(request: JSONRequest<CompletionsQuery, CompletionsResult>(body: query, url: buildURL(path: .completions), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func completionsStream(query: CompletionsQuery, onResult: @escaping (Result<CompletionsResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
performSteamingRequest(request: JSONRequest<CompletionsResult>(body: query.makeStreamable(), url: buildURL(path: .completions)), onResult: onResult, completion: completion)
performSteamingRequest(request: JSONRequest<CompletionsQuery, CompletionsResult>(body: query.makeStreamable(), url: buildURL(path: .completions), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion)
}

public func images(query: ImagesQuery, completion: @escaping (Result<ImagesResult, Error>) -> Void) {
performRequest(request: JSONRequest<ImagesResult>(body: query, url: buildURL(path: .images)), completion: completion)
performRequest(request: JSONRequest<ImagesQuery, ImagesResult>(body: query, url: buildURL(path: .images), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func embeddings(query: EmbeddingsQuery, completion: @escaping (Result<EmbeddingsResult, Error>) -> Void) {
performRequest(request: JSONRequest<EmbeddingsResult>(body: query, url: buildURL(path: .embeddings)), completion: completion)
performRequest(request: JSONRequest<EmbeddingsQuery, EmbeddingsResult>(body: query, url: buildURL(path: .embeddings), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func chats(query: ChatQuery, completion: @escaping (Result<ChatResult, Error>) -> Void) {
performRequest(request: JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats)), completion: completion)
performRequest(request: JSONRequest<ChatQuery, ChatResult>(body: query, url: buildURL(path: .chats), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func chatsStream(query: ChatQuery, onResult: @escaping (Result<ChatStreamResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
performSteamingRequest(request: JSONRequest<ChatResult>(body: query.makeStreamable(), url: buildURL(path: .chats)), onResult: onResult, completion: completion)
performSteamingRequest(request: JSONRequest<ChatQuery, ChatResult>(body: query.makeStreamable(), url: buildURL(path: .chats), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), onResult: onResult, completion: completion)
}

public func edits(query: EditsQuery, completion: @escaping (Result<EditsResult, Error>) -> Void) {
performRequest(request: JSONRequest<EditsResult>(body: query, url: buildURL(path: .edits)), completion: completion)
performRequest(request: JSONRequest<EditsQuery, EditsResult>(body: query, url: buildURL(path: .edits), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func model(query: ModelQuery, completion: @escaping (Result<ModelResult, Error>) -> Void) {
performRequest(request: JSONRequest<ModelResult>(url: buildURL(path: .models.withPath(query.model)), method: "GET"), completion: completion)
performRequest(request: JSONRequest<ModelQuery, ModelResult>(url: buildURL(path: .models.withPath(query.model)), method: "GET", headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func models(completion: @escaping (Result<ModelsResult, Error>) -> Void) {
performRequest(request: JSONRequest<ModelsResult>(url: buildURL(path: .models), method: "GET"), completion: completion)
performRequest(request: JSONRequest<Empty, ModelsResult>(url: buildURL(path: .models), method: "GET", headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func moderations(query: ModerationsQuery, completion: @escaping (Result<ModerationsResult, Error>) -> Void) {
performRequest(request: JSONRequest<ModerationsResult>(body: query, url: buildURL(path: .moderations)), completion: completion)
performRequest(request: JSONRequest<ModerationsQuery, ModerationsResult>(body: query, url: buildURL(path: .moderations), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func audioTranscriptions(query: AudioTranscriptionQuery, completion: @escaping (Result<AudioTranscriptionResult, Error>) -> Void) {
performRequest(request: MultipartFormDataRequest<AudioTranscriptionResult>(body: query, url: buildURL(path: .audioTranscriptions)), completion: completion)
performRequest(request: MultipartFormDataRequest<AudioTranscriptionQuery, AudioTranscriptionResult>(body: query, url: buildURL(path: .audioTranscriptions), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}

public func audioTranslations(query: AudioTranslationQuery, completion: @escaping (Result<AudioTranslationResult, Error>) -> Void) {
performRequest(request: MultipartFormDataRequest<AudioTranslationResult>(body: query, url: buildURL(path: .audioTranslations)), completion: completion)
performRequest(request: MultipartFormDataRequest<AudioTranslationQuery, AudioTranslationResult>(body: query, url: buildURL(path: .audioTranslations), headers: generateHeaders(), timeoutInterval: configuration.timeoutInterval), completion: completion)
}
}

extension OpenAI {
internal func generateHeaders() -> [String: String] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would change the naming to defaultHeaders to add the knowledge that it is going to be appended further.
What do you think?

var headers = configuration.additionalHeaders
headers["Authorization"] = "Bearer \(configuration.token)"
if let organizationIdentifier = configuration.organizationIdentifier {
headers["OpenAI-Organization"] = organizationIdentifier
}
return headers
}
}

extension OpenAI {

func performRequest<ResultType: Codable>(request: any URLRequestBuildable, completion: @escaping (Result<ResultType, Error>) -> Void) {
do {
let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
let task = session.dataTask(with: request) { data, _, error in
if let error = error {
completion(.failure(error))
return
}
guard let data = data else {
completion(.failure(OpenAIError.emptyData))
return
}

var apiError: Error? = nil
do {
let decoded = try JSONDecoder().decode(ResultType.self, from: data)
completion(.success(decoded))
} catch {
apiError = error
}

if let apiError = apiError {
do {
let decoded = try JSONDecoder().decode(APIErrorResponse.self, from: data)
completion(.failure(decoded))
} catch {
completion(.failure(apiError))
}
}
}
task.resume()
let request = try request.build()
dataTaskManager.performDataTask(with: request, completion: completion)
} catch {
completion(.failure(error))
}
}

func performSteamingRequest<ResultType: Codable>(request: any URLRequestBuildable, onResult: @escaping (Result<ResultType, Error>) -> Void, completion: ((Error?) -> Void)?) {
do {
let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
let request = try request.build()
let session = StreamingSession<ResultType>(urlRequest: request)
session.onReceiveContent = {_, object in
onResult(.success(object))
Expand Down
36 changes: 36 additions & 0 deletions Sources/OpenAI/Private/BaseRequest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//
// BaseRequest.swift
//
//
// Created by Benjamin Truitt on 7/28/23.
//

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

protocol RequestBuildable {
var url: URL { get }
var headers: [String: String] { get }
var method: String { get }
var timeoutInterval: TimeInterval { get }

func getBody() throws -> Data?
}

protocol BaseRequest: RequestBuildable {
func build() throws -> URLRequest
}

extension BaseRequest {
func build() throws -> URLRequest {
var request = URLRequest(url: url, timeoutInterval: timeoutInterval)

request.httpMethod = method
request.httpBody = try getBody()
request.allHTTPHeaderFields = headers

return request
}
}
14 changes: 14 additions & 0 deletions Sources/OpenAI/Private/Empty.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//
// Empty.swift
//
//
// Created by Benjamin Truitt on 7/29/23.
//

import Foundation

struct Empty: Encodable {
func encode(to encoder: Encoder) throws {
// Do nothing
}
}
39 changes: 14 additions & 25 deletions Sources/OpenAI/Private/JSONRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,26 @@
// Created by Sergii Kryvoblotskyi on 12/19/22.
//


import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

final class JSONRequest<ResultType> {

let body: Codable?
let url: URL
let method: String

init(body: Codable? = nil, url: URL, method: String = "POST") {
struct JSONRequest<BodyType: Encodable, ResultType>: BaseRequest, URLRequestBuildable {
var body: BodyType?
var url: URL
var method: String = "POST"
var headers: [String: String]
var timeoutInterval: TimeInterval

init(body: BodyType? = nil, url: URL, method: String = "POST", headers: [String: String]?, timeoutInterval: TimeInterval) {
self.body = body
self.url = url
self.method = method
self.headers = headers ?? [:]
self.headers["Content-Type"] = "application/json"
self.timeoutInterval = timeoutInterval
}
}

extension JSONRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
var request = URLRequest(url: url, timeoutInterval: timeoutInterval)
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}
request.httpMethod = method
if let body = body {
request.httpBody = try JSONEncoder().encode(body)
}
return request
func getBody() throws -> Data? {
return try body.map { try JSONEncoder().encode($0) }
}
}
39 changes: 15 additions & 24 deletions Sources/OpenAI/Private/MultipartFormDataRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,27 @@
//

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

final class MultipartFormDataRequest<ResultType> {
struct MultipartFormDataRequest<BodyType: MultipartFormDataBodyEncodable, ResultType>: BaseRequest, URLRequestBuildable {
var body: MultipartFormDataBodyEncodable?
var url: URL
var headers: [String: String]
var method: String = "POST"
var timeoutInterval: TimeInterval
var boundary: String = UUID().uuidString

let body: MultipartFormDataBodyEncodable
let url: URL
let method: String

init(body: MultipartFormDataBodyEncodable, url: URL, method: String = "POST") {
init(body: BodyType?, url: URL, method: String = "POST", headers: [String: String]?, timeoutInterval: TimeInterval) {
self.body = body
self.url = url
self.method = method
self.headers = headers ?? [:]
self.headers["Content-Type"] = "multipart/form-data; boundary=\(boundary)"
self.timeoutInterval = timeoutInterval
}
}

extension MultipartFormDataRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
var request = URLRequest(url: url)
let boundary: String = UUID().uuidString
request.timeoutInterval = timeoutInterval
request.httpMethod = method
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}
request.httpBody = body.encode(boundary: boundary)
return request
func getBody() throws -> Data? {
return body?.encode(boundary: boundary)
}
}


Loading