diff --git a/primedev/client/clientauthhooks.cpp b/primedev/client/clientauthhooks.cpp index 35ae3aa77..ceb648a53 100644 --- a/primedev/client/clientauthhooks.cpp +++ b/primedev/client/clientauthhooks.cpp @@ -3,8 +3,6 @@ #include "client/r2client.h" #include "core/vanilla.h" -AUTOHOOK_INIT() - ConVar* Cvar_ns_has_agreed_to_send_token; // mirrored in script @@ -12,16 +10,14 @@ const int NOT_DECIDED_TO_SEND_TOKEN = 0; const int AGREED_TO_SEND_TOKEN = 1; const int DISAGREED_TO_SEND_TOKEN = 2; -// clang-format off -AUTOHOOK(AuthWithStryder, engine.dll + 0x1843A0, -void, __fastcall, (void* a1)) -// clang-format on +static void (*__fastcall o_pAuthWithStryder)(void* a1) = nullptr; +static void __fastcall h_AuthWithStryder(void* a1) { // don't attempt to do Atlas auth if we are in vanilla compatibility mode // this prevents users from joining untrustworthy servers (unless they use a concommand or something) if (g_pVanillaCompatibility->GetVanillaCompatibility()) { - AuthWithStryder(a1); + o_pAuthWithStryder(a1); return; } @@ -38,15 +34,13 @@ void, __fastcall, (void* a1)) *g_pLocalPlayerOriginToken = 0; } - AuthWithStryder(a1); + o_pAuthWithStryder(a1); } char* p3PToken; -// clang-format off -AUTOHOOK(Auth3PToken, engine.dll + 0x183760, -char*, __fastcall, ()) -// clang-format on +static char* (*__fastcall o_pAuth3PToken)() = nullptr; +static char* __fastcall h_Auth3PToken() { if (!g_pVanillaCompatibility->GetVanillaCompatibility() && g_pMasterServerManager->m_sOwnClientAuthToken[0]) { @@ -54,12 +48,16 @@ char*, __fastcall, ()) strcpy(p3PToken, "Protocol 3: Protect the Pilot"); } - return Auth3PToken(); + return o_pAuth3PToken(); } ON_DLL_LOAD_CLIENT_RELIESON("engine.dll", ClientAuthHooks, ConVar, (CModule module)) { - AUTOHOOK_DISPATCH() + o_pAuthWithStryder = module.Offset(0x1843A0).RCast(); + HookAttach(&(PVOID&)o_pAuthWithStryder, (PVOID)h_AuthWithStryder); + + o_pAuth3PToken = module.Offset(0x183760).RCast(); + HookAttach(&(PVOID&)o_pAuth3PToken, (PVOID)h_Auth3PToken); p3PToken = module.Offset(0x13979D80).RCast(); diff --git a/primedev/client/clientruihooks.cpp b/primedev/client/clientruihooks.cpp index ad50d11a8..e49e6f63a 100644 --- a/primedev/client/clientruihooks.cpp +++ b/primedev/client/clientruihooks.cpp @@ -1,23 +1,20 @@ #include "core/convar/convar.h" -AUTOHOOK_INIT() - ConVar* Cvar_rui_drawEnable; -// clang-format off -AUTOHOOK(DrawRUIFunc, engine.dll + 0xFC500, -bool, __fastcall, (void* a1, float* a2)) -// clang-format on +static bool (*__fastcall o_pDrawRUIFunc)(void* a1, float* a2) = nullptr; +static bool __fastcall h_DrawRUIFunc(void* a1, float* a2) { if (!Cvar_rui_drawEnable->GetBool()) return 0; - return DrawRUIFunc(a1, a2); + return o_pDrawRUIFunc(a1, a2); } ON_DLL_LOAD_CLIENT_RELIESON("engine.dll", RUI, ConVar, (CModule module)) { - AUTOHOOK_DISPATCH() + o_pDrawRUIFunc = module.Offset(0xFC500).RCast(); + HookAttach(&(PVOID&)o_pDrawRUIFunc, (PVOID)h_DrawRUIFunc); Cvar_rui_drawEnable = new ConVar("rui_drawEnable", "1", FCVAR_CLIENTDLL, "Controls whether RUI should be drawn"); } diff --git a/primedev/client/clientvideooverrides.cpp b/primedev/client/clientvideooverrides.cpp index d8aa27541..54bb64697 100644 --- a/primedev/client/clientvideooverrides.cpp +++ b/primedev/client/clientvideooverrides.cpp @@ -1,11 +1,7 @@ #include "mods/modmanager.h" -AUTOHOOK_INIT() - -// clang-format off -AUTOHOOK_PROCADDRESS(BinkOpen, bink2w64.dll, BinkOpen, -void*, __fastcall, (const char* path, uint32_t flags)) -// clang-format on +static void* (*__fastcall o_pBinkOpen)(const char* path, uint32_t flags) = nullptr; +static void* __fastcall h_BinkOpen(const char* path, uint32_t flags) { std::string filename(fs::path(path).filename().string()); spdlog::info("BinkOpen {}", filename); @@ -25,16 +21,20 @@ void*, __fastcall, (const char* path, uint32_t flags)) { // create new path fs::path binkPath(fileOwner->m_ModDirectory / "media" / filename); - return BinkOpen(binkPath.string().c_str(), flags); + return o_pBinkOpen(binkPath.string().c_str(), flags); } else - return BinkOpen(path, flags); + return o_pBinkOpen(path, flags); } -ON_DLL_LOAD_CLIENT("engine.dll", BinkVideo, (CModule module)) +ON_DLL_LOAD_CLIENT("bink2w64.dll", BinkRead, (CModule module)) { - AUTOHOOK_DISPATCH() + o_pBinkOpen = module.GetExportedFunction("BinkOpen").RCast(); + HookAttach(&(PVOID&)o_pBinkOpen, (PVOID)h_BinkOpen); +} +ON_DLL_LOAD_CLIENT("engine.dll", BinkVideo, (CModule module)) +{ // remove engine check for whether the bik we're trying to load exists in r2/media, as this will fail for biks in mods // note: the check in engine is actually unnecessary, so it's just useless in practice and we lose nothing by removing it module.Offset(0x459AD).NOP(6); diff --git a/primedev/core/hooks.cpp b/primedev/core/hooks.cpp index 52fb4bf16..0a274805f 100644 --- a/primedev/core/hooks.cpp +++ b/primedev/core/hooks.cpp @@ -325,8 +325,6 @@ void CallLoadLibraryACallbacks(LPCSTR lpLibFileName, HMODULE moduleAddress) void CallLoadLibraryWCallbacks(LPCWSTR lpLibFileName, HMODULE moduleAddress) { - CModule cModule(moduleAddress); - while (true) { bool bDoneCalling = true; diff --git a/primedev/thirdparty/silver-bun/module.cpp b/primedev/thirdparty/silver-bun/module.cpp index 84f4da9e2..dceb602a4 100644 --- a/primedev/thirdparty/silver-bun/module.cpp +++ b/primedev/thirdparty/silver-bun/module.cpp @@ -66,6 +66,21 @@ void CModule::Init() m_ModuleSections.push_back(ModuleSections_t(reinterpret_cast(hCurrentSection.Name), static_cast(m_pModuleBase + hCurrentSection.VirtualAddress), hCurrentSection.SizeOfRawData)); // Push back a struct with the section data. } + + // Get the location of IMAGE_IMPORT_DESCRIPTOR for this module by adding the IMAGE_DIRECTORY_ENTRY_IMPORT relative virtual address onto our + // module base address. + IMAGE_IMPORT_DESCRIPTOR* pImageImportDescriptors = reinterpret_cast( + m_pModuleBase + m_pNTHeaders->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress); + if (!pImageImportDescriptors) + return; + + for (IMAGE_IMPORT_DESCRIPTOR* pIID = pImageImportDescriptors; pIID->Name != 0; pIID++) + { + // Get virtual relative Address of the imported module name. Then add module base Address to get the actual location. + const char* szImportedModuleName = reinterpret_cast(reinterpret_cast(m_pModuleBase + pIID->Name)); + + m_vImportedModules.push_back(szImportedModuleName); + } } //----------------------------------------------------------------------------- diff --git a/primedev/thirdparty/silver-bun/module.h b/primedev/thirdparty/silver-bun/module.h index 5683ee146..cc5130866 100644 --- a/primedev/thirdparty/silver-bun/module.h +++ b/primedev/thirdparty/silver-bun/module.h @@ -52,6 +52,7 @@ class CModule ModuleSections_t GetSectionByName(const char* szSectionName) const; inline const std::vector& GetSections() const { return m_ModuleSections; } + inline const std::vector& GetImportedModules() const { return m_vImportedModules; } inline uintptr_t GetModuleBase(void) const { return m_pModuleBase; } inline DWORD GetModuleSize(void) const { return m_nModuleSize; } inline const std::string& GetModuleName(void) const { return m_ModuleName; } @@ -73,4 +74,5 @@ class CModule uintptr_t m_pModuleBase; DWORD m_nModuleSize; std::vector m_ModuleSections; + std::vector m_vImportedModules; }; diff --git a/primedev/windows/libsys.cpp b/primedev/windows/libsys.cpp index 501eae687..0aff820b7 100644 --- a/primedev/windows/libsys.cpp +++ b/primedev/windows/libsys.cpp @@ -18,15 +18,31 @@ ILoadLibraryExW o_LoadLibraryExW = nullptr; //----------------------------------------------------------------------------- void LibSys_RunModuleCallbacks(HMODULE hModule) { + // Modules that we have already ran callbacks for. + // Note: If we ever hook unloading modules, then this will need updating to handle removal etc. + static std::vector vCalledModules; + if (!hModule) { return; } + // If we have already ran callbacks for this module, don't run them again. + if (std::find(vCalledModules.begin(), vCalledModules.end(), hModule) != vCalledModules.end()) + { + return; + } + vCalledModules.push_back(hModule); + // Get module base name in ASCII as noone wants to deal with unicode CHAR szModuleName[MAX_PATH]; GetModuleBaseNameA(GetCurrentProcess(), hModule, szModuleName, MAX_PATH); + // Run calllbacks for all imported modules + CModule cModule(hModule); + for (const std::string& svImport : cModule.GetImportedModules()) + LibSys_RunModuleCallbacks(GetModuleHandleA(svImport.c_str())); + // DevMsg(eLog::NONE, "%s\n", szModuleName); // Call callbacks