Skip to content

Commit

Permalink
more wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Sep 17, 2024
1 parent 519bebb commit fa4bf48
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 2 deletions.
39 changes: 39 additions & 0 deletions python/bempp/function_space.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Function space."""

import typing
import numpy as np
from bempp._bempprs import lib as _lib, ffi as _ffi
from ndelement._ndelementrs import ffi as _elementffi
from ndgrid._ndgridrs import ffi as _gridffi
from ndgrid.grid import Grid
from ndgrid.ownership import Owned, Ghost
from ndelement.ciarlet import ElementFamily, CiarletElement
from ndelement.reference_cell import ReferenceCellType

Expand Down Expand Up @@ -38,6 +40,43 @@ def element(self, entity: ReferenceCellType) -> CiarletElement:
owned=False,
)

# TODO: test
def get_local_dof_numbers(self, entity_dim: int, entity_index: int) -> typing.List[int]:
"""Get the local DOF numbers associated with an entity."""
dofs = np.empty(
_lib.space_get_local_dof_numbers_size(self._rs_space, entity_dim, entity_index),
dtype=np.uintp,
)
_lib.space_get_local_dof_numbers(
self._rs_space, entity_dim, entity_index, _ffi.cast("uintptr_t*", dofs.ctypes.data)
)
return [int(i) for i in dofs]

# TODO: test
def cell_dofs(self, cell: int) -> typing.Optional[typing.List[int]]:
"""Get the local DOF numbers associated with a cell."""
if not _lib.space_has_cell_dofs(self._rs_space, cell):
return None
dofs = np.empty(_lib.space_cell_dofs_size(self._rs_space, cell), dtype=np.uintp)
_lib.space_cell_dofs(self._rs_space, cell, _ffi.cast("uintptr_t*", dofs.ctypes.data))
return [int(i) for i in dofs]

# TODO: test
def global_dof_index(self, local_dof_index: int) -> typing.Optional[typing.List[int]]:
"""Get the global DOF number for a local DOF."""
return _lib.space_global_dof_index(self._rs_space, local_dof_index)

# TODO: test
def ownership(self, local_dof_index) -> typing.Union[Owned, Ghost]:
"""The ownership of a local DOF."""
if _lib.space_is_owned(self._rs_space, local_dof_index):
return Owned()
else:
return Ghost(
_lib.space_ownership_process(self._rs_space, local_dof_index),
_lib.space_ownership_index(self._rs_space, local_dof_index),
)

@property
def dtype(self):
"""Data type."""
Expand Down
255 changes: 254 additions & 1 deletion src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mod function {
bindings as ndelement_b, ciarlet, ciarlet::CiarletElement, traits::ElementFamily,
types::ReferenceCellType,
};
use ndgrid::{bindings as ndgrid_b, traits::Grid, SingleElementGrid};
use ndgrid::{bindings as ndgrid_b, traits::Grid, types::Ownership, SingleElementGrid};
use rlst::{c32, c64, MatrixInverse, RlstScalar};
use std::ffi::c_void;

Expand Down Expand Up @@ -300,6 +300,259 @@ mod function {
}
}

#[no_mangle]
pub unsafe extern "C" fn space_get_local_dof_numbers_size(
space: *mut FunctionSpaceWrapper,
entity_dim: usize,
entity_number: usize,
) -> usize {
match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
},
},
}
.len()
}

#[no_mangle]
pub unsafe extern "C" fn space_get_local_dof_numbers(
space: *mut FunctionSpaceWrapper,
entity_dim: usize,
entity_number: usize,
dofs: *mut usize,
) {
for (i, dof) in match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.get_local_dof_numbers(entity_dim, entity_number),
},
},
}
.iter()
.enumerate()
{
*dofs.add(i) = *dof;
}
}

#[no_mangle]
pub unsafe extern "C" fn space_has_cell_dofs(
space: *mut FunctionSpaceWrapper,
cell: usize,
) -> bool {
match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.cell_dofs(cell),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.cell_dofs(cell),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.cell_dofs(cell),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.cell_dofs(cell),
},
},
}
.is_some()
}

#[no_mangle]
pub unsafe extern "C" fn space_cell_dofs_size(
space: *mut FunctionSpaceWrapper,
cell: usize,
) -> usize {
match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.cell_dofs(cell),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.cell_dofs(cell),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.cell_dofs(cell),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.cell_dofs(cell),
},
},
}
.unwrap()
.len()
}

#[no_mangle]
pub unsafe extern "C" fn space_cell_dofs(
space: *mut FunctionSpaceWrapper,
cell: usize,
dofs: *mut usize,
) {
for (i, dof) in match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.cell_dofs(cell),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.cell_dofs(cell),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.cell_dofs(cell),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.cell_dofs(cell),
},
},
}
.unwrap()
.iter()
.enumerate()
{
*dofs.add(i) = *dof;
}
}

#[no_mangle]
pub unsafe extern "C" fn space_global_dof_index(
space: *mut FunctionSpaceWrapper,
local_dof_index: usize,
) -> usize {
match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.global_dof_index(local_dof_index),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.global_dof_index(local_dof_index),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.global_dof_index(local_dof_index),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.global_dof_index(local_dof_index),
},
},
}
}

unsafe fn space_ownership(
space: *mut FunctionSpaceWrapper,
local_dof_index: usize,
) -> Ownership {
match (*space).stype {
SpaceType::SerialFunctionSpace => match (*space).gtype {
GridType::SerialSingleElementGrid => match (*space).dtype {
DType::F32 => (*extract_space::<
SerialFunctionSpace<f32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.ownership(local_dof_index),
DType::F64 => (*extract_space::<
SerialFunctionSpace<f64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.ownership(local_dof_index),
DType::C32 => (*extract_space::<
SerialFunctionSpace<c32, SingleElementGrid<f32, CiarletElement<f32>>>,
>(space))
.ownership(local_dof_index),
DType::C64 => (*extract_space::<
SerialFunctionSpace<c64, SingleElementGrid<f64, CiarletElement<f64>>>,
>(space))
.ownership(local_dof_index),
},
},
}
}

#[no_mangle]
pub unsafe extern "C" fn space_is_owned(
space: *mut FunctionSpaceWrapper,
local_dof_index: usize,
) -> bool {
space_ownership(space, local_dof_index) == Ownership::Owned
}

#[no_mangle]
pub unsafe extern "C" fn space_ownership_process(
space: *mut FunctionSpaceWrapper,
local_dof_index: usize,
) -> usize {
if let Ownership::Ghost(process, _index) = space_ownership(space, local_dof_index) {
process
} else {
panic!("Cannot get process of owned DOF");
}
}

#[no_mangle]
pub unsafe extern "C" fn space_ownership_index(
space: *mut FunctionSpaceWrapper,
local_dof_index: usize,
) -> usize {
if let Ownership::Ghost(_process, index) = space_ownership(space, local_dof_index) {
index
} else {
panic!("Cannot get process of owned DOF");
}
}

#[no_mangle]
pub unsafe extern "C" fn space_dtype(space: *const FunctionSpaceWrapper) -> u8 {
(*space).dtype as u8
Expand Down
2 changes: 1 addition & 1 deletion src/traits/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub trait FunctionSpace {
/// Compute a colouring of the cells so that no two cells that share an entity with DOFs associated with it are assigned the same colour
fn cell_colouring(&self) -> HashMap<ReferenceCellType, Vec<Vec<usize>>>;

/// Get the global DOF indes associated with a local DOF indec
/// Get the global DOF index associated with a local DOF index
fn global_dof_index(&self, local_dof_index: usize) -> usize;

/// Get ownership of a local DOF
Expand Down

0 comments on commit fa4bf48

Please sign in to comment.