Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LLMModelFactory and VLMModelFactory inits public #226

Merged
merged 8 commits into from
Mar 8, 2025
74 changes: 55 additions & 19 deletions Libraries/MLXLLM/LLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,42 @@ private func create<C: Codable, M>(
/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
public class ModelTypeRegistry: @unchecked Sendable {

/// Creates a registry that starts out with no model type creators.
public init() {
    creators = Dictionary()
}

/// Creates a registry whose contents are exactly the given creators.
///
/// - Parameter creators: closures keyed by string; each one takes a `URL` and either
///   returns a `LanguageModel` or throws.
public init(
    creators: [String: @Sendable (URL) throws -> any LanguageModel]
) {
    self.creators = creators
}

/// Shared instance with default model types.
public static let shared: ModelTypeRegistry = .init(creators: all())

/// Builds the table of predefined model types.
///
/// Each key is a model type string and each value is the closure produced by `create`
/// from that type's configuration type and model initializer. Note that `"mistral"` is
/// wired to the same Llama configuration and model as `"llama"`.
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
    let predefined: [String: @Sendable (URL) throws -> any LanguageModel] = [
        "llama": create(LlamaConfiguration.self, LlamaModel.init),
        "mistral": create(LlamaConfiguration.self, LlamaModel.init),
        "phi": create(PhiConfiguration.self, PhiModel.init),
        "phi3": create(Phi3Configuration.self, Phi3Model.init),
        "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
        "gemma": create(GemmaConfiguration.self, GemmaModel.init),
        "gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
        "qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
        "starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
        "cohere": create(CohereConfiguration.self, CohereModel.init),
        "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
        "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
    ]
    return predefined
}

// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()

private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
"mistral": create(LlamaConfiguration.self, LlamaModel.init),
"llama": create(LlamaConfiguration.self, LlamaModel.init),
"phi": create(PhiConfiguration.self, PhiModel.init),
"phi3": create(Phi3Configuration.self, Phi3Model.init),
"phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
"gemma": create(GemmaConfiguration.self, GemmaModel.init),
"gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
"qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
"starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
"cohere": create(CohereConfiguration.self, CohereModel.init),
"openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
"internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
]
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]

/// Add a new model to the type registry.
public func registerModelType(
Expand Down Expand Up @@ -72,8 +89,21 @@ public class ModelTypeRegistry: @unchecked Sendable {
/// implementation, if needed.
public class ModelRegistry: @unchecked Sendable {

/// Creates a registry that starts out with no model configurations.
public init() {
    registry = [:]
}

/// Creates a new registry from the given model configurations.
///
/// Configurations are keyed by their `name`. If several configurations share a name, the
/// one appearing last in `modelConfigurations` is kept.
///
/// - Parameter modelConfigurations: the configurations to store in the registry.
public init(modelConfigurations: [ModelConfiguration]) {
    // `Dictionary(uniqueKeysWithValues:)` traps on duplicate keys. This initializer is
    // public and accepts arbitrary caller input, so resolve duplicates (last one wins)
    // instead of crashing.
    self.registry = Dictionary(
        modelConfigurations.map { ($0.name, $0) },
        uniquingKeysWith: { _, latest in latest })
}

/// Shared instance with default model configurations.
public static let shared = ModelRegistry(modelConfigurations: all())

private let lock = NSLock()
private var registry = Dictionary(uniqueKeysWithValues: all().map { ($0.name, $0) })
private var registry: [String: ModelConfiguration]

static public let smolLM_135M_4bit = ModelConfiguration(
id: "mlx-community/SmolLM-135M-Instruct-4bit",
Expand Down Expand Up @@ -274,13 +304,19 @@ private struct LLMUserInputProcessor: UserInputProcessor {
/// ```
public class LLMModelFactory: ModelFactory {

public static let shared = LLMModelFactory()
/// Creates a factory that uses the given registries.
///
/// - Parameters:
///   - typeRegistry: stored as ``typeRegistry``.
///   - modelRegistry: stored as ``modelRegistry``.
public init(
    typeRegistry: ModelTypeRegistry,
    modelRegistry: ModelRegistry
) {
    self.modelRegistry = modelRegistry
    self.typeRegistry = typeRegistry
}

/// Shared instance with default behavior.
public static let shared = LLMModelFactory(typeRegistry: .shared, modelRegistry: .shared)

/// registry of model type, e.g. configuration value `llama` -> configuration and init methods
public let typeRegistry = ModelTypeRegistry()
public let typeRegistry: ModelTypeRegistry

/// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit`
public let modelRegistry = ModelRegistry()
public let modelRegistry: ModelRegistry

public func configuration(id: String) -> ModelConfiguration {
modelRegistry.configuration(id: id)
Expand Down
100 changes: 79 additions & 21 deletions Libraries/MLXVLM/VLMModelFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,34 @@ private func create<C: Codable, P>(
/// Typically called via ``VLMModelFactory/load(hub:configuration:progressHandler:)``.
public class ModelTypeRegistry: @unchecked Sendable {

/// Creates a registry that starts out with no model type creators.
public init() {
    creators = Dictionary()
}

/// Creates a registry whose contents are exactly the given creators.
///
/// - Parameter creators: closures keyed by string; each one takes a `URL` and either
///   returns a `LanguageModel` or throws.
public init(
    creators: [String: @Sendable (URL) throws -> any LanguageModel]
) {
    self.creators = creators
}

/// Shared instance with default model types.
public static let shared: ModelTypeRegistry = .init(creators: all())

/// Builds the table of predefined model types.
///
/// Each key is a model type string and each value is the closure produced by `create`
/// from that type's configuration type and model initializer.
private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
    let predefined: [String: @Sendable (URL) throws -> any LanguageModel] = [
        "idefics3": create(Idefics3Configuration.self, Idefics3.init),
        "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
        "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
    ]
    return predefined
}

// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()

private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
"paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
"qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
"idefics3": create(Idefics3Configuration.self, Idefics3.init),
]
private var creators: [String: @Sendable (URL) throws -> any LanguageModel]

/// Add a new model to the type registry.
public func registerModelType(
Expand Down Expand Up @@ -90,20 +108,39 @@ public class ModelTypeRegistry: @unchecked Sendable {

public class ProcessorTypeRegistry: @unchecked Sendable {

// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()
/// Creates a registry that starts out with no processor creators.
public init() {
    creators = Dictionary()
}

/// Creates a registry whose contents are exactly the given creators.
///
/// - Parameter creators: closures keyed by string; each one takes a `URL` and a
///   `Tokenizer` and either returns a `UserInputProcessor` or throws.
public init(
    creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]
) {
    self.creators = creators
}

private var creators:
[String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [
/// Shared instance with default processor types.
public static let shared: ProcessorTypeRegistry = .init(creators: all())

/// All predefined processor types.
///
/// Each key is a processor type string and each value is the closure produced by
/// `create` from that processor's configuration type and initializer.
private static func all() -> [String: @Sendable (URL, any Tokenizer) throws ->
    any UserInputProcessor]
{
    // Every key must appear exactly once: a dictionary literal with a repeated key
    // traps at runtime, so "Qwen2VLProcessor" is listed a single time.
    [
        "PaliGemmaProcessor": create(
            PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init),
        "Qwen2VLProcessor": create(Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init),
        "Idefics3Processor": create(
            Idefics3ProcessorConfiguration.self, Idefics3Processor.init),
    ]
}

// Note: using NSLock as we have very small (just dictionary get/set)
// critical sections and expect no contention. this allows the methods
// to remain synchronous.
private let lock = NSLock()

private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]

/// Add a new processor type to the registry.
public func registerProcessorType(
Expand Down Expand Up @@ -140,12 +177,21 @@ public class ProcessorTypeRegistry: @unchecked Sendable {
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public class ModelRegistry: @unchecked Sendable {
/// Creates a registry that starts out with no model configurations.
public init() {
    registry = [:]
}

/// Creates a new registry from the given model configurations.
///
/// Configurations are keyed by their `name`. If several configurations share a name, the
/// one appearing last in `modelConfigurations` is kept.
///
/// - Parameter modelConfigurations: the configurations to store in the registry.
public init(modelConfigurations: [ModelConfiguration]) {
    // `Dictionary(uniqueKeysWithValues:)` traps on duplicate keys. This initializer is
    // public and accepts arbitrary caller input, so resolve duplicates (last one wins)
    // instead of crashing.
    registry = Dictionary(
        modelConfigurations.map { ($0.name, $0) },
        uniquingKeysWith: { _, latest in latest })
}

/// Shared instance with default model configurations.
public static let shared = ModelRegistry(modelConfigurations: all())

private let lock = NSLock()
private var registry = Dictionary(
uniqueKeysWithValues: all().map {
($0.name, $0)
})
private var registry: [String: ModelConfiguration]

static public let paligemma3bMix448_8bit = ModelConfiguration(
id: "mlx-community/paligemma-3b-mix-448-8bit",
Expand All @@ -166,6 +212,7 @@ public class ModelRegistry: @unchecked Sendable {
[
paligemma3bMix448_8bit,
qwen2VL2BInstruct4Bit,
smolvlminstruct4bit,
]
}

Expand Down Expand Up @@ -205,16 +252,27 @@ public class ModelRegistry: @unchecked Sendable {
/// ```
public class VLMModelFactory: ModelFactory {

public static let shared = VLMModelFactory()
/// Creates a factory that uses the given registries.
///
/// - Parameters:
///   - typeRegistry: stored as ``typeRegistry``.
///   - processorRegistry: stored as ``processorRegistry``.
///   - modelRegistry: stored as ``modelRegistry``.
public init(
    typeRegistry: ModelTypeRegistry,
    processorRegistry: ProcessorTypeRegistry,
    modelRegistry: ModelRegistry
) {
    self.modelRegistry = modelRegistry
    self.processorRegistry = processorRegistry
    self.typeRegistry = typeRegistry
}

/// Shared instance with default behavior.
public static let shared = VLMModelFactory(
typeRegistry: .shared, processorRegistry: .shared, modelRegistry: .shared)

/// registry of model type, e.g. configuration value `paligemma` -> configuration and init methods
public let typeRegistry = ModelTypeRegistry()
public let typeRegistry: ModelTypeRegistry

/// registry of input processor type, e.g. configuration value `PaliGemmaProcessor` -> configuration and init methods
public let processorRegistry = ProcessorTypeRegistry()
public let processorRegistry: ProcessorTypeRegistry

/// registry of model id to configuration, e.g. `mlx-community/paligemma-3b-mix-448-8bit`
public let modelRegistry = ModelRegistry()
public let modelRegistry: ModelRegistry

public func configuration(id: String) -> ModelConfiguration {
modelRegistry.configuration(id: id)
Expand Down