
Prompt lookup decoding #202

Closed
LeonEricsson wants to merge 11 commits

Conversation

@LeonEricsson (Contributor) commented Dec 28, 2023

This PR implements an example of the "Prompt Lookup Decoding" technique:

https://github.com/apoorvumang/prompt-lookup-decoding

using a local Mistral model.

  • This approach works best on input-grounded tasks such as summarization, document QA, and code editing, where there is high overlap between input and output.
  • It replaces the draft model in speculative decoding with a simple ngram search of the input tokens. Like speculative decoding, it has no impact on the output.

There is an ongoing PR for speculative decoding (#149), and I imagine we'll want to sync our implementations before merging. I've mostly followed #149, but some things differ, mainly because #149 uses a Llama model (loaded from HF) while I load a local Mistral model. The Mistral and Llama example implementations are inherently different, which also leads to some minor differences.

I have confirmed speedups through a simple repetition test:

With prompt lookup (accepted draft tokens in blue): [screenshot: test_pld]

Without prompt lookup: [screenshot: test_non_pld]

There are a couple of key considerations when implementing the prompt lookup search:

  1. If a match is found but we can't draft n_draft tokens, do we draft as many as we can, or move on to look for a match with a smaller ngram? We risk not finding a match if we move on, but on the other hand we might draft just a single token.
  2. How do we choose if there are multiple matches of the same ngram size? What ranking scheme should be used?

This implementation:

  1. Ignores a match if we can't draft n_draft tokens.
  2. Exits upon the first match.

TODO:

  • Perform find_draft in a single pass over the input_ids instead of iterating for every ngram_size. This should improve speed given the Python loop overhead (the current per-ngram-size search is sketched below).
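For readers, here is a minimal sketch of the kind of ngram draft search described above, written against plain Python token lists; the names (find_draft, ngram_max, n_draft) and defaults are illustrative, not this PR's actual code.

```python
def find_draft(input_ids, ngram_max=3, ngram_min=1, n_draft=10):
    """Look for an earlier occurrence of the prompt's current suffix ngram and,
    if found, propose the tokens that followed it as the draft."""
    for ngram_size in range(ngram_max, ngram_min - 1, -1):
        suffix = input_ids[-ngram_size:]
        # The suffix itself starts at len(input_ids) - ngram_size, so the scan
        # below only covers strictly earlier positions.
        for start in range(len(input_ids) - ngram_size):
            if input_ids[start : start + ngram_size] == suffix:
                draft = input_ids[start + ngram_size : start + ngram_size + n_draft]
                # Mirror the choices above: ignore a match that can't fill a
                # full n_draft window, and exit on the first full-length match.
                if len(draft) == n_draft:
                    return draft
    return []  # no usable match; decode normally without a draft


# Toy check: the suffix (1, 2) also appears earlier in the prompt, so the two
# tokens that followed it there become the draft.
tokens = [3, 7, 7, 1, 2, 9, 4, 7, 1, 2]
print(find_draft(tokens, ngram_max=2, n_draft=2))  # -> [9, 4]
```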

@andersonbcdefg (Contributor) commented:

Haha awesome I was actually planning this after finishing my speculative decoding but didn't get to it, glad someone else did! :D

@LeonEricsson (Contributor, Author) commented:

> Haha awesome I was actually planning this after finishing my speculative decoding but didn't get to it, glad someone else did! :D

Thanks for laying the groundwork! I had a review comment on your PR that was ignored, but I think it's quite important. Any thoughts on that?

@awni (Member) commented Dec 29, 2023

> I had a review comment on your PR that was ignored, but I think it's quite important. Any thoughts on that?

What was your comment? I couldn't find it.

@LeonEricsson (Contributor, Author) commented:

You've since changed that part of the code, but it was about this line in the decoder.py file:

new_tokens = sampled[: max(1, num_to_accept + 1)]

I'm fairly confident the accepted tokens should be num_to_accept + 1, meaning we accept the num_to_accept tokens that match between draft and sampled, but also the first mismatch, because that token was correctly sampled from an accepted token history.
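Here's a minimal sketch of the acceptance rule being argued for, with hypothetical names (this is not the actual code in #149):

```python
def accept_tokens(draft, sampled):
    """draft: tokens proposed by the draft model (or prompt lookup).
    sampled: tokens the main model produced while verifying the draft;
    sampled[i] is conditioned on the accepted history plus draft[:i], so there
    is always one more usable sampled token than matching draft tokens."""
    num_to_accept = 0
    for d, s in zip(draft, sampled):
        if d != s:
            break
        num_to_accept += 1
    # Keep the matching prefix plus the first mismatching token: that token was
    # sampled from an already-accepted history, so it is a valid next token.
    return sampled[: num_to_accept + 1]


print(accept_tokens([5, 8, 2], [5, 8, 9, 1]))  # -> [5, 8, 9]
```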

@awni (Member) commented Dec 30, 2023

Yea I think you are correct and it should have that behavior now. Let me know if you still think it's an issue.

@LeonEricsson (Contributor, Author) commented Dec 31, 2023

> Yea I think you are correct and it should have that behavior now. Let me know if you still think it's an issue.

Nice, looks good! I like that you avoided main model sampling and instead just compare draft / main probs directly. Do you have a reference for this code? I haven't seen this before.

Note: model_probs is overwritten on line 97; not sure what the intention was.

@awni (Member) commented Jan 1, 2024

> Nice, looks good! I like that you avoided main model sampling and instead just compare draft / main probs directly. Do you have a reference for this code? I haven't seen this before.

Everything there except the inclusion of the delta parameter came from the original paper. See Algorithm 1 and section 2.3.

> Note: model_probs is overwritten on line 97; not sure what the intention was.

The intention is to normalize them since the raw output of the model is not normalized.
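(For context, a small NumPy sketch of that kind of in-place normalization, assuming the raw outputs are unnormalized log-scores over the vocabulary; the names and shapes are illustrative rather than the example's actual code.)

```python
import numpy as np


def normalize_log_probs(scores):
    """Turn raw, unnormalized scores (last axis = vocab) into log-probabilities."""
    scores = scores - scores.max(axis=-1, keepdims=True)          # numerical stability
    scores -= np.log(np.exp(scores).sum(axis=-1, keepdims=True))  # subtract the log normalizer in place
    return scores


raw = np.array([[2.0, 1.0, 0.1]])
print(np.exp(normalize_log_probs(raw)).sum())  # -> ~1.0, i.e. a proper distribution
```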

@LeonEricsson (Contributor, Author) commented:

> Everything there except the inclusion of the delta parameter came from the original paper. See Algorithm 1 and section 2.3.

Thanks!

> The intention is to normalize them since the raw output of the model is not normalized.

I missed the - in model_probs -= on line 100. I was reading it as model_probs =, hence my confusion.

@awni (Member) commented Jan 3, 2024

Hey @LeonEricsson, I like this PR a lot, but TBH I'm not entirely sure what to do with it.

Prompt lookup decoding is a bit niche to dedicate a whole example to. So one thing we could do is try to merge it into the speculative decoding example (and support different draft model strategies). The only challenge there is that I changed it to be a T5 example instead of a more traditional causal LM.

So I see a couple options for moving forward (both of which involve integrating into speculative decoding).

  1. Do this prompt lookup decoding but using T5 instead.
  2. Change the speculative decoding example to use an LLM and see if we can actually get something faster. IMO, if we tried something like code gen (where the output space is a little more regular), it might work a lot better.

What do you think?

@LeonEricsson (Contributor, Author) commented Jan 3, 2024

Gotcha, it could probably slide into the speculative example; I just want to make sure things remain modular so as not to strain users attempting to understand and reimplement the examples.

I haven't looked through the speculative example thoroughly since the change to T5 but I'll give it a look and try to decide what's most appropriate between 1) and 2). Does T5 vs LLM really change that much in decoder.py?

Would you like a new PR for 1) and/or 2) or keep it all here?

@awni (Member) commented Jan 3, 2024

> I haven't looked through the speculative example thoroughly since the change to T5 but I'll give it a look and try to decide what's most appropriate between 1) and 2). Does T5 vs LLM really change that much in decoder.py?

Not really that much TBH

> Would you like a new PR for 1) and/or 2) or keep it all here?

Whatever is easier for you.

> I just want to make sure things remain modular so as not to strain users attempting to understand and reimplement the examples.

I very much agree with that goal. These examples are meant to be instructive and hackable (hence simple). With that in mind, I would say worry a bit less about code duplication and a bit more about keeping the actual implementation simple and modular.

For example, if you need to make a different class in decoder.py or a different file altogether, that may be OK.

@LeonEricsson (Contributor, Author) commented:

I've begun implementing a PromptLookupDecoder class inside the speculative decoder.py file; I think this is the best way to keep things readable. I'm trying to make things work with T5; I don't have the time to rewrite SpeculativeDecoder for an LLM right now (I think the goal should still be to change this to an LLM; I'm not sure how valuable a T5 example is for most practitioners, but maybe it's more popular than I think). The only crux is that T5 outputs a mask token <extra_id_0> at the beginning of its response which will never appear again; it ruins a lot of the early drafts generated through prompt lookup (although this effect diminishes with larger prompts).
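Roughly the division I have in mind, as an illustrative skeleton only (see #237 for the real code; method and attribute names here are placeholders):

```python
class SpeculativeDecoder:
    """Verifies draft tokens with the main model; drafts come from a draft model."""

    def __init__(self, model, draft_model):
        self.model = model
        self.draft_model = draft_model

    def generate_draft(self, input_ids, n_draft):
        # Drafts are produced by running the smaller draft model.
        ...


class PromptLookupDecoder(SpeculativeDecoder):
    """Same verification loop, but drafts come from an ngram search of the prompt."""

    def __init__(self, model):
        # No draft model needed: the prompt itself supplies the drafts.
        super().__init__(model, draft_model=None)

    def generate_draft(self, input_ids, n_draft):
        # Reuse an ngram search over input_ids (like the find_draft sketch above).
        ...
```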

@awni (Member) commented Jan 5, 2024

> I'm trying to make things work with T5; I don't have the time to rewrite SpeculativeDecoder for an LLM right now (I think the goal should still be to change this to an LLM; I'm not sure how valuable a T5 example is for most practitioners, but maybe it's more popular than I think).

I don't disagree. I would prefer the example to work with an LLM instead of T5. If you prefer to wait, I might have some time to look into it a bit more this weekend / next week.

@LeonEricsson (Contributor, Author) commented Jan 5, 2024

> I don't disagree. I would prefer the example to work with an LLM instead of T5. If you prefer to wait, I might have some time to look into it a bit more this weekend / next week.

I have done most of the work now; I just have to debug one error and then it should be ready (I've started a draft PR, #237). That being said, I don't mind updating PromptLookupDecoder if you have some time to spare to fix the speculative example.

Please double-check that you're satisfied with the division between speculative and prompt lookup in #237; I'd be happy to change things around if needed.

@LeonEricsson (Contributor, Author) commented:

Implementation moved to #237.
