Skip to content

Commit

Permalink
Interface: add cyl_bessel_j0 and cyl_bessel_j1 as mathemtical functio…
Browse files Browse the repository at this point in the history
…ns inside kernels (only accept real values for now) (#283)
  • Loading branch information
NaderAlAwar authored Aug 5, 2024
1 parent 359b08a commit cff567c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,17 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:

return rand_call

if name in {"cyl_bessel_j0", "cyl_bessel_j1"}:
if len(args) != 1:
self.error(node, "pk.cyl_bessel_j0/j1 accepts only one argument")

s = cppast.Serializer()
arg_str = s.serialize(args[0])
math_call = cppast.CallExpr(cppast.DeclRefExpr(f"Kokkos::Experimental::{name}<Kokkos::complex<decltype({arg_str})>, double, int>"), args)
real_number_call = cppast.MemberCallExpr(math_call, cppast.DeclRefExpr("real"), [])

return real_number_call

return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from .hierarchical import (
AUTO, TeamMember, PerTeam, PerThread, single
)
from .mathematical_special_functions import (
cyl_bessel_j0, cyl_bessel_j1
)
from .memory_space import MemorySpace, get_default_memory_space
from .parallel_dispatch import (
execute, flush,
Expand Down
6 changes: 6 additions & 0 deletions pykokkos/interface/mathematical_special_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

def cyl_bessel_j0(input: float) -> float:
pass

def cyl_bessel_j1(input: float) -> float:
pass

0 comments on commit cff567c

Please sign in to comment.