diff --git a/hl_gauss_pytorch/hl_gauss.py b/hl_gauss_pytorch/hl_gauss.py index 4e6725c..d1af4f6 100644 --- a/hl_gauss_pytorch/hl_gauss.py +++ b/hl_gauss_pytorch/hl_gauss.py @@ -73,16 +73,34 @@ def transform_to_probs(self, target): def transform_from_probs(self, probs): return (probs * self.centers).sum(dim = -1) + def forward_kl_div( + self, + pred, + target + ): + """ + allow for predicted value to be passed in, in which it will also be binned to probs and kl div used with target + """ + assert pred.shape == target.shape + + logprob_pred = self.transform_to_logprobs(pred) + logprob_target = self.transform_to_logprobs(target) + + return F.kl_div(logprob_pred, logprob_target, log_target = True, reduction = 'mean') + @torch.autocast('cuda', enabled = False) def forward( self, logits, target = None ): - assert logits.shape[-1] == self.num_bins - return_loss = exists(target) + if return_loss and logits.shape == target.shape: + return self.forward_kl_div(logits, target) + + assert logits.shape[-1] == self.num_bins + if return_loss: target_probs = self.transform_to_probs(target) return F.cross_entropy(logits, target_probs) diff --git a/pyproject.toml b/pyproject.toml index af2d294..b198491 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hl-gauss-pytorch" -version = "0.1.4" +version = "0.1.5" description = "HL Gauss - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }