Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

reduce temporary memory requirement for generate_permutation_matrix #24

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions src/ops_complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use cudart::result::CudaResult;
use cudart::slice::{DeviceSlice, DeviceVariable};
use cudart::stream::CudaStream;
use cudart::{cuda_kernel, cuda_kernel_declaration, cuda_kernel_signature_arguments_and_function};
use itertools::max;
use std::mem;

type BF = BaseField;
type EF = VectorizedExtensionField;
Expand Down Expand Up @@ -551,14 +551,17 @@ fn generate_permutation_matrix_raw(
}

pub fn get_generate_permutation_matrix_temp_storage_bytes(num_cells: usize) -> CudaResult<usize> {
let end_bit: i32 = (usize::BITS - num_cells.leading_zeros()) as i32;
let cell_size = mem::size_of::<u32>();
let num_bytes = num_cells * cell_size;
let sort_pairs_tsb =
get_sort_pairs_temp_storage_bytes::<u32, u32>(false, num_cells as u32, 0, end_bit)?;
get_sort_pairs_temp_storage_bytes::<u32, u32>(false, num_cells as u32, 0, 32)?;
assert!(sort_pairs_tsb <= 4 * num_bytes);
let encode_tsb = get_encode_temp_storage_bytes::<u32>(num_cells as i32)?;
assert!(encode_tsb <= 2 * num_bytes - cell_size);
let scan_tsb =
get_scan_temp_storage_bytes::<u32>(ScanOperation::Sum, false, false, num_cells as i32)?;
let cub_tsb = max([sort_pairs_tsb, encode_tsb, scan_tsb]).unwrap();
Ok((7 * num_cells + 1) * 4 + cub_tsb)
assert!(scan_tsb <= 2 * num_bytes);
Ok(5 * num_bytes)
robik75 marked this conversation as resolved.
Show resolved Hide resolved
}

pub fn generate_permutation_matrix(
Expand All @@ -569,22 +572,12 @@ pub fn generate_permutation_matrix(
stream: &CudaStream,
) -> CudaResult<()> {
let num_cells = variable_indexes.len();
let (sorted_variable_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (unsorted_cell_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (sorted_cell_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (unique_variable_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (run_lengths, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (run_offsets, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let (num_runs_out, temp_storage) = temp_storage.split_at_mut(4);
let (run_indexes, temp_storage) = temp_storage.split_at_mut(num_cells * 4);
let sorted_variable_indexes = unsafe { sorted_variable_indexes.transmute_mut() };
let unsorted_cell_indexes = unsafe { unsorted_cell_indexes.transmute_mut() };
let sorted_cell_indexes = unsafe { sorted_cell_indexes.transmute_mut() };
let unique_variable_indexes = unsafe { unique_variable_indexes.transmute_mut() };
let run_lengths = unsafe { run_lengths.transmute_mut() };
let run_offsets = unsafe { run_offsets.transmute_mut() };
let num_runs_out = unsafe { &mut num_runs_out.transmute_mut()[0] };
let run_indexes = unsafe { run_indexes.transmute_mut() };
let cell_size = mem::size_of::<u32>();
let num_bytes = num_cells * cell_size;
assert_eq!(result.len(), num_cells);
assert_eq!(mem::size_of::<BF>(), 2 * cell_size);
let result_as_u32 = unsafe { result.transmute_mut() };
let (sorted_variable_indexes, unsorted_cell_indexes) = result_as_u32.split_at_mut(num_cells);
set_by_val(1u32, unsorted_cell_indexes, stream)?;
scan_in_place(
ScanOperation::Sum,
Expand All @@ -594,6 +587,8 @@ pub fn generate_permutation_matrix(
unsorted_cell_indexes,
stream,
)?;
let (sorted_cell_indexes, temp_storage) = temp_storage.split_at_mut(num_bytes);
let sorted_cell_indexes = unsafe { sorted_cell_indexes.transmute_mut() };
sort_pairs(
false,
temp_storage,
Expand All @@ -605,15 +600,25 @@ pub fn generate_permutation_matrix(
32,
stream,
)?;
let (unique_variable_indexes, temp_storage) = temp_storage.split_at_mut(num_bytes);
let unique_variable_indexes = unsafe { unique_variable_indexes.transmute_mut() };
let (run_lengths, temp_storage) = temp_storage.split_at_mut(num_bytes);
let run_lengths = unsafe { run_lengths.transmute_mut() };
let (num_runs_out, encode_temp_storage) = temp_storage.split_at_mut(cell_size);
let num_runs_out = unsafe { &mut num_runs_out.transmute_mut()[0] };
set_to_zero(run_lengths, stream)?;
encode(
robik75 marked this conversation as resolved.
Show resolved Hide resolved
temp_storage,
encode_temp_storage,
sorted_variable_indexes,
unique_variable_indexes,
run_lengths,
num_runs_out,
stream,
)?;
let (run_offsets, run_indexes) = temp_storage.split_at_mut(num_cells * 4);
robik75 marked this conversation as resolved.
Show resolved Hide resolved
let run_offsets = unsafe { run_offsets.transmute_mut() };
let run_indexes = unsafe { run_indexes.transmute_mut() };
robik75 marked this conversation as resolved.
Show resolved Hide resolved
let temp_storage = unsafe { result.transmute_mut() };
memory_copy_async(run_offsets, run_lengths, stream)?;
scan_in_place(
ScanOperation::Sum,
Expand Down