Skip to content

Commit

Permalink
Feat/update to FSRS-5
Browse files Browse the repository at this point in the history
  • Loading branch information
miroim committed Sep 27, 2024
1 parent cba31ce commit a3bccc7
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
53 changes: 40 additions & 13 deletions src/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ impl FSRS {
self.set_due(&mut output_cards, Easy, Duration::days(easy_interval));
}
Learning | Relearning => {
self.next_stability(&mut output_cards, card.state);
self.next_difficulty(&mut output_cards);

self.set_scheduled_days(&mut output_cards, Again, 0);
self.set_due(&mut output_cards, Again, Duration::minutes(5));

Expand All @@ -58,7 +61,7 @@ impl FSRS {
self.set_due(&mut output_cards, Easy, Duration::days(easy_interval));
}
Review => {
self.next_stability(&mut output_cards);
self.next_stability(&mut output_cards, card.state);
self.next_difficulty(&mut output_cards);

let mut hard_interval = self.next_interval(&mut output_cards, Hard).unwrap();
Expand Down Expand Up @@ -111,17 +114,26 @@ impl FSRS {

fn init_difficulty_stability(&self, output_cards: &mut ScheduledCards) {
for rating in Rating::iter() {
let rating_int: i32 = *rating as i32;
let Some(card) = output_cards.cards.get_mut(rating) else {
continue;
};
card.difficulty = self.params.w[5]
.mul_add(-(rating_int as f32 - 3.0), self.params.w[4])
.clamp(1.0, 10.0);
card.stability = self.params.w[(rating_int - 1) as usize].max(0.1);
card.difficulty = self.init_difficulty(*rating);
card.stability = self.init_stability(*rating);
}
}

fn init_difficulty(&self, rating: Rating) -> f32 {
let rating_int: i32 = rating as i32;

(self.params.w[4] - f32::exp(self.params.w[5] * (rating_int as f32 - 1.0)) + 1.0)
.clamp(1.0, 10.0)
}

fn init_stability(&self, rating: Rating) -> f32 {
let rating_int: i32 = rating as i32;
self.params.w[(rating_int - 1) as usize].max(0.1)
}

#[allow(clippy::suboptimal_flops)]
fn next_interval(
&self,
Expand All @@ -136,12 +148,16 @@ impl FSRS {
Ok((new_interval.round() as i64).clamp(1, self.params.maximum_interval as i64))
}

fn next_stability(&self, output_cards: &mut ScheduledCards) {
for rating in Rating::iter() {
if rating == &Again {
self.next_forget_stability(output_cards);
} else {
self.next_recall_stability(output_cards, *rating);
fn next_stability(&self, output_cards: &mut ScheduledCards, state: State) {
if state == Learning || state == Relearning {
self.short_term_stability(output_cards)
} else if state == Review {
for rating in Rating::iter() {
if rating == &Again {
self.next_forget_stability(output_cards);
} else {
self.next_recall_stability(output_cards, *rating);
}
}
}
}
Expand Down Expand Up @@ -184,7 +200,7 @@ impl FSRS {
};
let next_difficulty =
self.params.w[6].mul_add(-(rating_int as f32 - 3.0), card.difficulty);
let mean_reversion = self.mean_reversion(self.params.w[4], next_difficulty);
let mean_reversion = self.mean_reversion(self.init_difficulty(Easy), next_difficulty);
card.difficulty = mean_reversion.clamp(1.0, 10.0);
output_cards.cards.insert(*rating, card);
}
Expand All @@ -193,4 +209,15 @@ impl FSRS {
fn mean_reversion(&self, initial: f32, current: f32) -> f32 {
self.params.w[7].mul_add(initial, (1.0 - self.params.w[7]) * current)
}

fn short_term_stability(&self, output_cards: &mut ScheduledCards) {
for rating in Rating::iter() {
let rating_int = *rating as i32;
let Some(card) = output_cards.cards.get_mut(rating) else {
continue;
};
card.stability *=
f32::exp(self.params.w[17] * (rating_int as f32 - 3.0 + self.params.w[18]));
}
}
}
6 changes: 3 additions & 3 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub struct ReviewLog {
pub struct Parameters {
pub request_retention: f32,
pub maximum_interval: i32,
pub w: [f32; 17],
pub w: [f32; 19],
}

pub const DECAY: f32 = -0.5;
Expand All @@ -81,8 +81,8 @@ impl Default for Parameters {
request_retention: 0.9,
maximum_interval: 36500,
w: [
0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342,
1.0166, 2.1174, 0.0839, 0.3204, 1.4676, 0.219, 2.8237,
0.4197, 1.1869, 3.0412, 15.2441, 7.1434, 0.6477, 1.0007, 0.0674, 1.6597, 0.1712,
1.1178, 2.0225, 0.0904, 0.3025, 2.1214, 0.2498, 2.9466, 0.4891, 0.6468,
],
}
}
Expand Down
13 changes: 7 additions & 6 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ static TEST_RATINGS: [Rating; 13] = [
];

#[cfg(test)]
static WEIGHTS: [f32; 17] = [
1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321, 2.1866,
0.0661, 0.336, 1.7766, 0.1693, 2.9244,
static WEIGHTS: [f32; 19] = [
0.4197, 1.1869, 3.0412, 15.2441, 7.1434, 0.6477, 1.0007, 0.0674, 1.6597, 0.1712, 1.1178,
2.0225, 0.0904, 0.3025, 2.1214, 0.2498, 2.9466, 0.4891, 0.6468,
];

#[cfg(test)]
Expand Down Expand Up @@ -55,7 +55,7 @@ fn test_interval() {
interval_history.push(card.scheduled_days);
now = card.due;
}
let expected = [0, 4, 15, 49, 143, 379, 0, 0, 15, 37, 85, 184, 376];
let expected = [0, 4, 17, 62, 198, 563, 0, 0, 9, 27, 74, 190, 457];
assert_eq!(interval_history, expected);
}

Expand Down Expand Up @@ -115,6 +115,7 @@ fn test_memo_state() {
},
);
card = scheduled_cards.select_card(Rating::Good);
assert!((card.stability - 43.05542).abs() < f32::EPSILON * 100f32);
assert_eq!(card.difficulty, 7.7609);
assert_eq!(card.stability, 71.4554);
// card.difficulty = 5.0976353
assert!((card.difficulty - 5.0976).abs() < f32::EPSILON * 1000f32)
}

0 comments on commit a3bccc7

Please sign in to comment.