Skip to content

Commit

Permalink
chore(gpu): remove remaining par_iter over gpu_indexes
Browse files Browse the repository at this point in the history
Rename some variables to try and make the code clearer
  • Loading branch information
agnesLeroy committed Jul 31, 2024
1 parent eba4f6a commit 0a2ad8c
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 149 deletions.
5 changes: 2 additions & 3 deletions tfhe/src/core_crypto/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::core_crypto::prelude::{
};
pub use algorithms::*;
pub use entities::*;
use rayon::prelude::*;
use std::ffi::c_void;
pub(crate) use tfhe_cuda_backend::cuda_bind::*;

Expand Down Expand Up @@ -284,7 +283,7 @@ pub unsafe fn convert_lwe_programmable_bootstrap_key_async<T: UnsignedInteger>(
polynomial_size: PolynomialSize,
) {
let size = std::mem::size_of_val(src);
streams.gpu_indexes.par_iter().for_each(|&gpu_index| {
for &gpu_index in streams.gpu_indexes.iter() {
assert_eq!(dest.len() * std::mem::size_of::<T>(), size);
cuda_convert_lwe_programmable_bootstrap_key_64(
streams.ptr[gpu_index as usize],
Expand All @@ -296,7 +295,7 @@ pub unsafe fn convert_lwe_programmable_bootstrap_key_async<T: UnsignedInteger>(
l_gadget.0 as u32,
polynomial_size.0 as u32,
);
});
}
}

/// Convert multi-bit programmable bootstrap key
Expand Down
78 changes: 40 additions & 38 deletions tfhe/src/core_crypto/gpu/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,24 @@ where
///
/// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
/// required.
pub unsafe fn copy_from_gpu_async(&mut self, src: &Self, streams: &CudaStreams, gpu_index: u32)
where
pub unsafe fn copy_from_gpu_async(
&mut self,
src: &Self,
streams: &CudaStreams,
stream_index: u32,
) where
T: Numeric,
{
assert_eq!(self.len(gpu_index), src.len(gpu_index));
let size = src.len(gpu_index) * std::mem::size_of::<T>();
assert_eq!(self.len(stream_index), src.len(stream_index));
let size = src.len(stream_index) * std::mem::size_of::<T>();
// We check that src is not empty to avoid invalid pointers
if size > 0 {
cuda_memcpy_async_gpu_to_gpu(
self.as_mut_c_ptr(gpu_index),
src.as_c_ptr(gpu_index),
self.as_mut_c_ptr(stream_index),
src.as_c_ptr(stream_index),
size as u64,
streams.ptr[gpu_index as usize],
streams.gpu_indexes[gpu_index as usize],
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize],
);
}
}
Expand All @@ -118,105 +122,103 @@ where
///
/// - [CudaStreams::synchronize] __must__ be called after the copy as soon as synchronization is
/// required.
pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, gpu_index: u32)
pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], streams: &CudaStreams, stream_index: u32)
where
T: Numeric,
{
assert_eq!(self.len(gpu_index), dest.len());
let size = self.len(gpu_index) * std::mem::size_of::<T>();
assert_eq!(self.len(stream_index), dest.len());
let size = self.len(stream_index) * std::mem::size_of::<T>();
// We check that self is not empty to avoid invalid pointers
if size > 0 {
cuda_memcpy_async_to_cpu(
dest.as_mut_ptr().cast::<c_void>(),
self.as_c_ptr(gpu_index),
self.as_c_ptr(stream_index),
size as u64,
streams.ptr[gpu_index as usize],
streams.gpu_indexes[gpu_index as usize],
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize],
);
}
}

/// Returns the number of elements in the vector, also referred to as its ‘length’.
pub fn len(&self, gpu_index: u32) -> usize {
self.lengths[gpu_index as usize]
pub fn len(&self, index: u32) -> usize {
self.lengths[index as usize]
}

/// Returns true if the vector at the given index contains no elements
pub fn is_empty(&self, gpu_index: u32) -> bool {
self.lengths[gpu_index as usize] == 0
pub fn is_empty(&self, index: u32) -> bool {
self.lengths[index as usize] == 0
}

pub(crate) fn get_mut<R>(&mut self, range: R, gpu_index: u32) -> Option<CudaSliceMut<T>>
pub(crate) fn get_mut<R>(&mut self, range: R, index: u32) -> Option<CudaSliceMut<T>>
where
R: std::ops::RangeBounds<usize>,
T: Numeric,
{
let (start, end) = range_bounds_to_start_end(self.len(gpu_index), range).into_inner();
let (start, end) = range_bounds_to_start_end(self.len(index), range).into_inner();

// Check the range is compatible with the vec
if end <= start || end > self.lengths[gpu_index as usize] - 1 {
if end <= start || end > self.lengths[index as usize] - 1 {
None
} else {
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptrs[gpu_index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());
self.ptrs[index as usize].wrapping_byte_add(start * std::mem::size_of::<T>());

// Compute the length
let new_len = end - start + 1;

// Create the slice
Some(unsafe {
CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[gpu_index as usize])
CudaSliceMut::new(shifted_ptr, new_len, self.gpu_indexes[index as usize])
})
}
}

pub(crate) fn split_at_mut(
&mut self,
mid: usize,
gpu_index: u32,
index: u32,
) -> (Option<CudaSliceMut<T>>, Option<CudaSliceMut<T>>)
where
T: Numeric,
{
// Check the index is compatible with the vec
if mid > self.lengths[gpu_index as usize] - 1 {
if mid > self.lengths[index as usize] - 1 {
(None, None)
} else if mid == 0 {
(
None,
Some(unsafe {
CudaSliceMut::new(
self.ptrs[gpu_index as usize],
self.lengths[gpu_index as usize],
gpu_index,
self.ptrs[index as usize],
self.lengths[index as usize],
index,
)
}),
)
} else if mid == self.lengths[gpu_index as usize] - 1 {
} else if mid == self.lengths[index as usize] - 1 {
(
Some(unsafe {
CudaSliceMut::new(
self.ptrs[gpu_index as usize],
self.lengths[gpu_index as usize],
gpu_index,
self.ptrs[index as usize],
self.lengths[index as usize],
index,
)
}),
None,
)
} else {
let new_len_1 = mid;
let new_len_2 = self.lengths[gpu_index as usize] - mid;
let new_len_2 = self.lengths[index as usize] - mid;
// Shift ptr
let shifted_ptr: *mut c_void =
self.ptrs[gpu_index as usize].wrapping_byte_add(mid * std::mem::size_of::<T>());
self.ptrs[index as usize].wrapping_byte_add(mid * std::mem::size_of::<T>());

// Create the slice
(
Some(unsafe {
CudaSliceMut::new(self.ptrs[gpu_index as usize], new_len_1, gpu_index)
}),
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, gpu_index) }),
Some(unsafe { CudaSliceMut::new(self.ptrs[index as usize], new_len_1, index) }),
Some(unsafe { CudaSliceMut::new(shifted_ptr, new_len_2, index) }),
)
}
}
Expand Down
Loading

0 comments on commit 0a2ad8c

Please sign in to comment.