Skip to content

Commit

Permalink
Debug and add to-do for next steps
Browse files Browse the repository at this point in the history
  • Loading branch information
DePasqualeOrg committed Jan 14, 2025
1 parent 96cdc14 commit c1deeb4
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions Libraries/MLXVLM/Models/Qwen2VL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -693,18 +693,22 @@ public class Qwen2VLProcessor: UserInputProcessor {
if let role = messages[0]["role"] as? String, role != "system" {
messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
}
// // Add image markers to last message if needed
// if let imageTHW {
// let lastIndex = messages.count - 1
// var content = messages[lastIndex]["content"] ?? ""
// let mergeLength = config.mergeSize * config.mergeSize
// for thw in imageTHW {
// content += "<|vision_start|>"
// content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
// content += "<|vision_end|>"
// }
// messages[lastIndex]["content"] = content
// }

// // Add image markers to last message if needed
// if let imageTHW {
// let lastIndex = messages.count - 1
// var content = messages[lastIndex]["content"] as? String ?? ""
// let mergeLength = config.mergeSize * config.mergeSize
// for thw in imageTHW {
// content += "<|vision_start|>"
// content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
// content += "<|vision_end|>"
// }
// messages[lastIndex]["content"] = content
// }

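// TODO: Instead of appending markers as in the commented-out code above, replace the single `<|image_pad|>` token with repeated `<|image_pad|>` tokens. Use the same count as above for each image: `thw.product / mergeLength`, where `mergeLength = config.mergeSize * config.mergeSize`.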

return messages
}

Expand All @@ -730,6 +734,11 @@ public class Qwen2VLProcessor: UserInputProcessor {
// Prepare messages with image markers
let messages = prepareMessages(input.prompt.asMessages(), imageTHW: image.imageGridThw)
let promptTokens = try tokenizer.applyChatTemplate(messages: messages)

// TODO: Debugging only: decodes the templated prompt tokens back to text and prints them. Remove later.
let promptTokensDecoded = try tokenizer.decode(tokens: promptTokens)
print(promptTokensDecoded)

let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
let mask = ones(like: promptArray).asType(.int8)
return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
Expand Down

0 comments on commit c1deeb4

Please sign in to comment.