Skip to content

Commit

Permalink
Visitors: add support for translating complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar committed Oct 14, 2024
1 parent 002584c commit 98bcedf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 7 additions & 1 deletion pykokkos/core/visitors/pykokkos_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,17 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
)
elif name in ["PerTeam", "PerThread", "fence"]:
name = "Kokkos::" + name
elif name in {"complex32", "complex64"}:
name = "Kokkos::complex"
if "32" in name:
name += "<float>"
else:
name += "<double>"

function = cppast.DeclRefExpr(name)
args: List[cppast.Expr] = [self.visit(a) for a in node.args]

if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence"]:
if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence", "Kokkos::complex<float>", "Kokkos::complex<double>"]:
return cppast.CallExpr(function, args)

if function in self.kokkos_functions:
Expand Down
7 changes: 5 additions & 2 deletions pykokkos/core/visitors/visitors_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def pretty_print(node):
"double": "double",
"bool": "bool",
"TeamMember": f"Kokkos::TeamPolicy<{Keywords.DefaultExecSpace.value}>::member_type",
"cpp_auto": "auto"
"cpp_auto": "auto",
"complex32": "Kokkos::complex<float>",
"complex64": "Kokkos::complex<double>"
}

# Maps from the DataType enum to cppast
Expand Down Expand Up @@ -307,7 +309,8 @@ def parse_view_template_params(

if parameter in ("int", "double", "float",
"int8_t", "int16_t", "int32_t", "int64_t",
"uint8_t", "uint16_t", "uint32_t", "uint64_t"):
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
"Kokkos::complex<float>", "Kokkos::complex<double>"):
datatype: str = parameter + "*" * rank
params["dtype"] = datatype

Expand Down

0 comments on commit 98bcedf

Please sign in to comment.