diff --git a/exllamav2/generator/base.py b/exllamav2/generator/base.py
index 3a8989f5..6904e8f8 100644
--- a/exllamav2/generator/base.py
+++ b/exllamav2/generator/base.py
@@ -52,7 +52,8 @@ def generate_simple(self, prompt: str or list,
                         encode_special_tokens = False,
                         decode_special_tokens = False,
                         loras = None,
-                        stop_token = -1):
+                        stop_token = -1,
+                        return_lowest_perplexity = False):

         # Default stop token

@@ -68,14 +69,20 @@ def generate_simple(self, prompt: str or list,

         # Tokenize input and produce padding mask if needed

-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if return_lowest_perplexity:
+            batch_size = self.cache.batch_size
+        else:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        assert batch_size > 1 or not return_lowest_perplexity, "When return_lowest_perplexity is set, batch_size should be greater than 1"
+        assert isinstance(prompt, str) or not return_lowest_perplexity, "When return_lowest_perplexity is set, the prompt should be a single string"

         ids, position_offsets = self.tokenizer.encode(prompt, encode_special_tokens = encode_special_tokens, return_offsets = True)
-        if batch_size == 1: position_offsets = None
+        if batch_size == 1 or return_lowest_perplexity: position_offsets = None
+        if return_lowest_perplexity: ids = ids.repeat(batch_size, 1)

         overflow = ids.shape[-1] + num_tokens - self.model.config.max_seq_len
         if overflow > 0: ids = ids[:, overflow:]

-        mask = self.tokenizer.padding_mask(ids) if batch_size > 1 else None
+        mask = self.tokenizer.padding_mask(ids) if batch_size > 1 and not return_lowest_perplexity else None

         # Prepare for healing

@@ -102,16 +109,21 @@ def generate_simple(self, prompt: str or list,
         # Generate tokens

         batch_eos = [False] * batch_size
+        if return_lowest_perplexity:
+            logprob_sum = torch.zeros(batch_size)
+            sequence_length = torch.zeros(batch_size)

         for i in range(num_tokens):

             logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask = mask, loras = loras, position_offsets = position_offsets).float().cpu()
-            token, _, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
+            token, output_probs, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)

             eos = False
             if stop_token is not None:
                 for b in range(batch_size):
                     if token[b, 0].item() == stop_token:
+                        if return_lowest_perplexity and not batch_eos[b]:
+                            sequence_length[b] = i
                         batch_eos[b] = True
                         if all(batch_eos): eos = True
                     if batch_eos[b]:
@@ -120,14 +132,23 @@ def generate_simple(self, prompt: str or list,
             self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
             gen_settings.feed_filters(token)

+            if return_lowest_perplexity:
+                logprob_sum = torch.add(logprob_sum,
+                                        torch.log(torch.squeeze(output_probs, -1)))
+
             unhealed_token = None
             if eos: break

         # Decode

-        text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens = decode_special_tokens)
+        if return_lowest_perplexity:
+            mean_log_prob = torch.div(logprob_sum, sequence_length)
+            lowest_perplexed_index = torch.argmax(mean_log_prob).item()
+            text = self.tokenizer.decode(self.sequence_ids[lowest_perplexed_index], decode_special_tokens = decode_special_tokens)
+        else:
+            text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens = decode_special_tokens)
+            text = text[0] if isinstance(prompt, str) else text

-        if isinstance(prompt, str): return text[0]
         return text
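For context, a rough usage sketch of the new flag follows. It is not part of the diff: the model path, cache batch size, and sampler settings are placeholders, and the setup just follows the library's usual loading pattern. The one requirement the patch adds is that the cache is allocated with batch_size > 1 and the prompt is a single string.

```python
# Rough usage sketch (assumed setup, not part of this diff): sample several
# candidate completions in one batch and keep the one with the lowest perplexity.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"        # placeholder path
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, batch_size = 4, lazy = True)   # batch_size > 1 is required
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.9
settings.top_p = 0.9

# The prompt must be a single string; it is repeated across the cache's batch
# dimension internally, and only one completion is returned.
best = generator.generate_simple("Once upon a time,",
                                 settings,
                                 num_tokens = 200,
                                 return_lowest_perplexity = True)
print(best)
```

Because the prompt is repeated across the cache's batch dimension, all candidates share one forward pass per step, and the call returns only the completion with the highest mean token log-probability, i.e. the lowest perplexity.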