Skip to content

Commit

Permalink
Closes #3337: Fix can_cast with NumPy 2.0 breaking changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Mar 3, 2025
1 parent 7b69a9a commit 6e85630
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 44 deletions.
9 changes: 8 additions & 1 deletion arkouda/numpy/_numeric.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
import json
from enum import Enum
from typing import TYPE_CHECKING, List, Sequence, Tuple, TypeVar, Union
Expand All @@ -9,13 +10,19 @@

from arkouda.client import generic_msg
from arkouda.groupbyclass import GroupBy, groupable
from arkouda.numpy.dtypes import _datatype_check, bigint
from arkouda.numpy.dtypes import (
_datatype_check,
_is_dtype_in_union,
bigint,
)
from arkouda.numpy.dtypes import bool_ as ak_bool
from arkouda.numpy.dtypes import dtype as ak_dtype
from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import float64 as ak_float64
from arkouda.numpy.dtypes import int64 as ak_int64
from arkouda.numpy.dtypes import (
int_scalars,
isSupportedInt,
isSupportedNumber,
numeric_scalars,
resolve_scalar_dtype,
Expand Down
31 changes: 31 additions & 0 deletions arkouda/numpy/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"bitType",
"bool_",
"bool_scalars",
"can_cast",
"complex128",
"complex64",
"dtype",
Expand Down Expand Up @@ -121,6 +122,36 @@ def dtype(dtype):
return np.dtype(dtype)


def can_cast(from_, to) -> builtins.bool:
"""
Returns True if cast between data types can occur according to the casting rule.
Parameters
__________
from_: dtype, dtype specifier, NumPy scalar, or pdarray
Data type, NumPy scalar, or array to cast from.
to: dtype or dtype specifier
Data type to cast to.
Return
------
bool
True if cast can occur according to the casting rule.
"""
if isSupportedInt(from_):
if (from_ < 2**64) and (from_ >= 0) and (to == dtype(uint64)):
return True

if (
np.isscalar(from_) or _is_dtype_in_union(from_, numeric_scalars)
) and not isinstance(from_, (int, float, complex)):
return np.can_cast(from_, to)

return False


def _is_dtype_in_union(dtype, union_type) -> builtins.bool:
"""
Check if a given type is in a typing.Union.
Expand Down
41 changes: 25 additions & 16 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@
bigint,
)
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import bool_scalars, dtype
from arkouda.numpy.dtypes import (
bool_scalars,
dtype,
)
from arkouda.numpy.dtypes import float64 as akfloat64
from arkouda.numpy.dtypes import get_byteorder, get_server_byteorder
from arkouda.numpy.dtypes import (
get_byteorder,
get_server_byteorder,
)
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import (
int_scalars,
Expand Down Expand Up @@ -571,7 +577,9 @@ def _binop(self, other: pdarray, op: str) -> pdarray:
# pdarray binop scalar
# If scalar cannot be safely cast, server will infer the return dtype
dt = resolve_scalar_dtype(other)
if self.dtype != bigint and np.can_cast(other, self.dtype):
from arkouda.numpy.dtypes import can_cast as ak_can_cast

if self.dtype != bigint and ak_can_cast(other, self.dtype):
# If scalar can be losslessly cast to array dtype,
# do the cast so that return array will have same dtype
dt = self.dtype.name
Expand Down Expand Up @@ -616,7 +624,9 @@ def _r_binop(self, other: pdarray, op: str) -> pdarray:
# pdarray binop scalar
# If scalar cannot be safely cast, server will infer the return dtype
dt = resolve_scalar_dtype(other)
if self.dtype != bigint and np.can_cast(other, self.dtype):
from arkouda.numpy.dtypes import can_cast as ak_can_cast

if self.dtype != bigint and ak_can_cast(other, self.dtype):
# If scalar can be losslessly cast to array dtype,
# do the cast so that return array will have same dtype
dt = self.dtype.name
Expand Down Expand Up @@ -4131,31 +4141,30 @@ def fmod(dividend: Union[pdarray, numeric_scalars], divisor: Union[pdarray, nume
)
# TODO: handle shape broadcasting for multidimensional arrays

# The code below creates a command string for fmod2vv, fmod2vs or fmod2sv.

# The code below creates a command string for fmod2vv, fmod2vs or fmod2sv.

if isinstance(dividend, pdarray) and isinstance(divisor, pdarray) :
if isinstance(dividend, pdarray) and isinstance(divisor, pdarray):
cmdstring = f"fmod2vv<{dividend.dtype},{dividend.ndim},{divisor.dtype}>"

elif isinstance(dividend, pdarray) and not (isinstance(divisor, pdarray)) :
if resolve_scalar_dtype(divisor) in ['float64', 'int64', 'uint64', 'bool'] :
acmd = 'fmod2vs_'+resolve_scalar_dtype(divisor)
else : # this condition *should* be impossible because of the isSupportedNumber check
elif isinstance(dividend, pdarray) and not (isinstance(divisor, pdarray)):
if resolve_scalar_dtype(divisor) in ["float64", "int64", "uint64", "bool"]:
acmd = "fmod2vs_" + resolve_scalar_dtype(divisor)
else: # this condition *should* be impossible because of the isSupportedNumber check
raise TypeError(f"Scalar divisor type {resolve_scalar_dtype(divisor)} not allowed in fmod")
cmdstring = f"{acmd}<{dividend.dtype},{dividend.ndim}>"

elif not (isinstance(dividend, pdarray) and isinstance(divisor, pdarray)) :
if resolve_scalar_dtype(dividend) in ['float64', 'int64', 'uint64', 'bool'] :
acmd = 'fmod2sv_'+resolve_scalar_dtype(dividend)
else : # this condition *should* be impossible because of the isSupportedNumber check
elif not (isinstance(dividend, pdarray) and isinstance(divisor, pdarray)):
if resolve_scalar_dtype(dividend) in ["float64", "int64", "uint64", "bool"]:
acmd = "fmod2sv_" + resolve_scalar_dtype(dividend)
else: # this condition *should* be impossible because of the isSupportedNumber check
raise TypeError(f"Scalar dividend type {resolve_scalar_dtype(dividend)} not allowed in fmod")
cmdstring = f"{acmd}<{divisor.dtype},{divisor.ndim}>" # type: ignore[union-attr]

else:
m = mod(dividend, divisor)
return _create_scalar_array(m)

# We reach here if this was any case other than scalar & scalar
# We reach here if this was any case other than scalar & scalar

return create_pdarray(
cast(
Expand Down
Loading

0 comments on commit 6e85630

Please sign in to comment.