Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 29, 2024
1 parent 6c3754a commit b9df0bc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
8 changes: 5 additions & 3 deletions src/pysimulators/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ def partition_init(self, *args, **keywords):
n = len(partitionin)
argss = tuple(
tuple(
a[i]
if class_args[j] in partition_args and not isscalarlike(a)
else a
(
a[i]
if class_args[j] in partition_args and not isscalarlike(a)
else a
)
for j, a in enumerate(args)
)
for i in range(n)
Expand Down
4 changes: 1 addition & 3 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def assert_partioning_chunk(cls, op, value, key):
nn = (
max(n1, n2)
if partitionin is None
else 1
if isscalarlike(partitionin)
else len(partitionin)
else 1 if isscalarlike(partitionin) else len(partitionin)
)
if n1 != n2 and not isscalarlike(value) and not isscalarlike(key):
# the partitioned arguments do not have the same length
Expand Down
34 changes: 18 additions & 16 deletions tests/test_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,24 @@ def test_function(array, func):

@pytest.mark.parametrize(
'value, expected_type',
[
(Quantity(1), float),
(Quantity(1, dtype='float32'), np.float32),
(Quantity(1.0), float),
(Quantity(complex(1, 0)), np.complex128),
(Quantity(1.0, dtype=np.complex64), np.complex64),
(Quantity(1.0, dtype=np.complex128), np.complex128),
(Quantity(np.array(complex(1, 0))), complex),
(Quantity(np.array(np.complex64(1.0))), np.complex64),
(Quantity(np.array(np.complex128(1.0))), complex),
]
+ [
(Quantity(1.0, dtype=np.complex256), np.complex256),
]
if hasattr(np, 'complex256')
else [],
(
[
(Quantity(1), float),
(Quantity(1, dtype='float32'), np.float32),
(Quantity(1.0), float),
(Quantity(complex(1, 0)), np.complex128),
(Quantity(1.0, dtype=np.complex64), np.complex64),
(Quantity(1.0, dtype=np.complex128), np.complex128),
(Quantity(np.array(complex(1, 0))), complex),
(Quantity(np.array(np.complex64(1.0))), np.complex64),
(Quantity(np.array(np.complex128(1.0))), complex),
]
+ [
(Quantity(1.0, dtype=np.complex256), np.complex256),
]
if hasattr(np, 'complex256')
else []
),
)
def test_dtype(value, expected_type):
assert value.dtype == expected_type
Expand Down

0 comments on commit b9df0bc

Please sign in to comment.