Skip to content

Commit

Permalink
ENH: Update FMM to use implicit padding in FFT based M2L (#142)
Browse files Browse the repository at this point in the history
* Add implicit padding

* Start working on kernel

* Add unchunked kernel

* Change way buffers are allocated

* Add a converging implementation without chunking

* Begin cleanup

* Add a chunked implementation

* Start experimenting with removal of if statement in kernel

* Working with no indirection in kernel

* Add chunking to pre processing

* Fix temperamental chunk size

* Respond to clippy, delete redundant code

* Cleanup

* Fix tests

* Remove sub dir for fmms
  • Loading branch information
skailasa authored Dec 15, 2023
1 parent faaa457 commit d53cb1b
Show file tree
Hide file tree
Showing 18 changed files with 830 additions and 2,789 deletions.
220 changes: 93 additions & 127 deletions field/src/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,55 @@ use fftw::{plan::*, types::*};
use num::Complex;
use rayon::prelude::*;

use crate::types::{FftMatrixc32, FftMatrixc64, FftMatrixf32, FftMatrixf64};

pub trait Fft<DtypeReal, DtypeCplx>
pub trait Fft
where
Self: Sized,
{
/// Compute a Real FFT over a rlst matrix which stores data corresponding to multiple 3 dimensional arrays of shape `shape`, stored in column major order.
/// Compute a parallel real to complex FFT over a slice which stores data corresponding to multiple 3 dimensional arrays of shape `shape`, stored in column major order.
/// This function is multithreaded, and uses the FFTW library.
///
/// # Arguments
/// * `input` - Input slice of real data, corresponding to a 3D array stored in column major order.
/// * `output` - Output slice.
/// * `shape` - Shape of input data.
fn rfft3_fftw_par_vec(input: &mut DtypeReal, output: &mut DtypeCplx, shape: &[usize]);
fn rfft3_fftw_par_slice(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]);

/// Compute an inverse Real FFT over a rlst matrix which stores data corresponding to multiple 3 dimensional arrays of shape `shape`, stored in column major order.
/// Compute a real to complex FFT over a slice which stores data corresponding to multiple 3 dimensional arrays of shape `shape`, stored in column major order.
/// This function processes the arrays sequentially (it does not use Rayon), and uses the FFTW library.
///
/// # Arguments
/// * `input` - Input slice of real data, corresponding to a 3D array stored in column major order.
/// * `output` - Output slice.
/// * `shape` - Shape of input data.
fn rfft3_fftw_slice(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]);

/// Compute a parallel complex to real inverse FFT over a slice which stores data corresponding to multiple 3 dimensional arrays of shape `shape`, stored in column major order.
/// This function is multithreaded, and uses the FFTW library.
///
/// # Arguments
/// * `input` - Input slice of complex data, corresponding to an FFT of a 3D array stored in column major order.
/// * `output` - Output slice.
/// * `shape` - Shape of output data.
fn irfft_fftw_par_vec(input: &mut DtypeCplx, output: &mut DtypeReal, shape: &[usize]);
fn irfft3_fftw_par_slice(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]);

/// Compute a Real FFT of an input slice corresponding to a 3D array stored in column major format, specified by `shape` using the FFTW library.
/// Compute a complex to real inverse FFT over a slice which stores data corresponding to multiple 3 dimensional arrays of shape `shape`, stored in column major order.
/// This function processes the arrays sequentially (it does not use Rayon), normalises the output, and uses the FFTW library.
///
/// # Arguments
/// * `input` - Input slice of complex data, corresponding to an FFT of a 3D array stored in column major order.
/// * `output` - Output slice.
/// * `shape` - Shape of output data.
fn irfft3_fftw_slice(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]);

/// Compute a real to complex FFT of an input slice corresponding to a 3D array stored in column major format, specified by `shape` using the FFTW library.
///
/// # Arguments
/// * `input` - Input slice of real data, corresponding to a 3D array stored in column major order.
/// * `output` - Output slice.
/// * `shape` - Shape of input data.
fn rfft3_fftw(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]);

/// Compute an inverse Real FFT of an input slice corresponding to the FFT of a 3D array stored in column major format, specified by `shape` using the FFTW library.
/// Compute a complex to real inverse FFT of an input slice corresponding to the FFT of a 3D array stored in column major format, specified by `shape` using the FFTW library.
/// This function normalises the output.
///
/// # Arguments
Expand All @@ -45,120 +61,8 @@ where
fn irfft3_fftw(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]);
}

// impl Fft<FftMatrixf32, FftMatrixc32> for f32 {
// fn rfft3_fftw_par_vec(input: &mut FftMatrixf32, output: &mut FftMatrixc32, shape: &[usize]) {
// let size: usize = shape.iter().product();
// let size_d = shape.last().unwrap();
// let size_real = (size / size_d) * (size_d / 2 + 1);
// let plan: R2CPlan32 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();

// let it_inp = input.data_mut().par_chunks_exact_mut(size).into_par_iter();
// let it_out = output
// .data_mut()
// .par_chunks_exact_mut(size_real)
// .into_par_iter();

// it_inp.zip(it_out).for_each(|(inp, out)| {
// let _ = plan.r2c(inp, out);
// });
// }

// fn irfft_fftw_par_vec(input: &mut FftMatrixc32, output: &mut FftMatrixf32, shape: &[usize]) {
// let size: usize = shape.iter().product();
// let size_d = shape.last().unwrap();
// let size_real = (size / size_d) * (size_d / 2 + 1);
// let plan: C2RPlan32 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();

// let it_inp = input
// .data_mut()
// .par_chunks_exact_mut(size_real)
// .into_par_iter();
// let it_out = output.data_mut().par_chunks_exact_mut(size).into_par_iter();

// it_inp.zip(it_out).for_each(|(inp, out)| {
// let _ = plan.c2r(inp, out);
// // Normalise output
// out.iter_mut()
// .for_each(|value| *value *= 1.0 / (size as f32));
// })
// }

// fn rfft3_fftw(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
// assert!(shape.len() == 3);
// let plan: R2CPlan32 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();
// let _ = plan.r2c(input, output);
// }

// fn irfft3_fftw(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]) {
// assert!(shape.len() == 3);
// let size: usize = shape.iter().product();
// let plan: C2RPlan32 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();
// let _ = plan.c2r(input, output);
// // Normalise
// output
// .iter_mut()
// .for_each(|value| *value *= 1.0 / (size as f32));
// }
// }

// impl Fft<FftMatrixf64, FftMatrixc64> for f64 {
// fn rfft3_fftw_par_vec(input: &mut FftMatrixf64, output: &mut FftMatrixc64, shape: &[usize]) {
// let size: usize = shape.iter().product();
// let size_d = shape.last().unwrap();
// let size_real = (size / size_d) * (size_d / 2 + 1);
// let plan: R2CPlan64 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();

// let it_inp = input.data_mut().par_chunks_exact_mut(size).into_par_iter();
// let it_out = output
// .data_mut()
// .par_chunks_exact_mut(size_real)
// .into_par_iter();

// it_inp.zip(it_out).for_each(|(inp, out)| {
// let _ = plan.r2c(inp, out);
// });
// }

// fn irfft_fftw_par_vec(input: &mut FftMatrixc64, output: &mut FftMatrixf64, shape: &[usize]) {
// let size: usize = shape.iter().product();
// let size_d = shape.last().unwrap();
// let size_real = (size / size_d) * (size_d / 2 + 1);
// let plan: C2RPlan64 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();

// let it_inp = input
// .data_mut()
// .par_chunks_exact_mut(size_real)
// .into_par_iter();
// let it_out = output.data_mut().par_chunks_exact_mut(size).into_par_iter();

// it_inp.zip(it_out).for_each(|(inp, out)| {
// let _ = plan.c2r(inp, out);
// // Normalise output
// out.iter_mut()
// .for_each(|value| *value *= 1.0 / (size as f64));
// })
// }

// fn rfft3_fftw(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
// assert!(shape.len() == 3);
// let plan: R2CPlan64 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();
// let _ = plan.r2c(input, output);
// }

// fn irfft3_fftw(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]) {
// assert!(shape.len() == 3);
// let size: usize = shape.iter().product();
// let plan: C2RPlan64 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();
// let _ = plan.c2r(input, output);
// // Normalise
// output
// .iter_mut()
// .for_each(|value| *value *= 1.0 / (size as f64));
// }
// }

impl Fft<FftMatrixf32, FftMatrixc32> for f32 {
fn rfft3_fftw_par_vec(input: &mut FftMatrixf32, output: &mut FftMatrixc32, shape: &[usize]) {
impl Fft for f32 {
fn rfft3_fftw_par_slice(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
let size: usize = shape.iter().product();
let size_d = shape.last().unwrap();
let size_real = (size / size_d) * (size_d / 2 + 1);
Expand All @@ -172,7 +76,7 @@ impl Fft<FftMatrixf32, FftMatrixc32> for f32 {
});
}

fn irfft_fftw_par_vec(input: &mut FftMatrixc32, output: &mut FftMatrixf32, shape: &[usize]) {
fn irfft3_fftw_par_slice(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]) {
let size: usize = shape.iter().product();
let size_d = shape.last().unwrap();
let size_real = (size / size_d) * (size_d / 2 + 1);
Expand All @@ -189,6 +93,37 @@ impl Fft<FftMatrixf32, FftMatrixc32> for f32 {
})
}

/// Sequential real to complex FFT over every array packed into `input` (single precision).
///
/// `input` is walked in consecutive chunks of `shape.iter().product()` reals and `output`
/// in consecutive chunks of `(size / d) * (d / 2 + 1)` complex values, where `d` is the
/// last entry of `shape`. One FFTW plan is built for `shape` and reused for every pair of
/// chunks. Trailing elements that do not fill a whole chunk are left untouched, and
/// iteration stops as soon as either slice runs out of whole chunks.
///
/// # Panics
/// Panics if `shape` is empty or if the FFTW plan cannot be created.
fn rfft3_fftw_slice(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
    // Element counts of a single array before (real) and after (complex) the transform.
    let n_real: usize = shape.iter().product();
    let last_dim = shape.last().unwrap();
    let n_cplx = (n_real / last_dim) * (last_dim / 2 + 1);

    let plan: R2CPlan32 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();

    let real_chunks = input.chunks_exact_mut(n_real);
    let cplx_chunks = output.chunks_exact_mut(n_cplx);
    for (signal, spectrum) in real_chunks.zip(cplx_chunks) {
        // The status returned by the plan execution is discarded.
        let _ = plan.r2c(signal, spectrum);
    }
}

/// Sequential complex to real inverse FFT over every array packed into `input`
/// (single precision).
///
/// `input` is walked in consecutive chunks of `(size / d) * (d / 2 + 1)` complex values
/// and `output` in consecutive chunks of `size = shape.iter().product()` reals, where `d`
/// is the last entry of `shape`. One FFTW plan is built for `shape` and reused for every
/// pair of chunks. After each transform the real chunk is multiplied by `1 / size`.
/// Trailing elements that do not fill a whole chunk are left untouched.
///
/// # Panics
/// Panics if `shape` is empty or if the FFTW plan cannot be created.
fn irfft3_fftw_slice(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]) {
    // Element counts of a single array in real space and in its transformed (complex) form.
    let n_real: usize = shape.iter().product();
    let last_dim = shape.last().unwrap();
    let n_cplx = (n_real / last_dim) * (last_dim / 2 + 1);

    let plan: C2RPlan32 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();

    // Normalisation factor applied to every output sample.
    let scale = 1.0 / (n_real as f32);

    let cplx_chunks = input.chunks_exact_mut(n_cplx);
    let real_chunks = output.chunks_exact_mut(n_real);
    for (spectrum, signal) in cplx_chunks.zip(real_chunks) {
        // The status returned by the plan execution is discarded.
        let _ = plan.c2r(spectrum, signal);
        for value in signal.iter_mut() {
            *value *= scale;
        }
    }
}

fn rfft3_fftw(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
assert!(shape.len() == 3);
let plan: R2CPlan32 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();
Expand All @@ -207,8 +142,8 @@ impl Fft<FftMatrixf32, FftMatrixc32> for f32 {
}
}

impl Fft<FftMatrixf64, FftMatrixc64> for f64 {
fn rfft3_fftw_par_vec(input: &mut FftMatrixf64, output: &mut FftMatrixc64, shape: &[usize]) {
impl Fft for f64 {
fn rfft3_fftw_par_slice(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
let size: usize = shape.iter().product();
let size_d = shape.last().unwrap();
let size_real = (size / size_d) * (size_d / 2 + 1);
Expand All @@ -222,7 +157,7 @@ impl Fft<FftMatrixf64, FftMatrixc64> for f64 {
});
}

fn irfft_fftw_par_vec(input: &mut FftMatrixc64, output: &mut FftMatrixf64, shape: &[usize]) {
fn irfft3_fftw_par_slice(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]) {
let size: usize = shape.iter().product();
let size_d = shape.last().unwrap();
let size_real = (size / size_d) * (size_d / 2 + 1);
Expand All @@ -239,6 +174,37 @@ impl Fft<FftMatrixf64, FftMatrixc64> for f64 {
})
}

/// Sequential real to complex FFT over every array packed into `input` (double precision).
///
/// `input` is walked in consecutive chunks of `shape.iter().product()` reals and `output`
/// in consecutive chunks of `(size / d) * (d / 2 + 1)` complex values, where `d` is the
/// last entry of `shape`. One FFTW plan is built for `shape` and reused for every pair of
/// chunks. Trailing elements that do not fill a whole chunk are left untouched, and
/// iteration stops as soon as either slice runs out of whole chunks.
///
/// # Panics
/// Panics if `shape` is empty or if the FFTW plan cannot be created.
fn rfft3_fftw_slice(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
    // Element counts of a single array before (real) and after (complex) the transform.
    let n_real: usize = shape.iter().product();
    let last_dim = shape.last().unwrap();
    let n_cplx = (n_real / last_dim) * (last_dim / 2 + 1);

    let plan: R2CPlan64 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();

    let real_chunks = input.chunks_exact_mut(n_real);
    let cplx_chunks = output.chunks_exact_mut(n_cplx);
    for (signal, spectrum) in real_chunks.zip(cplx_chunks) {
        // The status returned by the plan execution is discarded.
        let _ = plan.r2c(signal, spectrum);
    }
}

/// Sequential complex to real inverse FFT over every array packed into `input`
/// (double precision).
///
/// `input` is walked in consecutive chunks of `(size / d) * (d / 2 + 1)` complex values
/// and `output` in consecutive chunks of `size = shape.iter().product()` reals, where `d`
/// is the last entry of `shape`. One FFTW plan is built for `shape` and reused for every
/// pair of chunks. After each transform the real chunk is multiplied by `1 / size`.
/// Trailing elements that do not fill a whole chunk are left untouched.
///
/// # Panics
/// Panics if `shape` is empty or if the FFTW plan cannot be created.
fn irfft3_fftw_slice(input: &mut [Complex<Self>], output: &mut [Self], shape: &[usize]) {
    // Element counts of a single array in real space and in its transformed (complex) form.
    let n_real: usize = shape.iter().product();
    let last_dim = shape.last().unwrap();
    let n_cplx = (n_real / last_dim) * (last_dim / 2 + 1);

    let plan: C2RPlan64 = C2RPlan::aligned(shape, Flag::MEASURE).unwrap();

    // Normalisation factor applied to every output sample.
    let scale = 1.0 / (n_real as f64);

    let cplx_chunks = input.chunks_exact_mut(n_cplx);
    let real_chunks = output.chunks_exact_mut(n_real);
    for (spectrum, signal) in cplx_chunks.zip(real_chunks) {
        // The status returned by the plan execution is discarded.
        let _ = plan.c2r(spectrum, signal);
        for value in signal.iter_mut() {
            *value *= scale;
        }
    }
}

fn rfft3_fftw(input: &mut [Self], output: &mut [Complex<Self>], shape: &[usize]) {
assert!(shape.len() == 3);
let plan: R2CPlan64 = R2CPlan::aligned(shape, Flag::MEASURE).unwrap();
Expand Down
Loading

0 comments on commit d53cb1b

Please sign in to comment.