Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional zerocopy support #62

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,18 @@ version = "^1.0.0"
default-features = false
optional = true

[dependencies.zerocopy]
version = "0.8.9"
default-features = false
optional = true
[dependencies.zerocopy-derive]
version = "0.8.9"
default-features = false
optional = true

[features]
std = []
zerocopy = ["dep:zerocopy", "dep:zerocopy-derive"]

[workspace]
members = [
Expand Down
122 changes: 120 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
//!
//! - [`serde`](https://serde.rs/) implements `Serialize` and `Deserialize`
//! for `BitFlags<T>`.
//! - [`zerocopy`](https://github.com/google/zerocopy/) implements `Immutable`, `IntoBytes`,
//! `FromZeros`, `TryFromBytes`, and `KnownLayout` for all `BitFlags<T>` and `Unaligned` if the value type is unaligned.
//! - `std` implements `std::error::Error` for `FromBitsError`.
//!
//! ## `const fn`-compatible APIs
Expand Down Expand Up @@ -253,7 +255,7 @@ pub trait BitFlag: Copy + Clone + 'static + _internal::RawBitFlags {
///
/// All bits set in `val` must correspond to a value of the enum.
///
/// # Example
/// # Example
///
/// This is a convenience reexport of [`BitFlags::from_bits_unchecked`]. It can be
/// called with `MyFlag::from_bits_unchecked(bits)`, thus bypassing the need for
Expand Down Expand Up @@ -322,8 +324,8 @@ pub mod _internal {
}

use ::core::fmt;
use ::core::ops::{BitAnd, BitOr, BitXor, Not, Sub};
use ::core::hash::Hash;
use ::core::ops::{BitAnd, BitOr, BitXor, Not, Sub};

pub trait BitFlagNum:
Default
Expand Down Expand Up @@ -523,6 +525,14 @@ pub use crate::const_api::ConstToken;
/// `BitFlags` value where that isn't the case is only possible with
/// incorrect unsafe code.
#[derive(Copy, Clone)]
#[cfg_attr(
feature = "zerocopy",
derive(
zerocopy_derive::Immutable,
zerocopy_derive::KnownLayout,
zerocopy_derive::IntoBytes,
)
)]
#[repr(transparent)]
pub struct BitFlags<T, N = <T as _internal::RawBitFlags>::Numeric> {
val: N,
Expand Down Expand Up @@ -657,6 +667,33 @@ where
unsafe { BitFlags::from_bits_unchecked(bits & T::ALL_BITS) }
}

/// Validate if an underlying bitwise value can safely be converted to `BitFlags`.
/// Returns false if any invalid bits are set.
///
/// ```
/// # use enumflags2::{bitflags, BitFlags};
/// #[bitflags]
/// #[repr(u8)]
/// #[derive(Clone, Copy, PartialEq, Eq)]
/// enum MyFlag {
/// One = 0b0001,
/// Two = 0b0010,
/// Three = 0b1000,
/// }
///
/// assert_eq!(BitFlags::<MyFlag>::validate_bits(0b1011), true);
/// assert_eq!(BitFlags::<MyFlag>::validate_bits(0b0000), true);
/// assert_eq!(BitFlags::<MyFlag>::validate_bits(0b0100), false);
/// assert_eq!(BitFlags::<MyFlag>::validate_bits(0b1111), false);
/// ```
#[must_use]
#[inline(always)]
pub fn validate_bits(bits: T::Numeric) -> bool {
// SAFETY: We're truncating out all the invalid bits so it will
// only be different if there are invalid bits set.
(bits & T::ALL_BITS) == bits
}
Comment on lines +689 to +695
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I'm not sure if this makes sense as a public API.


/// Create a new BitFlags unsafely, without checking if the bits form
/// a valid bit pattern for the type.
///
Expand Down Expand Up @@ -1032,3 +1069,84 @@ mod impl_serde {
}
}
}

#[cfg(feature = "zerocopy")]
mod impl_zerocopy {
use super::{BitFlag, BitFlags};
use zerocopy::{FromZeros, Immutable, TryFromBytes, Unaligned};

// All zeros is always valid
unsafe impl<T> FromZeros for BitFlags<T>
where
T: BitFlag,
T::Numeric: Immutable,
T::Numeric: FromZeros,
{
// We are actually allowed to implement this trait. The scary name is just meant
// to convey that "this is dangerous and you'd better know what you're doing and
// be sure that you need to do this and can't just use the derives". (https://github.com/google/zerocopy/issues/287)
// We can not use the derives for this, because they dont support validation.
fn only_derive_is_allowed_to_implement_this_trait() {}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could use a comment linking to google/zerocopy#287, so that this doesn't scream "crimes! damn crimes!" at anyone who looks at this later.

}

// Mark all BitFlags as Unaligned if the underlying number type is unaligned
unsafe impl<T> Unaligned for BitFlags<T>
where
T: BitFlag,
T::Numeric: Unaligned,
{
// We are actually allowed to implement this trait. The scary name is just meant
// to convey that "this is dangerous and you'd better know what you're doing and
// be sure that you need to do this and can't just use the derives". (https://github.com/google/zerocopy/issues/287)
// We can not use the derives for this, because they dont support validation.
fn only_derive_is_allowed_to_implement_this_trait() {}
}

// Assert that there are no invalid bytes set
unsafe impl<T> TryFromBytes for BitFlags<T>
where
T: BitFlag,
T::Numeric: Immutable,
T::Numeric: TryFromBytes,
{
// We are actually allowed to implement this trait. The scary name is just meant
// to convey that "this is dangerous and you'd better know what you're doing and
// be sure that you need to do this and can't just use the derives". (https://github.com/google/zerocopy/issues/287)
// We can not use the derives for this, because they dont support validation.
fn only_derive_is_allowed_to_implement_this_trait()
where
Self: Sized,
{
}

#[inline]
fn is_bit_valid<
ZerocopyAliasing: zerocopy::pointer::invariant::Aliasing
+ zerocopy::pointer::invariant::AtLeast<zerocopy::pointer::invariant::Shared>,
>(
candidate: zerocopy::Maybe<'_, Self, ZerocopyAliasing>,
) -> bool {
// SAFETY:
// - The cast preserves address. The caller has promised that the
// cast results in an object of equal or lesser size, and so the
// cast returns a pointer which references a subset of the bytes
// of `p`.
// - The cast preserves provenance.
// - The caller has promised that the destination type has
// `UnsafeCell`s at the same byte ranges as the source type.
let candidate = unsafe { candidate.cast_unsized::<T::Numeric, _>(|p| p as *mut _) };

// SAFETY: The caller has promised that the referenced memory region
// will contain a valid `$repr`.
let my_candidate =
unsafe { candidate.assume_validity::<zerocopy::pointer::invariant::Valid>() };
{
// TODO: Currently this assumes that the candidate is aligned. We actually need to check this beforehand
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is something I would've caught. Do you have any links I could reference for example zerocopy implementations for similar types?

// Dereference the pointer to the candidate
let candidate =
my_candidate.read_unaligned::<zerocopy::pointer::BecauseImmutable>();
return BitFlags::<T>::validate_bits(candidate);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return BitFlags::<T>::validate_bits(candidate);
BitFlags::<T>::validate_bits(candidate)

}
}
}
}
11 changes: 10 additions & 1 deletion test_suite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ edition = "2018"

[dependencies.enumflags2]
path = "../"
features = ["serde"]
features = ["serde", "zerocopy"]

[dependencies.serde]
version = "1"
features = ["derive"]

[dependencies.zerocopy]
version = "0.8.9"
features = ["derive"]

[dev-dependencies]
trybuild = "1.0"
glob = "0.3"
Expand Down Expand Up @@ -65,3 +69,8 @@ edition = "2018"
name = "not_literal"
path = "tests/not_literal.rs"
edition = "2018"

[[test]]
name = "zerocopy"
path = "tests/zerocopy.rs"
edition = "2018"
32 changes: 32 additions & 0 deletions test_suite/tests/zerocopy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use enumflags2::{bitflags, BitFlags};
use zerocopy::{Immutable, IntoBytes, KnownLayout, TryFromBytes};

#[test]
fn zerocopy_compile() {
#[bitflags]
#[derive(Copy, Clone, Debug, KnownLayout)]
#[repr(u8)]
enum TestU8 {
A,
B,
C,
D,
}

#[bitflags]
#[derive(Copy, Clone, Debug, KnownLayout)]
#[repr(u16)]
enum TestU16 {
A,
B,
C,
D,
}

#[derive(Clone, Debug, Immutable, TryFromBytes, IntoBytes, KnownLayout)]
#[repr(packed)]
struct Other {
flags2: BitFlags<TestU8>,
flags: BitFlags<TestU16>,
}
}
Loading