Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support reverse proxy (includes demo) #176

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion Demo/App/APIKeyModalView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@ struct APIKeyModalView: View {
let isMandatory: Bool

@Binding private var apiKey: String
@Binding private var proxy: String
@State private var internalAPIKey: String
@State private var internalProxy: String

public init(
apiKey: Binding<String>,
proxy: Binding<String>,
isMandatory: Bool = true
) {
self._apiKey = apiKey
self._proxy = proxy
self._internalAPIKey = State(initialValue: apiKey.wrappedValue)
self._internalProxy = State(initialValue: proxy.wrappedValue)
self.isMandatory = isMandatory
}

Expand Down Expand Up @@ -68,12 +73,41 @@ struct APIKeyModalView: View {
.background(Color.white)
.clipShape(RoundedRectangle(cornerRadius: 8))


VStack(alignment: .leading, spacing: 8) {
Text(
"Set your proxy."
)
.font(.caption)
}

TextEditor(
text: $internalProxy
)
.frame(height: 120)
.font(.caption)
.padding(8)
.background(
RoundedRectangle(
cornerRadius: 8
)
.stroke(
strokeColor,
lineWidth: 1
)
)
.padding(4)
.background(Color.white)
.clipShape(RoundedRectangle(cornerRadius: 8))


if isMandatory {
HStack {
Spacer()

Button {
apiKey = internalAPIKey
proxy = internalProxy
dismiss()
} label: {
Text(
Expand All @@ -82,7 +116,7 @@ struct APIKeyModalView: View {
.padding(8)
}
.buttonStyle(.borderedProminent)
.disabled(internalAPIKey.isEmpty)
.disabled(internalAPIKey.isEmpty && internalProxy.isEmpty || !internalAPIKey.isEmpty && !internalProxy.isEmpty)

Spacer()
}
Expand All @@ -102,18 +136,34 @@ struct APIKeyModalView: View {
}
}
}
.padding()
.navigationTitle("OpenAI Reverse Proxy")
.toolbar {
ToolbarItem(placement: .primaryAction) {
if isMandatory {
EmptyView()
} else {
Button("Close") {
proxy = internalProxy
dismiss()
}
}
}
}
}
}
}

struct APIKeyModalView_Previews: PreviewProvider {
struct APIKeyModalView_PreviewsContainerView: View {
@State var apiKey = ""
@State var proxy = ""
let isMandatory: Bool

var body: some View {
APIKeyModalView(
apiKey: $apiKey,
proxy: $proxy,
isMandatory: isMandatory
)
}
Expand Down
22 changes: 19 additions & 3 deletions Demo/App/APIProvidedView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import SwiftUI

struct APIProvidedView: View {
@Binding var apiKey: String
@Binding var proxy: String
@StateObject var chatStore: ChatStore
@StateObject var imageStore: ImageStore
@StateObject var miscStore: MiscStore
Expand All @@ -21,23 +22,32 @@ struct APIProvidedView: View {

init(
apiKey: Binding<String>,
proxy: Binding<String>,
idProvider: @escaping () -> String
) {
self._apiKey = apiKey
self._proxy = proxy
var client: OpenAI? = nil
if apiKey.wrappedValue.isEmpty && !proxy.wrappedValue.isEmpty {
client = OpenAI(proxy: proxy.wrappedValue)
} else {
client = OpenAI(apiToken: apiKey.wrappedValue)
}

self._chatStore = StateObject(
wrappedValue: ChatStore(
openAIClient: OpenAI(apiToken: apiKey.wrappedValue),
openAIClient: client!,
idProvider: idProvider
)
)
self._imageStore = StateObject(
wrappedValue: ImageStore(
openAIClient: OpenAI(apiToken: apiKey.wrappedValue)
openAIClient: client!
)
)
self._miscStore = StateObject(
wrappedValue: MiscStore(
openAIClient: OpenAI(apiToken: apiKey.wrappedValue)
openAIClient: client!
)
)
}
Expand All @@ -54,5 +64,11 @@ struct APIProvidedView: View {
imageStore.openAIClient = client
miscStore.openAIClient = client
}
.onChange(of: proxy) { newProxy in
let client = OpenAI(apiToken: newProxy)
chatStore.openAIClient = client
imageStore.openAIClient = client
miscStore.openAIClient = client
}
}
}
6 changes: 4 additions & 2 deletions Demo/App/DemoApp.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import SwiftUI
@main
struct DemoApp: App {
@AppStorage("apiKey") var apiKey: String = ""
@AppStorage("proxy") var proxy: String = ""
@State var isShowingAPIConfigModal: Bool = true

let idProvider: () -> String
Expand All @@ -29,16 +30,17 @@ struct DemoApp: App {
Group {
APIProvidedView(
apiKey: $apiKey,
proxy: $proxy,
idProvider: idProvider
)
}
#if os(iOS)
.fullScreenCover(isPresented: $isShowingAPIConfigModal) {
APIKeyModalView(apiKey: $apiKey)
APIKeyModalView(apiKey: $apiKey, proxy: $proxy)
}
#elseif os(macOS)
.popover(isPresented: $isShowingAPIConfigModal) {
APIKeyModalView(apiKey: $apiKey)
APIKeyModalView(apiKey: $apiKey, proxy: $proxy)
}
#endif
}
Expand Down
19 changes: 15 additions & 4 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ final public class OpenAI: OpenAIProtocol {
public struct Configuration {

/// OpenAI API token. See https://platform.openai.com/docs/api-reference/authentication
public let token: String
public let token: String?

/// Optional OpenAI organization identifier. See https://platform.openai.com/docs/api-reference/authentication
public let organizationIdentifier: String?

Expand All @@ -26,19 +26,30 @@ final public class OpenAI: OpenAIProtocol {
/// Default request timeout
public let timeoutInterval: TimeInterval

public init(token: String, organizationIdentifier: String? = nil, host: String = "api.openai.com", timeoutInterval: TimeInterval = 60.0) {
self.token = token
public init(organizationIdentifier: String? = nil, host: String, timeoutInterval: TimeInterval = 60.0) {
self.token = nil
self.organizationIdentifier = organizationIdentifier
self.host = host
self.timeoutInterval = timeoutInterval
}

public init(token: String, organizationIdentifier: String? = nil, timeoutInterval: TimeInterval = 60.0) {
self.token = token
self.organizationIdentifier = organizationIdentifier
self.host = "api.openai.com"
self.timeoutInterval = timeoutInterval
}
}

private let session: URLSessionProtocol
private var streamingSessions = ArrayWithThreadSafety<NSObject>()

public let configuration: Configuration

public convenience init(proxy: String) {
self.init(configuration: Configuration(host: proxy), session: URLSession.shared)
}

public convenience init(apiToken: String) {
self.init(configuration: Configuration(token: apiToken), session: URLSession.shared)
}
Expand Down
6 changes: 4 additions & 2 deletions Sources/OpenAI/Private/JSONRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ final class JSONRequest<ResultType> {

extension JSONRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
func build(token: String?, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
var request = URLRequest(url: url, timeoutInterval: timeoutInterval)
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
if let token {
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
}
if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
}
Expand Down
6 changes: 4 additions & 2 deletions Sources/OpenAI/Private/MultipartFormDataRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ final class MultipartFormDataRequest<ResultType> {

extension MultipartFormDataRequest: URLRequestBuildable {

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
func build(token: String?, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest {
var request = URLRequest(url: url)
let boundary: String = UUID().uuidString
request.timeoutInterval = timeoutInterval
request.httpMethod = method
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
if let token {
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
}
request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type")
if let organizationIdentifier {
request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization")
Expand Down
2 changes: 1 addition & 1 deletion Sources/OpenAI/Private/URLRequestBuildable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ protocol URLRequestBuildable {

associatedtype ResultType

func build(token: String, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest
func build(token: String?, organizationIdentifier: String?, timeoutInterval: TimeInterval) throws -> URLRequest
}
6 changes: 3 additions & 3 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class OpenAITests: XCTestCase {
let jsonRequest = JSONRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token!)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Content-Type"), "application/json")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Organization"), configuration.organizationIdentifier)
XCTAssertEqual(urlRequest.timeoutInterval, configuration.timeoutInterval)
Expand All @@ -369,7 +369,7 @@ class OpenAITests: XCTestCase {
let jsonRequest = MultipartFormDataRequest<ChatResult>(body: completionQuery, url: URL(string: "http://google.com")!)
let urlRequest = try jsonRequest.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)

XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "Authorization"), "Bearer \(configuration.token!)")
XCTAssertEqual(urlRequest.value(forHTTPHeaderField: "OpenAI-Organization"), configuration.organizationIdentifier)
XCTAssertEqual(urlRequest.timeoutInterval, configuration.timeoutInterval)
}
Expand All @@ -382,7 +382,7 @@ class OpenAITests: XCTestCase {
}

func testCustomURLBuilt() {
let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
let configuration = OpenAI.Configuration(organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14)
let openAI = OpenAI(configuration: configuration, session: self.urlSession)
let chatsURL = openAI.buildURL(path: .chats)
XCTAssertEqual(chatsURL, URL(string: "https://my.host.com/v1/chat/completions"))
Expand Down
Loading