From 96f8289752318bea1040a04e082c92b65ece6950 Mon Sep 17 00:00:00 2001 From: Moulberry Date: Thu, 21 Nov 2024 02:17:41 +0800 Subject: [PATCH] Implement #[derive(Query)] for enum types --- macros/src/lib.rs | 7 +- macros/src/query.rs | 294 ++++++++++++++++++++++++++- tests/derive.rs | 3 +- tests/derive/enum.rs | 9 - tests/derive/enum.stderr | 11 - tests/derive/enum_query.rs | 24 +++ tests/derive/enum_unsupported.rs | 6 + tests/derive/enum_unsupported.stderr | 5 + tests/derive/nested_query.rs | 6 + tests/derive/no_prelude.rs | 9 + tests/derive/union.stderr | 2 +- tests/derive/wrong_lifetime.rs | 6 + tests/derive/wrong_lifetime.stderr | 11 + tests/tests.rs | 61 ++++++ 14 files changed, 419 insertions(+), 35 deletions(-) delete mode 100644 tests/derive/enum.rs delete mode 100644 tests/derive/enum.stderr create mode 100644 tests/derive/enum_query.rs create mode 100644 tests/derive/enum_unsupported.rs create mode 100644 tests/derive/enum_unsupported.stderr diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 4f6c5f12..634054b5 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -63,12 +63,15 @@ pub fn derive_dynamic_bundle_clone(input: TokenStream) -> TokenStream { .into() } -/// Implement `Query` for a struct +/// Implement `Query` for a struct or enum. /// -/// Queries structs can be passed to the type parameter of `World::query`. They must have exactly +/// Queries can be passed to the type parameter of `World::query`. They must have exactly /// one lifetime parameter, and all of their fields must be queries (e.g. references) using that /// lifetime. /// +/// For enum queries, the result will always be the first variant that matches the entity. +/// Unit variants and variants without any fields will always match an entity. +/// /// # Example /// ``` /// # use hecs::*; diff --git a/macros/src/query.rs b/macros/src/query.rs index ddc8e72a..ee0a65a9 100644 --- a/macros/src/query.rs +++ b/macros/src/query.rs @@ -1,20 +1,22 @@ use proc_macro2::Span; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{DeriveInput, Error, Ident, Lifetime, Result, Type}; +use syn::{DataEnum, DataStruct, DeriveInput, Error, Ident, Lifetime, Result, Type, Visibility}; pub fn derive(input: DeriveInput) -> Result { let ident = input.ident; - let vis = input.vis; - let data = match input.data { - syn::Data::Struct(s) => s, + + match input.data { + syn::Data::Struct(_) | syn::Data::Enum(_) => {} _ => { return Err(Error::new_spanned( ident, - "derive(Query) may only be applied to structs", + "derive(Query) may only be applied to structs and enums", )) } - }; + } + + let vis = input.vis; let lifetime = input .generics .lifetimes() @@ -36,6 +38,19 @@ pub fn derive(input: DeriveInput) -> Result { )); } + match input.data { + syn::Data::Struct(data_struct) => derive_struct(ident, vis, data_struct, lifetime), + syn::Data::Enum(data_enum) => derive_enum(ident, vis, data_enum, lifetime), + _ => unreachable!(), + } +} + +fn derive_struct( + ident: Ident, + vis: Visibility, + data: DataStruct, + lifetime: Lifetime, +) -> Result { let (fields, queries) = match data.fields { syn::Fields::Named(ref fields) => fields .named @@ -55,7 +70,7 @@ pub fn derive(input: DeriveInput) -> Result { ( syn::Member::Unnamed(syn::Index { index: i as u32, - span: Span::call_site(), + span: Span::mixed_site(), }), query_ty(&lifetime, &f.ty), ) @@ -67,7 +82,7 @@ pub fn derive(input: DeriveInput) -> Result { .iter() .map(|ty| quote! { <#ty as ::hecs::Query>::Fetch }) .collect::>(); - let fetch_ident = Ident::new(&format!("{}Fetch", ident), Span::call_site()); + let fetch_ident = Ident::new(&format!("{}Fetch", ident), Span::mixed_site()); let fetch = match data.fields { syn::Fields::Named(_) => quote! { #vis struct #fetch_ident { @@ -83,7 +98,7 @@ pub fn derive(input: DeriveInput) -> Result { #vis struct #fetch_ident; }, }; - let state_ident = Ident::new(&format!("{}State", ident), Span::call_site()); + let state_ident = Ident::new(&format!("{}State", ident), Span::mixed_site()); let state = match data.fields { syn::Fields::Named(_) => quote! { #[derive(Clone, Copy)] @@ -108,7 +123,7 @@ pub fn derive(input: DeriveInput) -> Result { .map(|x| match x { syn::Member::Named(ref ident) => ident.clone(), syn::Member::Unnamed(ref index) => { - Ident::new(&format!("field_{}", index.index), Span::call_site()) + Ident::new(&format!("field_{}", index.index), Span::mixed_site()) } }) .collect::>(); @@ -193,6 +208,263 @@ pub fn derive(input: DeriveInput) -> Result { }) } +fn derive_enum( + enum_ident: Ident, + vis: Visibility, + data: DataEnum, + lifetime: Lifetime, +) -> Result { + let mut dangling_constructor = None; + let mut fetch_variants = TokenStream2::new(); + let mut state_variants = TokenStream2::new(); + let mut query_get_variants = TokenStream2::new(); + let mut fetch_access_variants = TokenStream2::new(); + let mut fetch_borrow_variants = TokenStream2::new(); + let mut fetch_prepare_variants = TokenStream2::new(); + let mut fetch_execute_variants = TokenStream2::new(); + let mut fetch_release_variants = TokenStream2::new(); + let mut fetch_for_each_borrow = TokenStream2::new(); + + for variant in &data.variants { + let (fields, queries) = match variant.fields { + syn::Fields::Named(ref fields) => fields + .named + .iter() + .map(|f| { + ( + syn::Member::Named(f.ident.clone().unwrap()), + query_ty(&lifetime, &f.ty), + ) + }) + .unzip(), + syn::Fields::Unnamed(ref fields) => fields + .unnamed + .iter() + .enumerate() + .map(|(i, f)| { + ( + syn::Member::Unnamed(syn::Index { + index: i as u32, + span: Span::mixed_site(), + }), + query_ty(&lifetime, &f.ty), + ) + }) + .unzip(), + syn::Fields::Unit => (Vec::new(), Vec::new()), + }; + + let ident = variant.ident.clone(); + + if ident == "__HecsDanglingFetch__" { + return Err(Error::new_spanned( + ident, + "derive(Query) reserves this identifier for internal use", + )); + } + + let named_fields = fields + .iter() + .map(|x| match x { + syn::Member::Named(ref ident) => ident.clone(), + syn::Member::Unnamed(ref index) => { + Ident::new(&format!("field_{}", index.index), Span::mixed_site()) + } + }) + .collect::>(); + + let fetches = queries + .iter() + .map(|ty| quote! { <#ty as ::hecs::Query>::Fetch }) + .collect::>(); + + if dangling_constructor.is_none() && fields.is_empty() { + dangling_constructor = Some(quote! { + Self::#ident {} + }); + } + + fetch_variants.extend(quote! { + #ident { + #( + #named_fields: #fetches, + )* + }, + }); + + state_variants.extend(quote! { + #ident { + #( + #named_fields: <#fetches as ::hecs::Fetch>::State, + )* + }, + }); + + query_get_variants.extend(quote! { + Self::Fetch::#ident { #(#named_fields),* } => { + #( + let #named_fields: <#queries as ::hecs::Query>::Item<'q> = <#queries as ::hecs::Query>::get(#named_fields, n); + )* + Self::Item::#ident { #( #fields: #named_fields,)* } + }, + }); + + fetch_access_variants.extend(quote! { + 'block: { + let mut access = ::hecs::Access::Iterate; + #( + if let ::core::option::Option::Some(new_access) = #fetches::access(archetype) { + access = ::core::cmp::max(access, new_access); + } else { + break 'block; + } + )* + return ::core::option::Option::Some(access) + } + }); + + fetch_borrow_variants.extend(quote! { + Self::State::#ident { #(#named_fields),* } => { + #( + #fetches::borrow(archetype, #named_fields); + )* + }, + }); + + fetch_prepare_variants.extend(quote! { + 'block: { + #( + let ::core::option::Option::Some(#named_fields) = #fetches::prepare(archetype) else { + break 'block; + }; + )* + return ::core::option::Option::Some(Self::State::#ident { #(#named_fields,)* }); + } + }); + + fetch_execute_variants.extend(quote! { + Self::State::#ident { #(#named_fields),* } => { + return Self::#ident { + #( + #named_fields: #fetches::execute(archetype, #named_fields), + )* + }; + }, + }); + + fetch_release_variants.extend(quote! { + Self::State::#ident { #(#named_fields),* } => { + #( + #fetches::release(archetype, #named_fields); + )* + }, + }); + + fetch_for_each_borrow.extend(quote! { + #( + <#fetches as ::hecs::Fetch>::for_each_borrow(&mut f); + )* + }); + } + + let dangling_constructor = if let Some(dangling_constructor) = dangling_constructor { + dangling_constructor + } else { + fetch_variants.extend(quote! { + __HecsDanglingFetch__, + }); + query_get_variants.extend(quote! { + Self::Fetch::__HecsDanglingFetch__ => panic!("Called get() with dangling fetch"), + }); + quote! { + Self::__HecsDanglingFetch__ + } + }; + + let fetch_ident = Ident::new(&format!("{}Fetch", enum_ident), Span::mixed_site()); + let fetch = quote! { + #vis enum #fetch_ident { + #fetch_variants + } + }; + + let state_ident = Ident::new(&format!("{}State", enum_ident), Span::mixed_site()); + let state = quote! { + #vis enum #state_ident { + #state_variants + } + }; + + Ok(quote! { + const _: () = { + #[derive(Clone)] + #fetch + + impl<'a> ::hecs::Query for #enum_ident<'a> { + type Item<'q> = #enum_ident<'q>; + + type Fetch = #fetch_ident; + + #[allow(unused_variables)] + unsafe fn get<'q>(fetch: &Self::Fetch, n: usize) -> Self::Item<'q> { + match fetch { + #query_get_variants + } + } + } + + #[derive(Clone, Copy)] + #state + + unsafe impl ::hecs::Fetch for #fetch_ident { + type State = #state_ident; + + fn dangling() -> Self { + #dangling_constructor + } + + #[allow(unused_variables, unused_mut, unreachable_code)] + fn access(archetype: &::hecs::Archetype) -> ::core::option::Option<::hecs::Access> { + #fetch_access_variants + ::core::option::Option::None + } + + #[allow(unused_variables)] + fn borrow(archetype: &::hecs::Archetype, state: Self::State) { + match state { + #fetch_borrow_variants + } + } + + #[allow(unused_variables, unreachable_code)] + fn prepare(archetype: &::hecs::Archetype) -> ::core::option::Option { + #fetch_prepare_variants + ::core::option::Option::None + } + + #[allow(unused_variables)] + fn execute(archetype: &::hecs::Archetype, state: Self::State) -> Self { + match state { + #fetch_execute_variants + } + } + + #[allow(unused_variables)] + fn release(archetype: &::hecs::Archetype, state: Self::State) { + match state { + #fetch_release_variants + } + } + + #[allow(unused_variables, unused_mut)] + fn for_each_borrow(mut f: impl ::core::ops::FnMut(::core::any::TypeId, bool)) { + #fetch_for_each_borrow + } + } + }; + }) +} + fn query_ty(lifetime: &Lifetime, ty: &Type) -> TokenStream2 { struct Visitor<'a> { replace: &'a Lifetime, @@ -200,7 +472,7 @@ fn query_ty(lifetime: &Lifetime, ty: &Type) -> TokenStream2 { impl syn::visit_mut::VisitMut for Visitor<'_> { fn visit_lifetime_mut(&mut self, l: &mut Lifetime) { if l == self.replace { - *l = Lifetime::new("'static", Span::call_site()); + *l = Lifetime::new("'static", Span::mixed_site()); } } } diff --git a/tests/derive.rs b/tests/derive.rs index 49c94a2d..65388656 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -6,8 +6,9 @@ fn derive() { const TEST_DIR: &str = "tests/derive"; let t = trybuild::TestCases::new(); - let failures = &["enum.rs", "union.rs", "wrong_lifetime.rs"]; + let failures = &["enum_unsupported.rs", "union.rs", "wrong_lifetime.rs"]; let successes = &[ + "enum_query.rs", "unit_structs.rs", "tuple_structs.rs", "named_structs.rs", diff --git a/tests/derive/enum.rs b/tests/derive/enum.rs deleted file mode 100644 index a4cd4dd0..00000000 --- a/tests/derive/enum.rs +++ /dev/null @@ -1,9 +0,0 @@ -use hecs::{Bundle, Query}; - -#[derive(Query)] -enum Foo {} - -#[derive(Bundle)] -enum Bar {} - -fn main() {} diff --git a/tests/derive/enum.stderr b/tests/derive/enum.stderr deleted file mode 100644 index 620694b8..00000000 --- a/tests/derive/enum.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: derive(Query) may only be applied to structs - --> $DIR/enum.rs:4:6 - | -4 | enum Foo {} - | ^^^ - -error: derive(Bundle) does not support enums or unions - --> $DIR/enum.rs:7:6 - | -7 | enum Bar {} - | ^^^ diff --git a/tests/derive/enum_query.rs b/tests/derive/enum_query.rs new file mode 100644 index 00000000..5da2df8f --- /dev/null +++ b/tests/derive/enum_query.rs @@ -0,0 +1,24 @@ +use hecs::Query; + +#[derive(Query)] +enum Foo<'a> { + Foo(&'a i32) +} + +#[derive(Query)] +enum Bar<'a> { + Bar { + bar: &'a bool + }, +} + +#[derive(Query)] +enum All<'a> { + Foo(&'a i32), + Bar { + bar: &'a bool + }, + Baz +} + +fn main() {} diff --git a/tests/derive/enum_unsupported.rs b/tests/derive/enum_unsupported.rs new file mode 100644 index 00000000..569f1d77 --- /dev/null +++ b/tests/derive/enum_unsupported.rs @@ -0,0 +1,6 @@ +use hecs::Bundle; + +#[derive(Bundle)] +enum Foo {} + +fn main() {} diff --git a/tests/derive/enum_unsupported.stderr b/tests/derive/enum_unsupported.stderr new file mode 100644 index 00000000..f3c416d8 --- /dev/null +++ b/tests/derive/enum_unsupported.stderr @@ -0,0 +1,5 @@ +error: derive(Bundle) does not support enums or unions + --> $DIR/enum_unsupported.rs:4:6 + | +4 | enum Foo {} + | ^^^ diff --git a/tests/derive/nested_query.rs b/tests/derive/nested_query.rs index 9b4f4e89..c4ffdbca 100644 --- a/tests/derive/nested_query.rs +++ b/tests/derive/nested_query.rs @@ -11,4 +11,10 @@ struct Bar<'a> { baz: &'a mut bool, } +#[derive(Query)] +enum Baz<'a> { + Foo(Foo<'a>), + Bar(Bar<'a>), +} + fn main() {} diff --git a/tests/derive/no_prelude.rs b/tests/derive/no_prelude.rs index acd672c9..ef154f88 100644 --- a/tests/derive/no_prelude.rs +++ b/tests/derive/no_prelude.rs @@ -18,4 +18,13 @@ struct Quux<'a> { foo: &'a (), } +#[derive(::hecs::Query)] +enum Corge<'a> { + Foo (&'a i32), + Bar { + bar: &'a bool + }, + Baz +} + fn main() {} diff --git a/tests/derive/union.stderr b/tests/derive/union.stderr index 26749665..d6265355 100644 --- a/tests/derive/union.stderr +++ b/tests/derive/union.stderr @@ -1,4 +1,4 @@ -error: derive(Query) may only be applied to structs +error: derive(Query) may only be applied to structs and enums --> $DIR/union.rs:4:7 | 4 | union Foo { diff --git a/tests/derive/wrong_lifetime.rs b/tests/derive/wrong_lifetime.rs index 6c806732..0b55abfe 100644 --- a/tests/derive/wrong_lifetime.rs +++ b/tests/derive/wrong_lifetime.rs @@ -6,4 +6,10 @@ struct Foo<'a> { bar: &'static mut bool, } +#[derive(Query)] +enum Bar<'a> { + Foo(&'a i32), + Bar(&'static mut bool) +} + fn main() {} diff --git a/tests/derive/wrong_lifetime.stderr b/tests/derive/wrong_lifetime.stderr index 8c7deda2..6af1f9a0 100644 --- a/tests/derive/wrong_lifetime.stderr +++ b/tests/derive/wrong_lifetime.stderr @@ -8,3 +8,14 @@ error: lifetime may not live long enough | type annotation requires that `'q` must outlive `'static` | = note: this error originates in the derive macro `Query` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: lifetime may not live long enough + --> $DIR/wrong_lifetime.rs:9:10 + | +9 | #[derive(Query)] + | ^^^^^ + | | + | lifetime `'q` defined here + | type annotation requires that `'q` must outlive `'static` + | + = note: this error originates in the derive macro `Query` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/tests.rs b/tests/tests.rs index 9183a746..d36eb49c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -78,6 +78,67 @@ fn derived_query() { ); } +#[test] +#[cfg(feature = "macros")] +fn derived_enum_query() { + #[derive(Query, Debug, PartialEq)] + enum Foo<'a> { + NumberAndString(&'a i32, &'a String), + Number(&'a i32), + Boolean(&'a mut bool), + } + + let mut world = World::new(); + let e1 = world.spawn((42, false)); + + assert_eq!(world.query_one_mut::(e1).unwrap(), Foo::Number(&42)); + + let e2 = world.spawn((String::from("Hello"), false)); + + assert_eq!( + world.query_one_mut::(e2).unwrap(), + Foo::Boolean(&mut false) + ); + + let e3 = world.spawn((String::from("Hello"), 42)); + + assert_eq!( + world.query_one_mut::(e3).unwrap(), + Foo::NumberAndString(&42, &String::from("Hello")) + ); + + let e4 = world.spawn((String::from("Hello"), 0_usize)); + + assert_eq!( + world.query_one_mut::(e4), + Err(QueryOneError::Unsatisfied) + ); +} + +#[test] +#[cfg(feature = "macros")] +fn derived_enum_query_with_empty() { + #[derive(Query, Debug, PartialEq)] + enum Foo<'a> { + Number(&'a i32), + Empty, + Impossible(&'a String), + } + + let mut world = World::new(); + let e1 = world.spawn((42, false)); + + assert_eq!(world.query_one_mut::(e1).unwrap(), Foo::Number(&42)); + + let e2 = world.spawn((false, 0_usize)); + + assert_eq!(world.query_one_mut::(e2).unwrap(), Foo::Empty); + + let e3 = world.spawn((String::from("Hello"), false)); + + assert_eq!(world.query_one_mut::(e3).unwrap(), Foo::Empty); +} + #[test] #[cfg(feature = "macros")] fn derived_bundle_clone() {