From 60f552494060138d9fd7f427ba6785fe2c4dbac7 Mon Sep 17 00:00:00 2001 From: Zachary S Date: Mon, 12 Jun 2023 03:47:50 -0500 Subject: [PATCH] Cleanup custom bounds code. --- derive/src/lib.rs | 81 ++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 3a93ba9..e1acf05 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -366,69 +366,61 @@ fn derive_marker_trait(input: DeriveInput) -> TokenStream { .unwrap_or_else(|err| err.into_compile_error()) } -/// Find a `#[name(key = "value")]` attribute on the struct, and parse "value" -/// with `parser` and return it. +/// Find `#[name(key = "value")]` helper attributes on the struct, and return +/// their `"value"`s parsed with `parser`. /// -/// Returns an error if multiple attributes with `name` are found, or if the one -/// found does not match the expected format. Returns `Ok(None)` if no attribute -/// with `name` is found. -fn find_helper_attribute( +/// Returns an error any matching attributes do not match the expected format. +/// Returns `Ok([])` if no attributes with `name` are found. +fn find_helper_attributes( attributes: &[syn::Attribute], name: &str, key: &str, parser: P, - example_value: &str, -) -> Result> { + example_value: &str, invalid_value_msg: &str, +) -> Result> { + let invalid_format_msg = + format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",); let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta { - syn::Meta::Path(path) => path.is_ident(name).then(|| { - Err(syn::Error::new_spanned( - &path, - format!( - "{name} attribute must be `{name}({key} = \"{example_value}\")`", - ), - )) - }), + // If a `Path` matches our `name`, return an error, else ignore it. + // e.g. `#[zeroable]` + syn::Meta::Path(path) => path + .is_ident(name) + .then(|| Err(syn::Error::new_spanned(&path, &invalid_format_msg))), + // If a `NameValue` matches our `name`, return an error, else ignore it. + // e.g. `#[zeroable = "hello"]` syn::Meta::NameValue(namevalue) => { namevalue.path.is_ident(name).then(|| { - Err(syn::Error::new_spanned( - &namevalue.path, - format!( - "{name} attribute must be `{name}({key} = \"{example_value}\")`", - ), - )) + Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg)) }) } + // If a `List` matches our `name`, match its contents to our format, else + // ignore it. If its contents match our format, return the value, else + // return an error. syn::Meta::List(list) => list.path.is_ident(name).then(|| { let namevalue: MetaNameValue = syn::parse2(list.tokens.clone()).map_err(|_| { - syn::Error::new_spanned( - &list.tokens, - format!( - "{name} attribute must be `{name}({key} = \"{example_value}\")`", - ), - ) + syn::Error::new_spanned(&list.tokens, &invalid_format_msg) })?; - if namevalue.path.is_ident("bound") { + if namevalue.path.is_ident(key) { match namevalue.value { syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(strlit), .. }) => Ok(strlit), - _ => Err(syn::Error::new_spanned( - &namevalue.path, - format!( - "{name} attribute must be `{name}({key} = \"{example_value}\")`", - ), - )), + _ => { + Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg)) + } } } else { - Err(syn::Error::new_spanned( - &namevalue.path, - format!( - "{name} attribute must be `{name}({key} = \"{example_value}\")`", - ), - )) + Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg)) } }), }); + // Parse each value found with the given parser, and return them if no errors + // occur. values_to_check - .map(|r| r.and_then(|lit| Ok((lit.clone(), lit.parse_with(parser)?)))) + .map(|lit| { + let lit = lit?; + lit.parse_with(parser).map_err(|err| { + syn::Error::new_spanned(&lit, &format!("{invalid_value_msg}: {err}")) + }) + }) .collect() } @@ -436,12 +428,13 @@ fn derive_marker_trait_inner( mut input: DeriveInput, ) -> Result { let trait_ = Trait::ident(&input)?; - let explicit_bounds = find_helper_attribute( + let explicit_bounds = find_helper_attributes( &input.attrs, "zeroable", "bound", >::parse_terminated, "Type: Trait", + "invalid where predicate", )?; if explicit_bounds.is_empty() { // Enforce bound on all generic fields. @@ -451,7 +444,7 @@ fn derive_marker_trait_inner( // soundness) let explicit_bounds = explicit_bounds .into_iter() - .flat_map(|a| (a.1)) + .flatten() .collect::>(); input.generics.make_where_clause().predicates.extend(explicit_bounds);