Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Average Marginal Ranking Loss Function #742

Merged
merged 22 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0dc6f25
Docs: added documentation to marginal ranking loss function
jkauerl May 31, 2024
d771a44
Feat: created function signature
jkauerl May 31, 2024
6c152f5
Doc: added return comment
jkauerl May 31, 2024
d94e3ef
Feat: finished implementation and changed type of input
jkauerl May 31, 2024
9a29646
Test: created a test case
jkauerl May 31, 2024
30f8b05
Feat: added the correct exports
jkauerl May 31, 2024
7f5d3e4
Feat: now using option as a return type
jkauerl Jun 10, 2024
b0e552a
Test: added a macro for testing purposes as suggested
jkauerl Jun 10, 2024
2fe3de7
Test: macro tests took too long
jkauerl Jun 11, 2024
44a0813
Docs: added the algorithm to directory with the correct link
jkauerl Jun 19, 2024
090b50c
Feat: algorithm now returns Result and updated tests
jkauerl Jul 11, 2024
7dae972
Test: added a macro for testing purposes and more tests
jkauerl Jul 11, 2024
0bb2dc1
Docs: added more documentation to the file
jkauerl Jul 11, 2024
361d64f
Refcator: changed the name of the function
jkauerl Jul 18, 2024
7dd2fbf
Feat: added 1 more possible error message
jkauerl Jul 18, 2024
455fa8a
Test: added symmetric error handling
jkauerl Jul 18, 2024
94b9e81
Refactoring: added more rust-like syntaxis
jkauerl Jul 18, 2024
af0efde
Feat: fixed with the correct export
jkauerl Jul 18, 2024
b9f26cc
Feat: added suggested check_input function
jkauerl Jul 29, 2024
c7cb14b
Refactoring: changed the name to margin ranking loss
jkauerl Jul 30, 2024
90f1be4
docs: update dead link
vil02 Jul 30, 2024
1bad31e
Merge branch 'master' into feat/ml/loss/marginal_ranking
vil02 Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DIRECTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
* [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs)
* [Huber Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/huber_loss.rs)
* [Kl Divergence Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/kl_divergence_loss.rs)
* [Marginal Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs)
* [Mean Absolute Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_absolute_error_loss.rs)
* [Mean Squared Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_squared_error_loss.rs)
* [Negative Log Likelihood](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/negative_log_likelihood.rs)
Expand Down
113 changes: 113 additions & 0 deletions src/machine_learning/loss_function/average_margin_ranking_loss.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/// Marginal Ranking
///
/// The 'average_margin_ranking_loss' function calculates the Margin Ranking loss, which is a
/// loss function used for ranking problems in machine learning.
///
/// ## Formula
///
/// For a pair of values `x_first` and `x_second`, `margin`, and `y_true`,
/// the Margin Ranking loss is calculated as:
///
/// - loss = `max(0, -y_true * (x_first - x_second) + margin)`.
///
/// It returns the average loss by dividing the `total_loss` by total no. of
/// elements.
///
/// Pytorch implementation:
/// https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html
/// https://gombru.github.io/2019/04/03/ranking_loss/
/// https://vinija.ai/concepts/loss/#pairwise-ranking-loss
///

pub fn average_margin_ranking_loss(
x_first: &[f64],
x_second: &[f64],
margin: f64,
y_true: f64,
) -> Result<f64, MarginalRankingLossError> {
check_input(x_first, x_second, margin, y_true)?;

let total_loss: f64 = x_first
.iter()
.zip(x_second.iter())
.map(|(f, s)| (margin - y_true * (f - s)).max(0.0))
.sum();
Ok(total_loss / (x_first.len() as f64))
}

fn check_input(
x_first: &[f64],
x_second: &[f64],
margin: f64,
y_true: f64,
) -> Result<(), MarginalRankingLossError> {
if x_first.len() != x_second.len() {
return Err(MarginalRankingLossError::InputsHaveDifferentLength);
}
if x_first.is_empty() {
return Err(MarginalRankingLossError::EmptyInputs);
}
if margin < 0.0 {
return Err(MarginalRankingLossError::NegativeMargin);
}
if y_true != 1.0 && y_true != -1.0 {
return Err(MarginalRankingLossError::InvalidValues);
}

Ok(())
}

#[derive(Debug, PartialEq, Eq)]
pub enum MarginalRankingLossError {
InputsHaveDifferentLength,
EmptyInputs,
InvalidValues,
NegativeMargin,
}

#[cfg(test)]
mod tests {
use super::*;

macro_rules! test_with_wrong_inputs {
($($name:ident: $inputs:expr,)*) => {
$(
#[test]
fn $name() {
let (vec_a, vec_b, margin, y_true, expected) = $inputs;
assert_eq!(average_margin_ranking_loss(&vec_a, &vec_b, margin, y_true), expected);
assert_eq!(average_margin_ranking_loss(&vec_b, &vec_a, margin, y_true), expected);
}
)*
}
}

test_with_wrong_inputs! {
invalid_length0: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)),
invalid_length1: (vec![1.0, 2.0], vec![2.0, 3.0, 4.0], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)),
invalid_length2: (vec![], vec![1.0, 2.0, 3.0], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)),
invalid_length3: (vec![1.0, 2.0, 3.0], vec![], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)),
invalid_values: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], -1.0, 1.0, Err(MarginalRankingLossError::NegativeMargin)),
invalid_y_true: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], 1.0, 2.0, Err(MarginalRankingLossError::InvalidValues)),
empty_inputs: (vec![], vec![], 1.0, 1.0, Err(MarginalRankingLossError::EmptyInputs)),
}

macro_rules! test_average_margin_ranking_loss {
($($name:ident: $inputs:expr,)*) => {
$(
#[test]
fn $name() {
let (x_first, x_second, margin, y_true, expected) = $inputs;
assert_eq!(average_margin_ranking_loss(&x_first, &x_second, margin, y_true), Ok(expected));
}
)*
}
}

test_average_margin_ranking_loss! {
set_0: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], 1.0, -1.0, 0.0),
set_1: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], 1.0, 1.0, 2.0),
set_2: (vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0], 0.0, 1.0, 0.0),
set_3: (vec![4.0, 5.0, 6.0], vec![1.0, 2.0, 3.0], 1.0, -1.0, 4.0),
}
}
2 changes: 2 additions & 0 deletions src/machine_learning/loss_function/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod average_margin_ranking_loss;
mod hinge_loss;
mod huber_loss;
mod kl_divergence_loss;
mod mean_absolute_error_loss;
mod mean_squared_error_loss;
mod negative_log_likelihood;

pub use self::average_margin_ranking_loss::average_margin_ranking_loss;
pub use self::hinge_loss::hng_loss;
pub use self::huber_loss::huber_loss;
pub use self::kl_divergence_loss::kld_loss;
Expand Down
1 change: 1 addition & 0 deletions src/machine_learning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod optimization;
pub use self::cholesky::cholesky;
pub use self::k_means::k_means;
pub use self::linear_regression::linear_regression;
pub use self::loss_function::average_margin_ranking_loss;
pub use self::loss_function::hng_loss;
pub use self::loss_function::huber_loss;
pub use self::loss_function::kld_loss;
Expand Down