From eb661337d50323f7f58ab16ac36022cd52f5076f Mon Sep 17 00:00:00 2001 From: jmccrosky <42929912+jmccrosky@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:47:21 +0300 Subject: [PATCH] Add support for 0 temperature --- model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/model.py b/model.py index c698f8b601..4b498babff 100644 --- a/model.py +++ b/model.py @@ -315,15 +315,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): # forward the model to get the logits for the index in the sequence logits, _ = self(idx_cond) # pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / temperature - # optionally crop the logits to only the top k options - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') - # apply softmax to convert logits to (normalized) probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) + if temperature == 0: + logits = logits[:, -1, :] + idx_next = torch.argmax(logits, dim=-1, keepdim=True) + else: + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) # append sampled index to the running sequence and continue idx = torch.cat((idx, idx_next), dim=1)