Skip to content

Commit

Permalink
able to remove reduction for ppo
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 8, 2025
1 parent 40bdb75 commit 4d03c9f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions hl_gauss_pytorch/hl_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def forward_kl_div(
def forward(
self,
logits,
target = None
target = None,
reduction = 'mean'
):
return_loss = exists(target)

Expand All @@ -140,7 +141,7 @@ def forward(

if return_loss:
target_probs = self.transform_to_probs(target)
return F.cross_entropy(logits, target_probs)
return F.cross_entropy(logits, target_probs, reduction = reduction)

# if targets are not given, return the predicted value

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "hl-gauss-pytorch"
version = "0.1.12"
version = "0.1.14"
description = "HL Gauss - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 4d03c9f

Please sign in to comment.