From 413f162b3fd4423df9f23d4cbee12f6d502fc75a Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 17 Sep 2024 08:24:09 +0100 Subject: [PATCH] Add python wrapper (#299) --- .github/workflows/run-tests.yml | 27 ++ .github/workflows/run-weekly-tests.yml | 2 +- .gitignore | 4 +- Cargo.toml | 5 +- build.rs | 31 ++ cbindgen.toml | 7 + find_examples.py | 7 +- pyproject.toml | 50 ++++ python/bempp/__init__.py | 3 + python/bempp/function_space.py | 33 +++ python/test/test_function_space.py | 40 +++ src/bindings.rs | 274 ++++++++++++++++++ src/lib.rs | 1 + .../generate_rust_simplex_rules.py | 22 +- 14 files changed, 482 insertions(+), 24 deletions(-) create mode 100644 build.rs create mode 100644 cbindgen.toml create mode 100644 pyproject.toml create mode 100644 python/bempp/__init__.py create mode 100644 python/bempp/function_space.py create mode 100644 python/test/test_function_space.py create mode 100644 src/bindings.rs diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index bccf35a6..13769c6e 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -49,6 +49,33 @@ jobs: chmod +x examples.sh ./examples.sh + run-tests-python: + name: Run Python tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + steps: + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - uses: actions/checkout@v3 + - name: Install uv + run: pip install uv "maturin>=1.7" + - name: Make virtual environment + run: | + uv venv .venv + uv pip install pip pytest + - name: Install python package + run: | + source .venv/bin/activate + maturin develop + - name: Run Python tests + run: | + source .venv/bin/activate + python -m pytest python/test + check-dependencies: name: Check dependencies runs-on: ubuntu-latest diff --git a/.github/workflows/run-weekly-tests.yml b/.github/workflows/run-weekly-tests.yml index a2bc2288..f9a26445 100644 --- a/.github/workflows/run-weekly-tests.yml +++ b/.github/workflows/run-weekly-tests.yml @@ -1,4 +1,4 @@ -name: ๐Ÿงช +name: ๐Ÿงช๐Ÿ“… on: schedule: diff --git a/.gitignore b/.gitignore index 5de27325..73550fc1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,10 +9,8 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk -# nano swp files -.*.swp - *.pyc +python/bempp/_bempprs examples.sh diff --git a/Cargo.toml b/Cargo.toml index efb22ce7..b2ab6188 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ lazy_static = "1.4" ndelement = { git = "https://github.com/bempp/ndelement.git" } ndgrid = { git = "https://github.com/bempp/ndgrid.git" } rayon = "1.9" -rlst = "0.2" +rlst = { version = "0.2.0", default-features = false } green-kernels = "0.2.0" [dev-dependencies] @@ -37,6 +37,9 @@ cauchy = "0.4.*" criterion = { version = "0.5.*", features = ["html_reports"]} # kifmm = { version = "1.0" } +[build-dependencies] +cbindgen = "0.27.0" + [[bench]] name = "assembly_benchmark" harness = false diff --git a/build.rs b/build.rs new file mode 100644 index 00000000..e9ab3129 --- /dev/null +++ b/build.rs @@ -0,0 +1,31 @@ +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; + +fn main() { + let crate_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + + // Determine the target directory within the workspace root + let target_dir = env::var("CARGO_TARGET_DIR") + .map(PathBuf::from) + .unwrap_or_else(|_| crate_dir.join("target")); + + // Ensure the target directory exists + fs::create_dir_all(&target_dir).expect("Unable to create target directory"); + + // Create the header file path + let header_path = Path::new(&target_dir).join("include").join("_bempprs.h"); + + let config_path = Path::new(&crate_dir).join("cbindgen.toml"); + let config = cbindgen::Config::from_file(config_path).expect("Unable to load cbindgen config"); + + // Generate the bindings + let bindings = cbindgen::Builder::new() + .with_crate(crate_dir) + .with_config(config) + .generate() + .expect("Unable to generate bindings"); + + // Write the bindings to the header file + bindings.write_to_file(header_path); +} diff --git a/cbindgen.toml b/cbindgen.toml new file mode 100644 index 00000000..90ba9295 --- /dev/null +++ b/cbindgen.toml @@ -0,0 +1,7 @@ +language = "C" + +[export] +exclude = [] + +[enum] +prefix_with_name = true diff --git a/find_examples.py b/find_examples.py index 23926432..54952440 100644 --- a/find_examples.py +++ b/find_examples.py @@ -15,9 +15,8 @@ import os import argparse -parser = argparse.ArgumentParser(description='Parse inputs.') -parser.add_argument('--features', default=None, - help='feature flags to pass to the examples') +parser = argparse.ArgumentParser(description="Parse inputs.") +parser.add_argument("--features", default=None, help="feature flags to pass to the examples") raw_features = parser.parse_args().features @@ -61,7 +60,7 @@ if options is None: options = "" if "--features" in options: - a, b = options.split("--features \"") + a, b = options.split('--features "') options = f"{a}--features \"{','.join(features)},{b}" else: options += f" --features \"{','.join(features)}\"" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..1498aeeb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["maturin>=1,<2"] +build-backend = "maturin" + +[project] +name = "bempp-rs" +version = "0.1.0-dev" +description = "Boundary element method library" +readme = "README.md" +requires-python = ">=3.8" +license = { file = "LICENSE" } +authors = [ + {name = "Timo Betcke", email = "timo.betcke@gmail.com"}, + {name = "Srinath Kailasa", email = "srinathkailasa@gmail.com"}, + {name = "Matthew Scroggs", email = "rust@mscroggs.co.uk"} +] +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", +] +dependencies = [ + "maturin>=1.7", + "numpy", + "cffi", + 'patchelf; platform_system == "Linux"', + "ndelement", + "ndgrid" +] +packages = ["bempp"] + +[project.urls] +homepage = "https://github.com/bempp/bempp-rs" +repository = "https://github.com/bempp/bempp-rs" + +[tool.maturin] +python-source = "python" +module-name = "bempp._bempprs" + +[tool.ruff] +line-length = 100 +indent-width = 4 + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.mypy] +ignore_missing_imports = true diff --git a/python/bempp/__init__.py b/python/bempp/__init__.py new file mode 100644 index 00000000..1cffbb27 --- /dev/null +++ b/python/bempp/__init__.py @@ -0,0 +1,3 @@ +"""Bempp.""" + +from bempp import function_space diff --git a/python/bempp/function_space.py b/python/bempp/function_space.py new file mode 100644 index 00000000..ada31e22 --- /dev/null +++ b/python/bempp/function_space.py @@ -0,0 +1,33 @@ +"""Function space.""" + +from bempp._bempprs import lib as _lib, ffi as _ffi +from ndgrid.grid import Grid +from ndelement.ciarlet import ElementFamily + + +class FunctionSpace(object): + """Function space.""" + + def __init__(self, rs_space): + """Initialise.""" + self._rs_space = rs_space + + def __del__(self): + """Delete.""" + _lib.free_space(self._rs_space) + + @property + def local_size(self) -> int: + """Number of DOFs on current process.""" + return _lib.space_local_size(self._rs_space) + + @property + def global_size(self) -> int: + """Number of DOFs on all processes.""" + return _lib.space_global_size(self._rs_space) + + +def function_space(grid: Grid, family: ElementFamily) -> FunctionSpace: + return FunctionSpace( + _lib.space_new(_ffi.cast("void*", grid._rs_grid), _ffi.cast("void*", family._rs_family)) + ) diff --git a/python/test/test_function_space.py b/python/test/test_function_space.py new file mode 100644 index 00000000..45c0e029 --- /dev/null +++ b/python/test/test_function_space.py @@ -0,0 +1,40 @@ +import pytest +from bempp.function_space import function_space +from ndgrid.shapes import regular_sphere +from ndelement.ciarlet import create_family, Family, Continuity +from ndelement.reference_cell import ReferenceCellType + + +@pytest.mark.parametrize("level", range(4)) +def test_create_space_dp0(level): + grid = regular_sphere(level) + element = create_family(Family.Lagrange, 0, Continuity.Discontinuous) + + space = function_space(grid, element) + + assert space.local_size == grid.entity_count(ReferenceCellType.Triangle) + assert space.local_size == space.global_size + + +@pytest.mark.parametrize("level", range(4)) +def test_create_space_p1(level): + grid = regular_sphere(level) + element = create_family(Family.Lagrange, 1) + + space = function_space(grid, element) + + assert space.local_size == grid.entity_count(ReferenceCellType.Point) + assert space.local_size == space.global_size + + +@pytest.mark.parametrize("level", range(4)) +def test_create_space_p2(level): + grid = regular_sphere(level) + element = create_family(Family.Lagrange, 2) + + space = function_space(grid, element) + + assert space.local_size == grid.entity_count(ReferenceCellType.Point) + grid.entity_count( + ReferenceCellType.Interval + ) + assert space.local_size == space.global_size diff --git a/src/bindings.rs b/src/bindings.rs new file mode 100644 index 00000000..e39fc41a --- /dev/null +++ b/src/bindings.rs @@ -0,0 +1,274 @@ +//! Bindings for C + +#![allow(missing_docs)] +#![allow(clippy::missing_safety_doc)] + +#[derive(Debug, PartialEq, Clone, Copy)] +#[repr(u8)] +pub enum DType { + F32 = 0, + F64 = 1, + C32 = 2, + C64 = 3, +} + +#[derive(Debug, PartialEq, Clone, Copy)] +#[repr(u8)] +pub enum RealDType { + F32 = 0, + F64 = 1, +} + +mod function { + use super::DType; + use crate::{function::SerialFunctionSpace, traits::FunctionSpace}; + use ndelement::{ + bindings as ndelement_b, ciarlet, ciarlet::CiarletElement, traits::ElementFamily, + types::ReferenceCellType, + }; + use ndgrid::{bindings as ndgrid_b, traits::Grid, SingleElementGrid}; + use rlst::{c32, c64, MatrixInverse, RlstScalar}; + use std::ffi::c_void; + + #[derive(Debug, PartialEq, Clone, Copy)] + #[repr(u8)] + pub enum SpaceType { + SerialFunctionSpace = 0, + } + + #[derive(Debug, PartialEq, Clone, Copy)] + #[repr(u8)] + pub enum GridType { + SerialSingleElementGrid = 0, + } + + #[repr(C)] + pub struct FunctionSpaceWrapper { + pub space: *const c_void, + pub dtype: DType, + pub stype: SpaceType, + pub gtype: GridType, + } + + impl Drop for FunctionSpaceWrapper { + fn drop(&mut self) { + let Self { + space, + dtype, + stype, + gtype, + } = self; + match stype { + SpaceType::SerialFunctionSpace => match gtype { + GridType::SerialSingleElementGrid => match dtype { + DType::F32 => drop(unsafe { + Box::from_raw( + *space + as *mut SerialFunctionSpace< + f32, + SingleElementGrid>, + >, + ) + }), + DType::F64 => drop(unsafe { + Box::from_raw( + *space + as *mut SerialFunctionSpace< + f64, + SingleElementGrid>, + >, + ) + }), + DType::C32 => drop(unsafe { + Box::from_raw( + *space + as *mut SerialFunctionSpace< + c32, + SingleElementGrid>, + >, + ) + }), + DType::C64 => drop(unsafe { + Box::from_raw( + *space + as *mut SerialFunctionSpace< + c64, + SingleElementGrid>, + >, + ) + }), + }, + }, + } + } + } + + #[no_mangle] + pub unsafe extern "C" fn free_space(s: *mut FunctionSpaceWrapper) { + assert!(!s.is_null()); + unsafe { drop(Box::from_raw(s)) } + } + + pub(crate) unsafe fn extract_space( + space: *const FunctionSpaceWrapper, + ) -> *const S { + (*space).space as *const S + } + + #[no_mangle] + pub unsafe extern "C" fn space_local_size(space: *mut FunctionSpaceWrapper) -> usize { + match (*space).stype { + SpaceType::SerialFunctionSpace => match (*space).gtype { + GridType::SerialSingleElementGrid => match (*space).dtype { + DType::F32 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .local_size(), + DType::F64 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .local_size(), + DType::C32 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .local_size(), + DType::C64 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .local_size(), + }, + }, + } + } + + #[no_mangle] + pub unsafe extern "C" fn space_global_size(space: *mut FunctionSpaceWrapper) -> usize { + match (*space).stype { + SpaceType::SerialFunctionSpace => match (*space).gtype { + GridType::SerialSingleElementGrid => match (*space).dtype { + DType::F32 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .global_size(), + DType::F64 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .global_size(), + DType::C32 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .global_size(), + DType::C64 => (*extract_space::< + SerialFunctionSpace>>, + >(space)) + .global_size(), + }, + }, + } + } + + pub unsafe extern "C" fn space_new_internal< + T: RlstScalar + MatrixInverse, + G: Grid + Sync, + E: ElementFamily, CellType = ReferenceCellType>, + >( + g: *const ndgrid_b::grid::GridWrapper, + f: *const ndelement_b::ciarlet::ElementFamilyWrapper, + ) -> *const FunctionSpaceWrapper { + Box::into_raw(Box::new(FunctionSpaceWrapper { + space: Box::into_raw(Box::new(SerialFunctionSpace::new( + &*((*g).grid as *const G), + &*((*f).family as *const E), + ))) as *const c_void, + dtype: match (*f).dtype { + ndelement_b::ciarlet::DType::F32 => DType::F32, + ndelement_b::ciarlet::DType::F64 => DType::F64, + ndelement_b::ciarlet::DType::C32 => DType::C32, + ndelement_b::ciarlet::DType::C64 => DType::C64, + }, + stype: SpaceType::SerialFunctionSpace, + gtype: match (*g).gtype { + ndgrid_b::grid::GridType::SerialSingleElementGrid => { + GridType::SerialSingleElementGrid + } + }, + })) + } + + #[no_mangle] + pub unsafe extern "C" fn space_new( + g: *const c_void, + f: *const c_void, + ) -> *const FunctionSpaceWrapper { + let g = g as *const ndgrid_b::grid::GridWrapper; + let f = f as *const ndelement_b::ciarlet::ElementFamilyWrapper; + match (*g).gtype { + ndgrid_b::grid::GridType::SerialSingleElementGrid => match (*f).etype { + ndelement_b::ciarlet::ElementType::Lagrange => match (*g).dtype { + ndgrid_b::DType::F32 => match (*f).dtype { + ndelement_b::ciarlet::DType::F32 => space_new_internal::< + f32, + SingleElementGrid>, + ciarlet::LagrangeElementFamily, + >(g, f), + ndelement_b::ciarlet::DType::C32 => space_new_internal::< + c32, + SingleElementGrid>, + ciarlet::LagrangeElementFamily, + >(g, f), + _ => { + panic!("Incompatible data types."); + } + }, + ndgrid_b::DType::F64 => match (*f).dtype { + ndelement_b::ciarlet::DType::F64 => space_new_internal::< + f64, + SingleElementGrid>, + ciarlet::LagrangeElementFamily, + >(g, f), + ndelement_b::ciarlet::DType::C64 => space_new_internal::< + c64, + SingleElementGrid>, + ciarlet::LagrangeElementFamily, + >(g, f), + _ => { + panic!("Incompatible data types."); + } + }, + }, + ndelement_b::ciarlet::ElementType::RaviartThomas => match (*g).dtype { + ndgrid_b::DType::F32 => match (*f).dtype { + ndelement_b::ciarlet::DType::F32 => space_new_internal::< + f32, + SingleElementGrid>, + ciarlet::RaviartThomasElementFamily, + >(g, f), + ndelement_b::ciarlet::DType::C32 => space_new_internal::< + c32, + SingleElementGrid>, + ciarlet::RaviartThomasElementFamily, + >(g, f), + _ => { + panic!("Incompatible data types."); + } + }, + ndgrid_b::DType::F64 => match (*f).dtype { + ndelement_b::ciarlet::DType::F64 => space_new_internal::< + f64, + SingleElementGrid>, + ciarlet::RaviartThomasElementFamily, + >(g, f), + ndelement_b::ciarlet::DType::C64 => space_new_internal::< + c64, + SingleElementGrid>, + ciarlet::RaviartThomasElementFamily, + >(g, f), + _ => { + panic!("Incompatible data types."); + } + }, + }, + }, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 807c610a..cbaa38bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ extern crate lazy_static; pub mod assembly; +pub mod bindings; pub mod function; pub mod quadrature; pub mod traits; diff --git a/src/quadrature/simplex_rules/generate_rust_simplex_rules.py b/src/quadrature/simplex_rules/generate_rust_simplex_rules.py index 9f9e4c7f..91b9de57 100644 --- a/src/quadrature/simplex_rules/generate_rust_simplex_rules.py +++ b/src/quadrature/simplex_rules/generate_rust_simplex_rules.py @@ -10,8 +10,8 @@ orders = [] npoints = [] -for (dirpath, dirnames, filenames) in os.walk("."): - all_rule_files += [os.path.join(dirpath, file) for file in filenames if file.endswith(".txt")] +for dirpath, dirnames, filenames in os.walk("."): + all_rule_files += [os.path.join(dirpath, file) for file in filenames if file.endswith(".txt")] for rule_file in all_rule_files: base = os.path.basename(rule_file) @@ -19,7 +19,7 @@ orders += [int(order_str)] npoints += [int(points_str)] -with open("simplex_rule_definitions.rs", 'w') as f: +with open("simplex_rule_definitions.rs", "w") as f: f.write("//! Definition of simplex rules.\n") f.write("\n") f.write("use std::collections::HashMap;\n") @@ -38,10 +38,7 @@ f.write("m.insert(ReferenceCellType::Pyramid, HM::new());\n") f.write("m.insert(ReferenceCellType::Interval, HM::new());\n") - - - for (index, rule_file) in enumerate(all_rule_files): - + for index, rule_file in enumerate(all_rule_files): arr = np.atleast_2d(np.loadtxt(rule_file)) points = arr[:, :-1] weights = arr[:, -1] @@ -74,7 +71,6 @@ points = 0.5 * (1.0 + points) weights = weights / 8.0 - elif rule_file.startswith("./tet"): identifier = "ReferenceCellType::Tetrahedron" @@ -84,13 +80,12 @@ elif rule_file.startswith("./pyr"): identifier = "ReferenceCellType::Pyramid" - points = (1.0 + points) @ np.array([[0.5, 0, 0], - [0, 0.5, 0], - [-0.25, -0.25, 0.5]],dtype='float64') + points = (1.0 + points) @ np.array( + [[0.5, 0, 0], [0, 0.5, 0], [-0.25, -0.25, 0.5]], dtype="float64" + ) weights = weights / 8.0 - else: raise ValueError("Unknown simplex type.") @@ -131,11 +126,8 @@ f.write(f"{weight},") f.write("]));\n") - f.write("m };\n}") os.system("rustfmt ./simplex_rule_definitions.rs") os.system("cp ./simplex_rule_definitions.rs ../") os.system("rm ./simplex_rule_definitions.rs") - -