Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement lexicographical ordering for slices of arbitrary types #116

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ pub trait ConstantTimeGreater {
macro_rules! generate_unsigned_integer_greater {
($t_u: ty, $bit_width: expr) => {
impl ConstantTimeGreater for $t_u {
/// Returns Choice::from(1) iff x > y, and Choice::from(0) iff x <= y.
/// Returns Choice::from(1) if x > y, and Choice::from(0) if x <= y.
///
/// # Note
///
Expand Down Expand Up @@ -914,6 +914,106 @@ impl ConstantTimeGreater for cmp::Ordering {
}
}

/// Compares two slices lexicographically in constant time whose
/// elements can be ordered in constant time.
///
/// Returns:
///
/// - `Ordering::Less` if `lhs < rhs`
/// - `Ordering::Equal` if `lhs == rhs`
/// - `Ordering::Greater` if `lhs > rhs`
#[inline]
fn ct_slice_lex_cmp<T>(lhs: &[T], rhs: &[T]) -> cmp::Ordering
where
T: ConstantTimeEq + ConstantTimeGreater,
{
let mut whole_slice_is_eq = Choice(1);
let mut whole_slice_is_gt = Choice(0);

// Zip automatically stops iterating once one of the zipped
// iterators has been exhausted.
for (v1, v2) in lhs.iter().zip(rhs.iter()) {
// If the previous elements in the array were all equal, but `v1 > v2` in this
// position, then `lhs` is deemed to be greater than `rhs`.
//
// We want `whole_slice_is_gt` to remain true if we ever found this condition,
// but since we're aiming for constant-time, we cannot break the loop.
whole_slice_is_gt |= whole_slice_is_eq & v1.ct_gt(&v2);

// Track whether all elements in the slices up to this point are equal.
whole_slice_is_eq &= v1.ct_eq(&v2);
}

let l_len = lhs.len() as u64;
let r_len = rhs.len() as u64;
let lhs_is_longer = l_len.ct_gt(&r_len);
let rhs_is_longer = r_len.ct_gt(&l_len);

// Fallback: lhs < rhs
let mut order = cmp::Ordering::Less;

// both slices up to `min(l_len, r_len)` were equal.
order.conditional_assign(&cmp::Ordering::Equal, whole_slice_is_eq);

// `rhs` is a prefix of `lhs`. `lhs` is lexicographically greater.
order.conditional_assign(&cmp::Ordering::Greater, whole_slice_is_eq & lhs_is_longer);

// `lhs` is a prefix of `rhs`. `rhs` is lexicographically greater.
order.conditional_assign(&cmp::Ordering::Less, whole_slice_is_eq & rhs_is_longer);

// `lhs` contains the earliest strictly-greater element.
order.conditional_assign(&cmp::Ordering::Greater, whole_slice_is_gt);

order
}

/// A slice is greater than another slice lexicographically if it contains the earliest
/// element which is strictly greater than its counterpart element at the same index. If
/// one slice is a prefix of the other, then the longer of the two slices is deemed to be
/// greater. We stop performing pairwise comparisons on slice elements after
/// `min(lhs.len(), rhs.len())` loop iterations.
///
/// This blanket slice implementation requires the element type `T` to implement
/// [`ConstantTimeEq`](ConstantTimeEq) so that we can compare slices lexicographically.
///
/// # Example
///
/// It is easy to see in slices of integer-like types.
///
/// ```
/// use subtle::ConstantTimeGreater;
///
/// assert_eq!((&[1u32, 2, 4]).ct_gt(&[1, 2, 3]).unwrap_u8(), 1);
/// assert_eq!((&[0u32, 1, 2, 4]).ct_gt(&[1, 2, 3, 4]).unwrap_u8(), 0);
/// assert_eq!((&[5u8, 5, 5, 5]).ct_gt(&[5, 5, 5]).unwrap_u8(), 1);
/// assert_eq!((&[5u8, 5, 5, 5]).ct_gt(&[5, 5, 5, 5]).unwrap_u8(), 0);
/// ```
///
/// Strings can also be ordered this way.
///
/// ```
/// use subtle::ConstantTimeGreater;
///
/// assert_eq!("aab".as_bytes().ct_gt("aaa".as_bytes()).unwrap_u8(), 1);
/// assert_eq!("aaa".as_bytes().ct_gt("aa".as_bytes()).unwrap_u8(), 1);
/// assert_eq!("aaac".as_bytes().ct_gt("aaa".as_bytes()).unwrap_u8(), 1);
///
/// assert_eq!("aaa".as_bytes().ct_gt("aaa".as_bytes()).unwrap_u8(), 0);
/// assert_eq!("aaa".as_bytes().ct_gt("aab".as_bytes()).unwrap_u8(), 0);
/// assert_eq!("aaac".as_bytes().ct_gt("aab".as_bytes()).unwrap_u8(), 0);
/// ```
impl<T> ConstantTimeGreater for [T]
where
T: ConstantTimeGreater + ConstantTimeEq,
{
/// Returns `Choice::from(1)` if `self > other` using lexicographical sorting rules.
/// Returns `Choice::from(0)` if `self <= other`.
#[inline]
fn ct_gt(&self, other: &[T]) -> Choice {
ct_slice_lex_cmp(self, other).ct_eq(&cmp::Ordering::Greater)
}
}

/// A type which can be compared in some manner and be determined to be less
/// than another of the same type.
pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater {
Expand Down Expand Up @@ -974,3 +1074,17 @@ impl ConstantTimeLess for cmp::Ordering {
(a as u8).ct_lt(&(b as u8))
}
}

/// Lexicographical comparison of slices of constant-time ordered elements.
///
/// See the blanket implementation of [`ConstantTimeGreater`](ConstantTimeGreater)
/// on `[T]` for details.
impl<T> ConstantTimeLess for [T]
where
T: ConstantTimeGreater + ConstantTimeEq,
{
#[inline]
fn ct_lt(&self, other: &[T]) -> Choice {
ct_slice_lex_cmp(self, other).ct_eq(&cmp::Ordering::Less)
}
}
205 changes: 188 additions & 17 deletions tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,66 @@ fn test_ctoption() {
));

// Test (in)equality
assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(1, Choice::from(1))).unwrap_u8() == 0);
assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(1, Choice::from(0))).unwrap_u8() == 0);
assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(2, Choice::from(1))).unwrap_u8() == 0);
assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(2, Choice::from(0))).unwrap_u8() == 0);
assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(1, Choice::from(0))).unwrap_u8() == 1);
assert!(CtOption::new(1, Choice::from(0)).ct_eq(&CtOption::new(2, Choice::from(0))).unwrap_u8() == 1);
assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(2, Choice::from(1))).unwrap_u8() == 0);
assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(2, Choice::from(1))).unwrap_u8() == 0);
assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(1, Choice::from(1))).unwrap_u8() == 1);
assert!(CtOption::new(1, Choice::from(1)).ct_eq(&CtOption::new(1, Choice::from(1))).unwrap_u8() == 1);
assert!(
CtOption::new(1, Choice::from(0))
.ct_eq(&CtOption::new(1, Choice::from(1)))
.unwrap_u8()
== 0
);
assert!(
CtOption::new(1, Choice::from(1))
.ct_eq(&CtOption::new(1, Choice::from(0)))
.unwrap_u8()
== 0
);
assert!(
CtOption::new(1, Choice::from(0))
.ct_eq(&CtOption::new(2, Choice::from(1)))
.unwrap_u8()
== 0
);
assert!(
CtOption::new(1, Choice::from(1))
.ct_eq(&CtOption::new(2, Choice::from(0)))
.unwrap_u8()
== 0
);
assert!(
CtOption::new(1, Choice::from(0))
.ct_eq(&CtOption::new(1, Choice::from(0)))
.unwrap_u8()
== 1
);
assert!(
CtOption::new(1, Choice::from(0))
.ct_eq(&CtOption::new(2, Choice::from(0)))
.unwrap_u8()
== 1
);
assert!(
CtOption::new(1, Choice::from(1))
.ct_eq(&CtOption::new(2, Choice::from(1)))
.unwrap_u8()
== 0
);
assert!(
CtOption::new(1, Choice::from(1))
.ct_eq(&CtOption::new(2, Choice::from(1)))
.unwrap_u8()
== 0
);
assert!(
CtOption::new(1, Choice::from(1))
.ct_eq(&CtOption::new(1, Choice::from(1)))
.unwrap_u8()
== 1
);
assert!(
CtOption::new(1, Choice::from(1))
.ct_eq(&CtOption::new(1, Choice::from(1)))
.unwrap_u8()
== 1
);
}

#[test]
Expand Down Expand Up @@ -327,7 +377,7 @@ macro_rules! generate_greater_than_test {
assert!(z.unwrap_u8() == 1);
}
}
}
};
}

#[test]
Expand Down Expand Up @@ -358,16 +408,26 @@ fn greater_than_u128() {

#[test]
fn greater_than_ordering() {
assert_eq!(cmp::Ordering::Less.ct_gt(&cmp::Ordering::Greater).unwrap_u8(), 0);
assert_eq!(cmp::Ordering::Greater.ct_gt(&cmp::Ordering::Less).unwrap_u8(), 1);
assert_eq!(
cmp::Ordering::Less
.ct_gt(&cmp::Ordering::Greater)
.unwrap_u8(),
0
);
assert_eq!(
cmp::Ordering::Greater
.ct_gt(&cmp::Ordering::Less)
.unwrap_u8(),
1
);
}

#[test]
/// Test that the two's compliment min and max, i.e. 0000...0001 < 1111...1110,
/// gives the correct result. (This fails using the bit-twiddling algorithm that
/// go/crypto/subtle uses.)
fn less_than_twos_compliment_minmax() {
let z = 1u32.ct_lt(&(2u32.pow(31)-1));
let z = 1u32.ct_lt(&(2u32.pow(31) - 1));

assert!(z.unwrap_u8() == 1);
}
Expand All @@ -389,7 +449,7 @@ macro_rules! generate_less_than_test {
assert!(z.unwrap_u8() == 1);
}
}
}
};
}

#[test]
Expand Down Expand Up @@ -420,6 +480,117 @@ fn less_than_u128() {

#[test]
fn less_than_ordering() {
assert_eq!(cmp::Ordering::Greater.ct_lt(&cmp::Ordering::Less).unwrap_u8(), 0);
assert_eq!(cmp::Ordering::Less.ct_lt(&cmp::Ordering::Greater).unwrap_u8(), 1);
assert_eq!(
cmp::Ordering::Greater
.ct_lt(&cmp::Ordering::Less)
.unwrap_u8(),
0
);
assert_eq!(
cmp::Ordering::Less
.ct_lt(&cmp::Ordering::Greater)
.unwrap_u8(),
1
);
}

#[test]
fn slices_ordering() {
let mut buf = [0u8; 50];

for _ in 0..1000 {
OsRng.fill_bytes(&mut buf);

let l_start = (OsRng.next_u64() as usize) % buf.len();
let l_end = (OsRng.next_u64() as usize) % buf.len();

let r_start = (OsRng.next_u64() as usize) % buf.len();
let r_end = (OsRng.next_u64() as usize) % buf.len();

let lhs = &buf[l_start..=l_end.max(l_start)];
let rhs = &buf[r_start..=r_end.max(r_start)];

let is_lt = lhs.ct_lt(rhs);
let is_gt = lhs.ct_gt(rhs);

println!(
"lhs={:?} rhs={:?} is_lt={:?} is_gt={:?}",
lhs, rhs, is_lt, is_gt
);

if lhs < rhs {
assert_eq!(lhs.ct_lt(rhs).unwrap_u8(), 1);
assert_eq!(lhs.ct_gt(rhs).unwrap_u8(), 0);

// Comparison should be commutative
assert_eq!(rhs.ct_lt(lhs).unwrap_u8(), 0);
assert_eq!(rhs.ct_gt(lhs).unwrap_u8(), 1);
} else if lhs == rhs {
assert_eq!(lhs.ct_lt(rhs).unwrap_u8(), 0);
assert_eq!(lhs.ct_gt(rhs).unwrap_u8(), 0);

// Comparison should be commutative
assert_eq!(rhs.ct_lt(lhs).unwrap_u8(), 0);
assert_eq!(rhs.ct_gt(lhs).unwrap_u8(), 0);
} else if lhs > rhs {
assert_eq!(lhs.ct_lt(rhs).unwrap_u8(), 0);
assert_eq!(lhs.ct_gt(rhs).unwrap_u8(), 1);

// Comparison should be commutative
assert_eq!(rhs.ct_lt(lhs).unwrap_u8(), 1);
assert_eq!(rhs.ct_gt(lhs).unwrap_u8(), 0);
}
}

// Strings
let tests = [
("www", "www", cmp::Ordering::Equal),
("wwwa", "www", cmp::Ordering::Greater),
("wwwb", "www", cmp::Ordering::Greater),
("wwwc", "www", cmp::Ordering::Greater),
(" www", "www", cmp::Ordering::Less),
("www", "", cmp::Ordering::Greater),
(".", "", cmp::Ordering::Greater),
("", ".", cmp::Ordering::Less),
("", "", cmp::Ordering::Equal),
];

for (lhs, rhs, order) in tests {
let lhs_bytes = lhs.as_bytes();
let rhs_bytes = rhs.as_bytes();

let msg = format!(
"'{}' {} '{}'",
lhs,
match order {
cmp::Ordering::Less => "<",
cmp::Ordering::Equal => "==",
cmp::Ordering::Greater => ">",
},
rhs
);

if lhs < rhs {
assert_eq!(lhs_bytes.ct_lt(rhs_bytes).unwrap_u8(), 1, "{}", &msg);
assert_eq!(lhs_bytes.ct_gt(rhs_bytes).unwrap_u8(), 0, "{}", &msg);

// Comparison should be commutative.
assert_eq!(rhs_bytes.ct_lt(lhs_bytes).unwrap_u8(), 0, "{}", &msg);
assert_eq!(rhs_bytes.ct_gt(lhs_bytes).unwrap_u8(), 1, "{}", &msg);
} else if lhs == rhs {
assert_eq!(lhs_bytes.ct_lt(rhs_bytes).unwrap_u8(), 0, "{}", &msg);
assert_eq!(lhs_bytes.ct_gt(rhs_bytes).unwrap_u8(), 0, "{}", &msg);

// Comparison should be commutative.
assert_eq!(rhs_bytes.ct_lt(lhs_bytes).unwrap_u8(), 0, "{}", &msg);
assert_eq!(rhs_bytes.ct_gt(lhs_bytes).unwrap_u8(), 0, "{}", &msg);
} else if lhs > rhs {
assert_eq!(lhs_bytes.ct_lt(rhs_bytes).unwrap_u8(), 0, "{}", &msg);
assert_eq!(lhs_bytes.ct_gt(rhs_bytes).unwrap_u8(), 1, "{}", &msg);

// Comparison should be commutative.
assert_eq!(rhs_bytes.ct_lt(lhs_bytes).unwrap_u8(), 1, "{}", &msg);
assert_eq!(rhs_bytes.ct_gt(lhs_bytes).unwrap_u8(), 0, "{}", &msg);
}
}
}