Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument for extra EOS token to llm-tool #217

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Libraries/MLXLMCommon/ModelConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ public struct ModelConfiguration: Sendable {
public let overrideTokenizer: String?

/// A reasonable default prompt for the model
public let defaultPrompt: String
public var defaultPrompt: String

/// Additional tokens to use for end of string
public let extraEOSTokens: Set<String>
public var extraEOSTokens: Set<String>

public init(
id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
Expand Down
10 changes: 7 additions & 3 deletions Libraries/MLXLMCommon/ModelContainer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ import Tokenizers
/// }
/// ```
public actor ModelContainer {
let context: ModelContext
nonisolated public let configuration: ModelConfiguration
var context: ModelContext
public var configuration: ModelConfiguration { context.configuration }

public init(context: ModelContext) {
self.context = context
self.configuration = context.configuration
}

/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
Expand Down Expand Up @@ -75,4 +74,9 @@ public actor ModelContainer {
try await action(context, values)
}

/// Update the owned `ModelContext` in place.
///
/// Because `ModelContainer` is an actor, the mutation runs isolated to the
/// actor, so concurrent updates are serialized.
/// - Parameter action: closure that mutates the `ModelContext`; it must be
///   `@Sendable` since it crosses into the actor's isolation domain
public func update(_ action: @Sendable (inout ModelContext) -> Void) {
    action(&context)
}
}
8 changes: 4 additions & 4 deletions Libraries/MLXLMCommon/ModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ public enum ModelFactoryError: Error {
/// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and
/// ``ModelContainer``.
public struct ModelContext {
public let configuration: ModelConfiguration
public let model: any LanguageModel
public let processor: any UserInputProcessor
public let tokenizer: Tokenizer
public var configuration: ModelConfiguration
public var model: any LanguageModel
public var processor: any UserInputProcessor
public var tokenizer: Tokenizer

public init(
configuration: ModelConfiguration, model: any LanguageModel,
Expand Down
27 changes: 21 additions & 6 deletions Tools/llm-tool/LLMTool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ struct ModelArguments: ParsableArguments, Sendable {

let modelName = self.model ?? defaultModel

print("Loading \(modelName)...")

if modelName.hasPrefix("/") {
// path
modelConfiguration = ModelConfiguration(directory: URL(filePath: modelName))
Expand Down Expand Up @@ -67,6 +69,9 @@ struct GenerateArguments: ParsableArguments, Sendable {
@Option(name: .long, help: "The number of tokens to consider for repetition penalty")
var repetitionContextSize: Int = 20

@Option(name: .long, help: "Additional end-of-sequence token to stop generation")
var extraEosToken: String?

@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0

Expand All @@ -89,17 +94,22 @@ struct GenerateArguments: ParsableArguments, Sendable {
}
}

/// Apply command-line generation options to the model context before use.
///
/// Currently this inserts `extraEosToken` (the value of `--extra-eos-token`,
/// if supplied) into the configuration's `extraEOSTokens` set so generation
/// also stops on that token.
/// - Parameter context: the `ModelContext` to mutate in place
func prepare(
    _ context: inout ModelContext
) {
    if let extraEosToken {
        context.configuration.extraEOSTokens.insert(extraEosToken)
    }
}

func generate(
input: LMInput, context: ModelContext
)
throws -> GenerateResult
{
) throws -> GenerateResult {
var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)

return try MLXLMCommon.generate(
input: input, parameters: generateParameters, context: context
) { tokens in

if let last = tokens.last {
detokenizer.append(token: last)
}
Expand Down Expand Up @@ -276,11 +286,16 @@ struct EvaluateCommand: AsyncParsableCommand {
try await args.load(defaultModel: defaultModel.name, modelFactory: modelFactory)
}

// update the context/configuration with any command line parameters
await modelContainer.update { [generate] context in
generate.prepare(&context)
}

// Get the resolved configuration (this has the default prompt)
let modelConfiguration = modelContainer.configuration
let modelConfiguration = await modelContainer.configuration

if !generate.quiet {
print("Model loaded -> \(modelConfiguration.id)")
print("Loaded \(modelConfiguration.name)")
}

let userInput = self.userInput(modelConfiguration: modelConfiguration)
Expand Down
9 changes: 5 additions & 4 deletions Tools/llm-tool/LoraCommands.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct LoRAModelArguments: ParsableArguments, Sendable {
// convert some of the Linear layers to LoRALinear
await modelContainer.perform { context in
guard let lora = context.model as? LoRAModel else {
fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel")
fatalError("Model \(context.configuration.name) is not a LoRAModel")
}
LoRATrain.convert(model: context.model, layers: lora.loraLinearLayers(loraLayers))
}
Expand Down Expand Up @@ -197,7 +197,7 @@ struct LoRAFuseCommand: AsyncParsableCommand {
// fuse them back into Linear/QuantizedLinear
await modelContainer.perform { [args, deQuantize] context in
guard let lora = context.model as? LoRAModel else {
fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel")
fatalError("Model \(context.configuration.name) is not a LoRAModel")
}

LoRATrain.fuse(
Expand All @@ -207,7 +207,7 @@ struct LoRAFuseCommand: AsyncParsableCommand {

// make the new directory and copy files from source model
try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
let inputURL = modelContainer.configuration.modelDirectory()
let inputURL = await modelContainer.configuration.modelDirectory()
let enumerator = FileManager.default.enumerator(
at: inputURL, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
Expand Down Expand Up @@ -296,7 +296,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {

memory.start()

let prompt = generate.prompt ?? modelContainer.configuration.defaultPrompt
let defaultPrompt = await modelContainer.configuration.defaultPrompt
let prompt = generate.prompt ?? defaultPrompt

if !generate.quiet {
print("Starting generation ...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "--prompt &apos;Describe the image in English.&apos; --image https://www.gstatic.com/webp/gallery/1.webp"
argument = "--model microsoft/Phi-4-mini-instruct --prompt &quot;Why is the sky blue?&quot; --extra-eos-token &quot;&lt;|end|&gt;&quot;"
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "--model mlx-community/Qwen2-VL-2B-Instruct-4bit"
isEnabled = "NO">
argument = "--model mlx-community/Qwen2-VL-2B-Instruct-4bit --prompt &apos;Describe the image in English.&apos; --image https://www.gstatic.com/webp/gallery/1.webp"
isEnabled = "YES">
</CommandLineArgument>
<CommandLineArgument
argument = "--repetition-penalty 1.2"
Expand Down Expand Up @@ -89,15 +89,15 @@
</CommandLineArgument>
<CommandLineArgument
argument = "--prompt &apos;Why is the sky blue?&apos;"
isEnabled = "YES">
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "--model mlx-community/Llama-3.2-1B-Instruct-4bit"
isEnabled = "YES">
isEnabled = "NO">
</CommandLineArgument>
<CommandLineArgument
argument = "--model mlx-community/phi-2-hf-4bit-mlx"
Expand Down