Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BlockOperator direct and adjoint methods: can pass out as a DataContainer instead of a (1,1) BlockDataContainer where geometry permits #1926

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions Wrappers/Python/cil/optimisation/operators/BlockOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,25 @@ def direct(self, x, out=None):
prod += self.get_item(row,
col).direct(x_b.get_item(col))
res.append(prod)
if 1 == shape[0] == shape[1]:
if (1 == shape[0]) and (1 == shape[1]):
# the output is a single DataContainer, so we can take it out
return res[0]
else:
return BlockDataContainer(*res, shape=shape)


else:
if not isinstance(out, BlockDataContainer):
if (1 == shape[0]) and (1 == shape[1]):
out = BlockDataContainer(out)
flag_single_element = True
else:
raise ValueError(f'You passed to `out` a `DataContainer`. You needed to pass a `BlockDataContainer` of shape {shape}')
else:
flag_single_element = False
tmp = self.range_geometry().allocate()
if not isinstance(tmp, BlockDataContainer):
tmp = BlockDataContainer(tmp)
for row in range(self.shape[0]):
for col in range(self.shape[1]):
if col == 0:
Expand All @@ -225,7 +235,10 @@ def direct(self, x, out=None):
x_b.get_item(col),
out=tmp.get_item(row))
temp_out_row += tmp.get_item(row)
return out
if flag_single_element:
return out.get_item(0)
else:
return out

def adjoint(self, x, out=None):
'''Adjoint operation for the BlockOperator
Expand Down Expand Up @@ -265,6 +278,14 @@ def adjoint(self, x, out=None):
else:
return BlockDataContainer(*res, shape=shape)
else:
if not isinstance(out, BlockDataContainer):
if (1 == shape[0]) and (1 == shape[1]):
out = BlockDataContainer(out)
flag_single_element = True
else:
raise ValueError(f'You passed to `out` a `DataContainer`. You needed to pass a `BlockDataContainer` of shape {shape}')
else:
flag_single_element = False
for col in range(self.shape[1]):
for row in range(self.shape[0]):
if row == 0:
Expand All @@ -289,7 +310,10 @@ def adjoint(self, x, out=None):
temp_out_col += self.get_item(row,col).adjoint(
x_b.get_item(row),
)
return out
if flag_single_element:
return out.get_item(0)
else:
return out

def is_linear(self):
'''Returns whether all the elements of the BlockOperator are linear'''
Expand Down
59 changes: 59 additions & 0 deletions Wrappers/Python/test/test_BlockOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,66 @@ def test_block_operator_1_1(self):

self.assertEqual(K.domain_geometry(), ig)

def test_blockoperator_out_datacontainer(self):

#test direct
M, N ,W = 3, 4, 5
ig = ImageGeometry(M, N, W)
operator0=IdentityOperator(ig)
operator1=-IdentityOperator(ig)
K = BlockOperator(operator0, operator1, shape = (1,2))
bg=BlockGeometry(ig, ig)
data=bg.allocate('random', seed=2)
out=K.range.allocate(0)
assert not isinstance(out, BlockDataContainer)
ans = K.direct(data)
K.direct(data, out)
self.assertNumpyArrayEqual(ans.array, out.array)

#test direct out is BlockDataContainer
out = BlockDataContainer(out)
assert isinstance(out, BlockDataContainer)
ans = K.direct(data)
K.direct(data, out)
self.assertNumpyArrayEqual(ans.array, out.get_item(0).array)

#test adjoint wrong dimension
out=ig.allocate(0)
data = ig.allocate('random')
print(K.range_geometry)
with self.assertRaises(ValueError):
K.adjoint(data, out)


#test adjoint out not BlockDataContainer
M, N ,W = 3, 4, 5
operator0=IdentityOperator(ig)
operator1=-IdentityOperator(ig)
K = BlockOperator(operator0, operator1, shape = (2,1))
bg=BlockGeometry(ig, ig)
data=bg.allocate('random', seed=2)
out=K.domain.allocate(0)
assert not isinstance(out, BlockDataContainer)
ans = K.adjoint(data)
K.adjoint(data, out)
self.assertNumpyArrayEqual(ans.array, out.array)

#test adjoint out is BlockDataContainer
out = BlockDataContainer(out)
assert isinstance(out, BlockDataContainer)
ans = K.adjoint(data)
K.adjoint(data, out)
self.assertNumpyArrayEqual(ans.array, out.get_item(0).array)

#test direct wrong dimension
out=ig.allocate(0)
data = ig.allocate('random')
print(K.range_geometry)
with self.assertRaises(ValueError):
K.direct(data, out)




@unittest.skipIf(True, 'Skipping time tests')
def test_timedifference(self):
Expand Down
Loading