diff --git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index 43779dda08..0b9c6d1bed 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -365,7 +365,6 @@ where // DivRem is a bit special as it returns a tuple of quotient and remainder macro_rules! generic_integer_impl_scalar_div_rem { ( - key_method: $key_method:ident, // A 'list' of tuple, where the first element is the concrete Fhe type // e.g (FheUint8 and the rest is scalar types (u8, u16, etc) fhe_and_scalar_type: $( @@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem { InternalServerKey::Cpu(cpu_key) => { let (q, r) = cpu_key .pbs_key() - .$key_method(&*self.ciphertext.on_cpu(), rhs); + .signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs); ( <$concrete_type>::new(q, cpu_key.tag.clone()), <$concrete_type>::new(r, cpu_key.tag.clone()) ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices does not support div rem yet") + InternalServerKey::Cuda(cuda_key) => { + let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { + cuda_key.key.signed_scalar_div_rem( + &*self.ciphertext.on_gpu(), rhs, streams + ) + }); + let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); + ( + <$concrete_type>::new(q, cuda_key.tag.clone()), + <$concrete_type>::new(r, cuda_key.tag.clone()) + ) } }) } @@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem { )* // Closing first repeating pattern }; } + generic_integer_impl_scalar_div_rem!( - key_method: signed_scalar_div_rem_parallelized, fhe_and_scalar_type: (super::FheInt2, i8), (super::FheInt4, i8), @@ -826,8 +834,13 @@ generic_integer_impl_scalar_operation!( RadixCiphertext::Cpu(inner_result) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Div '/' with clear value is not yet supported by Cuda devices") + InternalServerKey::Cuda(cuda_key) => { + let inner_result = with_thread_local_cuda_streams(|streams| { + cuda_key.key.signed_scalar_div( + &lhs.ciphertext.on_gpu(), rhs, streams + ) + }); + RadixCiphertext::Cuda(inner_result) } }) } @@ -859,8 +872,13 @@ generic_integer_impl_scalar_operation!( RadixCiphertext::Cpu(inner_result) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Rem '%' with clear value is not yet supported by Cuda devices") + InternalServerKey::Cuda(cuda_key) => { + let inner_result = with_thread_local_cuda_streams(|streams| { + cuda_key.key.signed_scalar_rem( + &lhs.ciphertext.on_gpu(), rhs, streams + ) + }); + RadixCiphertext::Cuda(inner_result) } }) } diff --git a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs index 709c91cafa..6f46a0714d 100644 --- a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs @@ -285,9 +285,17 @@ where ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support overflowing_sub yet"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let inner_result = cuda_key.key.unsigned_overflowing_sub( + &self.ciphertext.on_gpu(), + &other.ciphertext.on_gpu(), + streams, + ); + ( + FheUint::::new(inner_result.0, cuda_key.tag.clone()), + FheBool::new(inner_result.1, cuda_key.tag.clone()), + ) + }), }) } } diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 451093555a..a5ddf6cf31 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -446,7 +446,6 @@ where // DivRem is a bit special as it returns a tuple of quotient and remainder macro_rules! generic_integer_impl_scalar_div_rem { ( - key_method: $key_method:ident, // A 'list' of tuple, where the first element is the concrete Fhe type // e.g (FheUint8 and the rest is scalar types (u8, u16, etc) fhe_and_scalar_type: $( @@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem { global_state::with_internal_keys(|key| { match key { InternalServerKey::Cpu(cpu_key) => { - let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs); + let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs); ( <$concrete_type>::new(q, cpu_key.tag.clone()), <$concrete_type>::new(r, cpu_key.tag.clone()) ) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Cuda devices do not support div_rem yet"); + InternalServerKey::Cuda(cuda_key) => { + let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { + cuda_key.key.scalar_div_rem( + &*self.ciphertext.on_gpu(), rhs, streams + ) + }); + let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); + ( + <$concrete_type>::new(q, cuda_key.tag.clone()), + <$concrete_type>::new(r, cuda_key.tag.clone()) + ) } } }) @@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem { }; } generic_integer_impl_scalar_div_rem!( - key_method: scalar_div_rem_parallelized, fhe_and_scalar_type: (super::FheUint2, u8), (super::FheUint4, u8), @@ -978,8 +985,13 @@ generic_integer_impl_scalar_operation!( RadixCiphertext::Cpu(inner_result) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Div '/' with clear value is not yet supported by Cuda devices") + InternalServerKey::Cuda(cuda_key) => { + let inner_result = with_thread_local_cuda_streams(|streams| { + cuda_key.key.scalar_div( + &lhs.ciphertext.on_gpu(), rhs, streams + ) + }); + RadixCiphertext::Cuda(inner_result) } }) } @@ -1014,8 +1026,13 @@ generic_integer_impl_scalar_operation!( RadixCiphertext::Cpu(inner_result) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - panic!("Rem '%' with clear value is not yet supported by Cuda devices") + InternalServerKey::Cuda(cuda_key) => { + let inner_result = with_thread_local_cuda_streams(|streams| { + cuda_key.key.scalar_rem( + &lhs.ciphertext.on_gpu(), rhs, streams + ) + }); + RadixCiphertext::Cuda(inner_result) } }) }