From ca9079474f39353fca47040c869e3c0843f4479a Mon Sep 17 00:00:00 2001 From: conduition Date: Wed, 18 Oct 2023 01:14:13 +0000 Subject: [PATCH 1/3] cargo fmt --- tests/mod.rs | 104 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 17 deletions(-) diff --git a/tests/mod.rs b/tests/mod.rs index f6b3982..a5130b6 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -290,16 +290,66 @@ fn test_ctoption() { )); // Test (in)equality - assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(1, Choice::from(1))).unwrap_u8() == 0); - assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(1, Choice::from(0))).unwrap_u8() == 0); - assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(2, Choice::from(1))).unwrap_u8() == 0); - assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(2, Choice::from(0))).unwrap_u8() == 0); - assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(1, Choice::from(0))).unwrap_u8() == 1); - assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(2, Choice::from(0))).unwrap_u8() == 1); - assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(2, Choice::from(1))).unwrap_u8() == 0); - assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(2, Choice::from(1))).unwrap_u8() == 0); - assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(1, Choice::from(1))).unwrap_u8() == 1); - assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(1, Choice::from(1))).unwrap_u8() == 1); + assert!( + CtOption::new(1, Choice::from(0)) + .ct_eq(&CtOption::new(1, Choice::from(1))) + .unwrap_u8() + == 0 + ); + assert!( + CtOption::new(1, Choice::from(1)) + .ct_eq(&CtOption::new(1, Choice::from(0))) + .unwrap_u8() + == 0 + ); + assert!( + CtOption::new(1, Choice::from(0)) + .ct_eq(&CtOption::new(2, Choice::from(1))) + .unwrap_u8() + == 0 + ); + assert!( + CtOption::new(1, Choice::from(1)) + .ct_eq(&CtOption::new(2, Choice::from(0))) + .unwrap_u8() + == 0 + ); + assert!( + CtOption::new(1, Choice::from(0)) + .ct_eq(&CtOption::new(1, Choice::from(0))) + .unwrap_u8() + == 1 + ); + assert!( + CtOption::new(1, Choice::from(0)) + .ct_eq(&CtOption::new(2, Choice::from(0))) + .unwrap_u8() + == 1 + ); + assert!( + CtOption::new(1, Choice::from(1)) + .ct_eq(&CtOption::new(2, Choice::from(1))) + .unwrap_u8() + == 0 + ); + assert!( + CtOption::new(1, Choice::from(1)) + .ct_eq(&CtOption::new(2, Choice::from(1))) + .unwrap_u8() + == 0 + ); + assert!( + CtOption::new(1, Choice::from(1)) + .ct_eq(&CtOption::new(1, Choice::from(1))) + .unwrap_u8() + == 1 + ); + assert!( + CtOption::new(1, Choice::from(1)) + .ct_eq(&CtOption::new(1, Choice::from(1))) + .unwrap_u8() + == 1 + ); } #[test] @@ -327,7 +377,7 @@ macro_rules! generate_greater_than_test { assert!(z.unwrap_u8() == 1); } } - } + }; } #[test] @@ -358,8 +408,18 @@ fn greater_than_u128() { #[test] fn greater_than_ordering() { - assert_eq!(cmp::Ordering::Less.ct_gt(&cmp::Ordering::Greater).unwrap_u8(), 0); - assert_eq!(cmp::Ordering::Greater.ct_gt(&cmp::Ordering::Less).unwrap_u8(), 1); + assert_eq!( + cmp::Ordering::Less + .ct_gt(&cmp::Ordering::Greater) + .unwrap_u8(), + 0 + ); + assert_eq!( + cmp::Ordering::Greater + .ct_gt(&cmp::Ordering::Less) + .unwrap_u8(), + 1 + ); } #[test] @@ -367,7 +427,7 @@ fn greater_than_ordering() { /// gives the correct result. (This fails using the bit-twiddling algorithm that /// go/crypto/subtle uses.) fn less_than_twos_compliment_minmax() { - let z = 1u32.ct_lt(&(2u32.pow(31)-1)); + let z = 1u32.ct_lt(&(2u32.pow(31) - 1)); assert!(z.unwrap_u8() == 1); } @@ -389,7 +449,7 @@ macro_rules! generate_less_than_test { assert!(z.unwrap_u8() == 1); } } - } + }; } #[test] @@ -420,6 +480,16 @@ fn less_than_u128() { #[test] fn less_than_ordering() { - assert_eq!(cmp::Ordering::Greater.ct_lt(&cmp::Ordering::Less).unwrap_u8(), 0); - assert_eq!(cmp::Ordering::Less.ct_lt(&cmp::Ordering::Greater).unwrap_u8(), 1); + assert_eq!( + cmp::Ordering::Greater + .ct_lt(&cmp::Ordering::Less) + .unwrap_u8(), + 0 + ); + assert_eq!( + cmp::Ordering::Less + .ct_lt(&cmp::Ordering::Greater) + .unwrap_u8(), + 1 + ); } From 09f5b702a23aa8d14a5b0a984f20b0922318b67e Mon Sep 17 00:00:00 2001 From: conduition Date: Wed, 18 Oct 2023 01:14:22 +0000 Subject: [PATCH 2/3] spellcheck comment in generate_unsigned_integer_greater --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 795eade..bf861ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -865,7 +865,7 @@ pub trait ConstantTimeGreater { macro_rules! generate_unsigned_integer_greater { ($t_u: ty, $bit_width: expr) => { impl ConstantTimeGreater for $t_u { - /// Returns Choice::from(1) iff x > y, and Choice::from(0) iff x <= y. + /// Returns Choice::from(1) if x > y, and Choice::from(0) if x <= y. /// /// # Note /// From 618100c56395268570f0d50b058495754ce3ebe6 Mon Sep 17 00:00:00 2001 From: conduition Date: Wed, 18 Oct 2023 02:11:36 +0000 Subject: [PATCH 3/3] implement lexicographical ordering for slices of arbitrary types --- src/lib.rs | 114 +++++++++++++++++++++++++++++++++++++++++++++++++++ tests/mod.rs | 101 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index bf861ce..4eeef9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -914,6 +914,106 @@ impl ConstantTimeGreater for cmp::Ordering { } } +/// Compares two slices lexicographically in constant time whose +/// elements can be ordered in constant time. +/// +/// Returns: +/// +/// - `Ordering::Less` if `lhs < rhs` +/// - `Ordering::Equal` if `lhs == rhs` +/// - `Ordering::Greater` if `lhs > rhs` +#[inline] +fn ct_slice_lex_cmp(lhs: &[T], rhs: &[T]) -> cmp::Ordering +where + T: ConstantTimeEq + ConstantTimeGreater, +{ + let mut whole_slice_is_eq = Choice(1); + let mut whole_slice_is_gt = Choice(0); + + // Zip automatically stops iterating once one of the zipped + // iterators has been exhausted. + for (v1, v2) in lhs.iter().zip(rhs.iter()) { + // If the previous elements in the array were all equal, but `v1 > v2` in this + // position, then `lhs` is deemed to be greater than `rhs`. + // + // We want `whole_slice_is_gt` to remain true if we ever found this condition, + // but since we're aiming for constant-time, we cannot break the loop. + whole_slice_is_gt |= whole_slice_is_eq & v1.ct_gt(&v2); + + // Track whether all elements in the slices up to this point are equal. + whole_slice_is_eq &= v1.ct_eq(&v2); + } + + let l_len = lhs.len() as u64; + let r_len = rhs.len() as u64; + let lhs_is_longer = l_len.ct_gt(&r_len); + let rhs_is_longer = r_len.ct_gt(&l_len); + + // Fallback: lhs < rhs + let mut order = cmp::Ordering::Less; + + // both slices up to `min(l_len, r_len)` were equal. + order.conditional_assign(&cmp::Ordering::Equal, whole_slice_is_eq); + + // `rhs` is a prefix of `lhs`. `lhs` is lexicographically greater. + order.conditional_assign(&cmp::Ordering::Greater, whole_slice_is_eq & lhs_is_longer); + + // `lhs` is a prefix of `rhs`. `rhs` is lexicographically greater. + order.conditional_assign(&cmp::Ordering::Less, whole_slice_is_eq & rhs_is_longer); + + // `lhs` contains the earliest strictly-greater element. + order.conditional_assign(&cmp::Ordering::Greater, whole_slice_is_gt); + + order +} + +/// A slice is greater than another slice lexicographically if it contains the earliest +/// element which is strictly greater than its counterpart element at the same index. If +/// one slice is a prefix of the other, then the longer of the two slices is deemed to be +/// greater. We stop performing pairwise comparisons on slice elements after +/// `min(lhs.len(), rhs.len())` loop iterations. +/// +/// This blanket slice implementation requires the element type `T` to implement +/// [`ConstantTimeEq`](ConstantTimeEq) so that we can compare slices lexicographically. +/// +/// # Example +/// +/// It is easy to see in slices of integer-like types. +/// +/// ``` +/// use subtle::ConstantTimeGreater; +/// +/// assert_eq!((&[1u32, 2, 4]).ct_gt(&[1, 2, 3]).unwrap_u8(), 1); +/// assert_eq!((&[0u32, 1, 2, 4]).ct_gt(&[1, 2, 3, 4]).unwrap_u8(), 0); +/// assert_eq!((&[5u8, 5, 5, 5]).ct_gt(&[5, 5, 5]).unwrap_u8(), 1); +/// assert_eq!((&[5u8, 5, 5, 5]).ct_gt(&[5, 5, 5, 5]).unwrap_u8(), 0); +/// ``` +/// +/// Strings can also be ordered this way. +/// +/// ``` +/// use subtle::ConstantTimeGreater; +/// +/// assert_eq!("aab".as_bytes().ct_gt("aaa".as_bytes()).unwrap_u8(), 1); +/// assert_eq!("aaa".as_bytes().ct_gt("aa".as_bytes()).unwrap_u8(), 1); +/// assert_eq!("aaac".as_bytes().ct_gt("aaa".as_bytes()).unwrap_u8(), 1); +/// +/// assert_eq!("aaa".as_bytes().ct_gt("aaa".as_bytes()).unwrap_u8(), 0); +/// assert_eq!("aaa".as_bytes().ct_gt("aab".as_bytes()).unwrap_u8(), 0); +/// assert_eq!("aaac".as_bytes().ct_gt("aab".as_bytes()).unwrap_u8(), 0); +/// ``` +impl ConstantTimeGreater for [T] +where + T: ConstantTimeGreater + ConstantTimeEq, +{ + /// Returns `Choice::from(1)` if `self > other` using lexicographical sorting rules. + /// Returns `Choice::from(0)` if `self <= other`. + #[inline] + fn ct_gt(&self, other: &[T]) -> Choice { + ct_slice_lex_cmp(self, other).ct_eq(&cmp::Ordering::Greater) + } +} + /// A type which can be compared in some manner and be determined to be less /// than another of the same type. pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater { @@ -974,3 +1074,17 @@ impl ConstantTimeLess for cmp::Ordering { (a as u8).ct_lt(&(b as u8)) } } + +/// Lexicographical comparison of slices of constant-time ordered elements. +/// +/// See the blanket implementation of [`ConstantTimeGreater`](ConstantTimeGreater) +/// on `[T]` for details. +impl ConstantTimeLess for [T] +where + T: ConstantTimeGreater + ConstantTimeEq, +{ + #[inline] + fn ct_lt(&self, other: &[T]) -> Choice { + ct_slice_lex_cmp(self, other).ct_eq(&cmp::Ordering::Less) + } +} diff --git a/tests/mod.rs b/tests/mod.rs index a5130b6..7b62761 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -493,3 +493,104 @@ fn less_than_ordering() { 1 ); } + +#[test] +fn slices_ordering() { + let mut buf = [0u8; 50]; + + for _ in 0..1000 { + OsRng.fill_bytes(&mut buf); + + let l_start = (OsRng.next_u64() as usize) % buf.len(); + let l_end = (OsRng.next_u64() as usize) % buf.len(); + + let r_start = (OsRng.next_u64() as usize) % buf.len(); + let r_end = (OsRng.next_u64() as usize) % buf.len(); + + let lhs = &buf[l_start..=l_end.max(l_start)]; + let rhs = &buf[r_start..=r_end.max(r_start)]; + + let is_lt = lhs.ct_lt(rhs); + let is_gt = lhs.ct_gt(rhs); + + println!( + "lhs={:?} rhs={:?} is_lt={:?} is_gt={:?}", + lhs, rhs, is_lt, is_gt + ); + + if lhs < rhs { + assert_eq!(lhs.ct_lt(rhs).unwrap_u8(), 1); + assert_eq!(lhs.ct_gt(rhs).unwrap_u8(), 0); + + // Comparison should be commutative + assert_eq!(rhs.ct_lt(lhs).unwrap_u8(), 0); + assert_eq!(rhs.ct_gt(lhs).unwrap_u8(), 1); + } else if lhs == rhs { + assert_eq!(lhs.ct_lt(rhs).unwrap_u8(), 0); + assert_eq!(lhs.ct_gt(rhs).unwrap_u8(), 0); + + // Comparison should be commutative + assert_eq!(rhs.ct_lt(lhs).unwrap_u8(), 0); + assert_eq!(rhs.ct_gt(lhs).unwrap_u8(), 0); + } else if lhs > rhs { + assert_eq!(lhs.ct_lt(rhs).unwrap_u8(), 0); + assert_eq!(lhs.ct_gt(rhs).unwrap_u8(), 1); + + // Comparison should be commutative + assert_eq!(rhs.ct_lt(lhs).unwrap_u8(), 1); + assert_eq!(rhs.ct_gt(lhs).unwrap_u8(), 0); + } + } + + // Strings + let tests = [ + ("www", "www", cmp::Ordering::Equal), + ("wwwa", "www", cmp::Ordering::Greater), + ("wwwb", "www", cmp::Ordering::Greater), + ("wwwc", "www", cmp::Ordering::Greater), + (" www", "www", cmp::Ordering::Less), + ("www", "", cmp::Ordering::Greater), + (".", "", cmp::Ordering::Greater), + ("", ".", cmp::Ordering::Less), + ("", "", cmp::Ordering::Equal), + ]; + + for (lhs, rhs, order) in tests { + let lhs_bytes = lhs.as_bytes(); + let rhs_bytes = rhs.as_bytes(); + + let msg = format!( + "'{}' {} '{}'", + lhs, + match order { + cmp::Ordering::Less => "<", + cmp::Ordering::Equal => "==", + cmp::Ordering::Greater => ">", + }, + rhs + ); + + if lhs < rhs { + assert_eq!(lhs_bytes.ct_lt(rhs_bytes).unwrap_u8(), 1, "{}", &msg); + assert_eq!(lhs_bytes.ct_gt(rhs_bytes).unwrap_u8(), 0, "{}", &msg); + + // Comparison should be commutative. + assert_eq!(rhs_bytes.ct_lt(lhs_bytes).unwrap_u8(), 0, "{}", &msg); + assert_eq!(rhs_bytes.ct_gt(lhs_bytes).unwrap_u8(), 1, "{}", &msg); + } else if lhs == rhs { + assert_eq!(lhs_bytes.ct_lt(rhs_bytes).unwrap_u8(), 0, "{}", &msg); + assert_eq!(lhs_bytes.ct_gt(rhs_bytes).unwrap_u8(), 0, "{}", &msg); + + // Comparison should be commutative. + assert_eq!(rhs_bytes.ct_lt(lhs_bytes).unwrap_u8(), 0, "{}", &msg); + assert_eq!(rhs_bytes.ct_gt(lhs_bytes).unwrap_u8(), 0, "{}", &msg); + } else if lhs > rhs { + assert_eq!(lhs_bytes.ct_lt(rhs_bytes).unwrap_u8(), 0, "{}", &msg); + assert_eq!(lhs_bytes.ct_gt(rhs_bytes).unwrap_u8(), 1, "{}", &msg); + + // Comparison should be commutative. + assert_eq!(rhs_bytes.ct_lt(lhs_bytes).unwrap_u8(), 1, "{}", &msg); + assert_eq!(rhs_bytes.ct_gt(lhs_bytes).unwrap_u8(), 0, "{}", &msg); + } + } +}