Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

reduce temporary memory requirement for generate_permutation_matrix #24

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions native/ops_complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -232,18 +232,18 @@ EXTERN __global__ void pack_variable_indexes_kernel(const uint64_t *src, uint32_
memory::store_cs(dst + gid, u32);
}

EXTERN __global__ void mark_ends_of_runs_kernel(const unsigned *run_lengths, const unsigned *run_offsets, unsigned *result, const unsigned count) {
EXTERN __global__ void mark_ends_of_runs_kernel(const unsigned *num_runs_out, const unsigned *run_offsets, unsigned *result, const unsigned count) {
const unsigned gid = blockIdx.x * blockDim.x + threadIdx.x;
if (gid >= count)
return;
const unsigned run_length = run_lengths[gid];
if (run_length == 0)
const unsigned num_runs = *num_runs_out;
if (gid >= num_runs)
return;
const unsigned run_offset = run_offsets[gid];
const unsigned run_offset_next = gid == num_runs - 1 ? count : run_offsets[gid + 1];
const unsigned run_length = run_offset_next - run_offset;
result[run_offset + run_length - 1] = 1;
}

EXTERN __global__ void generate_permutation_matrix_kernel(const unsigned *unique_variable_indexes, const unsigned *run_indexes, const unsigned *run_lengths,
EXTERN __global__ void generate_permutation_matrix_kernel(const unsigned *unique_variable_indexes, const unsigned *run_indexes, const unsigned *num_runs_out,
const unsigned *run_offsets, const unsigned *cell_indexes, const base_field *scalars,
base_field *result, const unsigned columns_count, const unsigned log_rows_count) {
const unsigned gid = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -253,8 +253,10 @@ EXTERN __global__ void generate_permutation_matrix_kernel(const unsigned *unique
const unsigned last_run_index = run_indexes[count - 1];
const unsigned last_run_variable_index = unique_variable_indexes[last_run_index];
const unsigned run_index = run_indexes[gid];
const unsigned run_length = run_lengths[run_index];
const unsigned run_offset = run_offsets[run_index];
const unsigned num_runs = *num_runs_out;
const unsigned run_offset_next = run_index == num_runs - 1 ? count : run_offsets[run_index + 1];
const unsigned run_length = run_offset_next - run_offset;
const unsigned src_in_run_index = gid - run_offset;
const bool is_placeholder = run_index == last_run_index && last_run_variable_index == (1U << 31);
const unsigned dst_in_run_index = is_placeholder ? src_in_run_index : (src_in_run_index + 1) % run_length;
Expand Down
87 changes: 42 additions & 45 deletions src/ops_complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ use crate::ops_simple::{set_by_val, set_to_zero};
use crate::utils::{get_grid_block_dims_for_threads_count, WARP_SIZE};
use crate::BaseField;
use cudart::execution::{CudaLaunchConfig, Dim3, KernelFunction};
use cudart::memory::memory_copy_async;
use cudart::paste::paste;
use cudart::result::CudaResult;
use cudart::slice::{DeviceSlice, DeviceVariable};
use cudart::stream::CudaStream;
use cudart::{cuda_kernel, cuda_kernel_declaration, cuda_kernel_signature_arguments_and_function};
use itertools::max;
use std::mem;

type BF = BaseField;
type EF = VectorizedExtensionField;
Expand Down Expand Up @@ -462,29 +461,28 @@ pub fn select(
cuda_kernel!(
MarkEndsOfRuns,
mark_ends_of_runs_kernel(
run_lengths: *const u32,
num_runs_out: *const u32,
run_offsets: *const u32,
result: *mut u32,
count: u32,
)
);

pub fn mark_ends_of_runs(
run_lengths: &DeviceSlice<u32>,
num_runs_out: &DeviceVariable<u32>,
run_offsets: &DeviceSlice<u32>,
result: &mut DeviceSlice<u32>,
stream: &CudaStream,
) -> CudaResult<()> {
assert_eq!(run_lengths.len(), run_offsets.len());
assert!(run_lengths.len() <= u32::MAX as usize);
assert!(result.len() <= u32::MAX as usize);
let count = run_lengths.len() as u32;
assert!(run_offsets.len() <= u32::MAX as usize);
assert_eq!(run_offsets.len(), result.len());
let count = run_offsets.len() as u32;
let (grid_dim, block_dim) = get_launch_dims(count);
let run_lengths = run_lengths.as_ptr();
let num_runs_out = num_runs_out.as_ptr();
let run_offsets = run_offsets.as_ptr();
let result = result.as_mut_ptr();
let config = CudaLaunchConfig::basic(grid_dim, block_dim, stream);
let args = MarkEndsOfRunsArguments::new(run_lengths, run_offsets, result, count);
let args = MarkEndsOfRunsArguments::new(num_runs_out, run_offsets, result, count);
MarkEndsOfRunsFunction::default().launch(&config, &args)
}

Expand All @@ -493,7 +491,7 @@ cuda_kernel!(
generate_permutation_matrix_kernel(
unique_variable_indexes: *const u32,
run_indexes: *const u32,
run_lengths: *const u32,
num_runs_out: *const u32,
run_offsets: *const u32,
cell_indexes: *const u32,
scalars: *const BF,
Expand All @@ -507,17 +505,16 @@ cuda_kernel!(
fn generate_permutation_matrix_raw(
unique_variable_indexes: &DeviceSlice<u32>,
run_indexes: &DeviceSlice<u32>,
run_lengths: &DeviceSlice<u32>,
num_runs_out: &DeviceVariable<u32>,
run_offsets: &DeviceSlice<u32>,
cell_indexes: &DeviceSlice<u32>,
scalars: &DeviceSlice<BF>,
result: &mut DeviceSlice<BF>,
stream: &CudaStream,
) -> CudaResult<()> {
assert!(run_indexes.len() <= u32::MAX as usize);
assert_eq!(run_lengths.len(), run_offsets.len());
assert_eq!(run_lengths.len(), unique_variable_indexes.len());
assert!(run_lengths.len() <= u32::MAX as usize);
assert_eq!(run_indexes.len(), run_offsets.len());
assert_eq!(run_indexes.len(), unique_variable_indexes.len());
assert!(cell_indexes.len() <= u32::MAX as usize);
assert!(scalars.len() <= u32::MAX as usize);
let columns_count = scalars.len() as u32;
Expand All @@ -530,7 +527,7 @@ fn generate_permutation_matrix_raw(
let log_rows_count = rows_count.ilog2();
let unique_variable_indexes = unique_variable_indexes.as_ptr();
let run_indexes = run_indexes.as_ptr();
let run_lengths = run_lengths.as_ptr();
let num_runs_out = num_runs_out.as_ptr();
let run_offsets = run_offsets.as_ptr();
let cell_indexes = cell_indexes.as_ptr();
let scalars = scalars.as_ptr();
Expand All @@ -539,7 +536,7 @@ fn generate_permutation_matrix_raw(
let args = GeneratePermutationMatrixArguments::new(
unique_variable_indexes,
run_indexes,
run_lengths,
num_runs_out,
run_offsets,
cell_indexes,
scalars,
Expand All @@ -551,14 +548,17 @@ fn generate_permutation_matrix_raw(
}

pub fn get_generate_permutation_matrix_temp_storage_bytes(num_cells: usize) -> CudaResult<usize> {
let end_bit: i32 = (usize::BITS - num_cells.leading_zeros()) as i32;
let cell_size = mem::size_of::<u32>();
let num_bytes = num_cells * cell_size;
let sort_pairs_tsb =
get_sort_pairs_temp_storage_bytes::<u32, u32>(false, num_cells as u32, 0, end_bit)?;
get_sort_pairs_temp_storage_bytes::<u32, u32>(false, num_cells as u32, 0, 32)?;
assert!(sort_pairs_tsb <= 3 * num_bytes + cell_size);
let encode_tsb = get_encode_temp_storage_bytes::<u32>(num_cells as i32)?;
assert!(encode_tsb <= num_bytes);
let scan_tsb =
get_scan_temp_storage_bytes::<u32>(ScanOperation::Sum, false, false, num_cells as i32)?;
let cub_tsb = max([sort_pairs_tsb, encode_tsb, scan_tsb]).unwrap();
Ok((7 * num_cells + 1) * 4 + cub_tsb)
assert!(scan_tsb <= 2 * num_bytes);
Ok(4 * num_bytes + cell_size)
}

pub fn generate_permutation_matrix(
Expand All @@ -569,22 +569,12 @@ pub fn generate_permutation_matrix(
stream: &CudaStream,
) -> CudaResult<()> {
let num_cells = variable_indexes.len();
let (sorted_variable_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (unsorted_cell_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (sorted_cell_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (unique_variable_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (run_lengths, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (run_offsets, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (num_runs_out, temp_storage) = temp_storage.split_at_mut(4);
let (run_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let sorted_variable_indexes = unsafe { sorted_variable_indexes.transmute_mut() };
let unsorted_cell_indexes = unsafe { unsorted_cell_indexes.transmute_mut() };
let sorted_cell_indexes = unsafe { sorted_cell_indexes.transmute_mut() };
let unique_variable_indexes = unsafe { unique_variable_indexes.transmute_mut() };
let run_lengths = unsafe { run_lengths.transmute_mut() };
let run_offsets = unsafe { run_offsets.transmute_mut() };
let num_runs_out = unsafe { &mut num_runs_out.transmute_mut()[0] };
let run_indexes = unsafe { run_indexes.transmute_mut() };
let cell_size = mem::size_of::<u32>();
let num_bytes = num_cells * cell_size;
assert_eq!(result.len(), num_cells);
assert_eq!(mem::size_of::<BF>(), 2 * cell_size);
let result_as_u32 = unsafe { result.transmute_mut() };
let (sorted_variable_indexes, unsorted_cell_indexes) = result_as_u32.split_at_mut(num_cells);
set_by_val(1u32, unsorted_cell_indexes, stream)?;
scan_in_place(
ScanOperation::Sum,
Expand All @@ -594,6 +584,8 @@ pub fn generate_permutation_matrix(
unsorted_cell_indexes,
stream,
)?;
let (sorted_cell_indexes, temp_storage) = temp_storage.split_at_mut(num_bytes);
let sorted_cell_indexes = unsafe { sorted_cell_indexes.transmute_mut() };
sort_pairs(
false,
temp_storage,
Expand All @@ -605,7 +597,12 @@ pub fn generate_permutation_matrix(
32,
stream,
)?;
set_to_zero(run_lengths, stream)?;
let (unique_variable_indexes, temp_storage) = temp_storage.split_at_mut(num_bytes);
let unique_variable_indexes = unsafe { unique_variable_indexes.transmute_mut() };
let (run_lengths, temp_storage) = temp_storage.split_at_mut(num_bytes);
let run_lengths = unsafe { run_lengths.transmute_mut() };
let (temp_storage, num_runs_out) = temp_storage.split_at_mut(num_bytes);
let num_runs_out = unsafe { &mut num_runs_out.transmute_mut()[0] };
encode(
robik75 marked this conversation as resolved.
Show resolved Hide resolved
temp_storage,
sorted_variable_indexes,
Expand All @@ -614,7 +611,9 @@ pub fn generate_permutation_matrix(
num_runs_out,
stream,
)?;
memory_copy_async(run_offsets, run_lengths, stream)?;
let run_offsets = run_lengths;
let run_indexes = unsafe { temp_storage.transmute_mut() };
let temp_storage = unsafe { result.transmute_mut() };
scan_in_place(
ScanOperation::Sum,
false,
Expand All @@ -624,7 +623,7 @@ pub fn generate_permutation_matrix(
stream,
)?;
set_to_zero(run_indexes, stream)?;
mark_ends_of_runs(run_lengths, run_offsets, run_indexes, stream)?;
mark_ends_of_runs(num_runs_out, run_offsets, run_indexes, stream)?;
scan_in_place(
ScanOperation::Sum,
false,
Expand All @@ -636,7 +635,7 @@ pub fn generate_permutation_matrix(
generate_permutation_matrix_raw(
unique_variable_indexes,
run_indexes,
run_lengths,
num_runs_out,
run_offsets,
sorted_cell_indexes,
scalars,
Expand Down Expand Up @@ -1845,11 +1844,9 @@ mod tests {
let mut d_unique_out = DeviceAllocation::alloc(N).unwrap();
let mut d_counts_out = DeviceAllocation::alloc(N).unwrap();
let mut d_num_runs_out = DeviceAllocation::alloc(1).unwrap();
let mut d_offsets = DeviceAllocation::alloc(N).unwrap();
let mut d_result = DeviceAllocation::alloc(N).unwrap();
let stream = CudaStream::default();
memory_copy_async(&mut d_in, &h_in, &stream).unwrap();
set_to_zero(&mut d_counts_out, &stream).unwrap();
encode(
&mut d_temp_storage,
&d_in,
Expand All @@ -1859,7 +1856,7 @@ mod tests {
&stream,
)
.unwrap();
memory_copy_async(&mut d_offsets, &d_counts_out, &stream).unwrap();
let mut d_offsets = d_counts_out;
scan_in_place(
ScanOperation::Sum,
false,
Expand All @@ -1870,7 +1867,7 @@ mod tests {
)
.unwrap();
set_to_zero(&mut d_result, &stream).unwrap();
super::mark_ends_of_runs(&d_counts_out, &d_offsets, &mut d_result, &stream).unwrap();
super::mark_ends_of_runs(&d_num_runs_out[0], &d_offsets, &mut d_result, &stream).unwrap();
memory_copy_async(&mut h_result, &d_result, &stream).unwrap();
stream.synchronize().unwrap();
for i in 0..N {
Expand Down
Loading