diff --git a/hl_gauss_pytorch/hl_gauss.py b/hl_gauss_pytorch/hl_gauss.py index 7995a04..464a41e 100644 --- a/hl_gauss_pytorch/hl_gauss.py +++ b/hl_gauss_pytorch/hl_gauss.py @@ -126,7 +126,8 @@ def forward_kl_div( def forward( self, logits, - target = None + target = None, + reduction = 'mean' ): return_loss = exists(target) @@ -140,7 +141,7 @@ def forward( if return_loss: target_probs = self.transform_to_probs(target) - return F.cross_entropy(logits, target_probs) + return F.cross_entropy(logits, target_probs, reduction = reduction) # if targets are not given, return the predicted value diff --git a/pyproject.toml b/pyproject.toml index c2c547e..4e6a27b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "hl-gauss-pytorch" -version = "0.1.12" +version = "0.1.14" description = "HL Gauss - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }