diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index 603e918..190974b 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -22,25 +22,42 @@ private func create( /// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``. public class ModelTypeRegistry: @unchecked Sendable { + /// Creates an empty registry. + public init() { + self.creators = [:] + } + + /// Creates a registry with given creators. + public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) { + self.creators = creators + } + + /// Shared instance with default model types. + public static let shared: ModelTypeRegistry = .init(creators: all()) + + /// All predefined model types. + private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] { + [ + "mistral": create(LlamaConfiguration.self, LlamaModel.init), + "llama": create(LlamaConfiguration.self, LlamaModel.init), + "phi": create(PhiConfiguration.self, PhiModel.init), + "phi3": create(Phi3Configuration.self, Phi3Model.init), + "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init), + "gemma": create(GemmaConfiguration.self, GemmaModel.init), + "gemma2": create(Gemma2Configuration.self, Gemma2Model.init), + "qwen2": create(Qwen2Configuration.self, Qwen2Model.init), + "starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init), + "cohere": create(CohereConfiguration.self, CohereModel.init), + "openelm": create(OpenElmConfiguration.self, OpenELMModel.init), + "internlm2": create(InternLM2Configuration.self, InternLM2Model.init), + ] + } + // Note: using NSLock as we have very small (just dictionary get/set) // critical sections and expect no contention. this allows the methods // to remain synchronous. private let lock = NSLock() - - private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ - "mistral": create(LlamaConfiguration.self, LlamaModel.init), - "llama": create(LlamaConfiguration.self, LlamaModel.init), - "phi": create(PhiConfiguration.self, PhiModel.init), - "phi3": create(Phi3Configuration.self, Phi3Model.init), - "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init), - "gemma": create(GemmaConfiguration.self, GemmaModel.init), - "gemma2": create(Gemma2Configuration.self, Gemma2Model.init), - "qwen2": create(Qwen2Configuration.self, Qwen2Model.init), - "starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init), - "cohere": create(CohereConfiguration.self, CohereModel.init), - "openelm": create(OpenElmConfiguration.self, OpenELMModel.init), - "internlm2": create(InternLM2Configuration.self, InternLM2Model.init), - ] + private var creators: [String: @Sendable (URL) throws -> any LanguageModel] /// Add a new model to the type registry. public func registerModelType( @@ -72,8 +89,21 @@ public class ModelTypeRegistry: @unchecked Sendable { /// implementation, if needed. public class ModelRegistry: @unchecked Sendable { + /// Creates an empty registry. + public init() { + self.registry = Dictionary() + } + + /// Creates a new registry with from given model configurations. + public init(modelConfigurations: [ModelConfiguration]) { + self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) }) + } + + /// Shared instance with default model configurations. + public static let shared = ModelRegistry(modelConfigurations: all()) + private let lock = NSLock() - private var registry = Dictionary(uniqueKeysWithValues: all().map { ($0.name, $0) }) + private var registry: [String: ModelConfiguration] static public let smolLM_135M_4bit = ModelConfiguration( id: "mlx-community/SmolLM-135M-Instruct-4bit", @@ -274,13 +304,19 @@ private struct LLMUserInputProcessor: UserInputProcessor { /// ``` public class LLMModelFactory: ModelFactory { - public static let shared = LLMModelFactory() + public init(typeRegistry: ModelTypeRegistry, modelRegistry: ModelRegistry) { + self.typeRegistry = typeRegistry + self.modelRegistry = modelRegistry + } + + /// Shared instance with default behavior. + public static let shared = LLMModelFactory(typeRegistry: .shared, modelRegistry: .shared) /// registry of model type, e.g. configuration value `llama` -> configuration and init methods - public let typeRegistry = ModelTypeRegistry() + public let typeRegistry: ModelTypeRegistry /// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit` - public let modelRegistry = ModelRegistry() + public let modelRegistry: ModelRegistry public func configuration(id: String) -> ModelConfiguration { modelRegistry.configuration(id: id) diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index 17ba1f1..ee7d15b 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -52,16 +52,34 @@ private func create( /// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``. public class ModelTypeRegistry: @unchecked Sendable { + /// Creates an empty registry. + public init() { + self.creators = [:] + } + + /// Creates a registry with given creators. + public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) { + self.creators = creators + } + + /// Shared instance with default model types. + public static let shared: ModelTypeRegistry = .init(creators: all()) + + /// All predefined model types + private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] { + [ + "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init), + "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init), + "idefics3": create(Idefics3Configuration.self, Idefics3.init), + ] + } + // Note: using NSLock as we have very small (just dictionary get/set) // critical sections and expect no contention. this allows the methods // to remain synchronous. private let lock = NSLock() - private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ - "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init), - "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init), - "idefics3": create(Idefics3Configuration.self, Idefics3.init), - ] + private var creators: [String: @Sendable (URL) throws -> any LanguageModel] /// Add a new model to the type registry. public func registerModelType( @@ -90,20 +108,39 @@ public class ModelTypeRegistry: @unchecked Sendable { public class ProcessorTypeRegistry: @unchecked Sendable { - // Note: using NSLock as we have very small (just dictionary get/set) - // critical sections and expect no contention. this allows the methods - // to remain synchronous. - private let lock = NSLock() + /// Creates an empty registry. + public init() { + self.creators = [:] + } + + /// Creates a registry with given creators. + public init(creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]) + { + self.creators = creators + } - private var creators: - [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [ + /// Shared instance with default processor types. + public static let shared: ProcessorTypeRegistry = .init(creators: all()) + + /// All predefined processor types. + private static func all() -> [String: @Sendable (URL, any Tokenizer) throws -> + any UserInputProcessor] + { + [ "PaliGemmaProcessor": create( PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init), - "Qwen2VLProcessor": create( - Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init), + "Qwen2VLProcessor": create(Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init), "Idefics3Processor": create( Idefics3ProcessorConfiguration.self, Idefics3Processor.init), ] + } + + // Note: using NSLock as we have very small (just dictionary get/set) + // critical sections and expect no contention. this allows the methods + // to remain synchronous. + private let lock = NSLock() + + private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] /// Add a new model to the type registry. public func registerProcessorType( @@ -140,12 +177,21 @@ public class ProcessorTypeRegistry: @unchecked Sendable { /// swift-tokenizers code handles a good chunk of that and this is a place to augment that /// implementation, if needed. public class ModelRegistry: @unchecked Sendable { + /// Creates an empty registry. + public init() { + registry = Dictionary() + } + + /// Creates a new registry with from given model configurations. + public init(modelConfigurations: [ModelConfiguration]) { + registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) }) + } + + /// Shared instance with default model configurations. + public static let shared = ModelRegistry(modelConfigurations: all()) private let lock = NSLock() - private var registry = Dictionary( - uniqueKeysWithValues: all().map { - ($0.name, $0) - }) + private var registry: [String: ModelConfiguration] static public let paligemma3bMix448_8bit = ModelConfiguration( id: "mlx-community/paligemma-3b-mix-448-8bit", @@ -166,6 +212,7 @@ public class ModelRegistry: @unchecked Sendable { [ paligemma3bMix448_8bit, qwen2VL2BInstruct4Bit, + smolvlminstruct4bit, ] } @@ -205,16 +252,27 @@ public class ModelRegistry: @unchecked Sendable { /// ``` public class VLMModelFactory: ModelFactory { - public static let shared = VLMModelFactory() + public init( + typeRegistry: ModelTypeRegistry, processorRegistry: ProcessorTypeRegistry, + modelRegistry: ModelRegistry + ) { + self.typeRegistry = typeRegistry + self.processorRegistry = processorRegistry + self.modelRegistry = modelRegistry + } + + /// Shared instance with default behavior. + public static let shared = VLMModelFactory( + typeRegistry: .shared, processorRegistry: .shared, modelRegistry: .shared) /// registry of model type, e.g. configuration value `paligemma` -> configuration and init methods - public let typeRegistry = ModelTypeRegistry() + public let typeRegistry: ModelTypeRegistry /// registry of input processor type, e.g. configuration value `PaliGemmaProcessor` -> configuration and init methods - public let processorRegistry = ProcessorTypeRegistry() + public let processorRegistry: ProcessorTypeRegistry /// registry of model id to configuration, e.g. `mlx-community/paligemma-3b-mix-448-8bit` - public let modelRegistry = ModelRegistry() + public let modelRegistry: ModelRegistry public func configuration(id: String) -> ModelConfiguration { modelRegistry.configuration(id: id)