From 97540a458e361e35d8080e3cf2a61c95984cdacf Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Tue, 3 Sep 2024 16:07:52 +0000 Subject: [PATCH 1/2] I think a fix --- .../optimisation/operators/BlockOperator.py | 28 ++++++++- Wrappers/Python/test/test_BlockOperator.py | 59 +++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/operators/BlockOperator.py b/Wrappers/Python/cil/optimisation/operators/BlockOperator.py index 66ede257bf..353a8f14bd 100644 --- a/Wrappers/Python/cil/optimisation/operators/BlockOperator.py +++ b/Wrappers/Python/cil/optimisation/operators/BlockOperator.py @@ -212,7 +212,17 @@ def direct(self, x, out=None): else: + if not isinstance(out, BlockDataContainer): + if 1 == shape[0] == shape[1]: + out = BlockDataContainer(out) + flag_return_data_container = True + else: + raise ValueError(f'You passed to `out` a `DataContainer`. You needed to pass a `BlockDataContainer` of shape {shape}') + else: + flag_return_data_container = False tmp = self.range_geometry().allocate() + if not isinstance(tmp, BlockDataContainer): + tmp = BlockDataContainer(tmp) for row in range(self.shape[0]): for col in range(self.shape[1]): if col == 0: @@ -225,7 +235,10 @@ def direct(self, x, out=None): x_b.get_item(col), out=tmp.get_item(row)) temp_out_row += tmp.get_item(row) - return out + if flag_return_data_container: + return out.get_item(0) + else: + return out def adjoint(self, x, out=None): '''Adjoint operation for the BlockOperator @@ -265,6 +278,14 @@ def adjoint(self, x, out=None): else: return BlockDataContainer(*res, shape=shape) else: + if not isinstance(out, BlockDataContainer): + if 1 == shape[0] == shape[1]: + out = BlockDataContainer(out) + flag_return_data_container = True + else: + raise ValueError(f'You passed to `out` a `DataContainer`. You needed to pass a `BlockDataContainer` of shape {shape}') + else: + flag_return_data_container = False for col in range(self.shape[1]): for row in range(self.shape[0]): if row == 0: @@ -289,7 +310,10 @@ def adjoint(self, x, out=None): temp_out_col += self.get_item(row,col).adjoint( x_b.get_item(row), ) - return out + if flag_return_data_container: + return out.get_item(0) + else: + return out def is_linear(self): '''Returns whether all the elements of the BlockOperator are linear''' diff --git a/Wrappers/Python/test/test_BlockOperator.py b/Wrappers/Python/test/test_BlockOperator.py index 910a711fda..0aae447992 100644 --- a/Wrappers/Python/test/test_BlockOperator.py +++ b/Wrappers/Python/test/test_BlockOperator.py @@ -209,7 +209,66 @@ def test_block_operator_1_1(self): self.assertEqual(K.domain_geometry(), ig) + def test_blockoperator_out_datacontainer(self): + + #test direct + M, N ,W = 3, 4, 5 + ig = ImageGeometry(M, N, W) + operator0=IdentityOperator(ig) + operator1=-IdentityOperator(ig) + K = BlockOperator(operator0, operator1, shape = (1,2)) + bg=BlockGeometry(ig, ig) + data=bg.allocate('random', seed=2) + out=K.range.allocate(0) + assert not isinstance(out, BlockDataContainer) + ans = K.direct(data) + K.direct(data, out) + self.assertNumpyArrayEqual(ans.array, out.array) + + #test direct out is BlockDataContainer + out = BlockDataContainer(out) + assert isinstance(out, BlockDataContainer) + ans = K.direct(data) + K.direct(data, out) + self.assertNumpyArrayEqual(ans.array, out.get_item(0).array) + + #test adjoint wrong dimension + out=ig.allocate(0) + data = ig.allocate('random') + print(K.range_geometry) + with self.assertRaises(ValueError): + K.adjoint(data, out) + + + #test adjoint out not BlockDataContainer + M, N ,W = 3, 4, 5 + operator0=IdentityOperator(ig) + operator1=-IdentityOperator(ig) + K = BlockOperator(operator0, operator1, shape = (2,1)) + bg=BlockGeometry(ig, ig) + data=bg.allocate('random', seed=2) + out=K.domain.allocate(0) + assert not isinstance(out, BlockDataContainer) + ans = K.adjoint(data) + K.adjoint(data, out) + self.assertNumpyArrayEqual(ans.array, out.array) + + #test adjoint out is BlockDataContainer + out = BlockDataContainer(out) + assert isinstance(out, BlockDataContainer) + ans = K.adjoint(data) + K.adjoint(data, out) + self.assertNumpyArrayEqual(ans.array, out.get_item(0).array) + + #test direct wrong dimension + out=ig.allocate(0) + data = ig.allocate('random') + print(K.range_geometry) + with self.assertRaises(ValueError): + K.direct(data, out) + + @unittest.skipIf(True, 'Skipping time tests') def test_timedifference(self): From 00ac333cf5201174a37f17f4204d2a4a8ebd671a Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Mon, 7 Oct 2024 11:58:04 +0000 Subject: [PATCH 2/2] Updates from discussion with Edo and Gemma --- .../optimisation/operators/BlockOperator.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/operators/BlockOperator.py b/Wrappers/Python/cil/optimisation/operators/BlockOperator.py index 353a8f14bd..f458a4191f 100644 --- a/Wrappers/Python/cil/optimisation/operators/BlockOperator.py +++ b/Wrappers/Python/cil/optimisation/operators/BlockOperator.py @@ -204,7 +204,7 @@ def direct(self, x, out=None): prod += self.get_item(row, col).direct(x_b.get_item(col)) res.append(prod) - if 1 == shape[0] == shape[1]: + if (1 == shape[0]) and (1 == shape[1]): # the output is a single DataContainer, so we can take it out return res[0] else: @@ -213,13 +213,13 @@ def direct(self, x, out=None): else: if not isinstance(out, BlockDataContainer): - if 1 == shape[0] == shape[1]: + if (1 == shape[0]) and (1 == shape[1]): out = BlockDataContainer(out) - flag_return_data_container = True + flag_single_element = True else: raise ValueError(f'You passed to `out` a `DataContainer`. You needed to pass a `BlockDataContainer` of shape {shape}') else: - flag_return_data_container = False + flag_single_element = False tmp = self.range_geometry().allocate() if not isinstance(tmp, BlockDataContainer): tmp = BlockDataContainer(tmp) @@ -235,7 +235,7 @@ def direct(self, x, out=None): x_b.get_item(col), out=tmp.get_item(row)) temp_out_row += tmp.get_item(row) - if flag_return_data_container: + if flag_single_element: return out.get_item(0) else: return out @@ -279,13 +279,13 @@ def adjoint(self, x, out=None): return BlockDataContainer(*res, shape=shape) else: if not isinstance(out, BlockDataContainer): - if 1 == shape[0] == shape[1]: + if (1 == shape[0]) and (1 == shape[1]): out = BlockDataContainer(out) - flag_return_data_container = True + flag_single_element = True else: raise ValueError(f'You passed to `out` a `DataContainer`. You needed to pass a `BlockDataContainer` of shape {shape}') else: - flag_return_data_container = False + flag_single_element = False for col in range(self.shape[1]): for row in range(self.shape[0]): if row == 0: @@ -310,7 +310,7 @@ def adjoint(self, x, out=None): temp_out_col += self.get_item(row,col).adjoint( x_b.get_item(row), ) - if flag_return_data_container: + if flag_single_element: return out.get_item(0) else: return out