Skip to content

Commit

Permalink
Merge pull request kokkos#124 from tylerjereddy/treddy_ufunc_templs
Browse files Browse the repository at this point in the history
ENH: more ufuncs to API std
  • Loading branch information
NaderAlAwar authored Jan 3, 2023
2 parents cccbffd + c3a0929 commit ffd607c
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ jobs:
# for hypothesis-driven test case generation
pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s
# only run a subset of the conformance tests to get started
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos
pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor
2 changes: 1 addition & 1 deletion pre_compile_tools/pre_compile_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():
# level kernels/workunits directly
filtered_function_list = []
for f in function_list:
if not "impl" in f[0] and not "dispatcher" in f[0]:
if not "impl" in f[0] and not f[0].startswith("_"):
filtered_function_list.append(f)
# TODO: expand types and view dimensions for
# ufunc pre-compilation as the support
Expand Down
6 changes: 5 additions & 1 deletion pykokkos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@
isinf,
isnan,
equal,
isfinite)
isfinite,
round,
trunc,
ceil,
floor)
from pykokkos.lib.info import iinfo, finfo
from pykokkos.lib.create import (zeros,
ones,
Expand Down
1 change: 1 addition & 0 deletions pykokkos/core/visitors/visitors_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def pretty_print(node):
"pow",
"radians",
"remainder",
"round",
"sin",
"sinh",
"sqrt",
Expand Down
136 changes: 136 additions & 0 deletions pykokkos/lib/ufunc_workunits.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,142 @@
import pykokkos as pk


@pk.workunit
def floor_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]):
out[tid] = floor(view[tid])


@pk.workunit
def floor_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View2D[pk.double]):
for i in range(view.extent(1)):
out[tid][i] = floor(view[tid][i])


@pk.workunit
def floor_impl_3d_double(tid: int, view: pk.View3D[pk.double], out: pk.View3D[pk.double]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = floor(view[tid][i][j])

@pk.workunit
def floor_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = floor(view[tid])


@pk.workunit
def floor_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.float]):
for i in range(view.extent(1)):
out[tid][i] = floor(view[tid][i])


@pk.workunit
def floor_impl_3d_float(tid: int, view: pk.View3D[pk.float], out: pk.View3D[pk.float]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = floor(view[tid][i][j])

@pk.workunit
def ceil_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]):
out[tid] = ceil(view[tid])


@pk.workunit
def ceil_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View2D[pk.double]):
for i in range(view.extent(1)):
out[tid][i] = ceil(view[tid][i])


@pk.workunit
def ceil_impl_3d_double(tid: int, view: pk.View3D[pk.double], out: pk.View3D[pk.double]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = ceil(view[tid][i][j])

@pk.workunit
def ceil_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = ceil(view[tid])


@pk.workunit
def ceil_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.float]):
for i in range(view.extent(1)):
out[tid][i] = ceil(view[tid][i])


@pk.workunit
def ceil_impl_3d_float(tid: int, view: pk.View3D[pk.float], out: pk.View3D[pk.float]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = ceil(view[tid][i][j])

@pk.workunit
def trunc_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]):
out[tid] = trunc(view[tid])


@pk.workunit
def trunc_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View2D[pk.double]):
for i in range(view.extent(1)):
out[tid][i] = trunc(view[tid][i])


@pk.workunit
def trunc_impl_3d_double(tid: int, view: pk.View3D[pk.double], out: pk.View3D[pk.double]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = trunc(view[tid][i][j])

@pk.workunit
def trunc_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = trunc(view[tid])


@pk.workunit
def trunc_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.float]):
for i in range(view.extent(1)):
out[tid][i] = trunc(view[tid][i])


@pk.workunit
def trunc_impl_3d_float(tid: int, view: pk.View3D[pk.float], out: pk.View3D[pk.float]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = trunc(view[tid][i][j])

@pk.workunit
def round_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]):
out[tid] = round(view[tid])


@pk.workunit
def round_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View2D[pk.double]):
for i in range(view.extent(1)):
out[tid][i] = round(view[tid][i])


@pk.workunit
def round_impl_3d_double(tid: int, view: pk.View3D[pk.double], out: pk.View3D[pk.double]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = round(view[tid][i][j])

@pk.workunit
def round_impl_1d_float(tid: int, view: pk.View1D[pk.float], out: pk.View1D[pk.float]):
out[tid] = round(view[tid])


@pk.workunit
def round_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.float]):
for i in range(view.extent(1)):
out[tid][i] = round(view[tid][i])


@pk.workunit
def round_impl_3d_float(tid: int, view: pk.View3D[pk.float], out: pk.View3D[pk.float]):
for i in range(view.extent(1)):
for j in range(view.extent(2)):
out[tid][i][j] = round(view[tid][i][j])

@pk.workunit
def isfinite_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.uint8]):
out[tid] = isfinite(view[tid])
Expand Down
Loading

0 comments on commit ffd607c

Please sign in to comment.