Skip to content

Commit

Permalink
Working
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Jan 14, 2025
1 parent c1deeb4 commit 806c7f2
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions Libraries/MLXVLM/Models/Qwen2VL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -686,37 +686,22 @@ public class Qwen2VLProcessor: UserInputProcessor {
return (flattenedPatches, .init(gridT, gridH, gridW))
}

private func prepareMessages(_ messages: [Message], imageTHW: [THW]?) -> [Message] {
private func prepareMessages(_ messages: [Message]) -> [Message] {
var messages = messages
print(messages)
// Add system message if not present
if let role = messages[0]["role"] as? String, role != "system" {
messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
}

// // Add image markers to last message if needed
// if let imageTHW {
// let lastIndex = messages.count - 1
// var content = messages[lastIndex]["content"] as? String ?? ""
// let mergeLength = config.mergeSize * config.mergeSize
// for thw in imageTHW {
// content += "<|vision_start|>"
// content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
// content += "<|vision_end|>"
// }
// messages[lastIndex]["content"] = content
// }

// TODO: Instead of the above, replace the single `<|image_pad|>` with repeated padding, using the same logic as above to determine the number of repeats.

return messages
}

public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
let messages = prepareMessages(prompt.asMessages(), imageTHW: imageTHW)
let tokens = try tokenizer.applyChatTemplate(messages: messages)
return tokenizer.decode(tokens: tokens)
}
// public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
// let messages = prepareMessages(prompt.asMessages())
// let tokens = try tokenizer.applyChatTemplate(messages: messages)
// return tokenizer.decode(tokens: tokens)
// }

public func prepare(input: UserInput) throws -> LMInput {
// Text-only input
Expand All @@ -725,15 +710,34 @@ public class Qwen2VLProcessor: UserInputProcessor {
let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
return LMInput(tokens: MLXArray(promptTokens))
}

// Input with images
let images = try input.images.map {
try preprocess(images: [$0.asCIImage()], processing: input.processing)
}
let pixels = concatenated(images.map { $0.0 })
let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
// Prepare messages with image markers
let messages = prepareMessages(input.prompt.asMessages(), imageTHW: image.imageGridThw)
let promptTokens = try tokenizer.applyChatTemplate(messages: messages)

// Get tokens from messages
let messages = prepareMessages(input.prompt.asMessages())
var promptTokens = try tokenizer.applyChatTemplate(messages: messages)

// Replace single image pad token with correct number for each image
let imagePadToken = try tokenizer.encode(text: "<|image_pad|>").first!
let mergeLength = config.mergeSize * config.mergeSize

// TODO: This assumes that there is only one image. A better solution is needed for the case when multiple images are included.
if let imageGridThw = image.imageGridThw {
for thw in imageGridThw {
if let padIndex = promptTokens.firstIndex(of: imagePadToken) {
let paddingCount = thw.product / mergeLength
promptTokens.replaceSubrange(
padIndex ... (padIndex),
with: Array(repeating: imagePadToken, count: paddingCount)
)
}
}
}

// TODO: For debugging. Remove later.
let promptTokensDecoded = try tokenizer.decode(tokens: promptTokens)
Expand Down

0 comments on commit 806c7f2

Please sign in to comment.