From d5ab055a10205d3bee308203708b9ce068152ca0 Mon Sep 17 00:00:00 2001 From: Srinath Kailasa Date: Fri, 19 Jan 2024 11:10:34 +0000 Subject: [PATCH] Add one more explicit displacement function --- .../field_translation/source_to_target/svd.rs | 103 +++++++++++------- fmm/src/fmm.rs | 2 +- 2 files changed, 63 insertions(+), 42 deletions(-) diff --git a/fmm/src/field_translation/source_to_target/svd.rs b/fmm/src/field_translation/source_to_target/svd.rs index 664e9972..93dcbb1d 100644 --- a/fmm/src/field_translation/source_to_target/svd.rs +++ b/fmm/src/field_translation/source_to_target/svd.rs @@ -278,6 +278,67 @@ pub mod matrix { pub mod adaptive { use super::*; + impl FmmDataAdaptive, T, SvdFieldTranslationKiFmm, U>, U> + where + T: Kernel + + ScaleInvariantKernel + + std::marker::Send + + std::marker::Sync + + Default, + U: Scalar + rlst_blis::interface::gemm::Gemm, + U: Float + Default, + U: std::marker::Send + std::marker::Sync + Default, + Array, 2>, 2>: MatrixSvd, + { + fn displacements(&self, level: u64) -> Vec>> { + let sources = self.fmm.tree().get_keys(level).unwrap(); + let nsources = sources.len(); + + let mut source_map = HashMap::new(); + + for (i, t) in sources.iter().enumerate() { + source_map.insert(t, i); + } + + let all_displacements = vec![vec![-1i64; nsources]; 316]; + let all_displacements = all_displacements.into_iter().map(Mutex::new).collect_vec(); + + sources.into_par_iter().enumerate().for_each(|(j, source)| { + let v_list = source + .parent() + .neighbors() + .iter() + .flat_map(|pn| pn.children()) + .filter(|pnc| { + !source.is_adjacent(pnc) && self.fmm.tree().get_all_keys_set().contains(pnc) + }) + .collect_vec(); + + let transfer_vectors = v_list + .iter() + .map(|target| target.find_transfer_vector(source)) + .collect_vec(); + + let mut transfer_vectors_map = HashMap::new(); + for (i, v) in transfer_vectors.iter().enumerate() { + transfer_vectors_map.insert(v, i); + } + + let transfer_vectors_set: HashSet<_> = transfer_vectors.iter().cloned().collect(); + for (i, tv) in self.fmm.m2l.transfer_vectors.iter().enumerate() { + let mut all_displacements_lock = all_displacements[i].lock().unwrap(); + if transfer_vectors_set.contains(&tv.hash) { + let target = &v_list[*transfer_vectors_map.get(&tv.hash).unwrap()]; + let target_index = source_map.get(target).unwrap(); + all_displacements_lock[j] = *target_index as i64; + } + } + }); + + all_displacements + } + } + /// Implement the multipole to local translation operator for an SVD accelerated KiFMM on a single node. impl FieldTranslation for FmmDataAdaptive, T, SvdFieldTranslationKiFmm, U>, U> @@ -391,47 +452,7 @@ pub mod adaptive { }; let nsources = sources.len(); - - let mut source_map = HashMap::new(); - - for (i, t) in sources.iter().enumerate() { - source_map.insert(t, i); - } - - let all_displacements = vec![vec![-1i64; nsources]; 316]; - let all_displacements = all_displacements.into_iter().map(Mutex::new).collect_vec(); - - sources.into_par_iter().enumerate().for_each(|(j, source)| { - let v_list = source - .parent() - .neighbors() - .iter() - .flat_map(|pn| pn.children()) - .filter(|pnc| { - !source.is_adjacent(pnc) && self.fmm.tree().get_all_keys_set().contains(pnc) - }) - .collect_vec(); - - let transfer_vectors = v_list - .iter() - .map(|target| target.find_transfer_vector(source)) - .collect_vec(); - - let mut transfer_vectors_map = HashMap::new(); - for (i, v) in transfer_vectors.iter().enumerate() { - transfer_vectors_map.insert(v, i); - } - - let transfer_vectors_set: HashSet<_> = transfer_vectors.iter().cloned().collect(); - for (i, tv) in self.fmm.m2l.transfer_vectors.iter().enumerate() { - let mut all_displacements_lock = all_displacements[i].lock().unwrap(); - if transfer_vectors_set.contains(&tv.hash) { - let target = &v_list[*transfer_vectors_map.get(&tv.hash).unwrap()]; - let target_index = source_map.get(target).unwrap(); - all_displacements_lock[j] = *target_index as i64; - } - } - }); + let all_displacements = self.displacements(level); // Interpret multipoles as a matrix let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order); diff --git a/fmm/src/fmm.rs b/fmm/src/fmm.rs index a07c95ab..d943d0c3 100644 --- a/fmm/src/fmm.rs +++ b/fmm/src/fmm.rs @@ -1369,7 +1369,7 @@ mod test { // Test matrix input let points = points_fixture::(npoints, None, None); - let ncharge_vecs = 6; + let ncharge_vecs = 3; let mut charge_mat = vec![vec![0.0; npoints]; ncharge_vecs]; charge_mat