Skip to content

Commit

Permalink
Updated for RlstScalar (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke authored Mar 9, 2024
1 parent 7aa0b65 commit 0cda3d5
Show file tree
Hide file tree
Showing 47 changed files with 289 additions and 345 deletions.
1 change: 0 additions & 1 deletion bem/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ rlst = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-blis-src = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-dense = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-sparse = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-common = { git = "https://github.com/linalg-rs/rlst.git" }
rand = "0.8.5"

[dev-dependencies]
Expand Down
12 changes: 6 additions & 6 deletions bem/src/assembly/common.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use rlst_common::types::Scalar;
use rlst_dense::types::RlstScalar;

pub struct RawData2D<T: Scalar> {
pub struct RawData2D<T: RlstScalar> {
pub data: *mut T,
pub shape: [usize; 2],
}

unsafe impl<T: Scalar> Sync for RawData2D<T> {}
unsafe impl<T: RlstScalar> Sync for RawData2D<T> {}

pub struct SparseMatrixData<T: Scalar> {
pub struct SparseMatrixData<T: RlstScalar> {
pub data: Vec<T>,
pub rows: Vec<usize>,
pub cols: Vec<usize>,
pub shape: [usize; 2],
}

impl<T: Scalar> SparseMatrixData<T> {
impl<T: RlstScalar> SparseMatrixData<T> {
pub fn new(shape: [usize; 2]) -> Self {
Self {
data: vec![],
Expand Down Expand Up @@ -52,4 +52,4 @@ impl<T: Scalar> SparseMatrixData<T> {
}
}

unsafe impl<T: Scalar> Sync for SparseMatrixData<T> {}
unsafe impl<T: RlstScalar> Sync for SparseMatrixData<T> {}
1 change: 0 additions & 1 deletion element/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,3 @@ approx = "0.5"
rlst = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-blis-src = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-dense = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-common = { git = "https://github.com/linalg-rs/rlst.git" }
10 changes: 5 additions & 5 deletions element/src/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use bempp_tools::arrays::AdjacencyList;
use bempp_traits::arrays::AdjacencyListAccess;
use bempp_traits::cell::ReferenceCellType;
use bempp_traits::element::{Continuity, FiniteElement, MapType};
use rlst_common::types::Scalar;
use rlst_dense::linalg::inverse::MatrixInverse;
use rlst_dense::types::RlstScalar;
use rlst_dense::{
array::views::ArrayViewMut,
array::Array,
Expand All @@ -31,7 +31,7 @@ pub enum ElementFamily {
RaviartThomas = 1,
}

pub struct CiarletElement<T: Scalar> {
pub struct CiarletElement<T: RlstScalar> {
cell_type: ReferenceCellType,
family: ElementFamily,
degree: usize,
Expand All @@ -47,7 +47,7 @@ pub struct CiarletElement<T: Scalar> {
// interpolation_weights: EntityWeights,
}

impl<T: Scalar> CiarletElement<T>
impl<T: RlstScalar> CiarletElement<T>
where
for<'a> Array<T, ArrayViewMut<'a, T, BaseArray<T, VectorContainer<T>, 2>, 2>, 2>: MatrixInverse,
{
Expand Down Expand Up @@ -255,7 +255,7 @@ where
}
}

impl<T: Scalar> FiniteElement for CiarletElement<T> {
impl<T: RlstScalar> FiniteElement for CiarletElement<T> {
type T = T;
fn value_shape(&self) -> &[usize] {
&self.value_shape
Expand Down Expand Up @@ -323,7 +323,7 @@ impl<T: Scalar> FiniteElement for CiarletElement<T> {
}
}

pub fn create_element<T: Scalar>(
pub fn create_element<T: RlstScalar>(
family: ElementFamily,
cell_type: ReferenceCellType,
degree: usize,
Expand Down
4 changes: 2 additions & 2 deletions element/src/element/lagrange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ use crate::element::{create_cell, CiarletElement, ElementFamily};
use crate::polynomials::polynomial_count;
use bempp_traits::cell::ReferenceCellType;
use bempp_traits::element::{Continuity, MapType};
use rlst_common::types::Scalar;
use rlst_dense::linalg::inverse::MatrixInverse;
use rlst_dense::types::RlstScalar;
use rlst_dense::{
array::views::ArrayViewMut, array::Array, base_array::BaseArray,
data_container::VectorContainer, rlst_dynamic_array2, rlst_dynamic_array3,
traits::RandomAccessMut,
};

/// Create a Lagrange element
pub fn create<T: Scalar>(
pub fn create<T: RlstScalar>(
cell_type: ReferenceCellType,
degree: usize,
continuity: Continuity,
Expand Down
4 changes: 2 additions & 2 deletions element/src/element/raviart_thomas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ use crate::element::{create_cell, CiarletElement, ElementFamily};
use crate::polynomials::polynomial_count;
use bempp_traits::cell::ReferenceCellType;
use bempp_traits::element::{Continuity, MapType};
use rlst_common::types::Scalar;
use rlst_dense::linalg::inverse::MatrixInverse;
use rlst_dense::types::RlstScalar;
use rlst_dense::{
array::views::ArrayViewMut, array::Array, base_array::BaseArray,
data_container::VectorContainer, rlst_dynamic_array2, rlst_dynamic_array3,
traits::RandomAccessMut,
};

/// Create a Raviart-Thomas element
pub fn create<T: Scalar>(
pub fn create<T: RlstScalar>(
cell_type: ReferenceCellType,
degree: usize,
continuity: Continuity,
Expand Down
12 changes: 6 additions & 6 deletions element/src/polynomials.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! Orthonormal polynomials

use bempp_traits::cell::ReferenceCellType;
use rlst_common::types::Scalar;
use rlst_dense::traits::{RandomAccessByRef, RandomAccessMut, Shape};
use rlst_dense::types::RlstScalar;

/// Tabulate orthonormal polynomials on a interval
fn tabulate_legendre_polynomials_interval<
T: Scalar,
T: RlstScalar,
Array2: RandomAccessByRef<2, Item = T> + Shape<2>,
Array3Mut: RandomAccessMut<3, Item = T> + RandomAccessByRef<3, Item = T> + Shape<3>,
>(
Expand Down Expand Up @@ -74,7 +74,7 @@ fn quad_index(i: usize, j: usize, n: usize) -> usize {

/// Tabulate orthonormal polynomials on a quadrilateral
fn tabulate_legendre_polynomials_quadrilateral<
T: Scalar,
T: RlstScalar,
Array2: RandomAccessByRef<2, Item = T> + Shape<2>,
Array3Mut: RandomAccessMut<3, Item = T> + RandomAccessByRef<3, Item = T> + Shape<3>,
>(
Expand Down Expand Up @@ -225,7 +225,7 @@ fn tabulate_legendre_polynomials_quadrilateral<
}
/// Tabulate orthonormal polynomials on a triangle
fn tabulate_legendre_polynomials_triangle<
T: Scalar,
T: RlstScalar,
Array2: RandomAccessByRef<2, Item = T> + Shape<2>,
Array3Mut: RandomAccessMut<3, Item = T> + RandomAccessByRef<3, Item = T> + Shape<3>,
>(
Expand Down Expand Up @@ -447,7 +447,7 @@ pub fn derivative_count(cell_type: ReferenceCellType, derivatives: usize) -> usi
}
}

pub fn legendre_shape<T: Scalar, Array2: RandomAccessByRef<2, Item = T> + Shape<2>>(
pub fn legendre_shape<T: RlstScalar, Array2: RandomAccessByRef<2, Item = T> + Shape<2>>(
cell_type: ReferenceCellType,
points: &Array2,
degree: usize,
Expand All @@ -462,7 +462,7 @@ pub fn legendre_shape<T: Scalar, Array2: RandomAccessByRef<2, Item = T> + Shape<

/// Tabulate orthonormal polynomials
pub fn tabulate_legendre_polynomials<
T: Scalar,
T: RlstScalar,
Array2: RandomAccessByRef<2, Item = T> + Shape<2>,
Array3Mut: RandomAccessMut<3, Item = T> + RandomAccessByRef<3, Item = T> + Shape<3>,
>(
Expand Down
1 change: 0 additions & 1 deletion field/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ num = "0.4"
rlst = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-dense = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-blis = { git = "https://github.com/linalg-rs/rlst.git" }
rlst-common = { git = "https://github.com/linalg-rs/rlst.git" }
fftw = {git = "https://github.com/skailasa/fftw.git" }
cauchy = "0.4.*"
approx = "0.5"
Expand Down
6 changes: 3 additions & 3 deletions field/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use itertools::Itertools;
use num::traits::Num;

use rlst_common::types::Scalar;
use rlst_dense::types::RlstScalar;
use rlst_dense::{
array::Array,
base_array::BaseArray,
Expand All @@ -31,7 +31,7 @@ pub fn argsort<T: Ord>(arr: &[T]) -> Vec<usize> {
/// * `arr` - An array to be padded.
/// * `pad_size` - The amount of padding to be added along each axis.
/// * `pad_index` - The position in the array to start the padding from.
pub fn pad3<T: Scalar>(
pub fn pad3<T: RlstScalar>(
arr: &Array<T, BaseArray<T, VectorContainer<T>, 3>, 3>,
pad_size: (usize, usize, usize),
pad_index: (usize, usize, usize),
Expand Down Expand Up @@ -64,7 +64,7 @@ where
///
/// # Arguments
/// * `arr` - An array to be flipped.
pub fn flip3<T: Scalar>(
pub fn flip3<T: RlstScalar>(
arr: &Array<T, BaseArray<T, VectorContainer<T>, 3>, 3>,
) -> Array<T, BaseArray<T, VectorContainer<T>, 3>, 3>
where
Expand Down
29 changes: 14 additions & 15 deletions field/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use bempp_traits::kernel::ScaleInvariantKernel;
use itertools::Itertools;
use num::Zero;
use num::{Complex, Float};
use rlst_blis::interface::gemm::Gemm;
use rlst_common::types::Scalar;
use rlst_dense::rlst_array_from_slice2;
use rlst_dense::types::RlstScalar;
use rlst_dense::{
array::{empty_array, Array},
base_array::BaseArray,
Expand Down Expand Up @@ -41,7 +40,7 @@ use crate::{
impl<T, U> FieldTranslationData<U> for SvdFieldTranslationKiFmm<T, U>
where
T: Float + Default,
T: Scalar<Real = T> + Gemm,
T: RlstScalar<Real = T>,
U: Kernel<T = T> + Default,
Array<T, BaseArray<T, VectorContainer<T>, 2>, 2>: MatrixSvd<Item = T>,
{
Expand Down Expand Up @@ -178,7 +177,7 @@ where
impl<T, U> FieldTranslationData<U> for SvdFieldTranslationKiFmmRcmp<T, U>
where
T: Float + Default,
T: Scalar<Real = T> + Gemm,
T: RlstScalar<Real = T>,
U: Kernel<T = T> + Default,
Array<T, BaseArray<T, VectorContainer<T>, 2>, 2>: MatrixSvd<Item = T>,
{
Expand Down Expand Up @@ -346,7 +345,7 @@ where
impl<T, U> FieldTranslationData<U> for SvdFieldTranslationKiFmmIA<T, U>
where
T: Float + Default,
T: Scalar<Real = T> + Gemm,
T: RlstScalar<Real = T>,
U: Kernel<T = T> + ScaleInvariantKernel<T = T> + Default,
Array<T, BaseArray<T, VectorContainer<T>, 2>, 2>: MatrixSvd<Item = T>,
{
Expand Down Expand Up @@ -452,7 +451,7 @@ where
fn m2l_scale<T>(level: u64) -> T
where
T: Float + Default,
T: Scalar<Real = T> + Gemm,
T: RlstScalar<Real = T>,
{
if level < 2 {
panic!("M2L only perfomed on level 2 and below")
Expand All @@ -462,11 +461,11 @@ where
T::from(1. / 2.).unwrap()
} else {
let two = T::from(2.0).unwrap();
Scalar::powf(two, T::from(level - 3).unwrap())
RlstScalar::powf(two, T::from(level - 3).unwrap())
}
}

fn retain_energy<T: Float + Default + Scalar<Real = T> + Gemm>(
fn retain_energy<T: Float + Default + RlstScalar<Real = T>>(
singular_values: &[T],
percentage: T,
) -> usize {
Expand Down Expand Up @@ -494,7 +493,7 @@ fn retain_energy<T: Float + Default + Scalar<Real = T> + Gemm>(
impl<T, U> SvdFieldTranslationKiFmm<T, U>
where
T: Float + Default,
T: Scalar<Real = T> + rlst_blis::interface::gemm::Gemm,
T: RlstScalar<Real = T>,
U: Kernel<T = T> + Default,
Array<T, BaseArray<T, VectorContainer<T>, 2>, 2>: MatrixSvd<Item = T>,
{
Expand Down Expand Up @@ -535,7 +534,7 @@ where
impl<T, U> SvdFieldTranslationKiFmmRcmp<T, U>
where
T: Float + Default,
T: Scalar<Real = T> + rlst_blis::interface::gemm::Gemm,
T: RlstScalar<Real = T>,
U: Kernel<T = T> + Default,
Array<T, BaseArray<T, VectorContainer<T>, 2>, 2>: MatrixSvd<Item = T>,
{
Expand Down Expand Up @@ -585,7 +584,7 @@ where
impl<T, U> SvdFieldTranslationKiFmmIA<T, U>
where
T: Float + Default,
T: Scalar<Real = T> + rlst_blis::interface::gemm::Gemm,
T: RlstScalar<Real = T>,
U: Kernel<T = T> + Default + ScaleInvariantKernel<T = T>,
Array<T, BaseArray<T, VectorContainer<T>, 2>, 2>: MatrixSvd<Item = T>,
{
Expand Down Expand Up @@ -625,8 +624,8 @@ where

impl<T, U> FieldTranslationData<U> for FftFieldTranslationKiFmm<T, U>
where
T: Scalar<Real = T> + Float + Default + Fft,
Complex<T>: Scalar,
T: RlstScalar<Real = T> + Float + Default + Fft,
Complex<T>: RlstScalar,
U: Kernel<T = T> + Default,
{
type Domain = Domain<T>;
Expand Down Expand Up @@ -835,8 +834,8 @@ where

impl<T, U> FftFieldTranslationKiFmm<T, U>
where
T: Float + Scalar<Real = T> + Default + Fft,
Complex<T>: Scalar,
T: Float + RlstScalar<Real = T> + Default + Fft,
Complex<T>: RlstScalar,
U: Kernel<T = T> + Default,
{
/// Constructor for FFT field translation struct for the kernel independent FMM (KiFMM).
Expand Down
Loading

0 comments on commit 0cda3d5

Please sign in to comment.