Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next]: Allow np.bool scalar in gtfn backend #1870

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ requires-python = '>=3.10, <3.12'
[project.optional-dependencies]
# bundles
all = ['gt4py[dace,formatting,jax,performance,testing]']
all-next = ['gt4py[dace-next,formatting,jax,performance,testing]']
# device-specific extras
cuda11 = ['cupy-cuda11x>=12.0']
cuda12 = ['cupy-cuda12x>=12.0']
Expand Down Expand Up @@ -443,9 +444,17 @@ conflicts = [
{extra = 'dace'},
{extra = 'dace-next'}
],
[
{extra = 'all'},
{extra = 'all-next'}
],
[
{extra = 'all'},
{extra = 'dace-next'}
],
[
{extra = 'all-next'},
{extra = 'dace'}
]
]

Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import diskcache
import factory
import filelock
import numpy as np

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
Expand All @@ -34,6 +35,9 @@ def convert_arg(arg: Any) -> Any:
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin
if isinstance(arg, np.bool_):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(arg, np.bool_):
if isinstance(arg, np.bool_):

Could you leave a comment why this is needed? Does numpy.float64 and python float work? If so, why does bool not work?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree about adding a comment here and would also suggest to use a ternary operator:

    return bool(arg) if isinstance(arg, np.bool_) else arg

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a switch-like if therefore I wouldn't do the ternary, like in the other cases.

# nanobind does not support implicit conversion of `np.bool` to `bool`# nanobind does not support implicit conversion of `np.bool` to `bool`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# nanobind does not support implicit conversion of `np.bool` to `bool`# nanobind does not support implicit conversion of `np.bool` to `bool`
# nanobind does not support implicit conversion of `np.bool` to `bool`

return bool(arg)
else:
return arg

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
# SPDX-License-Identifier: BSD-3-Clause

from functools import reduce

import numpy as np
import pytest

import gt4py.next as gtx
from gt4py.next import (
astype,
Expand All @@ -21,24 +23,24 @@
int64,
minimum,
neighbor_sum,
utils as gt_utils,
)
from gt4py.next.ffront.experimental import as_offset
from gt4py.next import utils as gt_utils

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
C2E,
E2V,
V2E,
E2VDim,
Edge,
IDim,
Ioff,
JDim,
KDim,
Koff,
V2EDim,
Vertex,
Edge,
cartesian_case,
unstructured_case,
unstructured_case_3d,
Expand Down Expand Up @@ -196,6 +198,21 @@ def testee(a: int32) -> cases.VField:
)


def test_np_bool_scalar_arg(unstructured_case):
"""Test scalar argument being turned into 0-dim field."""

@gtx.field_operator
def testee(a: gtx.bool) -> cases.VBoolField:
return broadcast(not a, (Vertex,))

a = np.bool_(True) # explicitly using a np.bool

ref = np.full([unstructured_case.default_sizes[Vertex]], not a, dtype=np.bool_)
out = cases.allocate(unstructured_case, testee, cases.RETURN)()

cases.verify(unstructured_case, testee, a, out=out, ref=ref)


def test_nested_scalar_arg(unstructured_case):
@gtx.field_operator
def testee_inner(a: int32) -> cases.VField:
Expand Down
Loading