Skip to content

Commit

Permalink
Add svd field translation tests too
Browse files Browse the repository at this point in the history
  • Loading branch information
skailasa committed Oct 12, 2023
1 parent 186abe8 commit 114ba7a
Showing 1 changed file with 83 additions and 12 deletions.
95 changes: 83 additions & 12 deletions field/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,18 +617,92 @@ mod test {
assert_eq!(m2l.kernel_data.keys().len(), 16);
}

#[test]
fn test_svd_field_translation() {
let kernel = Laplace3dKernel::new();
let order: usize = 2;

let domain = Domain {
origin: [0., 0., 0.],
diameter: [1., 1., 1.],
};
let alpha = 1.05;

// Some expansion data
let ncoeffs = 6 * (order - 1).pow(2) + 2;
let mut multipole = rlst_mat![f64, (ncoeffs, 1)];

for i in 0..ncoeffs {
*multipole.get_mut(i, 0).unwrap() = i as f64;
}

// Create field translation object
let svd = SvdFieldTranslationKiFmm::new(kernel, Some(1000), order, domain, alpha);

// Pick a random source/target pair
let idx = 153;
let (all_transfer_vectors, _) = compute_transfer_vectors();

let transfer_vector = &all_transfer_vectors[idx];

// Lookup correct components of SVD compressed M2L operator matrix
let c_idx = svd
.transfer_vectors
.iter()
.position(|x| x.hash == transfer_vector.hash)
.unwrap();

let (nrows, _) = svd.operator_data.c.shape();
let top_left = (0, c_idx * svd.k);
let dim = (nrows, svd.k);

let c_sub = svd.operator_data.c.block(top_left, dim);

let compressed_multipole = svd.operator_data.st_block.dot(&multipole).eval();

let compressed_check_potential = c_sub.dot(&compressed_multipole);

// Post process to find check potential
let check_potential = svd.operator_data.u.dot(&compressed_check_potential).eval();

let sources = transfer_vector
.source
.compute_surface(&domain, order, alpha);
let targets = transfer_vector
.target
.compute_surface(&domain, order, alpha);
let mut direct = vec![0f64; ncoeffs];
svd.kernel.evaluate_st(
EvalType::Value,
&sources[..],
&targets[..],
multipole.data(),
&mut direct[..],
);

let abs_error: f64 = check_potential
.data()
.iter()
.zip(direct.iter())
.map(|(a, b)| (a - b).abs())
.sum();
let rel_error: f64 = abs_error / (direct.iter().sum::<f64>());

assert!(rel_error < 1e-14);
}

#[test]
fn test_fft_field_translation() {
let kernel = Laplace3dKernel::new();
let order: usize = 3;
let order: usize = 5;

let domain = Domain {
origin: [0., 0., 0.],
diameter: [1., 1., 1.],
};
let alpha = 1.05;

// Some random expansion data
// Some expansion data
let ncoeffs = 6 * (order - 1).pow(2) + 2;
let mut multipole = rlst_mat![f64, (ncoeffs, 1)];

Expand All @@ -643,23 +717,20 @@ mod test {
let m2l = fft.compute_m2l_operators(order, domain);

// Pick a random source/target pair
// let idx = 29;
let idx = 153;
let (all_transfer_vectors, _) = compute_transfer_vectors();

let transfer_vector = &all_transfer_vectors[idx];
let unique_transfer_vector = fft.transfer_vector_map.get(&transfer_vector.hash).unwrap();

// Place charges on the convolution grid
let surface_map = fft
let permutation_matrix = fft
.operator_data
.permutation_matrices
.get(&transfer_vector.hash)
.unwrap();

let r_multipole = surface_map.dot(&multipole).eval();

// println!("HERE {:?} {:?}", multipole.data(), r_multipole.data());
let r_multipole = permutation_matrix.dot(&multipole).eval();

// Compute FFT of the representative signal
let r_signal = fft.compute_signal(order, r_multipole.data());
Expand Down Expand Up @@ -698,17 +769,17 @@ mod test {
);

// Unpermute the coefficients
let surface_multi_index_axial_diag = fft
let permuted_multi_indices = fft
.operator_data
.permuted_multi_indices
.get(&transfer_vector.hash)
.unwrap();

let mut tmp = Vec::new();
let ntargets = surface_multi_index_axial_diag.len() / 3;
let xs = &surface_multi_index_axial_diag[0..ntargets];
let ys = &surface_multi_index_axial_diag[ntargets..2 * ntargets];
let zs = &surface_multi_index_axial_diag[2 * ntargets..];
let ntargets = permuted_multi_indices.len() / 3;
let xs = &permuted_multi_indices[0..ntargets];
let ys = &permuted_multi_indices[ntargets..2 * ntargets];
let zs = &permuted_multi_indices[2 * ntargets..];

for i in 0..ntargets {
let val = r_potentials.get(zs[i], ys[i], xs[i]).unwrap();
Expand Down

0 comments on commit 114ba7a

Please sign in to comment.