From 7848c11e314ad1322a6f3f767d5538186125c07a Mon Sep 17 00:00:00 2001 From: Andrea Lattuada Date: Wed, 26 Jul 2023 15:27:07 +0200 Subject: [PATCH] allow #![trigger f(x)] on ensures (of broadcast_forall) functions --- dependencies/syn/src/gen/clone.rs | 1 + dependencies/syn/src/gen/debug.rs | 1 + dependencies/syn/src/gen/eq.rs | 2 +- dependencies/syn/src/gen/fold.rs | 1 + dependencies/syn/src/gen/hash.rs | 1 + dependencies/syn/src/gen/visit.rs | 3 + dependencies/syn/src/gen/visit_mut.rs | 3 + dependencies/syn/src/verus.rs | 7 +- dependencies/syn/syn.json | 5 + dependencies/syn/tests/debug/gen.rs | 3 + source/builtin_macros/src/syntax.rs | 214 ++++++++++++------- source/rust_verify_test/tests/quantifiers.rs | 27 +++ 12 files changed, 187 insertions(+), 81 deletions(-) diff --git a/dependencies/syn/src/gen/clone.rs b/dependencies/syn/src/gen/clone.rs index 7d109adf22..a040c6ea43 100644 --- a/dependencies/syn/src/gen/clone.rs +++ b/dependencies/syn/src/gen/clone.rs @@ -288,6 +288,7 @@ impl Clone for DeriveInput { impl Clone for Ensures { fn clone(&self) -> Self { Ensures { + attrs: self.attrs.clone(), token: self.token.clone(), exprs: self.exprs.clone(), } diff --git a/dependencies/syn/src/gen/debug.rs b/dependencies/syn/src/gen/debug.rs index d930c54182..b6e709e598 100644 --- a/dependencies/syn/src/gen/debug.rs +++ b/dependencies/syn/src/gen/debug.rs @@ -501,6 +501,7 @@ impl Debug for DeriveInput { impl Debug for Ensures { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { let mut formatter = formatter.debug_struct("Ensures"); + formatter.field("attrs", &self.attrs); formatter.field("token", &self.token); formatter.field("exprs", &self.exprs); formatter.finish() diff --git a/dependencies/syn/src/gen/eq.rs b/dependencies/syn/src/gen/eq.rs index 4df5e5ba04..a6da6f1e5f 100644 --- a/dependencies/syn/src/gen/eq.rs +++ b/dependencies/syn/src/gen/eq.rs @@ -306,7 +306,7 @@ impl Eq for Ensures {} #[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] impl PartialEq for Ensures { fn eq(&self, other: &Self) -> bool { - self.exprs == other.exprs + self.attrs == other.attrs && self.exprs == other.exprs } } #[cfg(any(feature = "derive", feature = "full"))] diff --git a/dependencies/syn/src/gen/fold.rs b/dependencies/syn/src/gen/fold.rs index 7f475b059d..3a92da2353 100644 --- a/dependencies/syn/src/gen/fold.rs +++ b/dependencies/syn/src/gen/fold.rs @@ -1318,6 +1318,7 @@ where F: Fold + ?Sized, { Ensures { + attrs: FoldHelper::lift(node.attrs, |it| f.fold_attribute(it)), token: Token![ensures](tokens_helper(f, &node.token.span)), exprs: f.fold_specification(node.exprs), } diff --git a/dependencies/syn/src/gen/hash.rs b/dependencies/syn/src/gen/hash.rs index ea29277120..b712e6c0f2 100644 --- a/dependencies/syn/src/gen/hash.rs +++ b/dependencies/syn/src/gen/hash.rs @@ -430,6 +430,7 @@ impl Hash for Ensures { where H: Hasher, { + self.attrs.hash(state); self.exprs.hash(state); } } diff --git a/dependencies/syn/src/gen/visit.rs b/dependencies/syn/src/gen/visit.rs index 01af9f49ab..1d1e9b7f88 100644 --- a/dependencies/syn/src/gen/visit.rs +++ b/dependencies/syn/src/gen/visit.rs @@ -1328,6 +1328,9 @@ pub fn visit_ensures<'ast, V>(v: &mut V, node: &'ast Ensures) where V: Visit<'ast> + ?Sized, { + for it in &node.attrs { + v.visit_attribute(it); + } tokens_helper(v, &node.token.span); v.visit_specification(&node.exprs); } diff --git a/dependencies/syn/src/gen/visit_mut.rs b/dependencies/syn/src/gen/visit_mut.rs index cf2c985215..8245c1cf05 100644 --- a/dependencies/syn/src/gen/visit_mut.rs +++ b/dependencies/syn/src/gen/visit_mut.rs @@ -1329,6 +1329,9 @@ pub fn visit_ensures_mut(v: &mut V, node: &mut Ensures) where V: VisitMut + ?Sized, { + for it in &mut node.attrs { + v.visit_attribute_mut(it); + } tokens_helper(v, &mut node.token.span); v.visit_specification_mut(&mut node.exprs); } diff --git a/dependencies/syn/src/verus.rs b/dependencies/syn/src/verus.rs index 2e4365d1c3..baf8ac0826 100644 --- a/dependencies/syn/src/verus.rs +++ b/dependencies/syn/src/verus.rs @@ -120,6 +120,7 @@ ast_struct! { ast_struct! { pub struct Ensures { + pub attrs: Vec, pub token: Token![ensures], pub exprs: Specification, } @@ -428,8 +429,12 @@ pub mod parsing { #[cfg_attr(doc_cfg, doc(cfg(feature = "parsing")))] impl Parse for Ensures { fn parse(input: ParseStream) -> Result { + let mut attrs = Vec::new(); + let token = input.parse()?; + attr::parsing::parse_inner(input, &mut attrs)?; Ok(Ensures { - token: input.parse()?, + attrs, + token, exprs: input.parse()?, }) } diff --git a/dependencies/syn/syn.json b/dependencies/syn/syn.json index af92924e87..e77e2f9b71 100644 --- a/dependencies/syn/syn.json +++ b/dependencies/syn/syn.json @@ -846,6 +846,11 @@ "any": [] }, "fields": { + "attrs": { + "vec": { + "syn": "Attribute" + } + }, "token": { "token": "Ensures" }, diff --git a/dependencies/syn/tests/debug/gen.rs b/dependencies/syn/tests/debug/gen.rs index dea852c693..ef681945cf 100644 --- a/dependencies/syn/tests/debug/gen.rs +++ b/dependencies/syn/tests/debug/gen.rs @@ -642,6 +642,9 @@ impl Debug for Lite { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { let _val = &self.value; let mut formatter = formatter.debug_struct("Ensures"); + if !_val.attrs.is_empty() { + formatter.field("attrs", Lite(&_val.attrs)); + } formatter.field("exprs", Lite(&_val.exprs)); formatter.finish() } diff --git a/source/builtin_macros/src/syntax.rs b/source/builtin_macros/src/syntax.rs index 9f352f6cb7..ceacbcb8c0 100644 --- a/source/builtin_macros/src/syntax.rs +++ b/source/builtin_macros/src/syntax.rs @@ -315,23 +315,63 @@ impl Visitor { )); } } - if let Some(Ensures { token, mut exprs }) = ensures { + if let Some(Ensures { attrs, token, mut exprs }) = ensures { if exprs.exprs.len() > 0 { for expr in exprs.exprs.iter_mut() { self.visit_expr_mut(expr); } - if let Some((p, ty)) = ret_pat { - stmts.push(Stmt::Semi( - Expr::Verbatim( - quote_spanned!(token.span => ::builtin::ensures(|#p: #ty| [#exprs])), - ), - Semi { spans: [token.span] }, - )); - } else { - stmts.push(Stmt::Semi( - Expr::Verbatim(quote_spanned!(token.span => ::builtin::ensures([#exprs]))), - Semi { spans: [token.span] }, - )); + let cont = match self.extract_quant_triggers(attrs, token.span) { + Ok( + found @ (ExtractQuantTriggersFound::Auto + | ExtractQuantTriggersFound::Triggers(..)), + ) => { + if exprs.exprs.len() == 0 { + let err = + "when using #![trigger f(x)], at least one ensures is required"; + let expr = + Expr::Verbatim(quote_spanned!(token.span => compile_error!(#err))); + stmts.push(Stmt::Semi(expr, Semi { spans: [token.span] })); + false + } else { + let e = take_expr(&mut exprs.exprs[0]); + match found { + ExtractQuantTriggersFound::Auto => { + exprs.exprs[0] = Expr::Verbatim( + quote_spanned!(exprs.exprs[0].span() => #[verus::internal(auto_trigger)] (#e)), + ); + } + ExtractQuantTriggersFound::Triggers(tuple) => { + exprs.exprs[0] = Expr::Verbatim( + quote_spanned!(exprs.exprs[0].span() => ::builtin::with_triggers(#tuple, #e)), + ); + } + ExtractQuantTriggersFound::None => unreachable!(), + } + true + } + } + Ok(ExtractQuantTriggersFound::None) => true, + Err(err_expr) => { + exprs.exprs[0] = err_expr; + false + } + }; + if cont { + if let Some((p, ty)) = ret_pat { + stmts.push(Stmt::Semi( + Expr::Verbatim( + quote_spanned!(token.span => ::builtin::ensures(|#p: #ty| [#exprs])), + ), + Semi { spans: [token.span] }, + )); + } else { + stmts.push(Stmt::Semi( + Expr::Verbatim( + quote_spanned!(token.span => ::builtin::ensures([#exprs])), + ), + Semi { spans: [token.span] }, + )); + } } } } @@ -956,7 +996,7 @@ impl Visitor { _ => panic!("expected closure for quantifier"), }; - match extract_quant_triggers(inner_attrs, span) { + match self.extract_quant_triggers(inner_attrs, span) { Ok(ExtractQuantTriggersFound::Auto) => match &mut *arg { Expr::Closure(closure) => { let body = take_expr(&mut closure.body); @@ -1028,8 +1068,12 @@ impl Visitor { stmts.push(stmt_with_semi!(token.span => ::builtin::invariant_ensures([#exprs]))); } } - if let Some(Ensures { token, mut exprs }) = ensures { - if exprs.exprs.len() > 0 { + if let Some(Ensures { token, mut exprs, attrs }) = ensures { + if attrs.len() > 0 { + let err = "outer attributes only allowed on function's ensures"; + let expr = Expr::Verbatim(quote_spanned!(token.span => compile_error!(#err))); + stmts.push(Stmt::Semi(expr, Semi { spans: [token.span] })); + } else if exprs.exprs.len() > 0 { for expr in exprs.exprs.iter_mut() { self.visit_expr_mut(expr); } @@ -1049,66 +1093,70 @@ impl Visitor { } self.inside_ghost -= 1; } -} -enum ExtractQuantTriggersFound { - Auto, - Triggers(ExprTuple), - None, -} + fn extract_quant_triggers( + &mut self, + inner_attrs: Vec, + span: Span, + ) -> Result { + let mut triggers: Vec = Vec::new(); + for attr in inner_attrs { + let trigger: syn_verus::Result = + syn_verus::parse2(attr.tokens.clone()); + let path_segments_str = + attr.path.segments.iter().map(|x| x.ident.to_string()).collect::>(); + let ident_str = match &path_segments_str[..] { + [attr_name] => Some(attr_name), + _ => None, + }; + match (trigger, ident_str) { + (Ok(trigger), Some(id)) if id == &"auto" && trigger.exprs.len() == 0 => { + return Ok(ExtractQuantTriggersFound::Auto); + } + (Ok(trigger), Some(id)) if id == &"trigger" => { + let mut exprs = trigger.exprs; + for expr in exprs.iter_mut() { + self.visit_expr_mut(expr); + } + let tuple = ExprTuple { attrs: vec![], paren_token: Paren(span), elems: exprs }; + triggers.push(Expr::Tuple(tuple)); + } + (Err(err), _) => { + let span = attr.span(); + let err = err.to_string(); -fn extract_quant_triggers( - inner_attrs: Vec, - span: Span, -) -> Result { - let mut triggers: Vec = Vec::new(); - for attr in inner_attrs { - let trigger: syn_verus::Result = - syn_verus::parse2(attr.tokens.clone()); - let path_segments_str = - attr.path.segments.iter().map(|x| x.ident.to_string()).collect::>(); - let ident_str = match &path_segments_str[..] { - [attr_name] => Some(attr_name), - _ => None, - }; - match (trigger, ident_str) { - (Ok(trigger), Some(id)) if id == &"auto" && trigger.exprs.len() == 0 => { - return Ok(ExtractQuantTriggersFound::Auto); - } - (Ok(trigger), Some(id)) if id == &"trigger" => { - let tuple = - ExprTuple { attrs: vec![], paren_token: Paren(span), elems: trigger.exprs }; - triggers.push(Expr::Tuple(tuple)); + return Err(Expr::Verbatim(quote_spanned!(span => compile_error!(#err)))); + } + _ => { + let span = attr.span(); + return Err(Expr::Verbatim( + quote_spanned!(span => compile_error!("expected trigger")), + )); + } } - (Err(err), _) => { - let span = attr.span(); - let err = err.to_string(); + } - return Err(Expr::Verbatim(quote_spanned!(span => compile_error!(#err)))); - } - _ => { - let span = attr.span(); - return Err(Expr::Verbatim( - quote_spanned!(span => compile_error!("expected trigger")), - )); + Ok(if triggers.len() > 0 { + let mut elems = Punctuated::new(); + for elem in triggers { + elems.push(elem); + elems.push_punct(Token![,](span)); } - } + ExtractQuantTriggersFound::Triggers(ExprTuple { + attrs: vec![], + paren_token: Paren(span), + elems, + }) + } else { + ExtractQuantTriggersFound::None + }) } +} - Ok(if triggers.len() > 0 { - let mut elems = Punctuated::new(); - for elem in triggers { - elems.push(elem); - elems.push_punct(Token![,](span)); - } - ExtractQuantTriggersFound::Triggers(ExprTuple { - attrs: vec![], - paren_token: Paren(span), - elems, - }) - } else { - ExtractQuantTriggersFound::None - }) +enum ExtractQuantTriggersFound { + Auto, + Triggers(ExprTuple), + None, } impl VisitMut for Visitor { @@ -1661,7 +1709,7 @@ impl VisitMut for Visitor { Expr::AssertForall(assert) => { let span = assert.assert_token.span; let mut arg = assert.expr; - match extract_quant_triggers(assert.attrs, span) { + match self.extract_quant_triggers(assert.attrs, span) { Ok(ExtractQuantTriggersFound::Auto) => { arg = Box::new(Expr::Verbatim( quote_spanned!(arg.span() => #[verus::internal(auto_trigger)] #arg), @@ -1731,16 +1779,24 @@ impl VisitMut for Visitor { stmts.push(stmt_with_semi!( token.span => ::builtin::requires([#exprs]))); } - if let Some(Ensures { token, mut exprs }) = ensures { - for expr in exprs.exprs.iter_mut() { - self.visit_expr_mut(expr); - } - if let Some((p, ty)) = ret_pat { - stmts.push(stmt_with_semi!( - token.span => ::builtin::ensures(|#p: #ty| [#exprs]))); + if let Some(Ensures { token, mut exprs, attrs }) = ensures { + if attrs.len() > 0 { + let err = "outer attributes only allowed on function's ensures"; + let expr = Expr::Verbatim( + quote_spanned!(token.span => compile_error!(#err)), + ); + stmts.push(Stmt::Semi(expr, Semi { spans: [token.span] })); } else { - stmts.push(stmt_with_semi!( - token.span => ::builtin::ensures([#exprs]))); + for expr in exprs.exprs.iter_mut() { + self.visit_expr_mut(expr); + } + if let Some((p, ty)) = ret_pat { + stmts.push(stmt_with_semi!( + token.span => ::builtin::ensures(|#p: #ty| [#exprs]))); + } else { + stmts.push(stmt_with_semi!( + token.span => ::builtin::ensures([#exprs]))); + } } } self.inside_ghost -= 1; diff --git a/source/rust_verify_test/tests/quantifiers.rs b/source/rust_verify_test/tests/quantifiers.rs index 6a79206ba5..137971c8e7 100644 --- a/source/rust_verify_test/tests/quantifiers.rs +++ b/source/rust_verify_test/tests/quantifiers.rs @@ -277,3 +277,30 @@ test_verify_one_file! { } } => Err(err) => assert_vir_error_msg(err, "forall, choose, and exists do not allow parentheses") } + +test_verify_one_file! { + #[test] test_inner_triggers_broadcast_forall verus_code! { + mod M { + pub struct A {} + impl A { + pub spec fn f1(&self) -> bool; + pub spec fn f2(&self) -> bool; + pub spec fn f3(&self) -> bool; + } + + #[verifier::external_body] + #[verifier::broadcast_forall] + pub proof fn ab(a: A) + ensures #![trigger a.f1()] (a.f1() ==> a.f2()) && a.f3() + { + } + } + + use M::*; + proof fn p(a: A) + requires a.f1(), + ensures a.f2(), + { + } + } => Ok(()) +}