diff --git a/fmm/examples/single_node.rs b/fmm/examples/single_node.rs index 1e62d371..42a55a06 100644 --- a/fmm/examples/single_node.rs +++ b/fmm/examples/single_node.rs @@ -4,13 +4,16 @@ use itertools::Itertools; use rlst_dense::traits::RawAccess; -use bempp_field::types::FftFieldTranslationKiFmm; +use bempp_field::types::{FftFieldTranslationKiFmm, SvdFieldTranslationKiFmm}; use bempp_fmm::{ charge::build_charge_dict, types::{FmmDataUniform, KiFmmLinear}, }; use bempp_kernel::laplace_3d::Laplace3dKernel; -use bempp_traits::{fmm::FmmLoop, tree::Tree}; +use bempp_traits::{ + fmm::{FmmLoop, M2LSetup}, + tree::Tree, +}; use bempp_tree::implementations::helpers::points_fixture; use bempp_tree::types::single_node::SingleNodeTree; @@ -25,27 +28,28 @@ fn main() { let order = 6; let alpha_inner = 1.05; let alpha_outer = 2.95; - let depth = 4; + let depth = 5; let tree = SingleNodeTree::new(points.data(), false, None, Some(depth), &global_idxs, true); let kernel = Laplace3dKernel::default(); let m2l_data: FftFieldTranslationKiFmm> = FftFieldTranslationKiFmm::new(kernel.clone(), order, *tree.get_domain(), alpha_inner); - // let m2l_data = SvdFieldTranslationKiFmm::new( - // kernel.clone(), - // Some(80), - // order, - // *tree.get_domain(), - // alpha_inner, - // ); + let m2l_data = SvdFieldTranslationKiFmm::new( + kernel.clone(), + Some(80), + order, + *tree.get_domain(), + alpha_inner, + ); let fmm = KiFmmLinear::new(order, alpha_inner, alpha_outer, kernel, tree, m2l_data); // Form charge dict, matching charges with their associated global indices let charge_dict = build_charge_dict(&global_idxs, &charges); - let datatree = FmmDataUniform::new(fmm, &charge_dict).unwrap(); + let mut datatree = FmmDataUniform::new(fmm, &charge_dict).unwrap(); + datatree.setup(); let s = Instant::now(); let times = datatree.run(true); diff --git a/fmm/examples/single_node_matrix.rs b/fmm/examples/single_node_matrix.rs index e025c227..e9c18fc1 100644 --- a/fmm/examples/single_node_matrix.rs +++ b/fmm/examples/single_node_matrix.rs @@ -1,6 +1,7 @@ use std::time::Instant; use bempp_fmm::types::FmmDataUniformMatrix; +use bempp_traits::fmm::M2LSetup; use itertools::Itertools; use rlst_dense::traits::RawAccess; @@ -23,8 +24,8 @@ fn main() { // Test matrix input let points = points_fixture::(npoints, None, None); - let ncharge_vecs = 3; - let depth = 4; + let ncharge_vecs = 8; + let depth = 5; let mut charge_mat = vec![vec![0.0; npoints]; ncharge_vecs]; charge_mat @@ -62,7 +63,8 @@ fn main() { .collect(); // Associate data with the FMM - let datatree = FmmDataUniformMatrix::new(fmm, &charge_dicts).unwrap(); + let mut datatree = FmmDataUniformMatrix::new(fmm, &charge_dicts).unwrap(); + datatree.setup(); let s = Instant::now(); let times = datatree.run(true); diff --git a/fmm/src/field_translation/source_to_target/fft.rs b/fmm/src/field_translation/source_to_target/fft.rs index 58d634d7..9cc6fce5 100644 --- a/fmm/src/field_translation/source_to_target/fft.rs +++ b/fmm/src/field_translation/source_to_target/fft.rs @@ -32,8 +32,34 @@ use crate::field_translation::hadamard::matmul8x8; /// Field translations defined on uniformly refined trees. pub mod uniform { + use bempp_traits::fmm::M2LSetup; + use super::*; + impl M2LSetup + for FmmDataUniform, T, FftFieldTranslationKiFmm, U>, U> + where + T: Kernel + + ScaleInvariantKernel + + std::marker::Send + + std::marker::Sync + + Default, + U: Scalar + Float + Default + std::marker::Send + std::marker::Sync + Fft, + Complex: Scalar, + Array, 2>, 2>: MatrixSvd, + { + fn setup(&mut self) -> &mut Self { + // let mut level_displacements = Vec::new(); + // for level in 2..=self.fmm.tree().get_depth() { + // level_displacements.push(self.displacements(level)); + // } + + // self.level_displacements = level_displacements; + self.is_setup = true; + self + } + } + impl FmmDataUniform, T, FftFieldTranslationKiFmm, U>, U> where T: Kernel @@ -373,9 +399,33 @@ pub mod uniform { /// Field translations defined on adaptively refined pub mod adaptive { + use bempp_traits::fmm::M2LSetup; use rlst_dense::rlst_array_from_slice2; use super::*; + impl M2LSetup + for FmmDataAdaptive, T, FftFieldTranslationKiFmm, U>, U> + where + T: Kernel + + ScaleInvariantKernel + + std::marker::Send + + std::marker::Sync + + Default, + U: Scalar + Float + Default + std::marker::Send + std::marker::Sync + Fft, + Complex: Scalar, + Array, 2>, 2>: MatrixSvd, + { + fn setup(&mut self) -> &mut Self { + // let mut level_displacements = Vec::new(); + // for level in 2..=self.fmm.tree().get_depth() { + // level_displacements.push(self.displacements(level)); + // } + + // self.level_displacements = level_displacements; + self.is_setup = true; + self + } + } impl FmmDataAdaptive, T, FftFieldTranslationKiFmm, U>, U> where diff --git a/fmm/src/field_translation/source_to_target/svd.rs b/fmm/src/field_translation/source_to_target/svd.rs index b32790ae..a5973ee8 100644 --- a/fmm/src/field_translation/source_to_target/svd.rs +++ b/fmm/src/field_translation/source_to_target/svd.rs @@ -28,10 +28,40 @@ use rlst_dense::{ /// Field translations for uniformly refined trees that take matrix input for charges. pub mod matrix { + use bempp_traits::fmm::M2LSetup; + use crate::types::{FmmDataUniformMatrix, KiFmmLinearMatrix}; use super::*; + impl M2LSetup + for FmmDataUniformMatrix< + KiFmmLinearMatrix, T, SvdFieldTranslationKiFmm, U>, + U, + > + where + T: Kernel + + ScaleInvariantKernel + + std::marker::Send + + std::marker::Sync + + Default, + U: Scalar + rlst_blis::interface::gemm::Gemm, + U: Float + Default, + U: std::marker::Send + std::marker::Sync + Default, + Array, 2>, 2>: MatrixSvd, + { + fn setup(&mut self) -> &mut Self { + let mut level_displacements = Vec::new(); + for level in 2..=self.fmm.tree().get_depth() { + level_displacements.push(self.displacements(level)); + } + + self.level_displacements = level_displacements; + self.is_setup = true; + self + } + } + impl FmmDataUniformMatrix< KiFmmLinearMatrix, T, SvdFieldTranslationKiFmm, U>, @@ -121,7 +151,8 @@ pub mod matrix { let nsources = sources.len(); - let all_displacements = self.displacements(level); + // let all_displacements = self.displacements(level); + let all_displacements = &self.level_displacements[(level - 2) as usize]; // Interpret multipoles as a matrix let multipoles = rlst_array_from_slice2!( @@ -276,8 +307,35 @@ pub mod matrix { } pub mod adaptive { + use bempp_traits::fmm::M2LSetup; + use super::*; + impl M2LSetup + for FmmDataAdaptive, T, SvdFieldTranslationKiFmm, U>, U> + where + T: Kernel + + ScaleInvariantKernel + + std::marker::Send + + std::marker::Sync + + Default, + U: Scalar + rlst_blis::interface::gemm::Gemm, + U: Float + Default, + U: std::marker::Send + std::marker::Sync + Default, + Array, 2>, 2>: MatrixSvd, + { + fn setup(&mut self) -> &mut Self { + let mut level_displacements = Vec::new(); + for level in 2..=self.fmm.tree().get_depth() { + level_displacements.push(self.displacements(level)); + } + + self.level_displacements = level_displacements; + self.is_setup = true; + self + } + } + impl FmmDataAdaptive, T, SvdFieldTranslationKiFmm, U>, U> where T: Kernel @@ -452,7 +510,8 @@ pub mod adaptive { }; let nsources = sources.len(); - let all_displacements = self.displacements(level); + // let all_displacements = self.displacements(level); + let all_displacements = &self.level_displacements[(level - 2) as usize]; // Interpret multipoles as a matrix let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order); @@ -576,8 +635,35 @@ pub mod adaptive { } pub mod uniform { + use bempp_traits::fmm::M2LSetup; + use super::*; + impl M2LSetup + for FmmDataUniform, T, SvdFieldTranslationKiFmm, U>, U> + where + T: Kernel + + ScaleInvariantKernel + + std::marker::Send + + std::marker::Sync + + Default, + U: Scalar + rlst_blis::interface::gemm::Gemm, + U: Float + Default, + U: std::marker::Send + std::marker::Sync + Default, + Array, 2>, 2>: MatrixSvd, + { + fn setup(&mut self) -> &mut Self { + let mut level_displacements = Vec::new(); + for level in 2..=self.fmm.tree().get_depth() { + level_displacements.push(self.displacements(level)); + } + + self.level_displacements = level_displacements; + self.is_setup = true; + self + } + } + impl FmmDataUniform, T, SvdFieldTranslationKiFmm, U>, U> where T: Kernel @@ -622,6 +708,7 @@ pub mod uniform { for (i, tv) in self.fmm.m2l.transfer_vectors.iter().enumerate() { let mut all_displacements_lock = all_displacements[i].lock().unwrap(); + // let mut all_displacements_lock = all_displacements[i]; if transfer_vectors_set.contains(&tv.hash) { let target = &v_list[*transfer_vectors_map.get(&tv.hash).unwrap()]; @@ -660,7 +747,9 @@ pub mod uniform { let nsources = sources.len(); - let all_displacements = self.displacements(level); + // let all_displacements = self.displacements(level); + let all_displacements = &self.level_displacements[(level - 2) as usize]; + // let all_displacements = all_displacements.into_iter().map(Mutex::new).collect_vec(); // Interpret multipoles as a matrix let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order); diff --git a/fmm/src/fmm.rs b/fmm/src/fmm.rs index 43c81fe1..95d73956 100644 --- a/fmm/src/fmm.rs +++ b/fmm/src/fmm.rs @@ -549,12 +549,17 @@ where } fn run(&self, time: bool) -> Option { - let t1 = self.upward_pass(time); - let t2 = self.downward_pass(time); - if let (Some(mut t1), Some(t2)) = (t1, t2) { - t1.extend(t2); - Some(t1) + if self.is_setup { + let t1 = self.upward_pass(time); + let t2 = self.downward_pass(time); + + if let (Some(mut t1), Some(t2)) = (t1, t2) { + t1.extend(t2); + Some(t1) + } else { + None + } } else { None } @@ -654,12 +659,17 @@ where } fn run(&self, time: bool) -> Option { - let t1 = self.upward_pass(time); - let t2 = self.downward_pass(time); - if let (Some(mut t1), Some(t2)) = (t1, t2) { - t1.extend(t2); - Some(t1) + if self.is_setup { + let t1 = self.upward_pass(time); + let t2 = self.downward_pass(time); + + if let (Some(mut t1), Some(t2)) = (t1, t2) { + t1.extend(t2); + Some(t1) + } else { + None + } } else { None } @@ -759,12 +769,16 @@ where } fn run(&self, time: bool) -> Option { - let t1 = self.upward_pass(time); - let t2 = self.downward_pass(time); - - if let (Some(mut t1), Some(t2)) = (t1, t2) { - t1.extend(t2); - Some(t1) + if self.is_setup { + let t1 = self.upward_pass(time); + let t2 = self.downward_pass(time); + + if let (Some(mut t1), Some(t2)) = (t1, t2) { + t1.extend(t2); + Some(t1) + } else { + None + } } else { None } diff --git a/fmm/src/types.rs b/fmm/src/types.rs index afbcfa8f..ecd8cb6a 100644 --- a/fmm/src/types.rs +++ b/fmm/src/types.rs @@ -1,5 +1,6 @@ //! Data structures FMM data and metadata. use std::collections::HashMap; +use std::sync::Mutex; use bempp_traits::fmm::KiFmm; use bempp_traits::kernel::ScaleInvariantKernel; @@ -27,6 +28,8 @@ where T: Fmm, U: Scalar + Float + Default, { + pub is_setup: bool, + /// The associated FMM object, which implements an FMM interface pub fmm: T, @@ -69,6 +72,8 @@ where /// Leaf downward surfaces pub leaf_downward_surfaces: Vec, + pub level_displacements: Vec>>>, + /// The charge data at each leaf box. pub charges: Vec, @@ -87,6 +92,8 @@ where T: Fmm, U: Scalar + Float + Default, { + pub is_setup: bool, + /// The associated FMM object, which implements an FMM interface pub fmm: T, @@ -152,6 +159,8 @@ where /// Global indices of each charge pub global_indices: Vec, + + pub level_displacements: Vec>>>, } pub struct FmmDataAdaptive @@ -159,6 +168,8 @@ where T: Fmm, U: Scalar + Float + Default, { + pub is_setup: bool, + /// The associated FMM object, which implements an FMM interface pub fmm: T, @@ -212,6 +223,8 @@ where /// Global indices of each charge pub global_indices: Vec, + + pub level_displacements: Vec>>>, } /// Type to store data associated with the kernel independent (KiFMM) in. @@ -494,6 +507,7 @@ where } return Ok(Self { + is_setup: false, fmm, multipoles, level_multipoles, @@ -512,6 +526,7 @@ where charge_index_pointer, scales, global_indices, + level_displacements: Vec::default(), }); } @@ -740,6 +755,7 @@ where } return Ok(Self { + is_setup: false, fmm, multipoles, level_multipoles, @@ -762,6 +778,7 @@ where ncoeffs, scales, global_indices, + level_displacements: Vec::default(), }); } else { return Err("Not a uniform tree".to_string()); @@ -934,6 +951,7 @@ where } return Ok(Self { + is_setup: false, fmm, multipoles, level_multipoles, @@ -952,6 +970,7 @@ where charge_index_pointer, scales, global_indices, + level_displacements: Vec::default(), }); } else { return Err("Not an adaptive tree".to_string()); diff --git a/traits/src/fmm.rs b/traits/src/fmm.rs index f6a6114c..49c5b014 100644 --- a/traits/src/fmm.rs +++ b/traits/src/fmm.rs @@ -66,6 +66,10 @@ where fn alpha_outer(&self) -> <::Kernel as Kernel>::T; } +pub trait M2LSetup { + fn setup(&mut self) -> &mut Self; +} + /// Dictionary containing timings pub type TimeDict = HashMap;