diff --git a/src/lib.rs b/src/lib.rs index e8b6524..f0bbee4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,6 +96,9 @@ use core::cmp; use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Neg, Not}; use core::option::Option; +#[cfg(feature = "core_hint_black_box")] +use core::hint::black_box; + /// The `Choice` struct represents a choice for use in conditional assignment. /// /// It is a wrapper around a `u8`, which should have the value either `1` (true) @@ -217,25 +220,23 @@ impl Not for Choice { /// code may break in a non-destructive way in the future, “constant-time” code /// is a continually moving target, and this is better than doing nothing. #[inline(never)] -fn black_box(input: u8) -> u8 { - debug_assert!((input == 0u8) | (input == 1u8)); - +fn black_box(input: T) -> T { unsafe { // Optimization barrier // - // Unsafe is ok, because: - // - &input is not NULL; - // - size of input is not zero; - // - u8 is neither Sync, nor Send; - // - u8 is Copy, so input is always live; - // - u8 type is always properly aligned. - core::ptr::read_volatile(&input as *const u8) + // SAFETY: + // - &input is not NULL because we own input; + // - input is Copy and always live; + // - input is always properly aligned. + core::ptr::read_volatile(&input) } } impl From for Choice { #[inline] fn from(input: u8) -> Choice { + debug_assert!((input == 0u8) | (input == 1u8)); + // Our goal is to prevent the compiler from inferring that the value held inside the // resulting `Choice` struct is really an `i1` instead of an `i8`. Choice(black_box(input)) @@ -986,3 +987,21 @@ impl ConstantTimeLess for cmp::Ordering { (a as u8).ct_lt(&(b as u8)) } } + +/// Wrapper type which implements an optimization barrier for all accesses. +#[derive(Clone, Copy, Debug)] +pub struct BlackBox(T); + +impl BlackBox { + /// Constructs a new instance of `BlackBox` which will wrap the specified value. + /// + /// All access to the inner value will be mediated by a `black_box` optimization barrier. + pub fn new(value: T) -> Self { + Self(value) + } + + /// Read the inner value, applying an optimization barrier on access. + pub fn get(self) -> T { + black_box(self.0) + } +} diff --git a/tests/mod.rs b/tests/mod.rs index 51f9fd8..888b9d0 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -423,3 +423,10 @@ fn less_than_ordering() { assert_eq!(cmp::Ordering::Greater.ct_lt(&cmp::Ordering::Less).unwrap_u8(), 0); assert_eq!(cmp::Ordering::Less.ct_lt(&cmp::Ordering::Greater).unwrap_u8(), 1); } + +#[test] +fn black_box_round_trip() { + let n = 42u64; + let black_box = BlackBox::new(n); + assert_eq!(n, black_box.get()); +}