Skip to content

Commit

Permalink
Add work on using index pointers to avoid re-allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
skailasa committed Aug 11, 2023
1 parent c00c627 commit 236892c
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 43 deletions.
31 changes: 30 additions & 1 deletion field/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashSet, usize, sync::{Arc, RwLock}};
use std::{collections::HashSet, usize, sync::{Arc, RwLock, Mutex}, ops::{Deref, DerefMut}};

use dashmap::DashMap;
use itertools::Itertools;
Expand Down Expand Up @@ -209,6 +209,35 @@ pub fn rfft3_fftw_par_vec(
});
}


/// Compute the real-to-complex 3D FFT of each signal in `input`, writing the
/// Hermitian-packed spectrum into the matching entry of `output`, in parallel
/// over signals.
///
/// `shape` must have exactly three entries `[p, q, r]`; each input buffer is
/// expected to hold `p*q*r` reals and each output buffer `p*q*(r/2+1)` complex
/// values (checked with debug assertions).
pub fn rfft3_fftw_par_vec_arc_mutex(
    input: &mut Vec<Arc<Mutex<Vec<f64>>>>,
    output: &mut Vec<Arc<Mutex<Vec<c64>>>>,
    shape: &[usize],
) {
    assert!(shape.len() == 3);
    assert_eq!(
        input.len(),
        output.len(),
        "one output buffer is required per input signal"
    );

    // Logical transform size, and the Hermitian-packed (real-to-complex)
    // output size; used to sanity-check per-signal buffer lengths below.
    let size: usize = shape.iter().product();
    let size_d = *shape.last().unwrap();
    let size_real = (size / size_d) * (size_d / 2 + 1);

    // `R2CPlan::r2c` takes `&mut self`, but the closure below runs on multiple
    // rayon workers, so a single shared plan must be serialized behind a mutex.
    // Planning with Flag::MEASURE is expensive, which is why one plan is shared
    // rather than re-planned per task.
    // NOTE(review): the lock serializes plan execution — if this becomes a
    // bottleneck, consider one plan per worker thread instead.
    let plan: Mutex<R2CPlan64> = Mutex::new(R2CPlan::aligned(shape, Flag::MEASURE).unwrap());

    (0..input.len()).into_par_iter().for_each(|i| {
        let input_arc = Arc::clone(&input[i]);
        let output_arc = Arc::clone(&output[i]);

        let mut input_data = input_arc.lock().unwrap();
        let mut output_data = output_arc.lock().unwrap();
        debug_assert_eq!(input_data.len(), size);
        debug_assert_eq!(output_data.len(), size_real);

        plan.lock()
            .unwrap()
            .r2c(input_data.as_mut_slice(), output_data.as_mut_slice())
            .expect("FFTW r2c execution failed");
    });
}

pub fn irfft3_fftw(mut input: &mut [c64], mut output: &mut[f64], shape: &[usize]) {
let size: usize = shape.iter().product();
let mut plan: C2RPlan64 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();
Expand Down
245 changes: 204 additions & 41 deletions fmm/src/field_translation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use std::{
};

use bempp_tools::Array3D;
use num::Complex;
use num::{Complex, FromPrimitive};
use itertools::Itertools;
use rayon::prelude::*;
use fftw::types::*;
use num::Zero;

use bempp_field::{types::{SvdFieldTranslationKiFmm, FftFieldTranslationNaiveKiFmm, FftFieldTranslationKiFmm}, helpers::{pad3, rfft3, irfft3, rfft3_fftw, irfft3_fftw, rfft3_fftw_par_dm, rfft3_fftw_par_vec}};
use bempp_field::{types::{SvdFieldTranslationKiFmm, FftFieldTranslationNaiveKiFmm, FftFieldTranslationKiFmm}, helpers::{pad3, rfft3, irfft3, rfft3_fftw, irfft3_fftw, rfft3_fftw_par_dm, rfft3_fftw_par_vec, rfft3_fftw_par_vec_arc_mutex}};
use bempp_traits::{
field::{FieldTranslation, FieldTranslationData},
fmm::{Fmm, InteractionLists, SourceTranslation, TargetTranslation},
Expand All @@ -24,7 +25,7 @@ use bempp_tree::types::{morton::MortonKey, single_node::SingleNodeTree};
use rlst::{
common::traits::*,
common::tools::PrettyPrint,
dense::{rlst_col_vec, rlst_mat, rlst_pointer_mat, traits::*, Dot, Shape, rlst_rand_col_vec},
dense::{rlst_col_vec, rlst_mat, rlst_pointer_mat, traits::*, Dot, Shape, rlst_rand_col_vec, global},
};

use crate::types::{FmmData, KiFmm};
Expand Down Expand Up @@ -627,6 +628,7 @@ where
use dashmap::DashMap;


// Alias for a dynamically-sized rlst dense matrix of complex (`c64`) entries
// that owns its storage via `VectorContainer`.
// NOTE(review): this alias appears unused in the visible code — confirm it is
// needed before keeping it. Naming convention would also suggest `FftMatrixC64`.
type FftMatrixc64 = rlst::dense::Matrix<c64, rlst::dense::base_matrix::BaseMatrix<c64, rlst::dense::VectorContainer<c64>, Dynamic, Dynamic>, Dynamic, Dynamic>;


impl<T> FieldTranslation for FmmData<KiFmm<SingleNodeTree, T, FftFieldTranslationKiFmm<T>>>
Expand Down Expand Up @@ -726,64 +728,225 @@ where
let q = n + 1;
let r = o + 1;
let size = p*q*r;
let size_real = p*q*(r/2+1);
let pad_size = (p-m, q-n, r-o);
let pad_index = (p-m, q-n, r-o);
let real_dim = q;

let mut padded_signals = rlst_col_vec![f64, (size*ntargets)];
let mut padded_signals = vec![Arc::new(Mutex::new(vec![0f64; size])); ntargets];

let mut chunks = padded_signals.data_mut().par_chunks_exact_mut(size);
let range = (0..chunks.len()).into_par_iter();
range.zip(chunks).for_each(|(i, chunk)| {
(0..ntargets).into_par_iter().for_each(|i| {
let fmm_arc = Arc::clone(&self.fmm);
let target = targets[i];
let target = &targets[i];
let source_multipole_arc = Arc::clone(self.multipoles.get(&target).unwrap());
let source_multipole_lock = source_multipole_arc.lock().unwrap();
let signal = fmm_arc.m2l.compute_signal(fmm_arc.order, source_multipole_lock.data());

let mut padded_signal = pad3(&signal, pad_size, pad_index);

chunk.copy_from_slice(padded_signal.get_data());
let mut padded_signal_arc = Arc::clone(&padded_signals[i]);

padded_signal_arc.lock().unwrap().deref_mut().copy_from_slice(padded_signal.get_data());
});
println!("data organisation time {:?}", start.elapsed().as_millis());

let size_real = p*q*(r/2+1);
let mut padded_signals_hat = rlst_col_vec![c64, (size_real*ntargets)];
// Each index maps to a target (sorted) from targets
let mut padded_signals_hat = vec![Arc::new(Mutex::new(vec![Complex::<f64>::zero(); size_real])); ntargets];

let start = Instant::now();
rfft3_fftw_par_vec(&mut padded_signals, &mut padded_signals_hat, &[p, q, r]);
rfft3_fftw_par_vec_arc_mutex(&mut padded_signals, &mut padded_signals_hat, &[p, q, r]);

println!("fft time {:?}", start.elapsed().as_millis());
println!("size real {:?} size {:?}", size_real, size);

let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order);
// Compute hadamard product with kernels
let range = (0..self.fmm.m2l.transfer_vectors.len()).into_par_iter();
self.fmm.m2l.transfer_vectors.iter().take(16).par_bridge().for_each(|tv| {
// Locate correct precomputed FFT of kernel
let k_idx = self.fmm
.m2l
.transfer_vectors
.iter()
.position(|x| x.vector == tv.vector)
.unwrap();
let padded_kernel_hat = &self.fmm.m2l.m2l[k_idx];
let &(m_, n_, o_) = padded_kernel_hat.shape();
let len_padded_kernel_hat= m_*n_*o_;
let padded_kernel_hat= unsafe {
rlst_pointer_mat!['a, Complex<f64>, padded_kernel_hat.get_data().as_ptr(), (len_padded_kernel_hat, 1), (1,1)]
};

let padded_kernel_hat_arc = Arc::new(padded_kernel_hat);

padded_signals_hat.data().chunks_exact(len_padded_kernel_hat).enumerate().for_each(|(i, padded_signal_hat)| {
let padded_signal_hat = unsafe {
rlst_pointer_mat!['a, Complex<f64>, padded_signal_hat.as_ptr(), (len_padded_kernel_hat, 1), (1,1)]
};

let padded_kernel_hat_ref = Arc::clone(&padded_kernel_hat_arc);
let start = Instant::now();

// Map between keys and index locations in targets at this level
let mut target_index_map = Arc::new(RwLock::new(HashMap::new()));

for (i, target) in targets.iter().enumerate() {

let mut map = target_index_map.write().unwrap();
map.insert(*target, i);
}

// Each index corresponds to a target, and contains a vector of pointers to the padded signals in the targets interactions list
let mut source_index_pointer: Vec<Arc<Mutex<Vec<Arc<Mutex<Vec<Complex<f64>>>>>>>> =
(0..ntargets).map(|_| Arc::new(Mutex::new(Vec::<Arc<Mutex<Vec<Complex<f64>>>>>::new()))).collect();

targets
.into_par_iter()
.zip(source_index_pointer.par_iter_mut())
.enumerate()
.for_each(|(i, (target, arc_mutex_vec))| {

let fmm_arc = Arc::clone(&self.fmm);
let v_list = target
.parent()
.neighbors()
.iter()
.flat_map(|pn| pn.children())
.filter(|pnc| !target.is_adjacent_same_level(pnc))
.collect_vec();

// Lookup indices for each element of v_list and add the pointers to the underlying data to the index pointer
let mut indices = Vec::new();
let target_index_map_arc = Arc::clone(&target_index_map);
let map = target_index_map.read().unwrap();
for source in v_list.iter() {
let idx = map.get(source).unwrap();
indices.push(*idx);
}

let mut outer_vec: MutexGuard<'_, Vec<Arc<Mutex<Vec<Complex<f64>>>>>> = arc_mutex_vec.lock().unwrap();
for &idx in indices.iter() {
let tmp: Arc<Mutex<Vec<Complex<f64>>>> = Arc::clone(&padded_signals_hat[idx]);
outer_vec.push(tmp);
}
});

println!("index pointer time {:?}", start.elapsed().as_millis());


// Compute Hadamard product with elements of V List, now stored in source_index_pointer

let start = Instant::now();
// let mut global_check_potentials_hat = vec![Arc::new(Mutex::new(vec![Complex::<f64>::zero(); size_real])); ntargets];

let mut global_check_potentials_hat = (0..ntargets)
.map(|_| Arc::new(Mutex::new(vec![Complex::<f32>::zero(); size_real]))).collect_vec();
// let mut global_check_potentials_hat = (0..ntargets)
// .map(|_| Arc::new(Mutex::new(vec![0f64; size_real]))).collect_vec();

global_check_potentials_hat
.par_iter_mut()
.zip(
source_index_pointer
.into_par_iter()
)
.zip(
targets.into_par_iter()
).for_each(|((check_potential_hat, sources), target)| {

// Find the corresponding Kernel matrices for each signal
let fmm_arc = Arc::clone(&self.fmm);
let v_list = target
.parent()
.neighbors()
.iter()
.flat_map(|pn| pn.children())
.filter(|pnc| !target.is_adjacent_same_level(pnc))
.collect_vec();


let k_idxs = v_list
.iter()
.map(|source| target.find_transfer_vector(source))
.map(|tv| {
fmm_arc
.m2l
.transfer_vectors
.iter()
.position(|x| x.vector == tv)
.unwrap()
}).collect_vec();


// Compute convolutions
let check_potential_hat_arc = Arc::clone(check_potential_hat);
let mut check_potential_hat_data = check_potential_hat_arc.lock().unwrap();

let tmp = sources.lock().unwrap();
let mut result = vec![Complex::<f64>::zero(); size_real];

let check_potential = padded_signal_hat.cmp_wise_product(padded_kernel_hat_ref.deref()).eval();
// for i in 0..result.len() {
// for _ in 0..189 {
// result[i] += Complex::<f64>::from(1.0);
// }
// }

for i in 0..1 {

let psh = tmp[i].lock().unwrap();
let pkh = &fmm_arc.m2l.m2l[k_idxs[i]].get_data();

let hadamard: Vec<c64> = psh.iter().zip(pkh.iter()).map(|(s, k)| {*s * *k}).collect_vec();
for j in 0..result.len() {
result[j] += Complex::<f64>::from(1.0);
}
}


// for ((i, source), &k_idx) in tmp.iter().enumerate().zip(k_idxs.iter()) {

// // let psh = source.lock().unwrap();
// // let pkh = &fmm_arc.m2l.m2l[k_idx];

// // let psh = unsafe {
// // rlst_pointer_mat!['a, c64, psh.as_ptr(), (size_real, 1), (1,1)]
// // };

// // let pkh = unsafe {
// // rlst_pointer_mat!['a, c64, pkh.get_data().as_ptr(), (size_real, 1), (1,1)]
// // };

// // let hadamard = psh.cmp_wise_product(&pkh).eval();
// // result.iter_mut().zip(hadamard.data().iter()).for_each(|(r, h)| *r += h);

// let psh = source.lock().unwrap();
// let pkh = &fmm_arc.m2l.m2l[k_idx].get_data();

// let hadamard: Vec<c64> = psh.iter().zip(pkh.iter()).map(|(s, k)| {*s * *k}).collect_vec();

// for j in 0..result.len() {
// result[j] += Complex::<f64>::from(1.0);
// }

// // result.iter_mut().zip(hadamard.iter()).for_each(|(r, h)| *r += Complex::<f32>::zero())
// // result.iter_mut().for_each(|(r)| *r += Complex::<f32>::zero())
// // check_potential_hat_data.deref_mut().iter_mut()
// // .zip(hadamard.iter())
// // .for_each(|(r, h)| *r += h);
// // check_potential_hat_data.deref_mut().iter_mut()
// // // .zip(hadamard.iter())
// // .for_each(|(r)| *r += Complex::<f64>::from(1.0));

// }

// check_potential_hat_data.deref_mut().iter_mut().for_each(|x| *x += Complex::zero());

});
});

println!("Hadamard time {:?}", start.elapsed().as_millis());
// let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order);
// // Compute hadamard product with kernels
// let range = (0..self.fmm.m2l.transfer_vectors.len()).into_par_iter();
// self.fmm.m2l.transfer_vectors.iter().take(16).par_bridge().for_each(|tv| {
// // Locate correct precomputed FFT of kernel
// let k_idx = self.fmm
// .m2l
// .transfer_vectors
// .iter()
// .position(|x| x.vector == tv.vector)
// .unwrap();
// let padded_kernel_hat = &self.fmm.m2l.m2l[k_idx];
// let &(m_, n_, o_) = padded_kernel_hat.shape();
// let len_padded_kernel_hat= m_*n_*o_;
// let padded_kernel_hat= unsafe {
// rlst_pointer_mat!['a, Complex<f64>, padded_kernel_hat.get_data().as_ptr(), (len_padded_kernel_hat, 1), (1,1)]
// };

// let padded_kernel_hat_arc = Arc::new(padded_kernel_hat);

// padded_signals_hat.data().chunks_exact(len_padded_kernel_hat).enumerate().for_each(|(i, padded_signal_hat)| {
// let padded_signal_hat = unsafe {
// rlst_pointer_mat!['a, Complex<f64>, padded_signal_hat.as_ptr(), (len_padded_kernel_hat, 1), (1,1)]
// };

// let padded_kernel_hat_ref = Arc::clone(&padded_kernel_hat_arc);

// let check_potential = padded_signal_hat.cmp_wise_product(padded_kernel_hat_ref.deref()).eval();
// });
// });


//////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion fmm/src/fmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ mod test {
let global_idxs = (0..npoints).collect_vec();
let charges = vec![1.0; npoints];

let order = 10;
let order = 9;
let alpha_inner = 1.05;
let alpha_outer = 2.9;
let adaptive = false;
Expand Down
26 changes: 26 additions & 0 deletions tree/src/implementations/impl_morton.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,32 @@ impl MortonKey {
.collect()
}

/// Check whether two keys' nodes are adjacent: on every axis, the distance
/// between the node centres is at most the sum of their half side lengths.
/// Overlapping nodes — including a key compared with itself — also pass this
/// test.
///
/// NOTE(review): despite the name, the radii are computed per key from each
/// key's own level, so keys at differing levels are handled too — confirm the
/// intended contract.
pub fn is_adjacent_same_level(&self, other: &MortonKey) -> bool {
    // Side lengths of each node, in units of the deepest-level grid.
    let da = 1 << (DEEPEST_LEVEL - self.level());
    let db = 1 << (DEEPEST_LEVEL - other.level());
    // Half side lengths (radii of the axis-aligned bounding boxes).
    let ra = (da as f64) * 0.5;
    let rb = (db as f64) * 0.5;

    // Nodes are adjacent iff, on every axis, the centre-to-centre distance
    // lies within [-(ra+rb), ra+rb]. Iterating over the anchors directly
    // avoids the temporary Vec allocations for centres/distances and lets
    // `all` short-circuit on the first separating axis.
    let limit = ra + rb;
    self.anchor
        .iter()
        .zip(other.anchor.iter())
        .all(|(&a, &b)| {
            // Per-axis distance between centres (other minus self).
            let d = ((b as f64) + rb) - ((a as f64) + ra);
            -limit <= d && d <= limit
        })
}

/// Check if two keys are adjacent with respect to each other
pub fn is_adjacent(&self, other: &MortonKey) -> bool {
let ancestors = self.ancestors();
Expand Down

0 comments on commit 236892c

Please sign in to comment.