Skip to content

Commit

Permalink
ENH: Add an FMM based on SVD field translations operating on an under…
Browse files Browse the repository at this point in the history
…lying linear data structure. (#139)

* Add something, not converging

* Fix up tests for linear svd based fmm

* Style checks and tests

* Add style checks
  • Loading branch information
skailasa authored Dec 5, 2023
1 parent cd06e06 commit faaa457
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 10 deletions.
1 change: 1 addition & 0 deletions fmm/src/field_translation/hashmap/source_to_target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ where
.fmm
.dc2e_inv_1
.dot(&self.fmm.dc2e_inv_2.dot(&check_potential_owned));

tmp.data_mut()
.iter_mut()
.for_each(|d| *d *= self.fmm.kernel.scale(level) * self.m2l_scale(level));
Expand Down
99 changes: 92 additions & 7 deletions fmm/src/field_translation/linear/source_to_target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use bempp_tools::Array3D;
use itertools::Itertools;
use num::{Complex, Float};
use rayon::prelude::*;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use bempp_field::{
array::pad3,
Expand Down Expand Up @@ -387,9 +387,7 @@ where
let scale = Complex::from(self.m2l_scale(level));

let kernel_data_halo = &self.fmm.m2l.operator_data.kernel_data_rearranged;
// println!("level {:?} pre processing time {:?} ", level, s.elapsed());

// let s = Instant::now();
m2l_cplx_chunked(
self.fmm.order,
level as usize,
Expand All @@ -401,15 +399,13 @@ where
chunksize,
scale,
);
// println!("level {:?} kernel time {:?} ", level, s.elapsed());

U::irfft_fftw_par_vec(
&mut global_check_potentials_hat,
&mut global_check_potentials,
&[p, q, r],
);

// let s = Instant::now();
// Compute local expansion coefficients and save to data tree
let (_, multi_indices) = MortonKey::surface_grid::<U>(self.fmm.order);

Expand Down Expand Up @@ -464,7 +460,6 @@ where
ptr = ptr.add(1)
}
});
// println!("level {:?} post processing time {:?} ", level, s.elapsed());
}

fn m2l_scale(&self, level: u64) -> U {
Expand Down Expand Up @@ -506,9 +501,99 @@ where
U: std::marker::Send + std::marker::Sync + Default,
{
fn m2l<'a>(&self, level: u64) {
let Some(_targets) = self.fmm.tree().get_keys(level) else {
let Some(sources) = self.fmm.tree().get_keys(level) else {
return;
};

let nsources = sources.len();

let mut source_map = HashMap::new();

for (i, t) in sources.iter().enumerate() {
source_map.insert(t, i);
}

let mut target_indices = vec![vec![-1i64; nsources]; 316];

// Need to identify all save locations in a pre-processing step.
for (j, source) in sources.iter().enumerate() {
let v_list = source
.parent()
.neighbors()
.iter()
.flat_map(|pn| pn.children())
.filter(|pnc| !source.is_adjacent(pnc))
.collect_vec();

let transfer_vectors = v_list
.iter()
.map(|target| target.find_transfer_vector(source))
.collect_vec();

let mut transfer_vectors_map = HashMap::new();
for (i, v) in transfer_vectors.iter().enumerate() {
transfer_vectors_map.insert(v, i);
}

let transfer_vectors_set: HashSet<_> = transfer_vectors.iter().collect();

for (i, tv) in self.fmm.m2l.transfer_vectors.iter().enumerate() {
if transfer_vectors_set.contains(&tv.hash) {
let target = &v_list[*transfer_vectors_map.get(&tv.hash).unwrap()];
let target_index = source_map.get(target).unwrap();
target_indices[i][j] = *target_index as i64;
}
}
}

// Interpret multipoles as a matrix
let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order);
let multipoles = unsafe {
rlst_pointer_mat!['a, U, self.level_multipoles[level as usize][0].raw, (ncoeffs, nsources), (1, ncoeffs)]
};

let (nrows, _) = self.fmm.m2l.operator_data.c.shape();
let dim = (nrows, self.fmm.m2l.k);

let mut compressed_multipoles = self.fmm.m2l.operator_data.st_block.dot(&multipoles);

compressed_multipoles
.data_mut()
.iter_mut()
.for_each(|d| *d *= self.fmm.kernel.scale(level) * self.m2l_scale(level));

(0..316).into_par_iter().for_each(|c_idx| {
let top_left = (0, c_idx * self.fmm.m2l.k);
let c_sub = self.fmm.m2l.operator_data.c.block(top_left, dim);

let locals = self.fmm.dc2e_inv_1.dot(
&self.fmm.dc2e_inv_2.dot(
&self
.fmm
.m2l
.operator_data
.u
.dot(&c_sub.dot(&compressed_multipoles)),
),
);

let displacements = &target_indices[c_idx];

for (result_idx, &save_idx) in displacements.iter().enumerate() {
if save_idx > -1 {
let save_idx = save_idx as usize;
let mut local_ptr = self.level_locals[(level) as usize][save_idx].raw;
let res = &locals.data()[result_idx * ncoeffs..(result_idx + 1) * ncoeffs];

unsafe {
for &r in res.iter() {
*local_ptr += r;
local_ptr = local_ptr.add(1);
}
}
}
}
})
}

fn m2l_scale(&self, level: u64) -> U {
Expand Down
2 changes: 1 addition & 1 deletion fmm/src/fmm/hashmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ mod test {
datatree.run(false);

// Test that direct computation is close to the FMM.
let leaf = &datatree.fmm.tree.get_keys(depth).unwrap()[0];
let leaf = &datatree.fmm.tree.get_all_leaves().unwrap()[0];

let potentials = datatree.potentials.get(leaf).unwrap().lock().unwrap();
let pts = datatree.fmm.tree().get_points(leaf).unwrap();
Expand Down
90 changes: 88 additions & 2 deletions fmm/src/fmm/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ mod test {

use super::*;

use bempp_field::types::FftFieldTranslationKiFmm;
use bempp_field::types::{FftFieldTranslationKiFmm, SvdFieldTranslationKiFmm};
use bempp_kernel::laplace_3d::Laplace3dKernel;
use bempp_tree::implementations::helpers::points_fixture;

Expand Down Expand Up @@ -678,7 +678,7 @@ mod test {
}

#[test]
fn test_fmm_linear() {
fn test_fmm_linear_fft_f64() {
let npoints = 10000;
let points = points_fixture::<f64>(npoints, None, None);
let global_idxs = (0..npoints).collect_vec();
Expand Down Expand Up @@ -755,4 +755,90 @@ mod test {

assert!(rel_error <= 1e-6);
}

#[test]
fn test_fmm_linear_svd_f64() {
let npoints = 10000;
let points = points_fixture::<f64>(npoints, None, None);
let global_idxs = (0..npoints).collect_vec();
let charges = vec![1.0; npoints];

let order = 6;
let alpha_inner = 1.05;
let alpha_outer = 2.95;
let adaptive = false;
let ncrit = 150;

let depth = 3;
let kernel = Laplace3dKernel::default();

let tree = SingleNodeTree::new(
points.data(),
adaptive,
Some(ncrit),
Some(depth),
&global_idxs[..],
);

let m2l_data_svd = SvdFieldTranslationKiFmm::new(
kernel.clone(),
Some(1000),
order,
*tree.get_domain(),
alpha_inner,
);

let fmm = KiFmmLinear::new(order, alpha_inner, alpha_outer, kernel, tree, m2l_data_svd);

// Form charge dict, matching charges with their associated global indices
let charge_dict = build_charge_dict(&global_idxs[..], &charges[..]);

let datatree = FmmDataLinear::new(fmm, &charge_dict).unwrap();

datatree.run(false);

// Test that direct computation is close to the FMM.
let leaf = &datatree.fmm.tree.get_all_leaves().unwrap()[0];
let leaf_idx = datatree.fmm.tree().get_leaf_index(leaf).unwrap();

let (l, r) = datatree.charge_index_pointer[*leaf_idx];

let potentials = &datatree.potentials[l..r];

let coordinates = datatree.fmm.tree().get_all_coordinates().unwrap();
let (l, r) = datatree.charge_index_pointer[*leaf_idx];
let leaf_coordinates = &coordinates[l * 3..r * 3];

let ntargets = leaf_coordinates.len() / datatree.fmm.kernel.space_dimension();

let leaf_coordinates = unsafe {
rlst_pointer_mat!['static, f64, leaf_coordinates.as_ptr(), (ntargets, datatree.fmm.kernel.space_dimension()), (datatree.fmm.kernel.space_dimension(), 1)]
}.eval();

let mut direct = vec![0f64; ntargets];
let all_point_coordinates = points_fixture::<f64>(npoints, None, None);

let all_charges = charge_dict.into_values().collect_vec();

let kernel = Laplace3dKernel::default();

kernel.evaluate_st(
EvalType::Value,
all_point_coordinates.data(),
leaf_coordinates.data(),
&all_charges[..],
&mut direct[..],
);

let abs_error: f64 = potentials
.iter()
.zip(direct.iter())
.map(|(a, b)| (a - b).abs())
.sum();
let rel_error: f64 = abs_error / (direct.iter().sum::<f64>());

println!("rel_error {:?}", rel_error);

assert!(rel_error <= 1e-6);
}
}

0 comments on commit faaa457

Please sign in to comment.