Skip to content

Commit

Permalink
Add converging adaptive fmm with svd
Browse files Browse the repository at this point in the history
  • Loading branch information
skailasa committed Jan 3, 2024
1 parent a52f017 commit 4423ae2
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 34 deletions.
141 changes: 121 additions & 20 deletions fmm/src/field_translation/source_to_target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,16 @@ where
.eval();

if nsources > 0 {
if target.level() < self.fmm.tree.get_depth() && self.fmm.tree.leaves_set.contains(target) {
let leaf_idx = self.fmm.tree().get_leaf_index(target).unwrap();

let (l, r) = self.charge_index_pointer[*leaf_idx];

if r - l > 0 {
println!("RUNNING P2L {:?}", self.fmm.tree.key_to_index.get(target));
}
}

let mut check_potential = rlst_col_vec![U, ncoeffs];
self.fmm.kernel.evaluate_st(
EvalType::Value,
Expand Down Expand Up @@ -605,25 +615,25 @@ where
let size_real = npad * npad * (npad / 2 + 1);
let all_displacements = self.displacements(level);

// let ntargets = targets.len();
// let min = &targets[0];
// let max = &targets[ntargets - 1];
let ntargets = targets.len();
let min = &targets[0];
let max = &targets[ntargets - 1];

// let min_idx = self.fmm.tree().key_to_index.get(min).unwrap();
// let max_idx = self.fmm.tree().key_to_index.get(max).unwrap();
let min_idx = self.fmm.tree().key_to_index.get(min).unwrap();
let max_idx = self.fmm.tree().key_to_index.get(max).unwrap();


// let multipoles = &self.multipoles[min_idx * ncoeffs..(max_idx + 1) * ncoeffs];
let mut targets = targets.iter().cloned().collect_vec();
let ntargets = targets.len();
targets.sort();
let multipoles = &self.multipoles[min_idx * ncoeffs..(max_idx + 1) * ncoeffs];
// let mut targets = targets.iter().cloned().collect_vec();
// let ntargets = targets.len();
// targets.sort();

let mut multipoles = Vec::new();
for target in targets.iter() {
let target_index_pointer = *self.level_index_pointer[level as usize].get(target).unwrap();
let multipole = self.level_multipoles[level as usize][target_index_pointer];
multipoles.push(multipole);
}
// let mut multipoles = Vec::new();
// for target in targets.iter() {
// let target_index_pointer = *self.level_index_pointer[level as usize].get(target).unwrap();
// let multipole = self.level_multipoles[level as usize][target_index_pointer];
// multipoles.push(multipole);
// }

////////////////////////////////////////////////////////////////////////////////////
// Pre-process to setup data structures for M2L kernel
Expand Down Expand Up @@ -655,16 +665,16 @@ where

// Pre-processing to find FFT
multipoles
// .par_chunks_exact(ncoeffs * nsiblings * chunk_size)
.par_chunks_exact(nsiblings * chunk_size)
.par_chunks_exact(ncoeffs * nsiblings * chunk_size)
// .par_chunks_exact(nsiblings * chunk_size)
.enumerate()
.for_each(|(i, multipole_chunk)| {
// Place Signal on convolution grid
let mut signal_chunk = vec![U::zero(); size * nsiblings * chunk_size];

for i in 0..nsiblings * chunk_size {
// let multipole = &multipole_chunk[i * ncoeffs..(i + 1) * ncoeffs];
let multipole = unsafe { std::slice::from_raw_parts(multipole_chunk[i].raw, ncoeffs) };
let multipole = &multipole_chunk[i * ncoeffs..(i + 1) * ncoeffs];
// let multipole = unsafe { std::slice::from_raw_parts(multipole_chunk[i].raw, ncoeffs) };
let signal = &mut signal_chunk[i * size..(i + 1) * size];
for (surf_idx, &conv_idx) in self.fmm.m2l.surf_to_conv_map.iter().enumerate() {
signal[conv_idx] = multipole[surf_idx]
Expand Down Expand Up @@ -878,7 +888,98 @@ where
>,
U: std::marker::Send + std::marker::Sync + Default,
{
fn p2l(&self, _level: u64) {}
fn p2l<'a>(&self, level: u64) {

let Some(targets) = self.fmm.tree().get_keys(level) else {
return;
};

let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order);
let dim = self.fmm.kernel().space_dimension();
let surface_size = ncoeffs * dim;
let min_idx = self.fmm.tree().key_to_index.get(&targets[0]).unwrap();
let max_idx = self
.fmm
.tree()
.key_to_index
.get(targets.last().unwrap())
.unwrap();
let downward_surfaces =
&self.downward_surfaces[min_idx * surface_size..(max_idx + 1) * surface_size];
let coordinates = self.fmm.tree().get_all_coordinates().unwrap();

// assert_eq!(ntargets, downward_surfaces.len() / surface_size);
targets
.par_iter()
.zip(downward_surfaces.par_chunks_exact(surface_size))
.zip(self.level_locals[level as usize].par_iter())
.for_each(|((target, downward_surface), local_ptr)| {
// Find check potential
if let Some(x_list) = self.fmm.get_x_list(target) {
let x_list_indices = x_list
.iter()
.filter_map(|k| self.fmm.tree().get_leaf_index(k));
let charges = x_list_indices
.clone()
.map(|&idx| {
let index_pointer = &self.charge_index_pointer[idx];
&self.charges[index_pointer.0..index_pointer.1]
})
.collect_vec();

let sources_coordinates = x_list_indices
.into_iter()
.map(|&idx| {
let index_pointer = &self.charge_index_pointer[idx];
&coordinates[index_pointer.0 * dim..index_pointer.1 * dim]
})
.collect_vec();

let target_local =
unsafe { std::slice::from_raw_parts_mut(local_ptr.raw, ncoeffs) };

for (&charges, sources) in charges.iter().zip(sources_coordinates) {
let nsources = sources.len() / dim;
let sources = unsafe {
rlst_pointer_mat!['a, U, sources.as_ptr(), (nsources, dim), (dim, 1)]
}
.eval();

if nsources > 0 {
// if target.level() < self.fmm.tree.get_depth() && self.fmm.tree.leaves_set.contains(target) {
// let leaf_idx = self.fmm.tree().get_leaf_index(target).unwrap();

// let (l, r) = self.charge_index_pointer[*leaf_idx];

// // if r - l > 0 {
// // println!("RUNNING P2L {:?}", self.fmm.tree.key_to_index.get(target));
// // }
// }

let mut check_potential = rlst_col_vec![U, ncoeffs];
self.fmm.kernel.evaluate_st(
EvalType::Value,
sources.data(),
downward_surface,
charges,
check_potential.data_mut(),
);
let scale = self.fmm.kernel().scale(target.level());
let mut tmp = self
.fmm
.dc2e_inv_1
.dot(&self.fmm.dc2e_inv_2.dot(&check_potential));
tmp.data_mut().iter_mut().for_each(|val| *val *= scale);

target_local
.iter_mut()
.zip(tmp.data())
.for_each(|(r, &t)| *r += t);
}
}
}
});
}

fn m2l<'a>(&self, level: u64) {
let Some(sources) = self.fmm.tree().get_keys(level) else {
Expand Down
26 changes: 14 additions & 12 deletions fmm/src/fmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,13 +1069,13 @@ mod test {
let global_idxs = (0..npoints).collect_vec();
let charges = vec![1.0; npoints];

let order = 7;
let order = 6;
let alpha_inner = 1.05;
let alpha_outer = 2.95;
let adaptive = true;
let ncrit = 15;
let ncrit = 150;
let sphere = true;
let sparse = true;
let sparse = false;

let points;
if sphere {
Expand All @@ -1090,21 +1090,21 @@ mod test {
points.data(),
adaptive,
Some(ncrit),
Some(3),
None,
&global_idxs[..],
sparse,
);

let m2l_data =
FftFieldTranslationKiFmm::new(kernel.clone(), order, *tree.get_domain(), alpha_inner);

// let m2l_data = SvdFieldTranslationKiFmm::new(
// kernel.clone(),
// Some(1000),
// order,
// *tree.get_domain(),
// alpha_inner,
// );
let m2l_data = SvdFieldTranslationKiFmm::new(
kernel.clone(),
Some(1000),
order,
*tree.get_domain(),
alpha_inner,
);

let fmm = KiFmmLinear::new(order, alpha_inner, alpha_outer, kernel, tree, m2l_data);

Expand All @@ -1131,7 +1131,6 @@ mod test {
keys_l2.len(),
keys_l3.len()
);
println!("DEPTH {:?}={:?}", datatree.fmm.tree().get_depth(), depth);

// Test that direct computation is close to the FMM.
// let mut test_idx = 0;
Expand All @@ -1146,6 +1145,9 @@ mod test {
println!("test idx vec {:?}", test_idx_vec.len());
let test_idx = test_idx_vec[123];
let leaf = &datatree.fmm.tree().get_all_leaves().unwrap()[test_idx];
let leaf = &datatree.fmm.tree().get_all_keys().unwrap()[3316];
let leaf = &datatree.fmm.tree().get_all_keys().unwrap()[3305];
println!("DEPTH {:?}={:?}", datatree.fmm.tree().get_depth(), leaf.level());

// Test that all points are contained in some leaf
let mut total_points = 0;
Expand Down
4 changes: 2 additions & 2 deletions fmm/src/interaction_lists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ where
.neighbors()
.iter()
.flat_map(|n| n.children())
.filter(|nc| self.tree.get_all_leaves_set().contains(nc) && !leaf.is_adjacent(nc))
.filter(|nc| self.tree.get_all_keys_set().contains(nc) && !leaf.is_adjacent(nc))
.collect_vec();

if !w_list.is_empty() {
Expand All @@ -118,7 +118,7 @@ where
.parent()
.neighbors()
.into_iter()
.filter(|pn| self.tree.get_all_leaves_set().contains(pn) && !key.is_adjacent(pn))
.filter(|pn| self.tree.get_all_keys_set().contains(pn) && !key.is_adjacent(pn))
.collect_vec();

if !x_list.is_empty() {
Expand Down

0 comments on commit 4423ae2

Please sign in to comment.