Skip to content

Commit

Permalink
chore(gpu): add scalar div and signed scalar div to hl api
Browse files Browse the repository at this point in the history
Also add overflowing sub to hl
  • Loading branch information
agnesLeroy committed Sep 18, 2024
1 parent 7b497d3 commit 50019e0
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 21 deletions.
36 changes: 27 additions & 9 deletions tfhe/src/high_level_api/integers/signed/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A list of tuples, where the first element is the concrete Fhe type
// (e.g. FheInt8) and the rest are scalar types (i8, i16, etc.)
fhe_and_scalar_type: $(
Expand Down Expand Up @@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key
.pbs_key()
.$key_method(&*self.ciphertext.on_cpu(), rhs);
.signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices does not support div rem yet")
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
})
}
Expand All @@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem {
)* // Closing first repeating pattern
};
}

generic_integer_impl_scalar_div_rem!(
key_method: signed_scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheInt2, i8),
(super::FheInt4, i8),
Expand Down Expand Up @@ -826,8 +834,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Div '/' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_div(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down Expand Up @@ -859,8 +872,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Rem '%' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_rem(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down
14 changes: 11 additions & 3 deletions tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,17 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support overflowing_sub yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_result = cuda_key.key.unsigned_overflowing_sub(
&self.ciphertext.on_gpu(),
&other.ciphertext.on_gpu(),
streams,
);
(
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
FheBool::new(inner_result.1, cuda_key.tag.clone()),
)
}),
})
}
}
Expand Down
35 changes: 26 additions & 9 deletions tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A list of tuples, where the first element is the concrete Fhe type
// (e.g. FheUint8) and the rest are scalar types (u8, u16, etc.)
fhe_and_scalar_type: $(
Expand All @@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
global_state::with_internal_keys(|key| {
match key {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs);
let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support div_rem yet");
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
}
})
Expand All @@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem {
};
}
generic_integer_impl_scalar_div_rem!(
key_method: scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheUint2, u8),
(super::FheUint4, u8),
Expand Down Expand Up @@ -978,8 +985,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Div '/' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_div(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down Expand Up @@ -1014,8 +1026,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Rem '%' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_rem(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down

0 comments on commit 50019e0

Please sign in to comment.