Skip to content

Commit

Permalink
During declosurefy, make sure to include additional trait bounds (#606)
Browse files Browse the repository at this point in the history
See Discussion #604
  • Loading branch information
asomers authored Sep 2, 2024
1 parent 7dd21ad commit 6c5276e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ This project adheres to [Semantic Versioning](http://semver.org/).

### Fixed

- When mocking functions with a closure argument, but not using
`#[concretize]`, include any additional trait bounds in the trait object
argument passed to `.with` and `.returning.
([#606](https://github.com/asomers/mockall/pull/606))

- Fixed naming conflict when mocking multiple traits with same name but from
different modules.
([#601](https://github.com/asomers/mockall/pull/601))
Expand Down
55 changes: 37 additions & 18 deletions mockall_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,34 +319,33 @@ fn declosurefy(gen: &Generics, args: &Punctuated<FnArg, Token![,]>) ->
{
let mut hm = HashMap::default();

let mut save_fn_types = |ident: &Ident, tpb: &TypeParamBound| {
if let TypeParamBound::Trait(tb) = tpb {
let fident = &tb.path.segments.last().unwrap().ident;
if ["Fn", "FnMut", "FnOnce"].iter().any(|s| fident == *s) {
let newty: Type = parse2(quote!(Box<dyn #tb>)).unwrap();
let subst_ty: Type = parse2(quote!(#ident)).unwrap();
assert!(hm.insert(subst_ty, newty).is_none(),
"A generic parameter had two Fn bounds?");
let mut save_fn_types = |ident: &Ident, bounds: &Punctuated<TypeParamBound, Token![+]>|
{
for tpb in bounds.iter() {
if let TypeParamBound::Trait(tb) = tpb {
let fident = &tb.path.segments.last().unwrap().ident;
if ["Fn", "FnMut", "FnOnce"].iter().any(|s| fident == *s) {
let newty: Type = parse2(quote!(Box<dyn #bounds>)).unwrap();
let subst_ty: Type = parse2(quote!(#ident)).unwrap();
assert!(hm.insert(subst_ty, newty).is_none(),
"A generic parameter had two Fn bounds?");
}
}
}
};

// First, build a HashMap of all Fn generic types
for g in gen.params.iter() {
if let GenericParam::Type(tp) = g {
for tpb in tp.bounds.iter() {
save_fn_types(&tp.ident, tpb);
}
save_fn_types(&tp.ident, &tp.bounds);
}
}
if let Some(wc) = &gen.where_clause {
for pred in wc.predicates.iter() {
if let WherePredicate::Type(pt) = pred {
let bounded_ty = &pt.bounded_ty;
if let Ok(ident) = parse2::<Ident>(quote!(#bounded_ty)) {
for tpb in pt.bounds.iter() {
save_fn_types(&ident, tpb);
}
save_fn_types(&ident, &pt.bounds);
} else {
// We can't yet handle where clauses this complicated
}
Expand Down Expand Up @@ -1620,12 +1619,14 @@ mod concretize_args {
check_concretize(
quote!(fn foo<F1: Fn(u32) -> u32,
F2: FnMut(&mut u32) -> u32,
F3: FnOnce(u32) -> u32>(f1: F1, f2: F2, f3: F3)),
F3: FnOnce(u32) -> u32,
F4: Fn() + Send>(f1: F1, f2: F2, f3: F3, f4: F4)),
&[quote!(f1: &(dyn Fn(u32) -> u32)),
quote!(f2: &mut(dyn FnMut(&mut u32) -> u32)),
quote!(f3: &(dyn FnOnce(u32) -> u32))],
&[quote!(&f1), quote!(&mut f2), quote!(&f3)],
&[quote!(f1: F1), quote!(mut f2: F2), quote!(f3: F3)]
quote!(f3: &(dyn FnOnce(u32) -> u32)),
quote!(f4: &(dyn Fn() + Send))],
&[quote!(&f1), quote!(&mut f2), quote!(&f3), quote!(&f4)],
&[quote!(f1: F1), quote!(mut f2: F2), quote!(f3: F3), quote!(f4: F4)]
);
}

Expand Down Expand Up @@ -1736,6 +1737,15 @@ mod declosurefy {
}
}

#[test]
fn bounds() {
check_declosurefy(
quote!(fn foo<F: Fn(u32) -> u32 + Send>(f: F)),
&[quote!(f: Box<dyn Fn(u32) -> u32 + Send>)],
&[quote!(Box::new(f))]
);
}

#[test]
fn r#fn() {
check_declosurefy(
Expand Down Expand Up @@ -1780,6 +1790,15 @@ mod declosurefy {
&[quote!(Box::new(f))]
);
}

#[test]
fn where_clause_with_bounds() {
check_declosurefy(
quote!(fn foo<F>(f: F) where F: Fn(u32) -> u32 + Send),
&[quote!(f: Box<dyn Fn(u32) -> u32 + Send>)],
&[quote!(Box::new(f))]
);
}
}

mod deimplify {
Expand Down

0 comments on commit 6c5276e

Please sign in to comment.