Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New callback to save a list of iterates every set number of iterations #1913

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
47 changes: 47 additions & 0 deletions Wrappers/Python/cil/optimisation/utilities/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tqdm.auto import tqdm as tqdm_auto
from tqdm.std import tqdm as tqdm_std
import numpy as np
from cil.processors import Slicer


class Callback(ABC):
Expand Down Expand Up @@ -135,6 +136,7 @@ class LogfileCallback(TextProgressCallback):
def __init__(self, log_file, mode='a', **kwargs):
self.fd = open(log_file, mode=mode)
super().__init__(file=self.fd, **kwargs)


class EarlyStoppingObjectiveValue(Callback):
'''Callback that stops iterations if the change in the objective value is less than a provided threshold value.
Expand Down Expand Up @@ -187,3 +189,48 @@ def __call__(self, algorithm):
raise StopIteration


class SaveIterates(Callback):
'''Callback to save iterates every set number of iterations.
Parameters
----------
interval: integer,
The iterates will be saved every `interval` number of iterations e.g. if `interval =4` the 0, 4, 8, 12,... iterates will be saved.

writer: TiffWriter, optional, default is None
If a writer is passed, it will be used to save the iterates. If not the iterates will be saved as a list in the class object `iterates`

roi: dict, optional default is None and no slicing will be applied
The region-of-interest to slice {'axis_name1':(start,stop,step), 'axis_name2':(start,stop,step)}
The `key` being the axis name to apply the processor to, the `value` holding a tuple containing the ROI description

Start: Starting index of input data. Must be an integer, or `None` defaults to index 0.
Stop: Stopping index of input data. Must be an integer, or `None` defaults to index N.
Step: Number of pixels to average together. Must be an integer or `None` defaults to 1.

'''
def __init__(self, interval=1, writer=None, roi=None): #TODO: optionally pass an ROI

self.writer=writer
if self.writer is None:
self.iterates=[]

self.interval=interval
self.roi=roi
if self.roi is not None:
self.slicer= Slicer(roi=self.roi)

super(SaveIterates, self).__init__()

def __call__(self, algo):
if algo.iteration % self.interval ==0:
if self.roi is None:
if self.writer is None:
self.iterates.append( algo.solution.copy())
else:
self.writer.write(algo.solution)
else:
self.slicer.set_input(algo.solution)
if self.writer is None:
self.iterates.append( self.slicer.get_output())
else:
self.writer.write( self.slicer.get_output())
75 changes: 75 additions & 0 deletions Wrappers/Python/test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,81 @@ def test_EarlyStoppingObjectiveValue(self):
callbacks.EarlyStoppingObjectiveValue(0.1)(alg)



class TestSaveIteratesCallback(unittest.TestCase):

class MockAlgo(Algorithm):
def __init__(self, initial, update_objective_interval=10, **kwargs):
super().__init__(update_objective_interval=update_objective_interval, **kwargs)
self.configured = True
self.x=initial

def update(self):
self.x -= 1

def update_objective(self):
self.loss.append(2 ** getattr(self, 'x', np.nan))


def setUp(self):
# Mock the algorithm object

self.image_geometry = ImageGeometry(10, 2)
self.data = self.image_geometry.allocate(10)
self.mock_algorithm = self.MockAlgo(self.data)

def test_save_iterates_no_writer_no_roi(self):
# Test saving iterates to a list with no writer and no ROI
callback = callbacks.SaveIterates(interval=1)

# Call the callback multiple times and increment iteration
self.mock_algorithm.run(5, callbacks=[callback])

# Check if iterates are saved correctly
self.assertEqual(len(callback.iterates), 6)
for i in range(5):
np.testing.assert_array_equal(callback.iterates[i].array, (10-i)*np.ones((2,10)))

def test_save_iterates_with_writer(self):
# Test saving iterates using a writer
mock_writer = MagicMock()
callback = callbacks.SaveIterates(interval=1, writer=mock_writer)

# Call the callback and check if writer.write() was called
callback(self.mock_algorithm)
mock_writer.write.assert_called_once_with(self.mock_algorithm.x)

def test_save_iterates_with_roi(self):
# Test saving iterates with an ROI applied
roi = {'horizontal_x': (0, 2, 1)}
from cil.processors import Slicer


callback = callbacks.SaveIterates(interval=1, roi=roi)

# Call the callback and check if slicer was used
callback(self.mock_algorithm)

np.testing.assert_array_equal(callback.iterates[0].array, 10*np.ones([2, 2]))

def test_save_iterates_with_interval(self):
# Test saving iterates with a specified interval
callback = callbacks.SaveIterates(interval=2)

# Call the callback multiple times and increment iteration
self.mock_algorithm.run(5, callbacks=[callback])

# Check if iterates were saved only at the correct intervals
self.assertEqual(len(callback.iterates), 3)
np.testing.assert_array_equal(callback.iterates[0].array, (10-0)*np.ones((2,10)))
np.testing.assert_array_equal(callback.iterates[1].array, (10-2)*np.ones((2,10)))
np.testing.assert_array_equal(callback.iterates[2].array, (10-4)*np.ones((2,10)))






class TestADMM(unittest.TestCase):
def setUp(self):
ig = ImageGeometry(2, 3, 2)
Expand Down
Loading