Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make the simd_runtime_generate macro explicitly unsafe #52

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,12 @@ use simdeez::*;
result
});
fn main() {
unsafe {
let _dist = distance_runtime_select(&0.0, &0.0, &10.0, &10.0);
}
}
```
This will generate 5 functions for you:
This will generate 5 unsafe functions for you:
* `distance<S:Simd>` the generic version of your function
* `distance_scalar` a scalar fallback
* `distance_sse2` SSE2 version
Expand Down
35 changes: 19 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,13 @@
//! result
//! });
//! # fn main() {
//! unsafe {
//! let _dist = distance_runtime_select(&0.0, &0.0, &10.0, &10.0);
//! }
//! # }
//! ```
//!
//! This will generate 5 functions for you:
//! This will generate 5 unsafe functions for you:
//! * `distance<S:Simd>` the generic version of your function
//! * `distance_scalar` a scalar fallback
//! * `distance_sse2` SSE2 version
Expand Down Expand Up @@ -631,7 +634,7 @@ pub trait Simd: Sync + Send {
/// * SSE41 (fn_name_sse41)
/// * SSE2 (fn_name_sse2)
/// * Scalar fallback (fn_name_scalar)
/// Finally, it also generates a function which will select at runtime the fastest version
/// Finally, it also generates a unsafe function which will select at runtime the fastest version
/// from above that the cpu supports. (fn_name_runtime_select)
#[macro_export]
macro_rules! simd_runtime_generate {
Expand All @@ -658,23 +661,23 @@ macro_rules! simd_runtime_generate {
$vis unsafe fn [<$fn_name _avx2>]($($arg:$typ,)*) $(-> $rt)? {
$fn_name::<Avx2>($($arg,)*)
}
$vis fn [<$fn_name _runtime_select>]($($arg:$typ,)*) $(-> $rt)? {
$vis unsafe fn [<$fn_name _runtime_select>]($($arg:$typ,)*) $(-> $rt)? {
if is_x86_feature_detected!("avx2") {
unsafe { [<$fn_name _avx2>]($($arg,)*) }
[<$fn_name _avx2>]($($arg,)*)
} else if is_x86_feature_detected!("sse4.1") {
unsafe { [<$fn_name _sse41>]($($arg,)*) }
[<$fn_name _sse41>]($($arg,)*)
} else if is_x86_feature_detected!("sse2") {
unsafe { [<$fn_name _sse2>]($($arg,)*) }
[<$fn_name _sse2>]($($arg,)*)
} else {
unsafe { [<$fn_name _scalar>]($($arg,)*) }
[<$fn_name _scalar>]($($arg,)*)
}
}
}
};

}

/// Generates a generic version of your function (fn_name)
/// Generates a generic unsafe version of your function (fn_name)
/// And the fastest version supported by your rust compilation settings
/// (fn_name_compiletime)
#[macro_export]
Expand All @@ -686,22 +689,22 @@ macro_rules! simd_compiletime_generate {

paste::item! {
#[cfg(target_feature = "avx2")]
$vis fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
unsafe { $fn_name::<Avx2>($($arg,)*) }
$vis unsafe fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
$fn_name::<Avx2>($($arg,)*)
}

#[cfg(all(target_feature = "sse4.1",not(target_feature = "avx2")))]
$vis fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
unsafe { $fn_name::<Sse41>($($arg,)*) }
$vis unsafe fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
$fn_name::<Sse41>($($arg,)*)
}
#[cfg(all(target_feature = "sse2",not(any(target_feature="sse4.1",target_feature = "avx2"))))]
$vis fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
unsafe { $fn_name::<Sse2>($($arg,)*) }
$vis unsafe fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
$fn_name::<Sse2>($($arg,)*)
}

#[cfg(not(any(target_feature="sse4.1",target_feature = "avx2",target_feature="sse2")))]
$vis fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
unsafe { $fn_name::<Scalar>($($arg,)*) }
$vis unsafe fn [<$fn_name _compiletime>]($($arg:$typ,)*) $(-> $rt)? {
$fn_name::<Scalar>($($arg,)*)
}


Expand Down