diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs index b007456cf2c..63a96106283 100644 --- a/crates/core_simd/src/ops.rs +++ b/crates/core_simd/src/ops.rs @@ -6,6 +6,7 @@ use core::ops::{Shl, Shr}; mod assign; mod deref; +mod shift_scalar; mod unary; impl core::ops::Index for Simd diff --git a/crates/core_simd/src/ops/shift_scalar.rs b/crates/core_simd/src/ops/shift_scalar.rs new file mode 100644 index 00000000000..77aac656395 --- /dev/null +++ b/crates/core_simd/src/ops/shift_scalar.rs @@ -0,0 +1,58 @@ +// Shift operations uniquely typically only have a scalar on the right-hand side. +// Here, we implement shifts for scalar RHS arguments. + +use crate::simd::{LaneCount, Simd, SupportedLaneCount}; + +macro_rules! impl_splatted_shifts { + { impl $trait:ident :: $trait_fn:ident for $ty:ty } => { + impl core::ops::$trait<$ty> for Simd<$ty, N> + where + LaneCount: SupportedLaneCount, + { + type Output = Self; + fn $trait_fn(self, rhs: $ty) -> Self::Output { + self.$trait_fn(Simd::splat(rhs)) + } + } + + impl core::ops::$trait<&$ty> for Simd<$ty, N> + where + LaneCount: SupportedLaneCount, + { + type Output = Self; + fn $trait_fn(self, rhs: &$ty) -> Self::Output { + self.$trait_fn(Simd::splat(*rhs)) + } + } + + impl<'lhs, const N: usize> core::ops::$trait<$ty> for &'lhs Simd<$ty, N> + where + LaneCount: SupportedLaneCount, + { + type Output = Simd<$ty, N>; + fn $trait_fn(self, rhs: $ty) -> Self::Output { + self.$trait_fn(Simd::splat(rhs)) + } + } + + impl<'lhs, const N: usize> core::ops::$trait<&$ty> for &'lhs Simd<$ty, N> + where + LaneCount: SupportedLaneCount, + { + type Output = Simd<$ty, N>; + fn $trait_fn(self, rhs: &$ty) -> Self::Output { + self.$trait_fn(Simd::splat(*rhs)) + } + } + }; + { $($ty:ty),* } => { + $( + impl_splatted_shifts! { impl Shl::shl for $ty } + impl_splatted_shifts! { impl Shr::shr for $ty } + )* + } +} + +// In the past there were inference issues when generically splatting arguments. +// Enumerate them instead. +impl_splatted_shifts! { i8, i16, i32, i64, isize, u8, u16, u32, u64, usize } diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs index 3a02f3f01e1..dfc0e1a3708 100644 --- a/crates/core_simd/tests/ops_macros.rs +++ b/crates/core_simd/tests/ops_macros.rs @@ -94,6 +94,36 @@ macro_rules! impl_binary_checked_op_test { macro_rules! impl_common_integer_tests { { $vector:ident, $scalar:ident } => { test_helpers::test_lanes! { + fn shr() { + use core::ops::Shr; + let shr = |x: $scalar, y: $scalar| x.wrapping_shr(y as _); + test_helpers::test_binary_elementwise( + &<$vector:: as Shr<$vector::>>::shr, + &shr, + &|_, _| true, + ); + test_helpers::test_binary_scalar_rhs_elementwise( + &<$vector:: as Shr<$scalar>>::shr, + &shr, + &|_, _| true, + ); + } + + fn shl() { + use core::ops::Shl; + let shl = |x: $scalar, y: $scalar| x.wrapping_shl(y as _); + test_helpers::test_binary_elementwise( + &<$vector:: as Shl<$vector::>>::shl, + &shl, + &|_, _| true, + ); + test_helpers::test_binary_scalar_rhs_elementwise( + &<$vector:: as Shl<$scalar>>::shl, + &shl, + &|_, _| true, + ); + } + fn reduce_sum() { test_helpers::test_1(&|x| { test_helpers::prop_assert_biteq! (