Skip to content

Commit

Permalink
llama : add DRY sampler (#9702)
Browse files Browse the repository at this point in the history
* sampling : add DRY sampler (post-refactor)

* DRY: Trying to fix coauthors, removed unneeded line

* DRY: Fixed redundant code

* DRY: Fixed crash issue due to DRY being in chain but uninitialized

---------

Co-authored-by: l3utterfly <[email protected]>
Co-authored-by: pi6am <[email protected]>
  • Loading branch information
3 people authored Oct 25, 2024
1 parent d80fb71 commit ff252ea
Show file tree
Hide file tree
Showing 17 changed files with 713 additions and 63 deletions.
61 changes: 61 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & antiprompt : params.antiprompt) {
string_process_escapes(antiprompt);
}
for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
}

if (!params.kv_overrides.empty()) {
Expand Down Expand Up @@ -997,6 +1000,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.penalty_freq = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-multiplier"}, "N",
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
[](common_params & params, const std::string & value) {
params.sparams.dry_multiplier = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-base"}, "N",
string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
[](common_params & params, const std::string & value) {
float potential_base = std::stof(value);
if (potential_base >= 1.0f)
{
params.sparams.dry_base = potential_base;
}
}
).set_sparam());
add_opt(common_arg(
{"--dry-allowed-length"}, "N",
string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
[](common_params & params, int value) {
params.sparams.dry_allowed_length = value;
}
).set_sparam());
add_opt(common_arg(
{"--dry-penalty-last-n"}, "N",
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
[](common_params & params, int value) {
params.sparams.dry_penalty_last_n = value;
}
).set_sparam());
add_opt(common_arg(
{"--dry-sequence-breaker"}, "STRING",
string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
params.sparams.dry_sequence_breakers.empty() ? "none" :
std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
params.sparams.dry_sequence_breakers.end(),
std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
[](const std::string& a, const std::string& b) {
std::string formatted_b = (b == "\n") ? "\\n" : b;
return a + ", '" + formatted_b + "'";
}).c_str()),
[](common_params & params, const std::string & value) {
static bool defaults_cleared = false;

if (!defaults_cleared) {
params.sparams.dry_sequence_breakers.clear();
defaults_cleared = true;
}

if (value == "none") {
params.sparams.dry_sequence_breakers.clear();
} else {
params.sparams.dry_sequence_breakers.emplace_back(value);
}
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
Expand Down
4 changes: 4 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2006,6 +2006,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
Expand Down
70 changes: 39 additions & 31 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,15 @@ enum llama_example {

enum common_sampler_type {
COMMON_SAMPLER_TYPE_NONE = 0,
COMMON_SAMPLER_TYPE_TOP_K = 1,
COMMON_SAMPLER_TYPE_TOP_P = 2,
COMMON_SAMPLER_TYPE_MIN_P = 3,
COMMON_SAMPLER_TYPE_TFS_Z = 4,
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,
COMMON_SAMPLER_TYPE_INFILL = 8,
COMMON_SAMPLER_TYPE_DRY = 1,
COMMON_SAMPLER_TYPE_TOP_K = 2,
COMMON_SAMPLER_TYPE_TOP_P = 3,
COMMON_SAMPLER_TYPE_MIN_P = 4,
COMMON_SAMPLER_TYPE_TFS_Z = 5,
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
};

// dimensionality reduction methods, used by cvector-generator
Expand All @@ -104,32 +105,39 @@ enum dimre_method {
struct common_sampler_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float tfs_z = 1.00f; // 1.0 = disabled
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float tfs_z = 1.00f; // 1.0 = disabled
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics

std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY


std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TFS_Z,
COMMON_SAMPLER_TYPE_TYPICAL_P,
Expand Down
17 changes: 17 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,11 @@ std::string common_sampler_params::print() const {

snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau);

Expand Down Expand Up @@ -174,6 +176,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
std::vector<const char*> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto& str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}

llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
Expand Down Expand Up @@ -358,6 +371,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_

char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return 'd';
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TFS_Z: return 'f';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
Expand All @@ -372,6 +386,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {

std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return "dry";
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
Expand All @@ -386,6 +401,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {

std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
Expand Down Expand Up @@ -434,6 +450,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect

std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
std::unordered_map<char, common_sampler_type> sampler_name_map = {
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
Expand Down
24 changes: 24 additions & 0 deletions examples/main/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,30 @@ Use the `--no-penalize-nl` option to disable newline penalization when applying

Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`

### DRY Repetition Penalty

DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).

- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled).
- `--dry-base N`: Set the DRY sampling base value (default: 1.75).
- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2).
- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size).
- `--dry-sequence-breaker STRING`: Add a sequence breaker for DRY sampling. Can be used more than once to add multiple sequence breakers. Using this clears out the default breakers, which consist of: `['\n', ':', '"', '*']`. If the string `"none"` is supplied, no sequence breakers are used.

The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8.

The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions.

The `dry-allowed-length` option sets the maximum length of repeated sequences that will not be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words.

The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context. Use a positive value to limit the consideration to a specific number of recent tokens.

The `dry-sequence-breaker` option adds a single sequence breaker and can be used more than once to specify multiple sequence breakers. Sequence breakers interrupt sequence matching and break the input into parts where matching can be applied.

DRY sampling provides more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence.

Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`

### Top-K Sampling

- `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).
Expand Down
15 changes: 15 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ The project is under active development, and we are [looking for feedback and co
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) |
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) |
| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) |
| `--dry-multiplier N` | DRY sampling multiplier (default: 0.0, 0.0 = disabled) |
| `--dry-base N` | DRY sampling base value (default: 1.75) |
| `--dry-allowed-length N` | allowed length for DRY sampling (default: 2) |
| `--dry-penalty-last-n N` | DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) |
| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers (`['\n', ':', '"', '*']`) in the process; use `"none"` to not use any sequence breakers
| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) |
| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) |
| `--mirostat N` | use Mirostat sampling.<br/>Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.<br/>(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) |
Expand Down Expand Up @@ -369,6 +374,16 @@ node index.js

`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.

`dry_multiplier`: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled.

`dry_base`: Set the DRY repetition penalty base value. Default: `1.75`

`dry_allowed_length`: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2`

`dry_penalty_last_n`: How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size.

`dry_sequence_breakers`: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`

`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.

`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
Expand Down
Loading

0 comments on commit ff252ea

Please sign in to comment.