From 806c7f263e7ea290bbd851dfa21ac1e4cf367350 Mon Sep 17 00:00:00 2001
From: Anthony DePasquale
Date: Tue, 14 Jan 2025 21:55:29 +0100
Subject: [PATCH] Working

---
 Libraries/MLXVLM/Models/Qwen2VL.swift | 52 ++++++++++++++-------------
 1 file changed, 28 insertions(+), 24 deletions(-)

diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift
index 2f90a7b..6e0fba1 100644
--- a/Libraries/MLXVLM/Models/Qwen2VL.swift
+++ b/Libraries/MLXVLM/Models/Qwen2VL.swift
@@ -686,7 +686,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }
 
-    private func prepareMessages(_ messages: [Message], imageTHW: [THW]?) -> [Message] {
+    private func prepareMessages(_ messages: [Message]) -> [Message] {
         var messages = messages
         print(messages)
         // Add system message if not present
@@ -694,29 +694,14 @@
             messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
         }
 
-        //        // Add image markers to last message if needed
-        //        if let imageTHW {
-        //            let lastIndex = messages.count - 1
-        //            var content = messages[lastIndex]["content"] as? String ?? ""
-        //            let mergeLength = config.mergeSize * config.mergeSize
-        //            for thw in imageTHW {
-        //                content += "<|vision_start|>"
-        //                content += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength).joined()
-        //                content += "<|vision_end|>"
-        //            }
-        //            messages[lastIndex]["content"] = content
-        //        }
-
-        // TODO: Instead of the above, replace the single `<|image_pad|>` with repeated padding, using the same logic as above to determine the number of repeats.
-
         return messages
     }
 
-    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
-        let messages = prepareMessages(prompt.asMessages(), imageTHW: imageTHW)
-        let tokens = try tokenizer.applyChatTemplate(messages: messages)
-        return tokenizer.decode(tokens: tokens)
-    }
+    // public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) throws -> String {
+    //     let messages = prepareMessages(prompt.asMessages())
+    //     let tokens = try tokenizer.applyChatTemplate(messages: messages)
+    //     return tokenizer.decode(tokens: tokens)
+    // }
 
     public func prepare(input: UserInput) throws -> LMInput {
         // Text-only input
@@ -725,15 +710,34 @@
             let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
             return LMInput(tokens: MLXArray(promptTokens))
         }
+
         // Input with images
         let images = try input.images.map {
             try preprocess(images: [$0.asCIImage()], processing: input.processing)
         }
         let pixels = concatenated(images.map { $0.0 })
         let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
-        // Prepare messages with image markers
-        let messages = prepareMessages(input.prompt.asMessages(), imageTHW: image.imageGridThw)
-        let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+
+        // Get tokens from messages
+        let messages = prepareMessages(input.prompt.asMessages())
+        var promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+
+        // Replace single image pad token with correct number for each image
+        let imagePadToken = try tokenizer.encode(text: "<|image_pad|>").first!
+        let mergeLength = config.mergeSize * config.mergeSize
+
+        // TODO: This assumes that there is only one image. A better solution is needed for the case when multiple images are included.
+        if let imageGridThw = image.imageGridThw {
+            for thw in imageGridThw {
+                if let padIndex = promptTokens.firstIndex(of: imagePadToken) {
+                    let paddingCount = thw.product / mergeLength
+                    promptTokens.replaceSubrange(
+                        padIndex ... (padIndex),
+                        with: Array(repeating: imagePadToken, count: paddingCount)
+                    )
+                }
+            }
+        }
 
         // TODO: For debugging. Remove later.
         let promptTokensDecoded = try tokenizer.decode(tokens: promptTokens)
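Editor's note on the TODO in the last hunk (a sketch, not part of the patch): the replacement loop only handles one image, since on the second iteration firstIndex(of:) would match one of the pad tokens just inserted for the previous image. One way to generalize it, assuming the same promptTokens, imagePadToken, and mergeLength values as in the hunk and taking the per-image grid products as input, is to walk the token array once and expand each <|image_pad|> occurrence with the grid entry of the corresponding image. The helper name and signature below are hypothetical:

    // Editor's sketch (not part of the patch): expand each <|image_pad|> occurrence in order,
    // consuming one grid product per image. `gridProducts` stands in for
    // `imageGridThw.map { $0.product }` from the patch; everything else is plain Swift.
    func expandImagePadTokens(
        in tokens: [Int],
        padToken: Int,
        gridProducts: [Int],
        mergeLength: Int
    ) -> [Int] {
        var result: [Int] = []
        result.reserveCapacity(tokens.count)
        var imageIndex = 0
        for token in tokens {
            if token == padToken, imageIndex < gridProducts.count {
                // Repeat the pad token once per merged patch for this image.
                let count = gridProducts[imageIndex] / mergeLength
                result.append(contentsOf: Array(repeating: padToken, count: count))
                imageIndex += 1
            } else {
                result.append(token)
            }
        }
        return result
    }

In prepare(input:) this could be called as something like promptTokens = expandImagePadTokens(in: promptTokens, padToken: imagePadToken, gridProducts: (image.imageGridThw ?? []).map { $0.product }, mergeLength: mergeLength), which keeps the single-image behavior of the patch unchanged while also covering prompts with several images.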