Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
further reduce temporary memory requirement for generate_permutation_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
robik75 committed Jan 26, 2024
1 parent a7972b4 commit d530195
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 39 deletions.
16 changes: 9 additions & 7 deletions native/ops_complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -232,18 +232,18 @@ EXTERN __global__ void pack_variable_indexes_kernel(const uint64_t *src, uint32_
memory::store_cs(dst + gid, u32);
}

EXTERN __global__ void mark_ends_of_runs_kernel(const unsigned *run_lengths, const unsigned *run_offsets, unsigned *result, const unsigned count) {
EXTERN __global__ void mark_ends_of_runs_kernel(const unsigned *num_runs_out, const unsigned *run_offsets, unsigned *result, const unsigned count) {
const unsigned gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid >= count)
return;
const unsigned run_length = run_lengths[gid];
if (run_length == 0)
const unsigned num_runs = *num_runs_out;
if (gid >= num_runs)
return;
const unsigned run_offset = run_offsets[gid];
const unsigned run_offset_next = gid == num_runs - 1 ? count : run_offsets[gid + 1];
const unsigned run_length = run_offset_next - run_offset;
result[run_offset + run_length - 1] = 1;
}

EXTERN __global__ void generate_permutation_matrix_kernel(const unsigned *unique_variable_indexes, const unsigned *run_indexes, const unsigned *run_lengths,
EXTERN __global__ void generate_permutation_matrix_kernel(const unsigned *unique_variable_indexes, const unsigned *run_indexes, const unsigned *num_runs_out,
const unsigned *run_offsets, const unsigned *cell_indexes, const base_field *scalars,
base_field *result, const unsigned columns_count, const unsigned log_rows_count) {
const unsigned gid = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -253,8 +253,10 @@ EXTERN __global__ void generate_permutation_matrix_kernel(const unsigned *unique
const unsigned last_run_index = run_indexes[count - 1];
const unsigned last_run_variable_index = unique_variable_indexes[last_run_index];
const unsigned run_index = run_indexes[gid];
const unsigned run_length = run_lengths[run_index];
const unsigned run_offset = run_offsets[run_index];
const unsigned num_runs = *num_runs_out;
const unsigned run_offset_next = run_index == num_runs - 1 ? count : run_offsets[run_index + 1];
const unsigned run_length = run_offset_next - run_offset;
const unsigned src_in_run_index = gid - run_offset;
const bool is_placeholder = run_index == last_run_index && last_run_variable_index == (1U << 31);
const unsigned dst_in_run_index = is_placeholder ? src_in_run_index : (src_in_run_index + 1) % run_length;
Expand Down
56 changes: 24 additions & 32 deletions src/ops_complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use crate::ops_simple::{set_by_val, set_to_zero};
use crate::utils::{get_grid_block_dims_for_threads_count, WARP_SIZE};
use crate::BaseField;
use cudart::execution::{CudaLaunchConfig, Dim3, KernelFunction};
use cudart::memory::memory_copy_async;
use cudart::paste::paste;
use cudart::result::CudaResult;
use cudart::slice::{DeviceSlice, DeviceVariable};
Expand Down Expand Up @@ -462,29 +461,28 @@ pub fn select(
cuda_kernel!(
MarkEndsOfRuns,
mark_ends_of_runs_kernel(
run_lengths: *const u32,
num_runs_out: *const u32,
run_offsets: *const u32,
result: *mut u32,
count: u32,
)
);

pub fn mark_ends_of_runs(
run_lengths: &DeviceSlice<u32>,
num_runs_out: &DeviceVariable<u32>,
run_offsets: &DeviceSlice<u32>,
result: &mut DeviceSlice<u32>,
stream: &CudaStream,
) -> CudaResult<()> {
assert_eq!(run_lengths.len(), run_offsets.len());
assert!(run_lengths.len() <= u32::MAX as usize);
assert!(result.len() <= u32::MAX as usize);
let count = run_lengths.len() as u32;
assert!(run_offsets.len() <= u32::MAX as usize);
assert_eq!(run_offsets.len(), result.len());
let count = run_offsets.len() as u32;
let (grid_dim, block_dim) = get_launch_dims(count);
let run_lengths = run_lengths.as_ptr();
let num_runs_out = num_runs_out.as_ptr();
let run_offsets = run_offsets.as_ptr();
let result = result.as_mut_ptr();
let config = CudaLaunchConfig::basic(grid_dim, block_dim, stream);
let args = MarkEndsOfRunsArguments::new(run_lengths, run_offsets, result, count);
let args = MarkEndsOfRunsArguments::new(num_runs_out, run_offsets, result, count);
MarkEndsOfRunsFunction::default().launch(&config, &args)
}

Expand All @@ -493,7 +491,7 @@ cuda_kernel!(
generate_permutation_matrix_kernel(
unique_variable_indexes: *const u32,
run_indexes: *const u32,
run_lengths: *const u32,
num_runs_out: *const u32,
run_offsets: *const u32,
cell_indexes: *const u32,
scalars: *const BF,
Expand All @@ -507,17 +505,16 @@ cuda_kernel!(
fn generate_permutation_matrix_raw(
unique_variable_indexes: &DeviceSlice<u32>,
run_indexes: &DeviceSlice<u32>,
run_lengths: &DeviceSlice<u32>,
num_runs_out: &DeviceVariable<u32>,
run_offsets: &DeviceSlice<u32>,
cell_indexes: &DeviceSlice<u32>,
scalars: &DeviceSlice<BF>,
result: &mut DeviceSlice<BF>,
stream: &CudaStream,
) -> CudaResult<()> {
assert!(run_indexes.len() <= u32::MAX as usize);
assert_eq!(run_lengths.len(), run_offsets.len());
assert_eq!(run_lengths.len(), unique_variable_indexes.len());
assert!(run_lengths.len() <= u32::MAX as usize);
assert_eq!(run_indexes.len(), run_offsets.len());
assert_eq!(run_indexes.len(), unique_variable_indexes.len());
assert!(cell_indexes.len() <= u32::MAX as usize);
assert!(scalars.len() <= u32::MAX as usize);
let columns_count = scalars.len() as u32;
Expand All @@ -530,7 +527,7 @@ fn generate_permutation_matrix_raw(
let log_rows_count = rows_count.ilog2();
let unique_variable_indexes = unique_variable_indexes.as_ptr();
let run_indexes = run_indexes.as_ptr();
let run_lengths = run_lengths.as_ptr();
let num_runs_out = num_runs_out.as_ptr();
let run_offsets = run_offsets.as_ptr();
let cell_indexes = cell_indexes.as_ptr();
let scalars = scalars.as_ptr();
Expand All @@ -539,7 +536,7 @@ fn generate_permutation_matrix_raw(
let args = GeneratePermutationMatrixArguments::new(
unique_variable_indexes,
run_indexes,
run_lengths,
num_runs_out,
run_offsets,
cell_indexes,
scalars,
Expand All @@ -555,13 +552,13 @@ pub fn get_generate_permutation_matrix_temp_storage_bytes(num_cells: usize) -> C
let num_bytes = num_cells * cell_size;
let sort_pairs_tsb =
get_sort_pairs_temp_storage_bytes::<u32, u32>(false, num_cells as u32, 0, 32)?;
assert!(sort_pairs_tsb <= 4 * num_bytes);
assert!(sort_pairs_tsb <= 3 * num_bytes + cell_size);
let encode_tsb = get_encode_temp_storage_bytes::<u32>(num_cells as i32)?;
assert!(encode_tsb <= 2 * num_bytes - cell_size);
assert!(encode_tsb <= num_bytes);
let scan_tsb =
get_scan_temp_storage_bytes::<u32>(ScanOperation::Sum, false, false, num_cells as i32)?;
assert!(scan_tsb <= 2 * num_bytes);
Ok(5 * num_bytes)
Ok(4 * num_bytes + cell_size)
}

pub fn generate_permutation_matrix(
Expand Down Expand Up @@ -604,22 +601,19 @@ pub fn generate_permutation_matrix(
let unique_variable_indexes = unsafe { unique_variable_indexes.transmute_mut() };
let (run_lengths, temp_storage) = temp_storage.split_at_mut(num_bytes);
let run_lengths = unsafe { run_lengths.transmute_mut() };
let (num_runs_out, encode_temp_storage) = temp_storage.split_at_mut(cell_size);
let (temp_storage, num_runs_out) = temp_storage.split_at_mut(num_bytes);
let num_runs_out = unsafe { &mut num_runs_out.transmute_mut()[0] };
set_to_zero(run_lengths, stream)?;
encode(
encode_temp_storage,
temp_storage,
sorted_variable_indexes,
unique_variable_indexes,
run_lengths,
num_runs_out,
stream,
)?;
let (run_offsets, run_indexes) = temp_storage.split_at_mut(num_cells * 4);
let run_offsets = unsafe { run_offsets.transmute_mut() };
let run_indexes = unsafe { run_indexes.transmute_mut() };
let run_offsets = run_lengths;
let run_indexes = unsafe { temp_storage.transmute_mut() };
let temp_storage = unsafe { result.transmute_mut() };
memory_copy_async(run_offsets, run_lengths, stream)?;
scan_in_place(
ScanOperation::Sum,
false,
Expand All @@ -629,7 +623,7 @@ pub fn generate_permutation_matrix(
stream,
)?;
set_to_zero(run_indexes, stream)?;
mark_ends_of_runs(run_lengths, run_offsets, run_indexes, stream)?;
mark_ends_of_runs(num_runs_out, run_offsets, run_indexes, stream)?;
scan_in_place(
ScanOperation::Sum,
false,
Expand All @@ -641,7 +635,7 @@ pub fn generate_permutation_matrix(
generate_permutation_matrix_raw(
unique_variable_indexes,
run_indexes,
run_lengths,
num_runs_out,
run_offsets,
sorted_cell_indexes,
scalars,
Expand Down Expand Up @@ -1850,11 +1844,9 @@ mod tests {
let mut d_unique_out = DeviceAllocation::alloc(N).unwrap();
let mut d_counts_out = DeviceAllocation::alloc(N).unwrap();
let mut d_num_runs_out = DeviceAllocation::alloc(1).unwrap();
let mut d_offsets = DeviceAllocation::alloc(N).unwrap();
let mut d_result = DeviceAllocation::alloc(N).unwrap();
let stream = CudaStream::default();
memory_copy_async(&mut d_in, &h_in, &stream).unwrap();
set_to_zero(&mut d_counts_out, &stream).unwrap();
encode(
&mut d_temp_storage,
&d_in,
Expand All @@ -1864,7 +1856,7 @@ mod tests {
&stream,
)
.unwrap();
memory_copy_async(&mut d_offsets, &d_counts_out, &stream).unwrap();
let mut d_offsets = d_counts_out;
scan_in_place(
ScanOperation::Sum,
false,
Expand All @@ -1875,7 +1867,7 @@ mod tests {
)
.unwrap();
set_to_zero(&mut d_result, &stream).unwrap();
super::mark_ends_of_runs(&d_counts_out, &d_offsets, &mut d_result, &stream).unwrap();
super::mark_ends_of_runs(&d_num_runs_out[0], &d_offsets, &mut d_result, &stream).unwrap();
memory_copy_async(&mut h_result, &d_result, &stream).unwrap();
stream.synchronize().unwrap();
for i in 0..N {
Expand Down

0 comments on commit d530195

Please sign in to comment.