Skip to content

Commit

Permalink
Nits
Browse files Browse the repository at this point in the history
  • Loading branch information
skailasa committed Nov 13, 2023
1 parent d08e676 commit 956e54d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 53 deletions.
90 changes: 45 additions & 45 deletions fmm/src/fmm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,28 @@ use crate::pinv::{pinv, SvdScalar};
use crate::types::{C2EType, ChargeDict, FmmData, KiFmm};

/// Implementation of constructor for single node KiFMM
impl<'a, T, U, W> KiFmm<SingleNodeTree<W>, T, U, W>
impl<'a, T, U, V> KiFmm<SingleNodeTree<V>, T, U, V>
where
T: Kernel<T = W> + KernelScale<T = W>,
T: Kernel<T = V> + KernelScale<T = V>,
U: FieldTranslationData<T>,
W: Scalar<Real = W> + Default + Float,
SvdScalar<W>: PartialOrd,
SvdScalar<W>: Scalar + Float + ToPrimitive,
DenseMatrixLinAlgBuilder<W>: Svd,
W: MultiplyAdd<
W,
VectorContainer<W>,
VectorContainer<W>,
VectorContainer<W>,
V: Scalar<Real = V> + Default + Float,
SvdScalar<V>: PartialOrd,
SvdScalar<V>: Scalar + Float + ToPrimitive,
DenseMatrixLinAlgBuilder<V>: Svd,
V: MultiplyAdd<
V,
VectorContainer<V>,
VectorContainer<V>,
VectorContainer<V>,
Dynamic,
Dynamic,
Dynamic,
>,
SvdScalar<W>: MultiplyAdd<
SvdScalar<W>,
VectorContainer<SvdScalar<W>>,
VectorContainer<SvdScalar<W>>,
VectorContainer<SvdScalar<W>>,
SvdScalar<V>: MultiplyAdd<
SvdScalar<V>,
VectorContainer<SvdScalar<V>>,
VectorContainer<SvdScalar<V>>,
VectorContainer<SvdScalar<V>>,
Dynamic,
Dynamic,
Dynamic,
Expand All @@ -69,10 +69,10 @@ where
/// * `m2l` - The M2L operator matrices, as well as metadata associated with this FMM.
pub fn new(
order: usize,
alpha_inner: W,
alpha_outer: W,
alpha_inner: V,
alpha_outer: V,
kernel: T,
tree: SingleNodeTree<W>,
tree: SingleNodeTree<V>,
m2l: U,
) -> Self {
let upward_equivalent_surface = ROOT.compute_surface(tree.get_domain(), order, alpha_inner);
Expand All @@ -86,21 +86,21 @@ where

// Store in RLST matrices
let upward_equivalent_surface = unsafe {
rlst_pointer_mat!['a, <W as cauchy::Scalar>::Real, upward_equivalent_surface.as_ptr(), (nequiv_surface, kernel.space_dimension()), (1, nequiv_surface)]
rlst_pointer_mat!['a, <V as cauchy::Scalar>::Real, upward_equivalent_surface.as_ptr(), (nequiv_surface, kernel.space_dimension()), (1, nequiv_surface)]
};
let upward_check_surface = unsafe {
rlst_pointer_mat!['a, <W as cauchy::Scalar>::Real, upward_check_surface.as_ptr(), (ncheck_surface, kernel.space_dimension()), (1, ncheck_surface)]
rlst_pointer_mat!['a, <V as cauchy::Scalar>::Real, upward_check_surface.as_ptr(), (ncheck_surface, kernel.space_dimension()), (1, ncheck_surface)]
};
let downward_equivalent_surface = unsafe {
rlst_pointer_mat!['a, <W as cauchy::Scalar>::Real, downward_equivalent_surface.as_ptr(), (nequiv_surface, kernel.space_dimension()), (1, nequiv_surface)]
rlst_pointer_mat!['a, <V as cauchy::Scalar>::Real, downward_equivalent_surface.as_ptr(), (nequiv_surface, kernel.space_dimension()), (1, nequiv_surface)]
};
let downward_check_surface = unsafe {
rlst_pointer_mat!['a, <W as cauchy::Scalar>::Real, downward_check_surface.as_ptr(), (ncheck_surface, kernel.space_dimension()), (1, ncheck_surface)]
rlst_pointer_mat!['a, <V as cauchy::Scalar>::Real, downward_check_surface.as_ptr(), (ncheck_surface, kernel.space_dimension()), (1, ncheck_surface)]
};

// Compute upward check to equivalent, and downward check to equivalent Gram matrices
// as well as their inverses using DGESVD.
let mut uc2e = rlst_dynamic_mat![W, (ncheck_surface, nequiv_surface)];
let mut uc2e = rlst_dynamic_mat![V, (ncheck_surface, nequiv_surface)];
kernel.assemble_st(
EvalType::Value,
upward_equivalent_surface.data(),
Expand All @@ -111,7 +111,7 @@ where
// Need to tranapose so that rows correspond to targets and columns to sources
let uc2e = uc2e.transpose().eval();

let mut dc2e = rlst_dynamic_mat![W, (ncheck_surface, nequiv_surface)];
let mut dc2e = rlst_dynamic_mat![V, (ncheck_surface, nequiv_surface)];
kernel.assemble_st(
EvalType::Value,
downward_equivalent_surface.data(),
Expand All @@ -122,11 +122,11 @@ where
// Need to tranapose so that rows correspond to targets and columns to sources
let dc2e = dc2e.transpose().eval();

let (s, ut, v) = pinv::<W>(&uc2e, None, None).unwrap();
let (s, ut, v) = pinv::<V>(&uc2e, None, None).unwrap();

let mut mat_s = rlst_dynamic_mat![SvdScalar<W>, (s.len(), s.len())];
let mut mat_s = rlst_dynamic_mat![SvdScalar<V>, (s.len(), s.len())];
for i in 0..s.len() {
mat_s[[i, i]] = SvdScalar::<W>::from_real(s[i]);
mat_s[[i, i]] = SvdScalar::<V>::from_real(s[i]);
}
let uc2e_inv_1 = v.dot(&mat_s);
let uc2e_inv_2 = ut;
Expand All @@ -137,27 +137,27 @@ where
let uc2e_inv_1 = uc2e_inv_1
.data()
.iter()
.map(|x| W::from(*x).unwrap())
.map(|x| V::from(*x).unwrap())
.collect_vec();
let uc2e_inv_1 = unsafe {
rlst_pointer_mat!['a, W, uc2e_inv_1.as_ptr(), uc2e_inv_1_shape, (1, uc2e_inv_1_shape.0)]
rlst_pointer_mat!['a, V, uc2e_inv_1.as_ptr(), uc2e_inv_1_shape, (1, uc2e_inv_1_shape.0)]
}
.eval();
let uc2e_inv_2 = uc2e_inv_2
.data()
.iter()
.map(|x| W::from(*x).unwrap())
.map(|x| V::from(*x).unwrap())
.collect_vec();
let uc2e_inv_2 = unsafe {
rlst_pointer_mat!['a, W, uc2e_inv_2.as_ptr(), uc2e_inv_2_shape, (1, uc2e_inv_2_shape.0)]
rlst_pointer_mat!['a, V, uc2e_inv_2.as_ptr(), uc2e_inv_2_shape, (1, uc2e_inv_2_shape.0)]
}
.eval();

let (s, ut, v) = pinv::<W>(&dc2e, None, None).unwrap();
let (s, ut, v) = pinv::<V>(&dc2e, None, None).unwrap();

let mut mat_s = rlst_dynamic_mat![SvdScalar<W>, (s.len(), s.len())];
let mut mat_s = rlst_dynamic_mat![SvdScalar<V>, (s.len(), s.len())];
for i in 0..s.len() {
mat_s[[i, i]] = SvdScalar::<W>::from_real(s[i]);
mat_s[[i, i]] = SvdScalar::<V>::from_real(s[i]);
}

let dc2e_inv_1 = v.dot(&mat_s);
Expand All @@ -169,40 +169,40 @@ where
let dc2e_inv_1 = dc2e_inv_1
.data()
.iter()
.map(|x| W::from(*x).unwrap())
.map(|x| V::from(*x).unwrap())
.collect_vec();
let dc2e_inv_1 = unsafe {
rlst_pointer_mat!['a, W, dc2e_inv_1.as_ptr(), dc2e_inv_1_shape, (1, dc2e_inv_1_shape.0)]
rlst_pointer_mat!['a, V, dc2e_inv_1.as_ptr(), dc2e_inv_1_shape, (1, dc2e_inv_1_shape.0)]
}
.eval();
let dc2e_inv_2 = dc2e_inv_2
.data()
.iter()
.map(|x| W::from(*x).unwrap())
.map(|x| V::from(*x).unwrap())
.collect_vec();
let dc2e_inv_2 = unsafe {
rlst_pointer_mat!['a, W, dc2e_inv_2.as_ptr(), dc2e_inv_2_shape, (1, dc2e_inv_2_shape.0)]
rlst_pointer_mat!['a, V, dc2e_inv_2.as_ptr(), dc2e_inv_2_shape, (1, dc2e_inv_2_shape.0)]
}
.eval();

// Calculate M2M/L2L matrices
let children = ROOT.children();
let mut m2m: Vec<C2EType<W>> = Vec::new();
let mut l2l: Vec<C2EType<W>> = Vec::new();
let mut m2m: Vec<C2EType<V>> = Vec::new();
let mut l2l: Vec<C2EType<V>> = Vec::new();

for child in children.iter() {
let child_upward_equivalent_surface =
child.compute_surface(tree.get_domain(), order, alpha_inner);
let child_downward_check_surface =
child.compute_surface(tree.get_domain(), order, alpha_inner);
let child_upward_equivalent_surface = unsafe {
rlst_pointer_mat!['a, <W as cauchy::Scalar>::Real, child_upward_equivalent_surface.as_ptr(), (nequiv_surface, kernel.space_dimension()), (1, nequiv_surface)]
rlst_pointer_mat!['a, <V as cauchy::Scalar>::Real, child_upward_equivalent_surface.as_ptr(), (nequiv_surface, kernel.space_dimension()), (1, nequiv_surface)]
};
let child_downward_check_surface = unsafe {
rlst_pointer_mat!['a, <W as cauchy::Scalar>::Real, child_downward_check_surface.as_ptr(), (ncheck_surface, kernel.space_dimension()), (1, ncheck_surface)]
rlst_pointer_mat!['a, <V as cauchy::Scalar>::Real, child_downward_check_surface.as_ptr(), (ncheck_surface, kernel.space_dimension()), (1, ncheck_surface)]
};

let mut pc2ce = rlst_dynamic_mat![W, (ncheck_surface, nequiv_surface)];
let mut pc2ce = rlst_dynamic_mat![V, (ncheck_surface, nequiv_surface)];

kernel.assemble_st(
EvalType::Value,
Expand All @@ -216,7 +216,7 @@ where

m2m.push(uc2e_inv_1.dot(&uc2e_inv_2.dot(&pc2ce)).eval());

let mut cc2pe = rlst_dynamic_mat![W, (ncheck_surface, nequiv_surface)];
let mut cc2pe = rlst_dynamic_mat![V, (ncheck_surface, nequiv_surface)];

kernel.assemble_st(
EvalType::Value,
Expand Down
8 changes: 0 additions & 8 deletions fmm/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,3 @@ where
/// The M2L operator matrices, as well as metadata associated with this FMM.
pub m2l: V,
}

pub trait SameType {
type Other;
}

impl<T> SameType for T {
type Other = T;
}

0 comments on commit 956e54d

Please sign in to comment.