Fix idefics3 do-image-split #192

Draft · wants to merge 1 commit into base: main
49 changes: 19 additions & 30 deletions Libraries/MLXVLM/Models/Idefics3.swift
@@ -664,46 +664,35 @@ public class Idefics3: Module, VLMModel, KVCacheDimensionProvider {
     }
 
     private func prepareInputsForMultimodal(
-        imageFeatures: MLXArray, inputs_embeds: MLXArray, inputIds: MLXArray
+        imageFeatures: MLXArray,
+        inputs_embeds: MLXArray,
+        inputIds: MLXArray
     ) -> MLXArray {
-        // Get input IDs as array and find image positions
+        // inputIds shape: (1, seq_len)
+        // asArray(Int.self) -> [[Int]], take [0] to get [Int]
         let ids: [[Int]] = [inputIds.asArray(Int.self)]
         let inputIdArray: [Int] = ids[0]
-
         let imageTokenIndex = config.imageTokenIndex
         let imagePositions = inputIdArray.enumerated().compactMap {
             $1 == imageTokenIndex ? $0 : nil
         }
 
-        var segments = [MLXArray]()
-        var start_idx = 0
-
-        for pos in imagePositions {
-            if pos > start_idx {
-                let textSegment = inputs_embeds[0..., start_idx ..< pos, 0...]
-                if textSegment.dim(1) > 0 {
-                    segments.append(textSegment)
-                }
-            }
-            start_idx = pos + 1
-            segments.append(imageFeatures)
-        }
+        // Get image feature dimensions and reshape
+        let (numImages, _, visionHiddenSize) = (
+            imageFeatures.dim(0),
+            imageFeatures.dim(1),
+            imageFeatures.dim(2)
+        )
+        let reshapedImageFeatures = imageFeatures.reshaped(-1, visionHiddenSize)
 
-        if start_idx < inputs_embeds.dim(1) {
-            let remain = inputs_embeds[0..., start_idx..., 0...]
-            if remain.dim(1) > 0 {
-                segments.append(remain)
-            }
-        }
+        // Convert to same dtype as inputs_embeds (handling quantized models)
+        let typedImageFeatures = reshapedImageFeatures.asType(inputs_embeds.dtype)
 
-        var finalEmbeds = segments[0]
-        for seg in segments.dropFirst() {
-            finalEmbeds = concatenated([finalEmbeds, seg], axis: 1)
-        }
-
-        return finalEmbeds
+        // Place image features at image token positions.
+        // [Int] does not conform to MLXArrayIndex, so wrap the positions in an
+        // MLXArray instead of force-casting.
+        inputs_embeds[0..., MLXArray(imagePositions.map { Int32($0) }), 0...] = typedImageFeatures
+
+        return inputs_embeds
     }
 
     public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
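The replacement logic above (locate the image-token positions, flatten the per-image features, and scatter them into the embedding sequence) can be sketched in NumPy. This is an illustrative analogue of the Swift code, not part of the PR; the function name and toy shapes are assumptions:

```python
import numpy as np

def place_image_features(image_features, inputs_embeds, input_ids, image_token_index):
    """Scatter image features into the embedding sequence at image-token slots.

    image_features: (num_images, tokens_per_image, hidden)
    inputs_embeds:  (1, seq_len, hidden)
    input_ids:      (1, seq_len)
    """
    # Positions of the image placeholder token in the sequence
    positions = np.flatnonzero(input_ids[0] == image_token_index)
    # Flatten to (num_images * tokens_per_image, hidden), mirroring
    # imageFeatures.reshaped(-1, visionHiddenSize) in the Swift diff
    flat = image_features.reshape(-1, image_features.shape[-1])
    assert len(positions) == flat.shape[0], "placeholder count must match feature rows"
    out = inputs_embeds.copy()
    # Match dtype (the Swift code uses asType for quantized models), then assign
    out[0, positions, :] = flat.astype(inputs_embeds.dtype)
    return out
```

Note that the removed segment-splicing version appended the entire `imageFeatures` block once per placeholder token, duplicating features whenever image splitting produced multiple placeholders; indexed assignment places exactly one feature row per placeholder.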
@@ -794,7 +783,7 @@ public class Idefics3Processor: UserInputProcessor {
 
     // From the Python code and default config, we know image_token_id is usually 49153.
     // Hardcode this since we can't pass it in or rely on it from the processor config.
-    private let imageTokenId = 49153
+    private let imageTokenId = 49190
Collaborator:
This no longer matches the comment. You might also want to change this line:

        self.imageTokenId = (try? container.decode(Int.self, forKey: .imageTokenId)) ?? 49153

That said, perhaps we should address the "Hardcode this since we can't pass it in or rely on it from the processor config" part of it: if UserInputProcessor needs the token id, it should be passed in (via VLMModelFactory).

Comparing this to the other VLMs it looks like the difference is here:

            // Encode only the text part of the prompt, without <image>
            var promptTokens = try tokenizer.encode(text: prompt)

            let imageTokenIndex = promptTokens.count / 2
            promptTokens.insert(imageTokenId, at: imageTokenIndex)

This code deals in numeric token ids after tokenization, while the other VLMs inject marker text into the prompt before tokenization.

I think we can keep this hardcoded for this specific issue, but it would be good to decide which approach to take here:

  • inject marker text into the prompt (no token id needed), or
  • pass in the model config so the token id is available
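The two routes differ only in where the image marker enters the pipeline. A toy sketch of both; the helper names and the `IMAGE_TOKEN_TEXT` value are hypothetical, and the midpoint insertion mirrors the `promptTokens.count / 2` logic quoted above:

```python
IMAGE_TOKEN_ID = 49190        # numeric route: the processor must know the id
IMAGE_TOKEN_TEXT = "<image>"  # text route: the tokenizer resolves the id itself

def inject_token_id(prompt_tokens):
    """Insert the image token id after tokenization (current Idefics3 approach)."""
    tokens = list(prompt_tokens)
    tokens.insert(len(tokens) // 2, IMAGE_TOKEN_ID)
    return tokens

def inject_text(prompt):
    """Insert the image marker into the raw prompt before tokenization,
    so no hardcoded token id is needed on the processor side."""
    mid = len(prompt) // 2
    return prompt[:mid] + IMAGE_TOKEN_TEXT + prompt[mid:]
```

The text route avoids the stale-constant problem that this PR hit, at the cost of requiring the marker string to round-trip through the tokenizer as a single special token.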


public init(
_ config: Idefics3ProcessorConfiguration,