Skip to content

Commit

Permalink
Add one more explicit displacement function
Browse files Browse the repository at this point in the history
  • Loading branch information
skailasa committed Jan 19, 2024
1 parent bf2b27c commit d5ab055
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 42 deletions.
103 changes: 62 additions & 41 deletions fmm/src/field_translation/source_to_target/svd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,67 @@ pub mod matrix {
pub mod adaptive {
use super::*;

impl<T, U> FmmDataAdaptive<KiFmmLinear<SingleNodeTree<U>, T, SvdFieldTranslationKiFmm<U, T>, U>, U>
where
T: Kernel<T = U>
+ ScaleInvariantKernel<T = U>
+ std::marker::Send
+ std::marker::Sync
+ Default,
U: Scalar<Real = U> + rlst_blis::interface::gemm::Gemm,
U: Float + Default,
U: std::marker::Send + std::marker::Sync + Default,
Array<U, BaseArray<U, VectorContainer<U>, 2>, 2>: MatrixSvd<Item = U>,
{
fn displacements(&self, level: u64) -> Vec<Mutex<Vec<i64>>> {
let sources = self.fmm.tree().get_keys(level).unwrap();
let nsources = sources.len();

let mut source_map = HashMap::new();

for (i, t) in sources.iter().enumerate() {
source_map.insert(t, i);
}

let all_displacements = vec![vec![-1i64; nsources]; 316];
let all_displacements = all_displacements.into_iter().map(Mutex::new).collect_vec();

sources.into_par_iter().enumerate().for_each(|(j, source)| {
let v_list = source
.parent()
.neighbors()
.iter()
.flat_map(|pn| pn.children())
.filter(|pnc| {
!source.is_adjacent(pnc) && self.fmm.tree().get_all_keys_set().contains(pnc)
})
.collect_vec();

let transfer_vectors = v_list
.iter()
.map(|target| target.find_transfer_vector(source))
.collect_vec();

let mut transfer_vectors_map = HashMap::new();
for (i, v) in transfer_vectors.iter().enumerate() {
transfer_vectors_map.insert(v, i);
}

let transfer_vectors_set: HashSet<_> = transfer_vectors.iter().cloned().collect();
for (i, tv) in self.fmm.m2l.transfer_vectors.iter().enumerate() {
let mut all_displacements_lock = all_displacements[i].lock().unwrap();
if transfer_vectors_set.contains(&tv.hash) {
let target = &v_list[*transfer_vectors_map.get(&tv.hash).unwrap()];
let target_index = source_map.get(target).unwrap();
all_displacements_lock[j] = *target_index as i64;
}
}
});

all_displacements
}
}

/// Implement the multipole to local translation operator for an SVD accelerated KiFMM on a single node.
impl<T, U> FieldTranslation<U>
for FmmDataAdaptive<KiFmmLinear<SingleNodeTree<U>, T, SvdFieldTranslationKiFmm<U, T>, U>, U>
Expand Down Expand Up @@ -391,47 +452,7 @@ pub mod adaptive {
};

let nsources = sources.len();

let mut source_map = HashMap::new();

for (i, t) in sources.iter().enumerate() {
source_map.insert(t, i);
}

let all_displacements = vec![vec![-1i64; nsources]; 316];
let all_displacements = all_displacements.into_iter().map(Mutex::new).collect_vec();

sources.into_par_iter().enumerate().for_each(|(j, source)| {
let v_list = source
.parent()
.neighbors()
.iter()
.flat_map(|pn| pn.children())
.filter(|pnc| {
!source.is_adjacent(pnc) && self.fmm.tree().get_all_keys_set().contains(pnc)
})
.collect_vec();

let transfer_vectors = v_list
.iter()
.map(|target| target.find_transfer_vector(source))
.collect_vec();

let mut transfer_vectors_map = HashMap::new();
for (i, v) in transfer_vectors.iter().enumerate() {
transfer_vectors_map.insert(v, i);
}

let transfer_vectors_set: HashSet<_> = transfer_vectors.iter().cloned().collect();
for (i, tv) in self.fmm.m2l.transfer_vectors.iter().enumerate() {
let mut all_displacements_lock = all_displacements[i].lock().unwrap();
if transfer_vectors_set.contains(&tv.hash) {
let target = &v_list[*transfer_vectors_map.get(&tv.hash).unwrap()];
let target_index = source_map.get(target).unwrap();
all_displacements_lock[j] = *target_index as i64;
}
}
});
let all_displacements = self.displacements(level);

// Interpret multipoles as a matrix
let ncoeffs = self.fmm.m2l.ncoeffs(self.fmm.order);
Expand Down
2 changes: 1 addition & 1 deletion fmm/src/fmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,7 @@ mod test {

// Test matrix input
let points = points_fixture::<f64>(npoints, None, None);
let ncharge_vecs = 6;
let ncharge_vecs = 3;

let mut charge_mat = vec![vec![0.0; npoints]; ncharge_vecs];
charge_mat
Expand Down

0 comments on commit d5ab055

Please sign in to comment.