diff --git a/Core/Package.swift b/Core/Package.swift index 184ee54..e33da87 100644 --- a/Core/Package.swift +++ b/Core/Package.swift @@ -66,6 +66,10 @@ let package = Package( ] ), + .testTarget( + name: "CodeCompletionServiceTests", + dependencies: ["CodeCompletionService"] + ), .testTarget( name: "FundamentalTests", dependencies: ["Fundamental"] diff --git a/Core/Sources/CodeCompletionService/AzureOpenAIService.swift b/Core/Sources/CodeCompletionService/API/AzureOpenAIService.swift similarity index 100% rename from Core/Sources/CodeCompletionService/AzureOpenAIService.swift rename to Core/Sources/CodeCompletionService/API/AzureOpenAIService.swift diff --git a/Core/Sources/CodeCompletionService/GoogleGeminiService.swift b/Core/Sources/CodeCompletionService/API/GoogleGeminiService.swift similarity index 100% rename from Core/Sources/CodeCompletionService/GoogleGeminiService.swift rename to Core/Sources/CodeCompletionService/API/GoogleGeminiService.swift diff --git a/Core/Sources/CodeCompletionService/OllamaService.swift b/Core/Sources/CodeCompletionService/API/OllamaService.swift similarity index 100% rename from Core/Sources/CodeCompletionService/OllamaService.swift rename to Core/Sources/CodeCompletionService/API/OllamaService.swift diff --git a/Core/Sources/CodeCompletionService/OpenAIService.swift b/Core/Sources/CodeCompletionService/API/OpenAIService.swift similarity index 100% rename from Core/Sources/CodeCompletionService/OpenAIService.swift rename to Core/Sources/CodeCompletionService/API/OpenAIService.swift diff --git a/Core/Sources/CodeCompletionService/TabbyService.swift b/Core/Sources/CodeCompletionService/API/TabbyService.swift similarity index 100% rename from Core/Sources/CodeCompletionService/TabbyService.swift rename to Core/Sources/CodeCompletionService/API/TabbyService.swift diff --git a/Core/Sources/CodeCompletionService/CodeCompletionService.swift b/Core/Sources/CodeCompletionService/CodeCompletionService.swift index ad8c69b..60cbd83 100644 --- a/Core/Sources/CodeCompletionService/CodeCompletionService.swift +++ b/Core/Sources/CodeCompletionService/CodeCompletionService.swift @@ -11,17 +11,20 @@ protocol CodeCompletionServiceType { extension CodeCompletionServiceType { func getCompletions( _ request: PromptStrategy, + streamStopStrategy: StreamStopStrategy, count: Int ) async throws -> [String] { try await withThrowingTaskGroup(of: String.self) { group in for _ in 0.. [String] { @@ -71,13 +75,18 @@ public struct CodeCompletionService { } }()) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result } public func getCompletions( _ prompt: PromptStrategy, + streamStopStrategy: StreamStopStrategy, model: ChatModel, count: Int ) async throws -> [String] { @@ -92,7 +101,11 @@ public struct CodeCompletionService { stopWords: prompt.stopWords, apiKey: apiKey ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .azureOpenAI: @@ -103,7 +116,11 @@ public struct CodeCompletionService { stopWords: prompt.stopWords, apiKey: apiKey ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .googleAI: @@ -112,7 +129,11 @@ public struct CodeCompletionService { maxToken: model.info.maxTokens, apiKey: apiKey ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .ollama: @@ -124,7 +145,11 @@ public struct CodeCompletionService { keepAlive: model.info.ollamaInfo.keepAlive, format: .none ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .unknown: @@ -134,6 +159,7 @@ public struct CodeCompletionService { public func getCompletions( _ prompt: PromptStrategy, + streamStopStrategy: StreamStopStrategy, model: CompletionModel, count: Int ) async throws -> [String] { @@ -148,7 +174,11 @@ public struct CodeCompletionService { stopWords: prompt.stopWords, apiKey: apiKey ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .azureOpenAI: @@ -159,7 +189,11 @@ public struct CodeCompletionService { stopWords: prompt.stopWords, apiKey: apiKey ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .ollama: @@ -171,7 +205,11 @@ public struct CodeCompletionService { keepAlive: model.info.ollamaInfo.keepAlive, format: .none ) - let result = try await service.getCompletions(prompt, count: count) + let result = try await service.getCompletions( + prompt, + streamStopStrategy: streamStopStrategy, + count: count + ) try Task.checkCancellation() return result case .unknown: diff --git a/Core/Sources/CodeCompletionService/StreamLineLimiter.swift b/Core/Sources/CodeCompletionService/StreamLineLimiter.swift new file mode 100644 index 0000000..c8e78db --- /dev/null +++ b/Core/Sources/CodeCompletionService/StreamLineLimiter.swift @@ -0,0 +1,62 @@ +import Foundation + +final class StreamLineLimiter { + public private(set) var result = "" + private var currentLine = "" + private var existedLines = [String]() + private let lineLimit: Int + private let strategy: any StreamStopStrategy + + enum PushResult: Equatable { + case `continue` + case finish(String) + } + + init( + lineLimit: Int = UserDefaults.shared.value(for: \.maxNumberOfLinesOfSuggestion), + strategy: any StreamStopStrategy + ) { + self.lineLimit = lineLimit + self.strategy = strategy + } + + func push(_ token: String) -> PushResult { + currentLine.append(token) + if let newLine = currentLine.last(where: { $0.isNewline }) { + let lines = currentLine + .breakLines(proposedLineEnding: String(newLine), appendLineBreakToLastLine: false) + let (newLines, lastLine) = lines.headAndTail + existedLines.append(contentsOf: newLines) + currentLine = lastLine ?? "" + } + + let stopResult = if lineLimit <= 0 { + StreamStopStrategyResult.continue + } else { + strategy.shouldStop( + existedLines: existedLines, + currentLine: currentLine, + proposedLineLimit: lineLimit + ) + } + + switch stopResult { + case .continue: + result.append(token) + return .continue + case let .stop(appendingNewContent): + if appendingNewContent { + result.append(token) + } + return .finish(result) + } + } +} + +extension Array { + var headAndTail: ([Element], Element?) { + guard let tail = last else { return ([], nil) } + return (Array(dropLast()), tail) + } +} + diff --git a/Core/Sources/CodeCompletionService/StreamStopStrategy/DefaultStreamStopStrategy.swift b/Core/Sources/CodeCompletionService/StreamStopStrategy/DefaultStreamStopStrategy.swift new file mode 100644 index 0000000..8e9c359 --- /dev/null +++ b/Core/Sources/CodeCompletionService/StreamStopStrategy/DefaultStreamStopStrategy.swift @@ -0,0 +1,15 @@ +public struct DefaultStreamStopStrategy: StreamStopStrategy { + public init() {} + + public func shouldStop( + existedLines: [String], + currentLine: String, + proposedLineLimit: Int + ) -> StreamStopStrategyResult { + if existedLines.count >= proposedLineLimit { + return .stop(appendingNewContent: true) + } + return .continue + } +} + diff --git a/Core/Sources/CodeCompletionService/StreamStopStrategy/NeverStreamStopStrategy.swift b/Core/Sources/CodeCompletionService/StreamStopStrategy/NeverStreamStopStrategy.swift new file mode 100644 index 0000000..fa924dd --- /dev/null +++ b/Core/Sources/CodeCompletionService/StreamStopStrategy/NeverStreamStopStrategy.swift @@ -0,0 +1,12 @@ +public struct NeverStreamStopStrategy: StreamStopStrategy { + public init() {} + + public func shouldStop( + existedLines: [String], + currentLine: String, + proposedLineLimit: Int + ) -> StreamStopStrategyResult { + .continue + } +} + diff --git a/Core/Sources/CodeCompletionService/StreamStopStrategy/OpeningTagBasedStreamStopStrategy.swift b/Core/Sources/CodeCompletionService/StreamStopStrategy/OpeningTagBasedStreamStopStrategy.swift new file mode 100644 index 0000000..02cd381 --- /dev/null +++ b/Core/Sources/CodeCompletionService/StreamStopStrategy/OpeningTagBasedStreamStopStrategy.swift @@ -0,0 +1,31 @@ +import Foundation + +public struct OpeningTagBasedStreamStopStrategy: StreamStopStrategy { + public let openingTag: String + public let toleranceIfNoOpeningTagFound: Int + + public init(openingTag: String, toleranceIfNoOpeningTagFound: Int) { + self.openingTag = openingTag + self.toleranceIfNoOpeningTagFound = toleranceIfNoOpeningTagFound + } + + public func shouldStop( + existedLines: [String], + currentLine: String, + proposedLineLimit: Int + ) -> StreamStopStrategyResult { + if let index = existedLines.firstIndex(where: { $0.contains(openingTag) }) { + if existedLines.count - index - 1 >= proposedLineLimit { + return .stop(appendingNewContent: true) + } + return .continue + } else { + if existedLines.count >= proposedLineLimit + toleranceIfNoOpeningTagFound { + return .stop(appendingNewContent: true) + } else { + return .continue + } + } + } +} + diff --git a/Core/Sources/CodeCompletionService/StreamStopStrategy/StreamStopStrategy.swift b/Core/Sources/CodeCompletionService/StreamStopStrategy/StreamStopStrategy.swift new file mode 100644 index 0000000..b493d55 --- /dev/null +++ b/Core/Sources/CodeCompletionService/StreamStopStrategy/StreamStopStrategy.swift @@ -0,0 +1,12 @@ +import Foundation + +public enum StreamStopStrategyResult { + case `continue` + case stop(appendingNewContent: Bool) +} + +public protocol StreamStopStrategy { + func shouldStop(existedLines: [String], currentLine: String, proposedLineLimit: Int) + -> StreamStopStrategyResult +} + diff --git a/Core/Sources/Storage/Preferences.swift b/Core/Sources/Storage/Preferences.swift index 2096aee..9337aec 100644 --- a/Core/Sources/Storage/Preferences.swift +++ b/Core/Sources/Storage/Preferences.swift @@ -80,6 +80,13 @@ public extension UserDefaultPreferenceKeys { key: "CustomSuggestionService-TabbyModel" ) } + + var maxNumberOfLinesOfSuggestion: PreferenceKey { + .init( + defaultValue: 0, + key: "CustomSuggestionService-MaxNumberOfLinesOfSuggestion" + ) + } var installBetaBuild: PreferenceKey { .init(defaultValue: false, key: "CustomSuggestionService-InstallBetaBuild") diff --git a/Core/Sources/SuggestionService/RawSuggestionPostProcessing/DefaultRawSuggestionPostProcessingStrategy.swift b/Core/Sources/SuggestionService/RawSuggestionPostProcessing/DefaultRawSuggestionPostProcessingStrategy.swift index e7bbfd3..d62360d 100644 --- a/Core/Sources/SuggestionService/RawSuggestionPostProcessing/DefaultRawSuggestionPostProcessingStrategy.swift +++ b/Core/Sources/SuggestionService/RawSuggestionPostProcessing/DefaultRawSuggestionPostProcessingStrategy.swift @@ -5,36 +5,28 @@ protocol RawSuggestionPostProcessingStrategy { func postProcess(rawSuggestion: String, infillPrefix: String, suffix: [String]) -> String } -extension RawSuggestionPostProcessingStrategy { - func removeTrailingNewlinesAndWhitespace(from string: String) -> String { - var text = string[...] - while let last = text.last, last.isNewline || last.isWhitespace { - text = text.dropLast(1) - } - return String(text) - } -} - struct DefaultRawSuggestionPostProcessingStrategy: RawSuggestionPostProcessingStrategy { - let openingCodeTag: String - let closingCodeTag: String + let codeWrappingTags: (opening: String, closing: String)? func postProcess(rawSuggestion: String, infillPrefix: String, suffix: [String]) -> String { var suggestion = extractSuggestion(from: rawSuggestion) removePrefix(from: &suggestion, infillPrefix: infillPrefix) removeSuffix(from: &suggestion, suffix: suffix) - return removeTrailingNewlinesAndWhitespace(from: infillPrefix + suggestion) + return infillPrefix + suggestion } func extractSuggestion(from response: String) -> String { let escapedMarkdownCodeBlock = removeLeadingAndTrailingMarkdownCodeBlockMark(from: response) - let escapedTags = extractEnclosingSuggestion( - from: escapedMarkdownCodeBlock, - openingTag: openingCodeTag, - closingTag: closingCodeTag - ) - - return escapedTags + if let tags = codeWrappingTags { + let escapedTags = extractEnclosingSuggestion( + from: escapedMarkdownCodeBlock, + openingTag: tags.opening, + closingTag: tags.closing + ) + return escapedTags + } else { + return escapedMarkdownCodeBlock + } } func removePrefix(from suggestion: inout String, infillPrefix: String) { diff --git a/Core/Sources/SuggestionService/RawSuggestionPostProcessing/NoOpRawSuggestionPostProcessingStrategy.swift b/Core/Sources/SuggestionService/RawSuggestionPostProcessing/NoOpRawSuggestionPostProcessingStrategy.swift index c8b047d..89b7eac 100644 --- a/Core/Sources/SuggestionService/RawSuggestionPostProcessing/NoOpRawSuggestionPostProcessingStrategy.swift +++ b/Core/Sources/SuggestionService/RawSuggestionPostProcessing/NoOpRawSuggestionPostProcessingStrategy.swift @@ -2,7 +2,7 @@ import Foundation struct NoOpRawSuggestionPostProcessingStrategy: RawSuggestionPostProcessingStrategy { func postProcess(rawSuggestion: String, infillPrefix: String, suffix: [String]) -> String { - removeTrailingNewlinesAndWhitespace(from: infillPrefix + rawSuggestion) + infillPrefix + rawSuggestion } } diff --git a/Core/Sources/SuggestionService/RequestStrategies/CodeLlamaFillInTheMiddleRequestStrategy.swift b/Core/Sources/SuggestionService/RequestStrategies/CodeLlamaFillInTheMiddleRequestStrategy.swift index 72bd498..2c7cef0 100644 --- a/Core/Sources/SuggestionService/RequestStrategies/CodeLlamaFillInTheMiddleRequestStrategy.swift +++ b/Core/Sources/SuggestionService/RequestStrategies/CodeLlamaFillInTheMiddleRequestStrategy.swift @@ -1,3 +1,4 @@ +import CodeCompletionService import CopilotForXcodeKit import Foundation import Fundamental @@ -20,8 +21,12 @@ struct CodeLlamaFillInTheMiddleRequestStrategy: RequestStrategy { ) } - func createRawSuggestionPostProcessor() -> NoOpRawSuggestionPostProcessingStrategy { - NoOpRawSuggestionPostProcessingStrategy() + func createStreamStopStrategy() -> some StreamStopStrategy { + FIMStreamStopStrategy(prefix: prefix) + } + + func createRawSuggestionPostProcessor() -> some RawSuggestionPostProcessingStrategy { + DefaultRawSuggestionPostProcessingStrategy(codeWrappingTags: nil) } enum Tag { @@ -92,8 +97,41 @@ struct CodeLlamaFillInTheMiddleWithSystemPromptRequestStrategy: RequestStrategy return prompt } + func createStreamStopStrategy() -> some StreamStopStrategy { + strategy.createStreamStopStrategy() + } + func createRawSuggestionPostProcessor() -> some RawSuggestionPostProcessingStrategy { - DefaultRawSuggestionPostProcessingStrategy(openingCodeTag: "", closingCodeTag: "") + strategy.createRawSuggestionPostProcessor() + } +} + +struct FIMStreamStopStrategy: StreamStopStrategy { + let prefix: [String] + + func shouldStop( + existedLines: [String], + currentLine: String, + proposedLineLimit: Int + ) -> StreamStopStrategyResult { + if let prefixLastLine = prefix.last { + if let lastLineIndex = existedLines.lastIndex(of: prefixLastLine) { + if existedLines.count >= lastLineIndex + 1 + proposedLineLimit { + return .stop(appendingNewContent: true) + } + return .continue + } else { + if existedLines.count >= proposedLineLimit { + return .stop(appendingNewContent: true) + } + return .continue + } + } else { + if existedLines.count >= proposedLineLimit { + return .stop(appendingNewContent: true) + } + return .continue + } } } diff --git a/Core/Sources/SuggestionService/RequestStrategies/ContinueRequestStrategy.swift b/Core/Sources/SuggestionService/RequestStrategies/ContinueRequestStrategy.swift index c136da8..48bc292 100644 --- a/Core/Sources/SuggestionService/RequestStrategies/ContinueRequestStrategy.swift +++ b/Core/Sources/SuggestionService/RequestStrategies/ContinueRequestStrategy.swift @@ -1,3 +1,4 @@ +import CodeCompletionService import CopilotForXcodeKit import Foundation import Fundamental @@ -21,11 +22,18 @@ struct ContinueRequestStrategy: RequestStrategy { suffix: suffix ) } - + func createRawSuggestionPostProcessor() -> DefaultRawSuggestionPostProcessingStrategy { - DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: Tag.openingCode, - closingCodeTag: Tag.closingCode + DefaultRawSuggestionPostProcessingStrategy(codeWrappingTags: ( + Tag.openingCode, + Tag.closingCode + )) + } + + func createStreamStopStrategy() -> some StreamStopStrategy { + OpeningTagBasedStreamStopStrategy( + openingTag: Tag.openingCode, + toleranceIfNoOpeningTagFound: 4 ) } diff --git a/Core/Sources/SuggestionService/RequestStrategies/DefaultRequestStrategy.swift b/Core/Sources/SuggestionService/RequestStrategies/DefaultRequestStrategy.swift index 448bf20..e065e11 100644 --- a/Core/Sources/SuggestionService/RequestStrategies/DefaultRequestStrategy.swift +++ b/Core/Sources/SuggestionService/RequestStrategies/DefaultRequestStrategy.swift @@ -1,3 +1,4 @@ +import CodeCompletionService import CopilotForXcodeKit import Foundation import Fundamental @@ -10,7 +11,7 @@ struct DefaultRequestStrategy: RequestStrategy { var sourceRequest: SuggestionRequest var prefix: [String] var suffix: [String] - + var shouldSkip: Bool { prefix.last?.trimmingCharacters(in: .whitespaces) == "}" } @@ -22,14 +23,21 @@ struct DefaultRequestStrategy: RequestStrategy { suffix: suffix ) } - - func createRawSuggestionPostProcessor() -> DefaultRawSuggestionPostProcessingStrategy { - DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: Tag.openingCode, - closingCodeTag: Tag.closingCode + + func createStreamStopStrategy() -> some StreamStopStrategy { + OpeningTagBasedStreamStopStrategy( + openingTag: Tag.openingCode, + toleranceIfNoOpeningTagFound: 4 ) } + func createRawSuggestionPostProcessor() -> DefaultRawSuggestionPostProcessingStrategy { + DefaultRawSuggestionPostProcessingStrategy(codeWrappingTags: ( + Tag.openingCode, + Tag.closingCode + )) + } + enum Tag { public static let openingCode = "" public static let closingCode = "" @@ -65,7 +73,7 @@ struct DefaultRequestStrategy: RequestStrategy { var relevantCodeSnippets: [RelevantCodeSnippet] { sourceRequest.relevantCodeSnippets } var stopWords: [String] { [Tag.closingCode, "\n\n"] } var language: CodeLanguage? { sourceRequest.language } - + var suggestionPrefix: SuggestionPrefix { guard let prefix = prefix.last else { return .empty } return .unchanged(prefix).curlyBracesLineBreak() @@ -103,7 +111,7 @@ struct DefaultRequestStrategy: RequestStrategy { File Path: \(filePath) Indentation: \ \(sourceRequest.indentSize) \(sourceRequest.usesTabsForIndentation ? "tab" : "space") - + --- Here is the code: diff --git a/Core/Sources/SuggestionService/RequestStrategies/NaiveRequestStrategy.swift b/Core/Sources/SuggestionService/RequestStrategies/NaiveRequestStrategy.swift index 3d21111..4d4f7aa 100644 --- a/Core/Sources/SuggestionService/RequestStrategies/NaiveRequestStrategy.swift +++ b/Core/Sources/SuggestionService/RequestStrategies/NaiveRequestStrategy.swift @@ -1,3 +1,4 @@ +import CodeCompletionService import CopilotForXcodeKit import Foundation import Fundamental @@ -8,7 +9,7 @@ struct NaiveRequestStrategy: RequestStrategy { var sourceRequest: SuggestionRequest var prefix: [String] var suffix: [String] - + var shouldSkip: Bool { prefix.last?.trimmingCharacters(in: .whitespaces) == "}" } @@ -20,11 +21,15 @@ struct NaiveRequestStrategy: RequestStrategy { suffix: suffix ) } - + func createRawSuggestionPostProcessor() -> some RawSuggestionPostProcessingStrategy { NoOpRawSuggestionPostProcessingStrategy() } + func createStreamStopStrategy() -> some StreamStopStrategy { + DefaultStreamStopStrategy() + } + struct Request: PromptStrategy { let systemPrompt: String = "" var sourceRequest: SuggestionRequest @@ -34,7 +39,7 @@ struct NaiveRequestStrategy: RequestStrategy { var relevantCodeSnippets: [RelevantCodeSnippet] { sourceRequest.relevantCodeSnippets } var stopWords: [String] { ["\n\n"] } var language: CodeLanguage? { sourceRequest.language } - + var suggestionPrefix: SuggestionPrefix { guard let prefix = prefix.last else { return .empty } return .unchanged(prefix).curlyBracesLineBreak() @@ -71,9 +76,9 @@ struct NaiveRequestStrategy: RequestStrategy { return [.init(role: .user, content: """ File path: \(filePath) - + --- - + \(code) """.trimmingCharacters(in: .whitespacesAndNewlines))] } diff --git a/Core/Sources/SuggestionService/RequestStrategies/TabbyRequestStrategy.swift b/Core/Sources/SuggestionService/RequestStrategies/TabbyRequestStrategy.swift index 0570779..bd94fe7 100644 --- a/Core/Sources/SuggestionService/RequestStrategies/TabbyRequestStrategy.swift +++ b/Core/Sources/SuggestionService/RequestStrategies/TabbyRequestStrategy.swift @@ -1,3 +1,4 @@ +import CodeCompletionService import CopilotForXcodeKit import Foundation import Fundamental @@ -24,6 +25,10 @@ struct TabbyRequestStrategy: RequestStrategy { NoOpRawSuggestionPostProcessingStrategy() } + func createStreamStopStrategy() -> some StreamStopStrategy { + NeverStreamStopStrategy() + } + struct Prompt: PromptStrategy { let systemPrompt: String = "" var sourceRequest: SuggestionRequest diff --git a/Core/Sources/SuggestionService/RequestStrategy.swift b/Core/Sources/SuggestionService/RequestStrategy.swift index 5f6060a..4256626 100644 --- a/Core/Sources/SuggestionService/RequestStrategy.swift +++ b/Core/Sources/SuggestionService/RequestStrategy.swift @@ -1,3 +1,4 @@ +import CodeCompletionService import CopilotForXcodeKit import Foundation import Fundamental @@ -8,6 +9,7 @@ import Parsing protocol RequestStrategy { associatedtype Prompt: PromptStrategy associatedtype RawSuggestionPostProcessor: RawSuggestionPostProcessingStrategy + associatedtype SomeStreamStopStrategy: StreamStopStrategy init(sourceRequest: SuggestionRequest, prefix: [String], suffix: [String]) @@ -17,6 +19,9 @@ protocol RequestStrategy { /// Create a prompt to generate code completion. func createPrompt() -> Prompt + /// Control how a stream should stop early. + func createStreamStopStrategy() -> SomeStreamStopStrategy + /// The AI model may not return a suggestion in a ideal format. You can use it to reformat the /// suggestions. func createRawSuggestionPostProcessor() -> RawSuggestionPostProcessor diff --git a/Core/Sources/SuggestionService/Service.swift b/Core/Sources/SuggestionService/Service.swift index 9e37ae0..dab853b 100644 --- a/Core/Sources/SuggestionService/Service.swift +++ b/Core/Sources/SuggestionService/Service.swift @@ -46,6 +46,7 @@ actor Service { let prompt = strategy.createPrompt() let postProcessor = strategy.createRawSuggestionPostProcessor() + let stopStream = strategy.createStreamStopStrategy() let suggestedCodeSnippets: [String] @@ -54,6 +55,7 @@ actor Service { CodeCompletionLogger.logger.logModel(model) suggestedCodeSnippets = try await service.getCompletions( prompt, + streamStopStrategy: stopStream, model: model, count: 1 ) @@ -61,6 +63,7 @@ actor Service { CodeCompletionLogger.logger.logModel(model) suggestedCodeSnippets = try await service.getCompletions( prompt, + streamStopStrategy: stopStream, model: model, count: 1 ) @@ -68,6 +71,7 @@ actor Service { CodeCompletionLogger.logger.logModel(model) suggestedCodeSnippets = try await service.getCompletions( prompt, + streamStopStrategy: stopStream, model: model, count: 1 ) @@ -78,13 +82,21 @@ actor Service { return suggestedCodeSnippets .filter { !$0.allSatisfy { $0.isWhitespace || $0.isNewline } } .map { - CodeSuggestion( - id: UUID().uuidString, - text: postProcessor.postProcess( + let suggestionText = postProcessor + .postProcess( rawSuggestion: $0, infillPrefix: prompt.suggestionPrefix.prependingValue, suffix: prompt.suffix - ), + ) + .keepLines( + count: UserDefaults.shared + .value(for: \.maxNumberOfLinesOfSuggestion) + ) + .removeTrailingNewlinesAndWhitespace() + + return CodeSuggestion( + id: UUID().uuidString, + text: suggestionText, position: request.cursorPosition, range: .init( start: .init( @@ -182,3 +194,19 @@ actor Service { } } +extension String { + func keepLines(count: Int) -> String { + if count <= 0 { return self } + let lines = breakLines() + return lines.prefix(count).joined() + } + + func removeTrailingNewlinesAndWhitespace() -> String { + var text = self[...] + while let last = text.last, last.isNewline || last.isWhitespace { + text = text.dropLast(1) + } + return String(text) + } +} + diff --git a/Core/Tests/CodeCompletionServiceTests/OpeningTagBasedStreamStopStrategyTests.swift b/Core/Tests/CodeCompletionServiceTests/OpeningTagBasedStreamStopStrategyTests.swift new file mode 100644 index 0000000..910bec3 --- /dev/null +++ b/Core/Tests/CodeCompletionServiceTests/OpeningTagBasedStreamStopStrategyTests.swift @@ -0,0 +1,106 @@ +import Foundation +import XCTest + +@testable import CodeCompletionService + +class OpeningTagBasedStreamStopStrategyTests: XCTestCase { + func test_no_opening_tag_found_and_not_hitting_limit() { + let strategy = OpeningTagBasedStreamStopStrategy( + openingTag: "", + toleranceIfNoOpeningTagFound: 3 + ) + let limiter = StreamLineLimiter(lineLimit: 1, strategy: strategy) + let content = """ + Hello World + My Friend + """ + for character in content { + let result = limiter.push(String(character)) + XCTAssertEqual(result, .continue) + } + XCTAssertEqual(limiter.result, content) + } + + func test_no_opening_tag_found_hitting_limit() { + let strategy = OpeningTagBasedStreamStopStrategy( + openingTag: "", + toleranceIfNoOpeningTagFound: 3 + ) + let limiter = StreamLineLimiter(lineLimit: 1, strategy: strategy) + let content = """ + Hello World + My Friend + How Are You + I Am Fine + Thank You + """ + + let expected = """ + Hello World + My Friend + How Are You + I Am Fine + + """ + + for character in content { + let result = limiter.push(String(character)) + if result == .finish(expected) { + XCTAssertEqual(limiter.result, expected) + return + } + } + XCTFail("Should return in the loop\n\n\(limiter.result)") + } + + func test_opening_tag_found_not_hitting_limit() { + let strategy = OpeningTagBasedStreamStopStrategy( + openingTag: "", + toleranceIfNoOpeningTagFound: 3 + ) + let limiter = StreamLineLimiter(lineLimit: 2, strategy: strategy) + let content = """ + Hello World + + How Are You + """ + for character in content { + let result = limiter.push(String(character)) + XCTAssertEqual(result, .continue) + } + XCTAssertEqual(limiter.result, content) + } + + func test_opening_tag_found_hitting_limit() { + let strategy = OpeningTagBasedStreamStopStrategy( + openingTag: "", + toleranceIfNoOpeningTagFound: 3 + ) + let limiter = StreamLineLimiter(lineLimit: 2, strategy: strategy) + let content = """ + Hello World + + How Are You + I Am Fine + Thank You + """ + + let expected = """ + Hello World + + How Are You + I Am Fine + + """ + + for character in content { + let result = limiter.push(String(character)) + if result == .finish(expected) { + XCTAssertEqual(limiter.result, expected) + return + } + } + XCTFail("Should return in the loop\n\n\(limiter.result)") + } +} + diff --git a/Core/Tests/CodeCompletionServiceTests/StreamLineLimiterTests.swift b/Core/Tests/CodeCompletionServiceTests/StreamLineLimiterTests.swift new file mode 100644 index 0000000..2153781 --- /dev/null +++ b/Core/Tests/CodeCompletionServiceTests/StreamLineLimiterTests.swift @@ -0,0 +1,69 @@ +import Foundation +import XCTest + +@testable import CodeCompletionService + +class StreamLineLimiterTests: XCTestCase { + func test_pushing_characters_without_hitting_limit() { + let limiter = StreamLineLimiter(lineLimit: 2, strategy: DefaultStreamStopStrategy()) + let content = "hello world\n" + for character in content { + let result = limiter.push(String(character)) + XCTAssertEqual(result, .continue) + } + XCTAssertEqual(limiter.result, content) + } + + func test_pushing_characters_hitting_limit() { + let limiter = StreamLineLimiter(lineLimit: 2, strategy: DefaultStreamStopStrategy()) + let content = "hello world\nhello world\nhello world" + for character in content { + let result = limiter.push(String(character)) + if result == .finish("hello world\nhello world\n") { + XCTAssertEqual(limiter.result, "hello world\nhello world\n") + return + } + } + XCTFail("Should return in the loop\n\(limiter.result)") + } + + func test_pushing_characters_with_early_exit_strategy() { + struct Strategy: StreamStopStrategy { + func shouldStop( + existedLines: [String], + currentLine: String, + proposedLineLimit: Int + ) -> StreamStopStrategyResult { + let hasPrefixP = currentLine.hasPrefix("p") + let hasNewLine = existedLines.first?.hasSuffix("\n") ?? false + if hasPrefixP && hasNewLine { + return .stop(appendingNewContent: false) + } + return .continue + } + } + + let limiter = StreamLineLimiter(lineLimit: 10, strategy: Strategy()) + let content = "hello world\npikachu\n" + for character in content { + let result = limiter.push(String(character)) + if result == .finish("hello world\n") { + XCTAssertEqual(limiter.result, "hello world\n") + return + } + } + XCTFail("Should return in the loop\n\(limiter.result)") + } + + func test_receiving_multiple_line_ending_as_a_single_token() { + let limiter = StreamLineLimiter(lineLimit: 4, strategy: DefaultStreamStopStrategy()) + let content = "hello world" + for character in content { + let result = limiter.push(String(character)) + XCTAssertEqual(result, .continue) + } + XCTAssertEqual(limiter.push("\n\n\n"), .continue) + XCTAssertEqual(limiter.push("\n"), .finish("hello world\n\n\n\n")) + } +} + diff --git a/Core/Tests/SuggestionServiceTests/DefaultRawSuggestionPostProcessingStrategyTests.swift b/Core/Tests/SuggestionServiceTests/DefaultRawSuggestionPostProcessingStrategyTests.swift index f8e92d2..a615402 100644 --- a/Core/Tests/SuggestionServiceTests/DefaultRawSuggestionPostProcessingStrategyTests.swift +++ b/Core/Tests/SuggestionServiceTests/DefaultRawSuggestionPostProcessingStrategyTests.swift @@ -6,8 +6,7 @@ import XCTest class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionHasCodeTagAtTheFirstLine_shouldExtractCodeInside() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -21,8 +20,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionHasCodeTagAtTheFirstLine_closingTagInOtherLines_shouldExtractCodeInside( ) { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -36,8 +34,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionHasCodeTag_butNoClosingTag_shouldExtractCodeAfterTheTag() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -51,8 +48,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenMultipleOpeningTagFound_shouldTreatTheNextOneAsClosing() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -64,8 +60,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenMarkdownCodeBlockFound_shouldExtractCodeInside() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -80,8 +75,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenOnlyLinebreaksOrSpacesBeforeMarkdownCodeBlock_shouldExtractCodeInside() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -120,8 +114,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenMarkdownCodeBlockAndCodeTagFound_firstlyExtractCodeTag_thenCodeTag() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -137,8 +130,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenMarkdownCodeBlockAndCodeTagFound_butNoClosingTag_firstlyExtractCodeTag_thenCodeTag( ) { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -153,8 +145,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionHasTheSamePrefix_removeThePrefix() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: "suggestion" @@ -165,8 +156,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionLooksLikeAMessage_parseItCorrectly() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.extractSuggestion( from: """ @@ -182,8 +172,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionHasTheSamePrefix_inTags_removeThePrefix() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) var suggestion = "prefix suggestion" strategy.removePrefix(from: &suggestion, infillPrefix: "prefix") @@ -193,8 +182,7 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { func test_whenSuggestionHasTheSameSuffix_removeTheSuffix() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) var suggestion = "suggestion\na\nb" strategy.removeSuffix(from: &suggestion, suffix: [ @@ -214,11 +202,10 @@ class DefaultRawSuggestionPostProcessingStrategyTests: XCTestCase { XCTAssertEqual(suggestion3, "suggestion\na\n") } - + func test_case_1() { let strategy = DefaultRawSuggestionPostProcessingStrategy( - openingCodeTag: "", - closingCodeTag: "" + codeWrappingTags: ("", "") ) let result = strategy.postProcess( rawSuggestion: """ diff --git a/CustomSuggestionService/ContentView.swift b/CustomSuggestionService/ContentView.swift index fa1fe18..38afa4b 100644 --- a/CustomSuggestionService/ContentView.swift +++ b/CustomSuggestionService/ContentView.swift @@ -14,6 +14,7 @@ struct ContentView: View { @AppStorage(\.chatModelId) var chatModelId @AppStorage(\.installBetaBuild) var installBetaBuild @AppStorage(\.verboseLog) var verboseLog + @AppStorage(\.maxNumberOfLinesOfSuggestion) var maxNumberOfLinesOfSuggestion } @StateObject var settings = Settings() @@ -26,16 +27,26 @@ struct ContentView: View { WithPerceptionTracking { VStack { Form { - HStack { - ExistedChatModelPicker() - if CustomModelType(rawValue: settings.chatModelId) != nil { - Button("Edit Model") { - isEditingCustomModel = true + Section { + HStack { + ExistedChatModelPicker() + if CustomModelType(rawValue: settings.chatModelId) != nil { + Button("Edit Model") { + isEditingCustomModel = true + } } } - } - RequestStrategyPicker() + RequestStrategyPicker() + + NumberInput( + value: settings.$maxNumberOfLinesOfSuggestion, + range: 0...Int.max, + step: 1 + ) { + Text("Suggestion Line Limit (0 for unlimited)") + } + } Section { HStack { @@ -192,6 +203,47 @@ struct RequestStrategyPicker: View { } } +struct NumberInput: View { + @Binding var value: V + let formatter = NumberFormatter() + let range: ClosedRange + let step: V.Stride + @ViewBuilder var label: () -> Label + + var body: some View { + TextField(value: .init(get: { + if value > range.upperBound { + return range.upperBound + } else if value < range.lowerBound { + return range.lowerBound + } else { + return value + } + }, set: { newValue in + if newValue > range.upperBound { + value = range.upperBound + } else if newValue < range.lowerBound { + value = range.lowerBound + } else { + value = newValue + } + }), formatter: formatter, prompt: nil) { + label() + } + .padding(.trailing) + .overlay(alignment: .trailing) { + Stepper( + value: $value, + in: range, + step: step + ) { + EmptyView() + } + } + .padding(.trailing, 4) + } +} + #Preview { ContentView() .frame(width: 800, height: 1000) diff --git a/TestPlan.xctestplan b/TestPlan.xctestplan index 70bfd2e..5b2e13d 100644 --- a/TestPlan.xctestplan +++ b/TestPlan.xctestplan @@ -15,8 +15,8 @@ { "target" : { "containerPath" : "container:Core", - "identifier" : "SuggestionServiceTests", - "name" : "SuggestionServiceTests" + "identifier" : "CodeCompletionServiceTests", + "name" : "CodeCompletionServiceTests" } }, { @@ -25,6 +25,13 @@ "identifier" : "FundamentalTests", "name" : "FundamentalTests" } + }, + { + "target" : { + "containerPath" : "container:Core", + "identifier" : "SuggestionServiceTests", + "name" : "SuggestionServiceTests" + } } ], "version" : 1 diff --git a/Version.xcconfig b/Version.xcconfig index 166c909..917fd82 100644 --- a/Version.xcconfig +++ b/Version.xcconfig @@ -1,3 +1,3 @@ -APP_VERSION = 0.2.0 -APP_BUILD = 20 +APP_VERSION = 0.3.0 +APP_BUILD = 30 diff --git a/appcast.xml b/appcast.xml index aff937d..8f3f886 100644 --- a/appcast.xml +++ b/appcast.xml @@ -2,6 +2,18 @@ Custom Suggestion Service + + 0.3.0 + Tue, 16 Apr 2024 00:16:25 +0800 + 30 + 0.3.0 + 13.0 + + https://github.com/intitni/CustomSuggestionServiceForCopilotForXcode/releases/tag/0.3.0 + + + + 0.2.0 Mon, 04 Mar 2024 12:50:34 +0800