Skip to content

Commit

Permalink
Fix get_batch_loss_metrics comments (#413)
Browse files Browse the repository at this point in the history
## Summary

Remove misleading docstring in `get_batch_loss_metrics()` of
`test/utils.py`.

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: A10G
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 authored Nov 28, 2024
1 parent 7e3683e commit 0137757
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def get_batch_loss_metrics(
ref_bias: torch.FloatTensor = None,
average_log_prob: bool = True,
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
"""Compute the loss metrics for the given batch of inputs for train or test."""

forward_output = self.concatenated_forward(
_input, weight, target, bias, average_log_prob
Expand Down

0 comments on commit 0137757

Please sign in to comment.