Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt Lookup Decoding - merged under Speculative example #237

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

LeonEricsson
Copy link
Contributor

@LeonEricsson LeonEricsson commented Jan 5, 2024

Continuation of #202. Decided to merge the Prompt Lookup Decoding under the Speculative Decoding example.

This PR implements a example for the "Prompt Lookup Decoding" technique:

https://github.com/apoorvumang/prompt-lookup-decoding

  • This approach works best on input-grounded tasks such as summarization, document QA, code editing etc where there is a high overlap in input and output.
  • It replaces the draft model in speculative decoding with a simple ngram search of the input tokens. Similar to speculative decoding it has no impact on output.

TODO

  • Debug: Model output doesn't have any spaces
  • Add --color flag to SpeculativeDecoder
    Ended up being quite a messy implementation; need to deal with the fact that output can be truncated and hence only come from draft model
  • Update README
  • Code cleaning

@LeonEricsson LeonEricsson mentioned this pull request Jan 5, 2024
1 task
@LeonEricsson LeonEricsson marked this pull request as ready for review January 7, 2024 11:02
@LeonEricsson
Copy link
Contributor Author

@awni perhaps we can leave this as T5 and then make an attempt at swapping to Llama in a new PR? I was thinking we could adopt the model format / conversion from hf_llm as well while we're still at it? I could take a first crack at it.

@awni
Copy link
Member

awni commented Jan 8, 2024

@awni perhaps we can leave this as T5 and then make an attempt at swapping to Llama in a new PR?

Yea that sounds like a great plan to me! Sorry for the delay in the review here, I will get to it shortly!

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really nice! I think we can get this in soon. I didn't look yet at the core of the prompt decoder but left a few comments.

Thanks a ton for refactoring them together, I think it makes a lot of sense this way.

@LeonEricsson
Copy link
Contributor Author

Looks really nice! I think we can get this in soon. I didn't look yet at the core of the prompt decoder but left a few comments.

Thanks a ton for refactoring them together, I think it makes a lot of sense this way.

Thanks! Addressed all your comments

@cmcmaster1
Copy link
Contributor

I've been poking around your code @LeonEricsson because I have some long summarization tasks that I'd like to speed up, but noticed a significant bottleneck from the loop. This is probably still not perfect, but I've had a go at speeding it up. This implementation is about 500x faster:

def find_draft(self, input_ids):
            # Convert MLX array to NumPy for vectorized operations
            input_ids_np = np.array(input_ids)
            
            # Create a sliding window of the last ngram_max tokens
            ngram = input_ids_np[-self.ngram_max:]
            
            # Vectorized comparison of ngram with all possible sub-arrays of input_ids
            matches = np.lib.stride_tricks.sliding_window_view(input_ids_np, self.ngram_max) == ngram
            
            # Check if all elements in ngram match for each sub-array
            match_indices = np.all(matches, axis=1).nonzero()[0]
            
            # Filter out matches that are too short or overlap with the ngram itself
            match_indices = match_indices[(match_indices + self.ngram_max <= input_ids_np.size - self.ngram_max) & (match_indices >= self.ngram_min)]
            
            # Find the largest match
            if match_indices.size > 0:
                largest_match_idx = match_indices[-1]  # Assuming the last match is the largest
                start_idx = largest_match_idx + self.ngram_max
                end_idx = min(start_idx + self.n_draft, input_ids_np.size)
                candidate = input_ids_np[start_idx:end_idx]
                
                # Convert the candidate back to MLX array
                return mx.array(candidate, dtype=mx.uint32)
            
            return mx.array([], dtype=mx.uint32)

@LeonEricsson
Copy link
Contributor Author

LeonEricsson commented Jan 22, 2024

I've been poking around your code @LeonEricsson because I have some long summarization tasks that I'd like to speed up, but noticed a significant bottleneck from the loop. This is probably still not perfect, but I've had a go at speeding it up. This implementation is about 500x faster:

nice 🚀 the original implementation employed numpy's sliding windows, but I chose to maintain a purely mlx approach. However, as these are user examples, we should prioritize what is most beneficial for the user. A performance bottleneck like this is indeed a significant issue, and I concur that it warrants a change.

sidenote: perhaps we can attain comparable speed improvements using mlx.core.vmap?

@cmcmaster1
Copy link
Contributor

cmcmaster1 commented Jan 22, 2024

That makes sense. And I'm guessing you didn't notice a huge performance gap, because you didn't try it on long texts? I'm shaving ~30 seconds off inference time. I was thinking about trying a vmap version next.

Edit: I should clarify, without vectorization prompt lookup is slower than generate for anything but the most trivial task (e.g. repetition). So I think this change is necessary to really justify its existence as a useful example for the community.

@LeonEricsson
Copy link
Contributor Author

LeonEricsson commented Jan 27, 2024

@cmcmaster1 finally implemented a pure MLX version that should be comparable in performance to the numpy one. Would be great if you could confirm this on your end. However, before you do so note that your current implementation does not consider ngram matches other than of size self.ngram_max, which is not aligned with how Prompt Lookup was originally proposed. The idea of prompt lookup is to iteratively check for smaller ngram keys until you get to self.ngram_min; note the for loop here. I spent a lot of time trying to do away with this for loop and letting mlx.core do the work but couldn't find a way I was happy with. The further distance between self.ngram_max and self.ngram_min the more of a python overhead you're going to have, you could set self.ngram_max = self.ngram_min if you don't want this behaviour.

@awni imo this is ready to be merged, sorry for the delay.

@cmcmaster1
Copy link
Contributor

@LeonEricsson oops, you're right. I somehow missed that and just tested on examples where it made no difference! Still much faster than the original and definitely comparable to the (flawed) numpy implementation.

@LeonEricsson
Copy link
Contributor Author

ping @awni

@awni
Copy link
Member

awni commented Feb 23, 2024

Sorry for the delay!! I will review and get this in early next week

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants