Prefer chat_template.json for chat template #184

Merged 2 commits on Feb 27, 2025
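With this change, loadConfig also downloads chat_template.json and, when that file provides a chat_template string, copies it into the tokenizer config, so it takes precedence over any chat_template defined in tokenizer_config.json. The sketch below only illustrates that precedence rule; the helper name resolvedChatTemplate is hypothetical and not part of this PR, while Config, chatTemplate, and stringValue are taken from the diff that follows.

import Hub

// Hypothetical helper, not part of this PR: it only illustrates the precedence
// rule implemented in loadConfig below. A chat_template string found in
// chat_template.json wins over the one in tokenizer_config.json.
func resolvedChatTemplate(tokenizerConfig: Config?, chatTemplateConfig: Config?) -> String? {
    if let templateFromFile = chatTemplateConfig?.chatTemplate?.stringValue {
        return templateFromFile
    }
    return tokenizerConfig?.chatTemplate?.stringValue
}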
29 changes: 20 additions & 9 deletions Sources/Hub/Hub.swift
@@ -177,28 +177,39 @@ public class LanguageModelConfigurationFromHub {
         modelName: String,
         hubApi: HubApi = .shared
     ) async throws -> Configurations {
-        let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
+        let filesToDownload = ["config.json", "tokenizer_config.json", "chat_template.json", "tokenizer.json"]
         let repo = Hub.Repo(id: modelName)
         let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)

         return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi)
     }

     func loadConfig(
         modelFolder: URL,
         hubApi: HubApi = .shared
     ) async throws -> Configurations {
-        // Note tokenizerConfig may be nil (does not exist in all models)
+        // Load required configurations
         let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json"))
-        let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
-        let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))
-
-        let configs = Configurations(
+        let tokenizerData = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))
+        // Load tokenizer config
+        var tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
+        // Check for chat template and merge if available
+        if let chatTemplateConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "chat_template.json")),
+           let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue {
+            // The value of chat_template could also be an array of strings, but we're not handling that case here, since it's discouraged.
+            // Create or update tokenizer config with chat template
+            if var configDict = tokenizerConfig?.dictionary {
+                configDict["chat_template"] = chatTemplate
+                tokenizerConfig = Config(configDict)
+            } else {
+                tokenizerConfig = Config(["chat_template": chatTemplate])
+            }
+        }
+        return Configurations(
             modelConfig: modelConfig,
             tokenizerConfig: tokenizerConfig,
-            tokenizerData: tokenizerVocab
+            tokenizerData: tokenizerData
         )
-        return configs
     }

     static func fallbackTokenizerConfig(for modelType: String) -> Config?
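For downstream callers, the practical effect is that the tokenizer config surfaced by LanguageModelConfigurationFromHub now carries the template from chat_template.json whenever the repo ships one. A rough usage sketch follows; the initializer and the async tokenizerConfig property are assumptions based on the class shown above, not part of this diff.

import Hub

// Rough sketch, to be run inside an async throwing context. The initializer and the
// async tokenizerConfig property are assumed; only Config, chatTemplate, and
// stringValue appear in the diff above.
let configuration = LanguageModelConfigurationFromHub(modelName: "mlx-community/Qwen2.5-VL-7B-Instruct-4bit")
let tokenizerConfig = try await configuration.tokenizerConfig
// After this change, the template below comes from chat_template.json when that file
// exists in the repo, and falls back to tokenizer_config.json otherwise.
print(tokenizerConfig?.chatTemplate?.stringValue ?? "no chat template")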
38 changes: 38 additions & 0 deletions Tests/TokenizersTests/ChatTemplateTests.swift
@@ -178,6 +178,44 @@ What is the weather in Paris today?<|im_end|>
         XCTAssertTrue(tokenizer.hasChatTemplate)
     }

+    // Test for vision models with a vision chat template in chat_template.json
+    func testChatTemplateFromChatTemplateJson() async throws {
+        let visionMessages = [
+            [
+                "role": "user",
+                "content": [
+                    [
+                        "type": "text",
+                        "text": "What's in this image?",
+                    ] as [String: String],
+                    [
+                        "type": "image",
+                        "image_url": "example.jpg",
+                    ] as [String: String],
+                ] as [[String: String]],
+            ] as [String: Any]
+        ] as [[String: Any]]
+        // Qwen 2 VL does not have a chat_template.json file. The chat template is in tokenizer_config.json.
+        let qwen2VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2-VL-7B-Instruct-4bit")
+        // Qwen 2.5 VL has a chat_template.json file with a different chat template than the one in tokenizer_config.json.
+        let qwen2_5VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-VL-7B-Instruct-4bit")
+        let qwen2VLEncoded = try qwen2VLTokenizer.applyChatTemplate(messages: visionMessages)
+        let qwen2VLDecoded = qwen2VLTokenizer.decode(tokens: qwen2VLEncoded)
+        let qwen2_5VLEncoded = try qwen2_5VLTokenizer.applyChatTemplate(messages: visionMessages)
+        let qwen2_5VLDecoded = qwen2_5VLTokenizer.decode(tokens: qwen2_5VLEncoded)
+        let expectedOutput = """
+            <|im_start|>system
+            You are a helpful assistant.<|im_end|>
+            <|im_start|>user
+            What's in this image?<|vision_start|><|image_pad|><|vision_end|><|im_end|>
+            <|im_start|>assistant
+
+            """
+        XCTAssertEqual(qwen2VLEncoded, qwen2_5VLEncoded, "Encoded sequences should be equal")
+        XCTAssertEqual(qwen2VLDecoded, qwen2_5VLDecoded, "Decoded sequences should be equal")
+        XCTAssertEqual(qwen2_5VLDecoded, expectedOutput, "Decoded sequence should match expected output")
+    }
+
     func testApplyTemplateError() async throws {
         let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased")
         XCTAssertFalse(tokenizer.hasChatTemplate)
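For reference, chat_template.json is a small JSON file whose chat_template key holds the Jinja template string. The snippet below builds an equivalent Config in memory the same way the Hub change does; the template text is a placeholder, not the real Qwen template.

import Hub

// Placeholder template text; the real Qwen2.5-VL template is much longer.
let chatTemplateConfig = Config(["chat_template": "{% for message in messages %}...{% endfor %}"])
print(chatTemplateConfig.chatTemplate?.stringValue ?? "missing")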