Skip to content

Commit

Permalink
Refactor _blocks._find_common_type to use np.result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
SpacemanPaul committed Jul 11, 2024
1 parent 26671fd commit 8166143
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions odc/geo/_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
from .types import Chunks2d


def _find_common_type(array_types, scalar_types):
# TODO: don't use find_common_type as it's being removed from numpy
return np.find_common_type(array_types, scalar_types)
def _find_common_type(array_types, scalar_type=None):
if scalar_type is None:
return np.result_type(*array_types)
else:
if np.issubdtype(scalar_type, np.floating):
array_types = array_types + [0.0]
elif np.issubdtype(scalar_type, np.complexfloating):
array_types = array_types + [0j]
return np.result_type(*array_types)


class BlockAssembler:
Expand All @@ -33,7 +39,7 @@ def __init__(
self._dtype = (
np.dtype("float32")
if len(blocks) == 0
else _find_common_type([b.dtype for b in blocks.values()], [])
else _find_common_type([b.dtype for b in blocks.values()])
)
self._axis = axis
self._blocks = blocks
Expand Down Expand Up @@ -126,7 +132,8 @@ def extract(
dtype = self._dtype
if fill_value is not None:
# possibly upgrade to float based on fill_value
dtype = _find_common_type([dtype], [np.min_scalar_type(fill_value)])
fill_dtype = np.min_scalar_type(fill_value)
dtype = _find_common_type([dtype], np.min_scalar_type(fill_value))
else:
dtype = np.dtype(dtype)

Expand Down

0 comments on commit 8166143

Please sign in to comment.