
speculative : refactor and add a simpler example #10362

Draft: wants to merge 2 commits into master
common/CMakeLists.txt (2 additions, 0 deletions)
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
    ngram-cache.h
    sampling.cpp
    sampling.h
    speculative.cpp
    speculative.h
    )

if (BUILD_SHARED_LIBS)
common/sampling.cpp (22 additions, 0 deletions)
@@ -320,6 +320,28 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    return cur_p.data[cur_p.selected].id;
}

std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        // sample the target token at the position of the i-th draft token
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        result.push_back(id);

        // on the first disagreement, keep the target token and stop
        if (draft[i] != id) {
            return result;
        }
    }

    // the entire draft was accepted - sample one token past the last draft position
    result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));

    return result;
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
}
common/sampling.h (13 additions, 0 deletions)
@@ -60,6 +60,19 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

// generalized version of common_sampler_sample
//
// cross-references the sampled tokens with a batch of draft tokens:
// if the sampler disagrees at some position, sampling stops and the tokens sampled so far are returned
//
// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)`
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

// helpers
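To make the contract concrete, here is a minimal caller-side sketch (not part of this change; `smpl` and `ctx_tgt` are assumed to exist, and the draft token ids are hypothetical):

    // idxs[i] is the index in the last decoded target batch that holds the logits
    // for draft token i; the extra final index samples one token past the draft
    std::vector<int>         idxs  = { 0, 1, 2, 3 };       // draft.size() + 1 indices
    std::vector<llama_token> draft = { id_a, id_b, id_c }; // hypothetical draft tokens

    // returns between 1 and 4 tokens: the accepted prefix of the draft,
    // plus one more sampled token
    std::vector<llama_token> out = common_sampler_sample_n(smpl, ctx_tgt, idxs, draft);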
common/speculative.cpp (154 additions, 0 deletions)
@@ -0,0 +1,154 @@
#include "speculative.h"

#include "log.h"
#include "common.h"
#include "sampling.h"

struct common_speculative {
    struct common_speculative_params params;

    llama_batch batch_dft;

    struct common_sampler * smpl;

    std::vector<int> i_batch_tgt;

    std::vector<llama_token> tokens;
};

struct common_speculative * common_speculative_init(struct common_speculative_params params) {
    auto * result = new common_speculative {
        /* .params      = */ params,
        /* .batch_dft   = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
        /* .smpl        = */ nullptr,
        /* .i_batch_tgt = */ {},
        /* .tokens      = */ {},
    };

    // TODO: optimize or pass from outside?
#if 0
    {
        common_sampler_params sparams;
        sparams.no_perf = false;

        sparams.top_k = 40;
        sparams.top_p = 0.9;

        sparams.samplers = {
            COMMON_SAMPLER_TYPE_TOP_K,
            COMMON_SAMPLER_TYPE_TOP_P,
            COMMON_SAMPLER_TYPE_INFILL,
        };

        result->smpl = common_sampler_init(params.model_dft, sparams);
    }
#else
    {
        common_sampler_params sparams;
        sparams.no_perf = false;

        sparams.top_k = 10;

        sparams.samplers = {
            COMMON_SAMPLER_TYPE_TOP_K,
        };

        result->smpl = common_sampler_init(params.model_dft, sparams);
    }
#endif

    return result;
}

void common_speculative_free(struct common_speculative * spec) {
    common_sampler_free(spec->smpl);

    llama_batch_free(spec->batch_dft);

    delete spec;
}

void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
    llama_kv_cache_clear(spec->params.ctx_dft);

    // TODO: error handling
    llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
}

void common_speculative_add_draft(
        struct common_speculative * spec,
        struct llama_batch & batch_tgt,
        llama_token id_last,
        int n_past) {
    spec->tokens.clear();

    spec->i_batch_tgt.clear();
    spec->i_batch_tgt.push_back(0);

    common_sampler_reset(spec->smpl);

    common_batch_clear(spec->batch_dft);
    common_batch_add  (spec->batch_dft, id_last, n_past, { 0 }, true);

    llama_decode(spec->params.ctx_dft, spec->batch_dft);

    // sample n_draft tokens from the draft model
    for (int i = 0; i < spec->params.n_draft; ++i) {
        common_batch_clear(spec->batch_dft);

        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);

        const auto * cur_p = common_sampler_get_candidates(spec->smpl);

        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;

        // only collect very high-confidence draft tokens
        if (cur_p->data[0].p < 0.75) {
            break;
        }

        common_sampler_accept(spec->smpl, id, true);

        spec->tokens.push_back(id);

        // add unique drafted tokens to the target batch
        spec->i_batch_tgt.push_back(batch_tgt.n_tokens);

        common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);

        if (batch_tgt.n_tokens > spec->params.n_draft) {
            break;
        }

        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);

        // evaluate the drafted tokens on the draft model
        llama_decode(spec->params.ctx_dft, spec->batch_dft);
    }

    // don't waste time on small drafts
    // TODO: do not evaluate the draft model for that many rounds
    if (batch_tgt.n_tokens < spec->params.n_min) {
        batch_tgt.n_tokens = 1;
        spec->tokens.resize(0);
        spec->i_batch_tgt.resize(1);
    }

    // print the current draft sequence
    LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
}

std::vector<llama_token> common_speculative_sample(
        struct common_speculative * spec,
        struct common_sampler * smpl,
        struct llama_context * ctx_tgt) {
    return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
}
common/speculative.h (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
#pragma once

#include "llama.h"

#include <vector>

struct common_speculative;

struct common_speculative_params {
    int n_draft = 16;
    int n_min   = 5; // do not add drafts smaller than this, TODO: leave this to the user?

    struct llama_model * model_dft = nullptr;

    struct llama_context * ctx_dft = nullptr;
};

struct common_speculative * common_speculative_init(struct common_speculative_params params);

void common_speculative_free(struct common_speculative * spec);

// TODO: remove
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);

// sample up to n_draft tokens from the draft model and add them to the target batch
//
// TODO: change to:
//
//   void common_speculative_add_draft(
//           struct common_speculative * spec,
//           struct llama_batch & batch_tgt,
//           llama_token * tokens,
//           int32_t n_tokens);
//
// and update the internal logic to compute only the new tokens
//
void common_speculative_add_draft(
        struct common_speculative * spec,
        struct llama_batch & batch_tgt,
        llama_token id_last,
        int n_past);

std::vector<llama_token> common_speculative_sample(
        struct common_speculative * spec,
        struct common_sampler * smpl,
        struct llama_context * ctx_tgt);
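
To show how the new API is meant to be driven, here is a minimal sketch of the decoding loop implied by the declarations above. It is an illustration, not part of this change: model/context setup, tokenization, output, and stop conditions are elided, and `ctx_tgt`, `ctx_dft`, `model_dft`, `smpl` (a common_sampler for the target model), and the tokenized `prompt` are assumed to exist.

    struct common_speculative_params sparams;
    sparams.model_dft = model_dft;
    sparams.ctx_dft   = ctx_dft;

    struct common_speculative * spec = common_speculative_init(sparams);

    std::vector<llama_token> tokens = prompt; // running sequence: prompt + accepted tokens

    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

    // evaluate the prompt on the target model, keeping the last token for the loop below
    llama_decode(ctx_tgt, llama_batch_get_one(tokens.data(), (int32_t) tokens.size() - 1));

    while (true) {
        const llama_token id_last = tokens.back();
        const int         n_past  = (int) tokens.size() - 1;

        // re-prime the draft model with everything except the last token
        // (simple but wasteful - the TODOs above point at computing only the new tokens)
        common_speculative_set_prompt(spec, tokens.data(), (int32_t) tokens.size() - 1);

        // batch index 0 is always the last accepted token (i_batch_tgt starts at { 0 })
        common_batch_clear(batch_tgt);
        common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);

        // append up to n_draft tokens from the draft model to the target batch
        common_speculative_add_draft(spec, batch_tgt, id_last, n_past);

        // evaluate [id_last, draft...] on the target model in a single call
        llama_decode(ctx_tgt, batch_tgt);

        // the accepted prefix of the draft plus one extra sampled token
        // (penalty/grammar samplers would also need common_sampler_accept calls - elided)
        const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);

        tokens.insert(tokens.end(), ids.begin(), ids.end());

        // drop the rejected draft tokens from the target KV cache
        llama_kv_cache_seq_rm(ctx_tgt, 0, (int) tokens.size() - 1, -1);

        // ... stream the accepted tokens, break on EOS or a token limit (elided) ...
    }

    llama_batch_free(batch_tgt);
    common_speculative_free(spec);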
examples/CMakeLists.txt (1 addition, 0 deletions)
@@ -50,5 +50,6 @@ else()
    add_subdirectory(simple)
    add_subdirectory(simple-chat)
    add_subdirectory(speculative)
    add_subdirectory(speculative-simple)
    add_subdirectory(tokenize)
endif()
examples/speculative-simple/CMakeLists.txt (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
set(TARGET llama-speculative-simple)
add_executable(${TARGET} speculative-simple.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
examples/speculative-simple/README.md (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
# llama.cpp/examples/speculative-simple

Demonstration of basic greedy speculative decoding
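
A hypothetical invocation, using flags that already exist in the common argument parser (model paths are placeholders):

    ./llama-speculative-simple -m target-model.gguf -md draft-model.gguf -p "your prompt" --draft 16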