Skip to content

Commit

Permalink
Add perfect derive semantics to #[zeroable(bound = "...")].
Browse files Browse the repository at this point in the history
  • Loading branch information
zachs18 committed Jul 17, 2023
1 parent 74aa251 commit e4d7b31
Showing 1 changed file with 51 additions and 15 deletions.
66 changes: 51 additions & 15 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ pub fn derive_anybitpattern(
///
/// ```rust
/// # use bytemuck_derive::{Zeroable};
///
/// #[derive(Copy, Clone, Zeroable)]
/// #[repr(C)]
/// struct Test {
Expand All @@ -91,12 +90,15 @@ pub fn derive_anybitpattern(
/// Custom bounds for the derived `Zeroable` impl can be given using the
/// `#[zeroable(bound = "")]` helper attribute.
///
/// Using this attribute additionally opts-in to "perfect derive" semantics,
/// where instead of adding bounds for each generic type parameter, bounds are
/// added for each field's type.
///
/// ## Examples
///
/// ```rust
/// # use bytemuck::Zeroable;
/// # use std::marker::PhantomData;
///
/// #[derive(Clone, Zeroable)]
/// #[zeroable(bound = "")]
/// struct AlwaysZeroable<T> {
Expand All @@ -109,7 +111,6 @@ pub fn derive_anybitpattern(
/// ```rust,compile_fail
/// # use bytemuck::Zeroable;
/// # use std::marker::PhantomData;
///
/// #[derive(Clone, Zeroable)]
/// #[zeroable(bound = "T: Copy")]
/// struct ZeroableWhenTIsCopy<T> {
Expand All @@ -119,17 +120,27 @@ pub fn derive_anybitpattern(
/// ZeroableWhenTIsCopy::<String>::zeroed();
/// ```
///
/// The restriction that all fields must be Zeroable is still applied, and an
/// error will be produced if the custom bounds do not guarantee this.
///
/// ```rust,compile_fail
/// # use bytemuck_derive::{Zeroable};
/// The restriction that all fields must be Zeroable is still applied, and this
/// is enforced using the mentioned "perfect derive" semantics.
///
/// ```rust
/// # use bytemuck::Zeroable;
/// #[derive(Clone, Zeroable)]
/// #[zeroable(bound = "")]
/// struct AlwaysZeroable<T> {
/// struct ZeroableWhenTIsZeroable<T> {
/// a: T,
/// }
/// ZeroableWhenTIsZeroable::<u32>::zeroed();
/// ```
///
/// ```rust,compile_fail
/// # use bytemuck::Zeroable;
/// # #[derive(Clone, Zeroable)]
/// # #[zeroable(bound = "")]
/// # struct ZeroableWhenTIsZeroable<T> {
/// # a: T,
/// # }
/// ZeroableWhenTIsZeroable::<String>::zeroed();
/// ```
#[proc_macro_derive(Zeroable, attributes(zeroable))]
pub fn derive_zeroable(
Expand Down Expand Up @@ -366,7 +377,7 @@ fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
/// Find `#[name(key = "value")]` helper attributes on the struct, and return
/// their `"value"`s parsed with `parser`.
///
/// Returns an error any attributes with the given name do not match the
/// Returns an error if any attributes with the given `name` do not match the
/// expected format. Returns `Ok([])` if no attributes with `name` are found.
fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
Expand Down Expand Up @@ -441,22 +452,47 @@ fn derive_marker_trait_inner<Trait: Derivable>(

if !explicit_bounds.is_empty() {
// Explicit bounds were given.
// Only enforce explicitly given bounds (the asserts emitted should ensure
// soundness).
// Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
// bounds for each field's type).
let explicit_bounds = explicit_bounds
.into_iter()
.flatten()
.collect::<Vec<syn::WherePredicate>>();

input.generics.make_where_clause().predicates.extend(explicit_bounds);
let predicates = &mut input.generics.make_where_clause().predicates;

predicates.extend(explicit_bounds);

let fields = match &input.data {
syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
syn::Data::Union(_) => {
return Err(syn::Error::new_spanned(
trait_,
&"perfect derive is not supported for unions",
));
}
syn::Data::Enum(_) => {
return Err(syn::Error::new_spanned(
trait_,
&"perfect derive is not supported for enums",
));
}
};

for field in fields {
let ty = field.ty;
predicates.push(syn::parse_quote!(
#ty: #trait_
));
}
} else {
// No explicit bounds were given.
// Enforce trait bound on all generic fields.
// Enforce trait bound on all type generics.
add_trait_marker(&mut input.generics, &trait_);
}
} else {
// This trait does not allow explicit bounds.
// Enforce trait bound on all generic fields.
// Enforce trait bound on all type generics.
add_trait_marker(&mut input.generics, &trait_);
}

Expand Down

0 comments on commit e4d7b31

Please sign in to comment.