Skip to content

Commit

Permalink
Stop depending on avrt.dll statically on Windows
Browse files Browse the repository at this point in the history
* Load `avrt.dll` dynamically with `LoadLibraryW`

* Fail with an `AudioThreadPriorityError`

See also https://bugzilla.mozilla.org/show_bug.cgi?id=1884214
  • Loading branch information
yjugl committed Mar 12, 2024
1 parent 10c8fc3 commit 7252f2f
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 10 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ libc = "0.2"
version = "0.52"
features = [
"Win32_Foundation",
"Win32_System_Threading",
"Win32_System_LibraryLoader",
]

[target.'cfg(target_os = "windows")'.dependencies.once_cell]
version = "1.19"

[target.'cfg(target_os = "linux")'.dependencies]
libc = "0.2"

Expand Down
116 changes: 107 additions & 9 deletions src/rt_win.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::AudioThreadPriorityError;
use once_cell::sync;
use windows_sys::core::PCSTR;
use windows_sys::s;
use windows_sys::Win32::Foundation::FreeLibrary;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::BOOL;
use windows_sys::Win32::Foundation::FALSE;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::System::Threading::{
AvRevertMmThreadCharacteristics, AvSetMmThreadCharacteristicsA,
};

use crate::AudioThreadPriorityError;
use windows_sys::Win32::Foundation::HMODULE;
use windows_sys::Win32::Foundation::WIN32_ERROR;
use windows_sys::Win32::System::LibraryLoader::GetProcAddress;
use windows_sys::Win32::System::LibraryLoader::LoadLibraryA;

use log::info;

Expand All @@ -28,7 +32,9 @@ impl RtPriorityHandleInternal {
pub fn demote_current_thread_from_real_time_internal(
rt_priority_handle: RtPriorityHandleInternal,
) -> Result<(), AudioThreadPriorityError> {
let rv = unsafe { AvRevertMmThreadCharacteristics(rt_priority_handle.task_handle) };
let rv = unsafe {
(av_rt()?.av_revert_mm_thread_characteristics)(rt_priority_handle.task_handle)
};
if rv == FALSE {
return Err(AudioThreadPriorityError::new(&format!(
"Unable to restore the thread priority ({:?})",
Expand All @@ -49,13 +55,14 @@ pub fn promote_current_thread_to_real_time_internal(
_audio_samplerate_hz: u32,
) -> Result<RtPriorityHandleInternal, AudioThreadPriorityError> {
let mut task_index = 0u32;

let handle = unsafe { AvSetMmThreadCharacteristicsA(s!("Audio"), &mut task_index) };
let handle = unsafe {
(av_rt()?.av_set_mm_thread_characteristics_a)(s!("Audio"), &mut task_index)
};
let handle = RtPriorityHandleInternal::new(task_index, handle);

if handle.task_handle == 0 {
return Err(AudioThreadPriorityError::new(&format!(
"Unable to restore the thread priority ({:?})",
"Unable to bump the thread priority ({:?})",
unsafe { GetLastError() }
)));
}
Expand All @@ -67,3 +74,94 @@ pub fn promote_current_thread_to_real_time_internal(

Ok(handle)
}

// We don't expect to see API failures on test machines
#[test]
fn test_successful_api_use() {
let handle = promote_current_thread_to_real_time_internal(0, 0);
println!("handle: {handle:?}");
assert!(handle.is_ok());
let result = demote_current_thread_from_real_time_internal(handle.unwrap());
println!("result: {result:?}");
assert!(result.is_ok());
}

fn av_rt() -> Result<&'static AvRtLibrary, AudioThreadPriorityError> {
static AV_RT_LIBRARY: sync::OnceCell<Result<AvRtLibrary, WIN32_ERROR>> = sync::OnceCell::new();
AV_RT_LIBRARY
.get_or_init(AvRtLibrary::try_new)
.as_ref()
.map_err(|win32_error| {
AudioThreadPriorityError::new(&format!("Unable to load avrt.dll ({win32_error})"))
})
}

// We don't expect to fail to load the library on test machines
#[test]
fn test_successful_avrt_library_load_as_static_ref() {
assert!(av_rt().is_ok())
}

type AvSetMmThreadCharacteristicsAFn = unsafe extern "system" fn(PCSTR, *mut u32) -> HANDLE;
type AvRevertMmThreadCharacteristicsFn = unsafe extern "system" fn(HANDLE) -> BOOL;

#[derive(Debug)]
struct AvRtLibrary {
// This field is used for its Drop behavior
#[allow(dead_code)]
module: OwnedLibrary,
av_set_mm_thread_characteristics_a: AvSetMmThreadCharacteristicsAFn,
av_revert_mm_thread_characteristics: AvRevertMmThreadCharacteristicsFn,
}

impl AvRtLibrary {
fn try_new() -> Result<Self, WIN32_ERROR> {
let module = unsafe { LoadLibraryA(s!("avrt.dll")) };
if module != 0 {
let module = OwnedLibrary(module);
let set_fn =
unsafe { GetProcAddress(module.raw(), s!("AvSetMmThreadCharacteristicsA")) };
if let Some(set_fn) = set_fn {
let revert_fn =
unsafe { GetProcAddress(module.raw(), s!("AvRevertMmThreadCharacteristics")) };
if let Some(revert_fn) = revert_fn {
let av_set_mm_thread_characteristics_a = unsafe {
std::mem::transmute::<_, AvSetMmThreadCharacteristicsAFn>(set_fn)
};
let av_revert_mm_thread_characteristics = unsafe {
std::mem::transmute::<_, AvRevertMmThreadCharacteristicsFn>(revert_fn)
};
return Ok(AvRtLibrary {
module,
av_set_mm_thread_characteristics_a,
av_revert_mm_thread_characteristics,
});
}
}
}
Err(unsafe { GetLastError() })
}
}

// We don't expect to fail to load the library on test machines
#[test]
fn test_successful_temporary_avrt_library_load() {
assert!(AvRtLibrary::try_new().is_ok())
}

#[derive(Debug)]
struct OwnedLibrary(HMODULE);

impl OwnedLibrary {
fn raw(&self) -> HMODULE {
self.0
}
}

impl Drop for OwnedLibrary {
fn drop(&mut self) {
unsafe {
FreeLibrary(self.raw());
}
}
}

0 comments on commit 7252f2f

Please sign in to comment.