diff --git a/rivets-macros/src/lib.rs b/rivets-macros/src/lib.rs index 1e1b5e3..f382da0 100644 --- a/rivets-macros/src/lib.rs +++ b/rivets-macros/src/lib.rs @@ -9,8 +9,13 @@ use lazy_regex::regex; use proc_macro::{self, Diagnostic, Level, Span, TokenStream}; use proc_macro2::TokenStream as TokenStream2; use quote::quote; +use std::sync::{atomic::AtomicBool, LazyLock, Mutex}; use syn::{parse_macro_input, Abi, DeriveInput, Error, Expr, FnArg, Ident, ItemFn, Variant}; +static IS_FINALIZED: AtomicBool = AtomicBool::new(false); +static MANGLED_NAMES: LazyLock>> = LazyLock::new(|| Mutex::new(vec![])); +static CPP_IMPORTS: LazyLock>> = LazyLock::new(|| Mutex::new(vec![])); + macro_rules! derive_error { ($string: tt) => { Error::new(proc_macro2::Span::call_site(), $string) @@ -19,6 +24,14 @@ macro_rules! derive_error { }; } +macro_rules! check_finalized { + () => { + if IS_FINALIZED.load(std::sync::atomic::Ordering::Relaxed) { + panic!("The rivets library has already been finalized!"); + } + }; +} + fn failure(callback: proc_macro2::TokenStream, error_message: &str) -> TokenStream { Diagnostic::spanned(Span::call_site(), Level::Error, error_message).emit(); callback.into() @@ -42,8 +55,6 @@ fn determine_calling_convention(input: &ItemFn, unmangled_name: &str) -> Result< } } -static mut MANGLED_NAMES: Vec<(String, String)> = vec![]; - /// A procedural macro for detouring a C++ compiled function. /// /// The argument to the macro is the mangled name of the C++ function to detour. @@ -58,6 +69,8 @@ static mut MANGLED_NAMES: Vec<(String, String)> = vec![]; /// /// This macro cannot hook into the middle of a C++ function. It can only hook into the beginning or end of a function. /// +/// This macro cannot hook into a function that has been inlined by the compiler. Prominent examples of this include `lua_gettop`. +/// /// Exposes an `unsafe` `back` function that can be called in order to resume control flow to the original C++ function. 
/// /// Internally uses the `retour` crate to create a static detour for the function and thus inherits the safety guarantees of that crate. @@ -88,6 +101,8 @@ static mut MANGLED_NAMES: Vec<(String, String)> = vec![]; /// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions. #[proc_macro_attribute] pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream { + check_finalized!(); + let mangled_name = attr.to_string(); let unmangled_name = rivets_shared::demangle(&mangled_name).unwrap_or_else(|| mangled_name.clone()); @@ -144,61 +159,179 @@ pub fn detour(attr: TokenStream, item: TokenStream) -> TokenStream { #callback pub unsafe fn hook(address: u64) -> Result<(), rivets::retour::Error> { - let compiled_function: #cpp_function_header = std::mem::transmute(address); + let compiled_function: #cpp_function_header = std::mem::transmute(address); // todo: rust documentation recommends casting this to a raw function pointer. address as *const _ Detour.initialize(compiled_function, #name)?.enable()?; Ok(()) } } }; - unsafe { - MANGLED_NAMES.push((mangled_name.clone(), format!("{name}"))); - } + MANGLED_NAMES + .lock() + .expect("Failed to lock mangled names") + .push((mangled_name.clone(), name.to_string())); Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit(); result.into() } +/// A procedural macro for importing a C++ compiled function into the rust scope. +/// This macro is useful in the case where you need to directly call any C++ function from rust. +/// +/// # Arguments +/// * `mangled_name` - The mangled name of the C++ function to import. +/// * `dll_name` (optional) - Argument for the name of the DLL to import the function from. If not provided, factorio.exe will be used. +/// +/// Note that most Factorio libraries (such as allegro and lua) are statically linked. In this case, the `dll_name` argument is not needed. 
+/// +/// # Examples +/// ``` +/// // Summons the lua_gettop function from the compiled lua library. +/// // lua_gettop is compiled without name mangling, so calling convention (in this case, extern "C") must be manually provided. +/// #[import(lua_gettop)] +/// extern "C" fn lua_gettop(lua_state: *mut luastate::lua_State) -> i64 {} +/// +/// // Calls the lua_gettop function with correct arguments. +/// fn my_func(lua_state: *mut luastate::lua_State) { +/// let top = unsafe { lua_gettop(lua_state) }; +/// println!("Lua stack top: {top}"); +/// } +/// ``` +/// +/// # Safety +/// The arguments and return type of the imported function must be exactly matching FFI types. +/// All structs, classes, enums, and union arguments must have a corresponding `#[repr(C)]` attribute and must also have the correct offsets and sizes. +/// Alternatively, the user can use the `rivets::Opaque` type to represent any arbitrary FFI data if they do not intend to interact with the data. +/// See the `pdb2hpp` module for a tool that can generate the correct FFI types for C++ functions. +/// +/// The user must also ensure that the calling convention is correct. +/// Rivets attempts to automatically parse this information from the mangled name; however: +/// - If the calling convention is not one of cdecl, stdcall, fastcall, thiscall, or vectorcall, the user must specify the calling convention manually. +/// - If the calling convention is not present in the mangled name, the user must specify the calling convention manually. +/// - In rare cases the function may use a non-standard calling convention. In this case, the user must manually populate the required stack and registers via inline assembly. +/// +/// Calling any imported function represents calling into the C++ compiled codebase and thus is inherently unsafe. 
+#[proc_macro_attribute] +pub fn import(attr: TokenStream, item: TokenStream) -> TokenStream { + check_finalized!(); + + let mangled_name = attr.to_string(); + let unmangled_name = + rivets_shared::demangle(&mangled_name).unwrap_or_else(|| mangled_name.clone()); + + let input = parse_macro_input!(item as ItemFn); + + let calling_convention = match determine_calling_convention(&input, &unmangled_name) { + Ok(calling_convention) => Some(calling_convention), + Err(e) => return failure(quote! { #input }, &e.to_string()), + }; + + let arg_types = input.sig.inputs.iter().map(|arg| match arg { + FnArg::Receiver(_) => { + quote! {compile_error!("Summoned functions cannot use the self parameter.")} + } + FnArg::Typed(pat) => { + let ty = &pat.ty; + quote! { #ty } + } + }); + + let return_type = &input.sig.output; + let vis = &input.vis; + let attr = &input.attrs; + let attr = quote! { #(#attr)* }; + + let name = &input.sig.ident; + let function_type = + quote! { #attr #vis unsafe #calling_convention fn(#(#arg_types),*) #return_type }; + + CPP_IMPORTS + .lock() + .expect("Failed to lock cpp imports") + .push((mangled_name.clone(), name.to_string())); + + Diagnostic::spanned(Span::call_site(), Level::Note, unmangled_name.clone()).emit(); + + quote! { + #[allow(non_upper_case_globals)] + static mut #name: rivets::UnsafeSummonedFunction<#function_type> = rivets::UnsafeSummonedFunction::Uninitialized; + }.into() +} + +fn get_hooks() -> Vec { + MANGLED_NAMES + .lock() + .expect("Failed to lock mangled names") + .iter() + .map(|(mangled_name, module_name)| { + let module_name = Ident::new(module_name, proc_macro2::Span::call_site()); + quote! 
{ + hooks.push( + rivets::RivetsHook { + mangled_name: #mangled_name.into(), + hook: #module_name::hook + } + ); + } + }) + .collect() +} + +fn get_imports() -> Vec { + CPP_IMPORTS.lock().expect("Failed to lock cpp imports") + .iter() + .map(|(mangled_name, rust_name)| { + let rust_name = Ident::new(rust_name, proc_macro2::Span::call_site()); + quote! { + let Some(address) = symbol_cache.get_function_address(base_address, #mangled_name) + else { + panic!( + "Failed to find address for the following mangled function inside the PDB: {}", #mangled_name + ); + }; + let function = unsafe { + std::mem::transmute(address) // todo: rust documentation recommends casting this to a raw function pointer. address as *const _ + }; + unsafe { #rust_name = rivets::UnsafeSummonedFunction::Function(function); } + } + }) + .collect() +} + /// A procedural macro for finalizing the rivets library. /// This macro should be called once at the end of the `main.rs` file. /// It will finalize the rivets library and inject all of the detours. #[proc_macro] pub fn finalize(_: TokenStream) -> TokenStream { - let injects = unsafe { MANGLED_NAMES.clone() }; - let injects = injects.iter().map(|(mangled_name, name)| { - let name = Ident::new(name, proc_macro2::Span::call_site()); - quote! { - hooks.push( - rivets::RivetsHook { - mangled_name: #mangled_name.into(), - hook: #name::hook - } - ); - } - }); + check_finalized!(); + IS_FINALIZED.store(true, std::sync::atomic::Ordering::Relaxed); - quote! { - rivets::dll_syringe::payload_procedure! 
{ - fn rivets_finalize(symbol_cache: rivets::SymbolCache) -> Option { - let base_address = match symbol_cache.get_module_base_address() { - Ok(base_address) => base_address, - Err(e) => return Some(format!("{e}")), - }; + let hooks = get_hooks(); + let imports = get_imports(); - let mut hooks: Vec = Vec::new(); - #(#injects)* - for hook in &hooks { - let inject_result = unsafe { symbol_cache.inject(base_address, hook) }; - if inject_result.is_err() { - return Some(format!("{inject_result:?}")); - } + let finalize = quote! { + fn rivets_finalize(symbol_cache: rivets::SymbolCache) -> Option { + let base_address = match symbol_cache.get_module_base_address() { + Ok(base_address) => base_address, + Err(e) => return Some(format!("{e}")), + }; + + #(#imports)* + + let mut hooks: Vec = Vec::new(); + #(#hooks)* + for hook in &hooks { + let inject_result = unsafe { symbol_cache.inject(base_address, hook) }; + if inject_result.is_err() { + return Some(format!("{inject_result:?}")); } - None } + None } - } - .into() + }; + + quote! { rivets::dll_syringe::payload_procedure! { #finalize } }.into() } #[derive(FromDeriveInput)] diff --git a/rivets-shared/src/lib.rs b/rivets-shared/src/lib.rs index c8703ea..2a7c069 100644 --- a/rivets-shared/src/lib.rs +++ b/rivets-shared/src/lib.rs @@ -5,6 +5,7 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize}; use std::collections::HashMap; use std::ffi::{CStr, CString}; use std::fs::File; +use std::ops::Deref; use std::path::Path; use undname::Flags; use windows::core::PCSTR; @@ -154,3 +155,27 @@ impl SymbolCache { Ok((hook.hook)(address)?) } } + +/// Represents a function that has been imported from a C++ compiled DLL. +/// Invariant: If the function is not initialized, it is UB to dereference it. +/// The rivets::finalize!() macro should be used to ensure that the function is initialized. 
+pub enum UnsafeSummonedFunction +where + T: 'static + Sized, +{ + Function(T), + Uninitialized, +} + +impl Deref for UnsafeSummonedFunction { + type Target = T; + + #[inline] + #[track_caller] + fn deref(&self) -> &Self::Target { + match self { + Self::Function(x) => x, + Self::Uninitialized => unsafe { std::hint::unreachable_unchecked() }, + } + } +}