From fd0af10119cbaf0284f48a3b3042afbc1cf05c33 Mon Sep 17 00:00:00 2001 From: Yuekai Jia Date: Sun, 28 Jul 2024 23:20:56 +0800 Subject: [PATCH] Add support for access paths to call_interface --- README.md | 46 ++++++++++++++++++++++++++++++++++++++ src/lib.rs | 65 +++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 93 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 54b4aee..bd3bb6c 100644 --- a/README.md +++ b/README.md @@ -38,3 +38,49 @@ assert_eq!( "Hello, rust 456!" ); ``` + +## Implementation + +The procedural macro in the above example will generate the following code: + +* `#[def_interface]`: + + ```rust,ignore + pub trait HelloIf { + fn hello(&self, name: &str, id: usize) -> String; + } + + #[allow(non_snake_case)] + pub mod __HelloIf_mod { + use super::*; + extern "Rust" { + pub fn __HelloIf_hello(name: &str, id: usize) -> String; + } + } + ``` +* `#[impl_interface]`: + + ```rust,ignore + struct HelloIfImpl; + + impl HelloIf for HelloIfImpl { + fn hello(&self, name: &str, id: usize) -> String { + { + #[export_name = "__HelloIf_hello"] + extern "Rust" fn __HelloIf_hello(name: &str, id: usize) -> String { + let _impl: HelloIfImpl = HelloIfImpl; + _impl.hello(name, id) + } + } + { + format!("Hello, {} {}!", name, id) + } + } + } + ``` + +* `call_interface!`: + + ```rust,ignore + unsafe { __HelloIf_mod::__HelloIf_hello("world", 123) } + ``` diff --git a/src/lib.rs b/src/lib.rs index adc2621..c73282e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ #![doc = include_str!("../README.md")] -#![feature(iter_next_chunk)] use proc_macro::TokenStream; use proc_macro2::Span; @@ -30,6 +29,7 @@ pub fn def_interface(attr: TokenStream, item: TokenStream) -> TokenStream { let ast = syn::parse_macro_input!(item as ItemTrait); let trait_name = &ast.ident; + let vis = &ast.vis; let mut extern_fn_list = vec![]; for item in &ast.items { @@ -46,16 +46,22 @@ pub fn def_interface(attr: TokenStream, item: TokenStream) -> TokenStream { } let extern_fn = quote! { - #sig; + pub #sig; }; extern_fn_list.push(extern_fn); } } + let mod_name = format_ident!("__{}_mod", trait_name); quote! { #ast - extern "Rust" { - #(#extern_fn_list)* + + #[allow(non_snake_case)] + #vis mod #mod_name { + use super::*; + extern "Rust" { + #(#extern_fn_list)* + } } } .into() @@ -117,8 +123,8 @@ pub fn impl_interface(attr: TokenStream, item: TokenStream) -> TokenStream { let call_impl = if has_self { quote! { - let IMPL: #impl_name = #impl_name; - IMPL.#fn_name( #(#args),* ) + let _impl: #impl_name = #impl_name; + _impl.#fn_name( #(#args),* ) } } else { quote! { #impl_name::#fn_name( #(#args),* ) } @@ -159,20 +165,43 @@ pub fn call_interface(item: TokenStream) -> TokenStream { } fn parse_call_interface(item: TokenStream) -> Result { - let mut iter = item.into_iter(); - let tt = iter - .next_chunk::<4>() - .or(Err("expect `Trait::func`"))? - .map(|t| t.to_string()); - - let trait_name = &tt[0]; - if tt[1] != ":" || tt[2] != ":" { - return Err("missing `::`".into()); + use proc_macro::TokenTree; + + let mut colon_cnt = 0; + let mut path = Vec::new(); + let mut args = Vec::new(); + for token in item.into_iter() { + if args.len() > 0 { + args.push(token.to_string()); + continue; + } + if let TokenTree::Ident(ident) = token { + colon_cnt = 0; + path.push(ident.to_string()); + } else if token.to_string() == ":" { + colon_cnt += 1; + if path.len() == 0 { + return Err("expect `Trait::func`".into()); + } + if colon_cnt > 2 { + return Err("expect `::`".into()); + } + } else { + args.push(token.to_string()); + } + } + + if path.len() < 2 { + return Err("expect `Trait::func`".into()); } - let fn_name = &tt[3]; + let fn_name = path.pop().unwrap(); + let trait_name = path.pop().unwrap(); let extern_fn_name = format!("__{}_{}", trait_name, fn_name); - let mut args = iter.map(|x| x.to_string()).collect::>().join(""); + path.push(format!("__{}_mod", trait_name)); + let mod_path = path.join("::"); + + let mut args = args.join(" "); if args.starts_with(',') { args.remove(0); } else if args.starts_with('(') && args.ends_with(')') { @@ -180,7 +209,7 @@ fn parse_call_interface(item: TokenStream) -> Result { args.pop(); } - let call = format!("unsafe {{ {}( {} ) }}", extern_fn_name, args); + let call = format!("unsafe {{ {}::{}( {} ) }}", mod_path, extern_fn_name, args); Ok(call .parse::() .or(Err("expect a correct argument list"))?)