Skip to content

Commit

Permalink
Start working on kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
skailasa committed Dec 13, 2023
1 parent d2a041c commit 49c97ba
Showing 1 changed file with 103 additions and 68 deletions.
171 changes: 103 additions & 68 deletions fmm/src/field_translation/linear/source_to_target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
use bempp_tools::Array3D;

use itertools::Itertools;
use num::{Complex, Float};
use num::{Complex, Float, Zero};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
use std::{collections::{HashMap, HashSet}, time::Instant};

use bempp_field::{
array::pad3,
Expand Down Expand Up @@ -301,10 +301,11 @@ where
let Some(targets) = self.fmm.tree().get_keys(level) else {
return;
};
// let s = Instant::now();
let s = Instant::now();
// Form signals to use for convolution first
let n = 2 * self.fmm.order - 1;
let ntargets = targets.len();
let nparents = nparents(level as usize);
let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order);

// Pad the signal
Expand All @@ -315,11 +316,9 @@ where
let r = o + 1;
let size = p * q * r;
let size_real = p * q * (r / 2 + 1);
let pad_size = (p - m, q - n, r - o);
let pad_index = (p - m, q - n, r - o);
let mut padded_signals = vec![U::default(); size * ntargets];
let mut signals = vec![U::default(); size * ntargets];

let padded_signals_chunks = padded_signals.par_chunks_exact_mut(size);
let signals_chunks = signals.par_chunks_exact_mut(size);

let ntargets = targets.len();
let min = &targets[0];
Expand All @@ -330,81 +329,117 @@ where
let multipoles = &self.multipoles[min_idx * ncoeffs..(max_idx + 1) * ncoeffs];
let multipoles_chunks = multipoles.par_chunks_exact(ncoeffs);

padded_signals_chunks
signals_chunks
.zip(multipoles_chunks)
.for_each(|(padded_signal, multipole)| {
let signal = self.fmm.m2l.compute_signal(self.fmm.order, multipole);

let tmp = pad3(&signal, pad_size, pad_index);

padded_signal.copy_from_slice(tmp.get_data());
.for_each(|(signal, multipole)| {
for (i, &j) in self.fmm.m2l.surf_to_conv_map.iter().enumerate() {
signal[j] = multipole[i]
}
});

let mut padded_signals_hat = vec![Complex::<U>::default(); size_real * ntargets];
let mut signals_hat = vec![Complex::<U>::default(); size_real * ntargets];

U::rfft3_fftw_par_vec(&mut padded_signals, &mut padded_signals_hat, &[p, q, r]);
U::rfft3_fftw_par_vec(&mut signals, &mut signals_hat, &[p, q, r]);

let ntargets = targets.len();
let mut global_check_potentials_hat = vec![Complex::<U>::default(); size_real * ntargets];
let mut global_check_potentials = vec![U::default(); size * ntargets];

// Get check potentials in frequency order
let global_check_potentials_hat_freq = unsafe {
check_potentials_freq_cplx(
self.fmm.order,
level as usize,
&mut global_check_potentials_hat,
)
};
// Get signal FFTs into frequency order
signals_hat.par_chunks_exact_mut(8*size_real).for_each(|sibling_chunk| {

let padded_signals_hat_freq = signal_freq_order_cplx_optimized(
self.fmm.order,
level as usize,
&padded_signals_hat[..],
);
for i in 0..size_real {
for j in 0..8 {
sibling_chunk[8 * i + j] = sibling_chunk[size_real * j + i]
}
}
});

// Allocate check potentials (in frequency order at this point implicitly)
let mut check_potentials_hat = vec![Complex::<U>::default(); size_real * ntargets];

println!("pre processing time {:?}", s.elapsed());

let s = Instant::now();
let kernel_data_halo = &self.fmm.m2l.operator_data.kernel_data_rearranged;
(0..size_real)
.into_par_iter()
.zip(signals_hat.par_chunks_exact(ntargets))
.zip(check_potentials_hat.par_chunks_exact_mut(ntargets))
.for_each(|((freq, signal_freq), check_potentials_freq)| {

// Create a map between targets and index positions in vec of len 'ntargets'
let mut target_map = HashMap::new();
(0..nparents).for_each(|sibling_index| {
let save_locations = &mut check_potentials_freq[(sibling_index*8)..(sibling_index + 1)*8];

for (i, t) in targets.iter().enumerate() {
target_map.insert(t, i);
}
for (i, kernel_data) in kernel_data_halo.iter().enumerate() {
let frequency_offset = 64 * freq;
let kernel_data_freq = &kernel_data[frequency_offset..(frequency_offset + 64)];


let chunksize;
if level == 2 {
chunksize = 8;
} else if level == 3 {
chunksize = 64
} else {
chunksize = 128
}
}
})

let all_displacements = displacements(&self.fmm.tree, level, &target_map);
});

let (chunked_displacements, chunked_save_locations) =
chunked_displacements(level as usize, chunksize, &all_displacements);
println!("kernel time {:?}", s.elapsed());

let scale = Complex::from(self.m2l_scale(level));
// let ntargets = targets.len();
// let mut global_check_potentials_hat = vec![Complex::<U>::default(); size_real * ntargets];
// let mut global_check_potentials = vec![U::default(); size * ntargets];

let kernel_data_halo = &self.fmm.m2l.operator_data.kernel_data_rearranged;
// // Get check potentials in frequency order
// let global_check_potentials_hat_freq = unsafe {
// check_potentials_freq_cplx(
// self.fmm.order,
// level as usize,
// &mut global_check_potentials_hat,
// )
// };

m2l_cplx_chunked(
self.fmm.order,
level as usize,
&padded_signals_hat_freq,
&global_check_potentials_hat_freq,
kernel_data_halo,
&chunked_displacements,
&chunked_save_locations,
chunksize,
scale,
);

U::irfft_fftw_par_vec(
&mut global_check_potentials_hat,
&mut global_check_potentials,
&[p, q, r],
);
// let padded_signals_hat_freq = signal_freq_order_cplx_optimized(
// self.fmm.order,
// level as usize,
// &padded_signals_hat[..],
// );

// // Create a map between targets and index positions in vec of len 'ntargets'
// let mut target_map = HashMap::new();

// for (i, t) in targets.iter().enumerate() {
// target_map.insert(t, i);
// }

// let chunksize;
// if level == 2 {
// chunksize = 8;
// } else if level == 3 {
// chunksize = 64
// } else {
// chunksize = 128
// }

// let all_displacements = displacements(&self.fmm.tree, level, &target_map);

// let (chunked_displacements, chunked_save_locations) =
// chunked_displacements(level as usize, chunksize, &all_displacements);

// let scale = Complex::from(self.m2l_scale(level));

// let kernel_data_halo = &self.fmm.m2l.operator_data.kernel_data_rearranged;

// m2l_cplx_chunked(
// self.fmm.order,
// level as usize,
// &padded_signals_hat_freq,
// &global_check_potentials_hat_freq,
// kernel_data_halo,
// &chunked_displacements,
// &chunked_save_locations,
// chunksize,
// scale,
// );

// U::irfft_fftw_par_vec(
// &mut global_check_potentials_hat,
// &mut global_check_potentials,
// &[p, q, r],
// );

// // Compute local expansion coefficients and save to data tree
// let (_, multi_indices) = MortonKey::surface_grid::<U>(self.fmm.order);
Expand Down

0 comments on commit 49c97ba

Please sign in to comment.