From a634fde471478b2040c4a111e0a8ecaff2744cfb Mon Sep 17 00:00:00 2001
From: Anthony DePasquale
Date: Thu, 27 Feb 2025 10:27:08 +0100
Subject: [PATCH 1/2] Improve llm-tool

---
 Tools/llm-tool/LLMTool.swift                 | 50 ++++++++++++++++---
 .../xcshareddata/xcschemes/llm-tool.xcscheme | 10 ++--
 2 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 13b470e..f6acf33 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -31,6 +31,8 @@ struct ModelArguments: ParsableArguments, Sendable {
 
         let modelName = self.model ?? defaultModel
 
+        print("Loading \(modelName)...")
+
         if modelName.hasPrefix("/") {
             // path
             modelConfiguration = ModelConfiguration(directory: URL(filePath: modelName))
@@ -67,6 +69,9 @@ struct GenerateArguments: ParsableArguments, Sendable {
     @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
     var repetitionContextSize: Int = 20
 
+    @Option(name: .long, help: "Additional end-of-sequence token to stop generation")
+    var extraEosToken: String?
+
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0
 
@@ -91,15 +96,48 @@ struct GenerateArguments: ParsableArguments, Sendable {
     func generate(
         input: LMInput, context: ModelContext
-    )
-        throws -> GenerateResult
-    {
+    ) throws -> GenerateResult {
         var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
 
+        // If an extra EOS token is provided, create a new context with the updated configuration
+        let contextToUse: ModelContext
+        if let extraToken = extraEosToken {
+            // Create a new configuration with the extra EOS token
+            var extraTokens = context.configuration.extraEOSTokens
+            extraTokens.insert(extraToken)
+            // Create a new configuration based on the existing one
+            let newConfig: ModelConfiguration
+            switch context.configuration.id {
+            case .id(let id):
+                newConfig = ModelConfiguration(
+                    id: id,
+                    tokenizerId: context.configuration.tokenizerId,
+                    overrideTokenizer: context.configuration.overrideTokenizer,
+                    defaultPrompt: context.configuration.defaultPrompt,
+                    extraEOSTokens: extraTokens
+                )
+            case .directory(let url):
+                newConfig = ModelConfiguration(
+                    directory: url,
+                    tokenizerId: context.configuration.tokenizerId,
+                    overrideTokenizer: context.configuration.overrideTokenizer,
+                    defaultPrompt: context.configuration.defaultPrompt,
+                    extraEOSTokens: extraTokens
+                )
+            }
+            // Create a new context with the updated configuration
+            contextToUse = ModelContext(
+                configuration: newConfig,
+                model: context.model,
+                processor: context.processor,
+                tokenizer: context.tokenizer
+            )
+        } else {
+            contextToUse = context
+        }
 
         return try MLXLMCommon.generate(
-            input: input, parameters: generateParameters, context: context
+            input: input, parameters: generateParameters, context: contextToUse
        ) { tokens in
-
             if let last = tokens.last {
                 detokenizer.append(token: last)
             }
@@ -280,7 +318,7 @@ struct EvaluateCommand: AsyncParsableCommand {
         let modelConfiguration = modelContainer.configuration
 
         if !generate.quiet {
-            print("Model loaded -> \(modelConfiguration.id)")
+            print("Loaded \(modelConfiguration.name)")
         }
 
         let userInput = self.userInput(modelConfiguration: modelConfiguration)
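Note: the new `extraEosToken` option surfaces on the command line as `--extra-eos-token`, following swift-argument-parser's default kebab-case naming. A minimal standalone sketch of that derivation; the `Demo` command and the token value are illustrative only, not part of this patch:

    import ArgumentParser

    // Hypothetical command, for illustration: the property name
    // `extraEosToken` is what produces the `--extra-eos-token` flag.
    struct Demo: ParsableCommand {
        @Option(name: .long, help: "Additional end-of-sequence token to stop generation")
        var extraEosToken: String?

        func run() throws {
            // e.g. `demo --extra-eos-token "<|im_end|>"`
            print(extraEosToken ?? "no extra EOS token")
        }
    }

    Demo.main()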
diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
index eb1ceff..095ac3e 100644
--- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
+++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
@@ -56,12 +56,12 @@
             isEnabled = "NO">
+            argument = "--model mlx-community/Qwen2-VL-2B-Instruct-4bit --prompt 'Describe the image in English.' --image https://www.gstatic.com/webp/gallery/1.webp"
+            isEnabled = "YES">
+            isEnabled = "NO">
+            isEnabled = "NO">

Date: Tue, 4 Mar 2025 11:27:50 -0800
Subject: [PATCH 2/2] - make ModelConfiguration and ModelContext properties
 mutable

- update context/configuration with extra EOS tokens
---
 .../MLXLMCommon/ModelConfiguration.swift   |  4 +-
 Libraries/MLXLMCommon/ModelContainer.swift | 10 ++--
 Libraries/MLXLMCommon/ModelFactory.swift   |  8 +--
 Tools/llm-tool/LLMTool.swift               | 53 ++++++-------
 Tools/llm-tool/LoraCommands.swift          |  9 ++--
 5 files changed, 33 insertions(+), 51 deletions(-)

diff --git a/Libraries/MLXLMCommon/ModelConfiguration.swift b/Libraries/MLXLMCommon/ModelConfiguration.swift
index bad4c8f..bd60893 100644
--- a/Libraries/MLXLMCommon/ModelConfiguration.swift
+++ b/Libraries/MLXLMCommon/ModelConfiguration.swift
@@ -31,10 +31,10 @@ public struct ModelConfiguration: Sendable {
     public let overrideTokenizer: String?
 
     /// A reasonable default prompt for the model
-    public let defaultPrompt: String
+    public var defaultPrompt: String
 
     /// Additional tokens to use for end of string
-    public let extraEOSTokens: Set<String>
+    public var extraEOSTokens: Set<String>
 
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
diff --git a/Libraries/MLXLMCommon/ModelContainer.swift b/Libraries/MLXLMCommon/ModelContainer.swift
index b969378..5b4e462 100644
--- a/Libraries/MLXLMCommon/ModelContainer.swift
+++ b/Libraries/MLXLMCommon/ModelContainer.swift
@@ -32,12 +32,11 @@ import Tokenizers
 /// }
 /// ```
 public actor ModelContainer {
-    let context: ModelContext
-    nonisolated public let configuration: ModelConfiguration
+    var context: ModelContext
+    public var configuration: ModelConfiguration { context.configuration }
 
     public init(context: ModelContext) {
         self.context = context
-        self.configuration = context.configuration
     }
 
     /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
@@ -75,4 +74,9 @@ public actor ModelContainer {
         try await action(context, values)
     }
 
+    /// Update the owned `ModelContext`.
+    /// - Parameter action: update action
+    public func update(_ action: @Sendable (inout ModelContext) -> Void) {
+        action(&context)
+    }
 }
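The `update(_:)` method added above is the hook the rest of this patch builds on: callers mutate the actor-owned `ModelContext` in place instead of rebuilding configuration and context by hand. A minimal sketch of a call site, assuming an already-loaded `ModelContainer`; the helper name and token string are illustrative:

    import MLXLMCommon

    // Insert an additional stop token into the live configuration.
    // The closure is @Sendable and runs isolated to the actor.
    func addStopToken(_ token: String, to container: ModelContainer) async {
        await container.update { context in
            context.configuration.extraEOSTokens.insert(token)
        }
    }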
diff --git a/Libraries/MLXLMCommon/ModelFactory.swift b/Libraries/MLXLMCommon/ModelFactory.swift
index 2035b7e..fcc5bbf 100644
--- a/Libraries/MLXLMCommon/ModelFactory.swift
+++ b/Libraries/MLXLMCommon/ModelFactory.swift
@@ -22,10 +22,10 @@ public enum ModelFactoryError: Error {
 /// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and
 /// ``ModelContainer``.
 public struct ModelContext {
-    public let configuration: ModelConfiguration
-    public let model: any LanguageModel
-    public let processor: any UserInputProcessor
-    public let tokenizer: Tokenizer
+    public var configuration: ModelConfiguration
+    public var model: any LanguageModel
+    public var processor: any UserInputProcessor
+    public var tokenizer: Tokenizer
 
     public init(
         configuration: ModelConfiguration, model: any LanguageModel,
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index f6acf33..30fe1fd 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -94,49 +94,21 @@ struct GenerateArguments: ParsableArguments, Sendable {
         }
     }
 
+    func prepare(
+        _ context: inout ModelContext
+    ) {
+        if let extraEosToken {
+            context.configuration.extraEOSTokens.insert(extraEosToken)
+        }
+    }
+
     func generate(
         input: LMInput, context: ModelContext
     ) throws -> GenerateResult {
         var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
 
-        // If an extra EOS token is provided, create a new context with the updated configuration
-        let contextToUse: ModelContext
-        if let extraToken = extraEosToken {
-            // Create a new configuration with the extra EOS token
-            var extraTokens = context.configuration.extraEOSTokens
-            extraTokens.insert(extraToken)
-            // Create a new configuration based on the existing one
-            let newConfig: ModelConfiguration
-            switch context.configuration.id {
-            case .id(let id):
-                newConfig = ModelConfiguration(
-                    id: id,
-                    tokenizerId: context.configuration.tokenizerId,
-                    overrideTokenizer: context.configuration.overrideTokenizer,
-                    defaultPrompt: context.configuration.defaultPrompt,
-                    extraEOSTokens: extraTokens
-                )
-            case .directory(let url):
-                newConfig = ModelConfiguration(
-                    directory: url,
-                    tokenizerId: context.configuration.tokenizerId,
-                    overrideTokenizer: context.configuration.overrideTokenizer,
-                    defaultPrompt: context.configuration.defaultPrompt,
-                    extraEOSTokens: extraTokens
-                )
-            }
-            // Create a new context with the updated configuration
-            contextToUse = ModelContext(
-                configuration: newConfig,
-                model: context.model,
-                processor: context.processor,
-                tokenizer: context.tokenizer
-            )
-        } else {
-            contextToUse = context
-        }
 
         return try MLXLMCommon.generate(
-            input: input, parameters: generateParameters, context: contextToUse
+            input: input, parameters: generateParameters, context: context
         ) { tokens in
             if let last = tokens.last {
                 detokenizer.append(token: last)
@@ -314,8 +286,13 @@ struct EvaluateCommand: AsyncParsableCommand {
                 try await args.load(defaultModel: defaultModel.name, modelFactory: modelFactory)
             }
 
+        // update the context/configuration with any command line parameters
+        await modelContainer.update { [generate] context in
+            generate.prepare(&context)
+        }
+
         // Get the resolved configuration (this has the default prompt)
-        let modelConfiguration = modelContainer.configuration
+        let modelConfiguration = await modelContainer.configuration
 
         if !generate.quiet {
             print("Loaded \(modelConfiguration.name)")
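Because `configuration` is now a computed property on the `ModelContainer` actor rather than a `nonisolated let`, reads from outside the actor become asynchronous, which is why this call site (and the LoRA commands below) gain an `await`. A small sketch, assuming a loaded container; the function name is illustrative:

    import MLXLMCommon

    // Reading configuration now hops onto the actor, hence `await`.
    func printLoadedModel(_ container: ModelContainer) async {
        let configuration = await container.configuration
        print("Loaded \(configuration.name)")
    }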
diff --git a/Tools/llm-tool/LoraCommands.swift b/Tools/llm-tool/LoraCommands.swift
index 4e47604..60c1621 100644
--- a/Tools/llm-tool/LoraCommands.swift
+++ b/Tools/llm-tool/LoraCommands.swift
@@ -48,7 +48,7 @@ struct LoRAModelArguments: ParsableArguments, Sendable {
         // convert some of the Linear layers to LoRALinear
         await modelContainer.perform { context in
             guard let lora = context.model as? LoRAModel else {
-                fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel")
+                fatalError("Model \(context.configuration.name) is not a LoRAModel")
             }
 
             LoRATrain.convert(model: context.model, layers: lora.loraLinearLayers(loraLayers))
         }
@@ -197,7 +197,7 @@ struct LoRAFuseCommand: AsyncParsableCommand {
         // fuse them back into Linear/QuantizedLinear
         await modelContainer.perform { [args, deQuantize] context in
             guard let lora = context.model as? LoRAModel else {
-                fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel")
+                fatalError("Model \(context.configuration.name) is not a LoRAModel")
             }
 
             LoRATrain.fuse(
@@ -207,7 +207,7 @@ struct LoRAFuseCommand: AsyncParsableCommand {
         // make the new directory and copy files from source model
         try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
-        let inputURL = modelContainer.configuration.modelDirectory()
+        let inputURL = await modelContainer.configuration.modelDirectory()
         let enumerator = FileManager.default.enumerator(
             at: inputURL, includingPropertiesForKeys: nil)!
         for case let url as URL in enumerator {
@@ -296,7 +296,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         memory.start()
 
-        let prompt = generate.prompt ?? modelContainer.configuration.defaultPrompt
+        let defaultPrompt = await modelContainer.configuration.defaultPrompt
+        let prompt = generate.prompt ?? defaultPrompt
 
         if !generate.quiet {
             print("Starting generation ...")