diff --git a/src/util/macro_util.rs b/src/util/macro_util.rs index 9d3f6ed77c..fcc7ca2397 100644 --- a/src/util/macro_util.rs +++ b/src/util/macro_util.rs @@ -29,6 +29,20 @@ use crate::{ Immutable, IntoBytes, Ptr, TryFromBytes, Unalign, ValidityError, }; +/// Projects the type of the field at `Index` in `Self`. +/// +/// The `Index` parameter is any sort of handle that identifies the field; its +/// definition is the obligation of the implementer. +/// +/// # Safety +/// +/// Unsafe code may assume that this accurately reflects the definition of +/// `Self`. +pub unsafe trait Field { + /// The type of the field at `Index`. + type Type: ?Sized; +} + #[cfg_attr( zerocopy_diagnostic_on_unimplemented, diagnostic::on_unimplemented( diff --git a/zerocopy-derive/src/ext.rs b/zerocopy-derive/src/ext.rs index 9b446f53a9..d1be8cf373 100644 --- a/zerocopy-derive/src/ext.rs +++ b/zerocopy-derive/src/ext.rs @@ -8,7 +8,7 @@ use proc_macro2::{Span, TokenStream}; use quote::ToTokens; -use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Ident, Index, Type}; +use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Ident, Index, Type, Visibility}; pub(crate) trait DataExt { /// Extracts the names and types of all fields. For enums, extracts the names @@ -19,15 +19,15 @@ pub(crate) trait DataExt { /// makes sense because we don't care about where they live - we just care /// about transitive ownership. But for field names, we'd only use them when /// generating is_bit_valid, which cares about where they live. - fn fields(&self) -> Vec<(TokenStream, &Type)>; + fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)>; - fn variants(&self) -> Vec>; + fn variants(&self) -> Vec>; fn tag(&self) -> Option; } impl DataExt for Data { - fn fields(&self) -> Vec<(TokenStream, &Type)> { + fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> { match self { Data::Struct(strc) => strc.fields(), Data::Enum(enm) => enm.fields(), @@ -35,7 +35,7 @@ impl DataExt for Data { } } - fn variants(&self) -> Vec> { + fn variants(&self) -> Vec> { match self { Data::Struct(strc) => strc.variants(), Data::Enum(enm) => enm.variants(), @@ -53,11 +53,11 @@ impl DataExt for Data { } impl DataExt for DataStruct { - fn fields(&self) -> Vec<(TokenStream, &Type)> { + fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> { map_fields(&self.fields) } - fn variants(&self) -> Vec> { + fn variants(&self) -> Vec> { vec![self.fields()] } @@ -67,11 +67,11 @@ impl DataExt for DataStruct { } impl DataExt for DataEnum { - fn fields(&self) -> Vec<(TokenStream, &Type)> { + fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> { map_fields(self.variants.iter().flat_map(|var| &var.fields)) } - fn variants(&self) -> Vec> { + fn variants(&self) -> Vec> { self.variants.iter().map(|var| map_fields(&var.fields)).collect() } @@ -81,11 +81,11 @@ impl DataExt for DataEnum { } impl DataExt for DataUnion { - fn fields(&self) -> Vec<(TokenStream, &Type)> { + fn fields(&self) -> Vec<(&Visibility, TokenStream, &Type)> { map_fields(&self.fields.named) } - fn variants(&self) -> Vec> { + fn variants(&self) -> Vec> { vec![self.fields()] } @@ -96,12 +96,13 @@ impl DataExt for DataUnion { fn map_fields<'a>( fields: impl 'a + IntoIterator, -) -> Vec<(TokenStream, &'a Type)> { +) -> Vec<(&'a Visibility, TokenStream, &'a Type)> { fields .into_iter() .enumerate() .map(|(idx, f)| { ( + &f.vis, f.ident .as_ref() .map(ToTokens::to_token_stream) diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 4b717c1494..ccef0496d2 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -144,8 +144,8 @@ fn derive_known_layout_inner(ast: &DeriveInput, _top_level: Trait) -> Result Result = + fields.iter().map(|(_vis, name, _ty)| field_index(name)).collect(); + + // Define the collection of type-level field handles. + let field_defs = field_indices.iter().zip(&fields).map(|(idx, (vis, _, _))| { + quote! { + #[allow(non_camel_case_types)] + #vis struct #idx; + } + }); + + let field_impls = field_indices.iter().zip(&fields).map(|(idx, (_, _, ty))| quote! { + // SAFETY: `#ty` is the type of `#ident`'s field at `#idx`. + unsafe impl #impl_generics ::zerocopy::util::macro_util::Field<#idx> for #ident #ty_generics + where + #predicates + { + type Type = #ty; + } + }); + + let trailing_field_index = field_index(trailing_field_name); + let leading_field_indices = + leading_fields.iter().map(|(_vis, name, _ty)| field_index(name)); + + let trailing_field_ty = quote! { + <#ident #ty_generics as + ::zerocopy::util::macro_util::Field<#trailing_field_index> + >::Type + }; + let methods = make_methods(&parse_quote! { <#trailing_field_ty as ::zerocopy::KnownLayout>::MaybeUninit }); quote! { + #(#field_defs)* + + #(#field_impls)* + // SAFETY: This has the same layout as the derive target type, - // except that it admits uninit bytes. This is ensured by using the - // same repr as the target type, and by using field types which have - // the same layout as the target type's fields, except that they - // admit uninit bytes. + // except that it admits uninit bytes. This is ensured by using + // the same repr as the target type, and by using field types + // which have the same layout as the target type's fields, + // except that they admit uninit bytes. We indirect through + // `Field` to ensure that occurrences of `Self` resolve to + // `#ty`, not `__ZerocopyKnownLayoutMaybeUninit` (see #2116). #repr #[doc(hidden)] #vis struct __ZerocopyKnownLayoutMaybeUninit<#params> ( - #(::zerocopy::util::macro_util::core_reexport::mem::MaybeUninit<#leading_fields_tys>,)* + #(::zerocopy::util::macro_util::core_reexport::mem::MaybeUninit< + <#ident #ty_generics as + ::zerocopy::util::macro_util::Field<#leading_field_indices> + >::Type + >,)* <#trailing_field_ty as ::zerocopy::KnownLayout>::MaybeUninit ) where @@ -295,9 +341,6 @@ fn derive_known_layout_inner(ast: &DeriveInput, _top_level: Trait) -> Result::MaybeUninit: ::zerocopy::KnownLayout, #predicates { #[allow(clippy::missing_inline_in_public_items)] @@ -494,8 +537,8 @@ fn derive_try_from_bytes_struct( ) -> Result { let extras = try_gen_trivial_is_bit_valid(ast, top_level).unwrap_or_else(|| { let fields = strct.fields(); - let field_names = fields.iter().map(|(name, _ty)| name); - let field_tys = fields.iter().map(|(_name, ty)| ty); + let field_names = fields.iter().map(|(_vis, name, _ty)| name); + let field_tys = fields.iter().map(|(_vis, _name, ty)| ty); quote!( // SAFETY: We use `is_bit_valid` to validate that each field is // bit-valid, and only return `true` if all of them are. The bit @@ -552,8 +595,8 @@ fn derive_try_from_bytes_union( FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); let extras = try_gen_trivial_is_bit_valid(ast, top_level).unwrap_or_else(|| { let fields = unn.fields(); - let field_names = fields.iter().map(|(name, _ty)| name); - let field_tys = fields.iter().map(|(_name, ty)| ty); + let field_names = fields.iter().map(|(_vis, name, _ty)| name); + let field_tys = fields.iter().map(|(_vis, _name, ty)| ty); quote!( // SAFETY: We use `is_bit_valid` to validate that any field is // bit-valid; we only return `true` if at least one of them is. The @@ -1419,12 +1462,13 @@ fn impl_block( parse_quote!(#ty: #(#traits)+*) } let field_type_bounds: Vec<_> = match (field_type_trait_bounds, &fields[..]) { - (FieldBounds::All(traits), _) => { - fields.iter().map(|(_name, ty)| bound_tt(ty, normalize_bounds(trt, traits))).collect() - } + (FieldBounds::All(traits), _) => fields + .iter() + .map(|(_vis, _name, ty)| bound_tt(ty, normalize_bounds(trt, traits))) + .collect(), (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![], (FieldBounds::Trailing(traits), [.., last]) => { - vec![bound_tt(last.1, normalize_bounds(trt, traits))] + vec![bound_tt(last.2, normalize_bounds(trt, traits))] } (FieldBounds::Explicit(bounds), _) => bounds, }; @@ -1436,7 +1480,7 @@ fn impl_block( let padding_check_bound = padding_check.and_then(|check| (!fields.is_empty()).then_some(check)).map(|check| { let variant_types = variants.iter().map(|var| { - let types = var.iter().map(|(_name, ty)| ty); + let types = var.iter().map(|(_vis, _name, ty)| ty); quote!([#(#types),*]) }); let validator_context = check.validator_macro_context(); diff --git a/zerocopy-derive/src/output_tests.rs b/zerocopy-derive/src/output_tests.rs index 7226d600a9..1297997408 100644 --- a/zerocopy-derive/src/output_tests.rs +++ b/zerocopy-derive/src/output_tests.rs @@ -176,20 +176,47 @@ fn test_known_layout() { ::pointer_to_metadata(ptr as *mut _) } } + #[allow(non_camel_case_types)] + struct __Zerocopy_Field_0; + #[allow(non_camel_case_types)] + struct __Zerocopy_Field_1; + unsafe impl ::zerocopy::util::macro_util::Field<__Zerocopy_Field_0> + for Foo { + type Type = T; + } + unsafe impl ::zerocopy::util::macro_util::Field<__Zerocopy_Field_1> + for Foo { + type Type = U; + } #[repr(C)] #[repr(align(2))] #[doc(hidden)] struct __ZerocopyKnownLayoutMaybeUninit( - ::zerocopy::util::macro_util::core_reexport::mem::MaybeUninit, - ::MaybeUninit, + ::zerocopy::util::macro_util::core_reexport::mem::MaybeUninit< + as ::zerocopy::util::macro_util::Field<__Zerocopy_Field_0>>::Type, + >, + < as ::zerocopy::util::macro_util::Field< + __Zerocopy_Field_1, + >>::Type as ::zerocopy::KnownLayout>::MaybeUninit, ) where - U: ::zerocopy::KnownLayout; - unsafe impl ::zerocopy::KnownLayout - for __ZerocopyKnownLayoutMaybeUninit + as ::zerocopy::util::macro_util::Field< + __Zerocopy_Field_1, + >>::Type: ::zerocopy::KnownLayout; + unsafe impl ::zerocopy::KnownLayout for __ZerocopyKnownLayoutMaybeUninit where - U: ::zerocopy::KnownLayout, - ::MaybeUninit: ::zerocopy::KnownLayout, + as ::zerocopy::util::macro_util::Field< + __Zerocopy_Field_1, + >>::Type: ::zerocopy::KnownLayout, { #[allow(clippy::missing_inline_in_public_items)] fn only_derive_is_allowed_to_implement_this_trait() {} @@ -205,7 +232,12 @@ fn test_known_layout() { meta: Self::PointerMetadata, ) -> ::zerocopy::util::macro_util::core_reexport::ptr::NonNull { use ::zerocopy::KnownLayout; - let trailing = <::MaybeUninit as KnownLayout>::raw_from_ptr_len( + let trailing = << as ::zerocopy::util::macro_util::Field< + __Zerocopy_Field_1, + >>::Type as ::zerocopy::KnownLayout>::MaybeUninit as KnownLayout>::raw_from_ptr_len( bytes, meta, ); @@ -218,7 +250,12 @@ fn test_known_layout() { } #[inline(always)] fn pointer_to_metadata(ptr: *mut Self) -> Self::PointerMetadata { - <::MaybeUninit>::pointer_to_metadata( + << as ::zerocopy::util::macro_util::Field< + __Zerocopy_Field_1, + >>::Type as ::zerocopy::KnownLayout>::MaybeUninit>::pointer_to_metadata( ptr as *mut _, ) } diff --git a/zerocopy-derive/tests/struct_known_layout.rs b/zerocopy-derive/tests/struct_known_layout.rs index 1cfc584099..e34843f644 100644 --- a/zerocopy-derive/tests/struct_known_layout.rs +++ b/zerocopy-derive/tests/struct_known_layout.rs @@ -10,6 +10,8 @@ #![no_implicit_prelude] #![allow(warnings)] +extern crate rustversion; + include!("include.rs"); #[derive(imp::KnownLayout)] @@ -46,16 +48,56 @@ util_assert_impl_all!(TypeParams<'static, (), imp::IntoIter<()>>: imp::KnownLayo util_assert_impl_all!(TypeParams<'static, util::AU16, imp::IntoIter<()>>: imp::KnownLayout); // Deriving `KnownLayout` should work if the struct has bounded parameters. +// +// N.B. We limit this test to rustc >= 1.62, since earlier versions of rustc ICE +// when `KnownLayout` is derived on a `repr(C)` struct whose trailing field +// contains non-static lifetimes. +#[rustversion::since(1.62)] +const _: () = { + #[derive(imp::KnownLayout)] + #[repr(C)] + struct WithParams<'a: 'b, 'b: 'a, T: 'a + 'b + imp::KnownLayout, const N: usize>( + [T; N], + imp::PhantomData<&'a &'b ()>, + ) + where + 'a: 'b, + 'b: 'a, + T: 'a + 'b + imp::KnownLayout; + + util_assert_impl_all!(WithParams<'static, 'static, u8, 42>: imp::KnownLayout); +}; + +const _: () = { + // Similar to the previous test, except that the trailing field contains + // only static lifetimes. This is exercisable on all supported toolchains. + + #[derive(imp::KnownLayout)] + #[repr(C)] + struct WithParams<'a: 'b, 'b: 'a, T: 'a + 'b + imp::KnownLayout, const N: usize>( + &'a &'b [T; N], + imp::PhantomData<&'static ()>, + ) + where + 'a: 'b, + 'b: 'a, + T: 'a + 'b + imp::KnownLayout; + + util_assert_impl_all!(WithParams<'static, 'static, u8, 42>: imp::KnownLayout); +}; + +// Deriving `KnownLayout` should work if the struct references `Self`. See +// #2116. #[derive(imp::KnownLayout)] #[repr(C)] -struct WithParams<'a: 'b, 'b: 'a, T: 'a + 'b + imp::KnownLayout, const N: usize>( - [T; N], - imp::PhantomData<&'a &'b ()>, -) -where - 'a: 'b, - 'b: 'a, - T: 'a + 'b + imp::KnownLayout; - -util_assert_impl_all!(WithParams<'static, 'static, u8, 42>: imp::KnownLayout); +struct WithSelfReference { + leading: [u8; Self::N], + trailing: [[u8; Self::N]], +} + +impl WithSelfReference { + const N: usize = 42; +} + +util_assert_impl_all!(WithSelfReference: imp::KnownLayout);