Skip to content

Commit

Permalink
last commit for the day
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 7, 2025
1 parent cfa2f05 commit 78f3f94
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
22 changes: 20 additions & 2 deletions hl_gauss_pytorch/hl_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,34 @@ def transform_to_probs(self, target):
def transform_from_probs(self, probs):
return (probs * self.centers).sum(dim = -1)

def forward_kl_div(
self,
pred,
target
):
"""
allow for predicted value to be passed in, in which it will also be binned to probs and kl div used with target
"""
assert pred.shape == target.shape

logprob_pred = self.transform_to_logprobs(pred)
logprob_target = self.transform_to_logprobs(target)

return F.kl_div(logprob_pred, logprob_target, log_target = True, reduction = 'mean')

@torch.autocast('cuda', enabled = False)
def forward(
self,
logits,
target = None
):
assert logits.shape[-1] == self.num_bins

return_loss = exists(target)

if return_loss and logits.shape == target.shape:
return self.forward_kl_div(logits, target)

assert logits.shape[-1] == self.num_bins

if return_loss:
target_probs = self.transform_to_probs(target)
return F.cross_entropy(logits, target_probs)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "hl-gauss-pytorch"
version = "0.1.4"
version = "0.1.5"
description = "HL Gauss - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 78f3f94

Please sign in to comment.