Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New callback to save a list of iterates every set number of iterations #1913

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
50 changes: 48 additions & 2 deletions Wrappers/Python/cil/optimisation/utilities/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from tqdm.auto import tqdm as tqdm_auto
from tqdm.std import tqdm as tqdm_std
import numpy as np

from cil.processors import Slicer
import os
from cil.io import TIFFWriter

class Callback(ABC):
'''Base Callback to inherit from for use in :code:`Algorithm.run(callbacks: list[Callback])`.
Expand Down Expand Up @@ -135,6 +137,7 @@ class LogfileCallback(TextProgressCallback):
def __init__(self, log_file, mode='a', **kwargs):
self.fd = open(log_file, mode=mode)
super().__init__(file=self.fd, **kwargs)


class EarlyStoppingObjectiveValue(Callback):
'''Callback that stops iterations if the change in the objective value is less than a provided threshold value.
Expand All @@ -158,8 +161,9 @@ def __call__(self, algorithm):
raise StopIteration

class CGLSEarlyStopping(Callback):
'''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` is set to default as '1e-6', :math:`x` is the current iterate and :math:`x_0` is the initial value.
r'''Callback to work with CGLS. It causes the algorithm to terminate if :math:`||A^T(Ax-b)||_2 < \epsilon||A^T(Ax_0-b)||_2` where `epsilon` is set to default as '1e-6', :math:`x` is the current iterate and :math:`x_0` is the initial value.
It will also terminate if the algorithm begins to diverge i.e. if :math:`||x||_2> \omega`, where `omega` is set to default as 1e6.

Parameters
----------
epsilon: float, default 1e-6
Expand Down Expand Up @@ -187,3 +191,45 @@ def __call__(self, algorithm):
raise StopIteration


class SaveIterates(Callback):
r'''Callback to save iterates as tiff files every set number of iterations.

Parameters
----------
interval: integer,
The iterates will be saved every `interval` number of iterations e.g. if `interval =4` the 0, 4, 8, 12,... iterates will be saved.
file_name : string
This defines the file name prefix, i.e. the file name without the extension.
dir_path : string
The place to store the images
roi: dict, optional default is None and no slicing will be applied
The region-of-interest to slice {'axis_name1':(start,stop,step), 'axis_name2':(start,stop,step)}
The `key` being the axis name to apply the processor to, the `value` holding a tuple containing the ROI description
Start: Starting index of input data. Must be an integer, or `None` defaults to index 0.
Stop: Stopping index of input data. Must be an integer, or `None` defaults to index N.
Step: Number of pixels to average together. Must be an integer or `None` defaults to 1.
compression : str, default None. Accepted values None, 'uint8', 'uint16'
The lossy compression to apply. The default None will not compress data.
uint8' or 'unit16' will compress to unsigned int 8 and 16 bit respectively.
'''
def __init__(self, interval=1, file_name='iter', dir_path='./', roi=None, compression=None):

self.file_path= os.path.join(dir_path, file_name)

self.interval=interval
self.roi=roi
if self.roi is not None:
self.slicer= Slicer(roi=self.roi)
self.compression=compression
super(SaveIterates, self).__init__()

def __call__(self, algo):

if algo.iteration % self.interval ==0:
if self.roi is None:
TIFFWriter(data=algo.solution, file_name=self.file_path+f'_{algo.iteration:04d}.tif', counter_offset=-1,compression=self.compression ).write()
else:
self.slicer.set_input(algo.solution)
TIFFWriter(self.slicer.get_output(), file_name=self.file_path+f'_{algo.iteration:04d}.tif', counter_offset=-1,compression=self.compression ).write()
casperdcl marked this conversation as resolved.
Show resolved Hide resolved


90 changes: 89 additions & 1 deletion Wrappers/Python/test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import unittest
from os import unlink
from tempfile import NamedTemporaryFile

import os, glob
import numpy as np
import logging

Expand Down Expand Up @@ -63,6 +63,8 @@

from unittest.mock import MagicMock

from cil.io import TIFFStackReader

log = logging.getLogger(__name__)
initialise_tests()

Expand Down Expand Up @@ -1428,6 +1430,92 @@ def test_EarlyStoppingObjectiveValue(self):
callbacks.EarlyStoppingObjectiveValue(0.1)(alg)



class TestSaveIteratesCallback(unittest.TestCase):

class MockAlgo(Algorithm):
def __init__(self, initial, update_objective_interval=10, **kwargs):
super().__init__(update_objective_interval=update_objective_interval, **kwargs)
self.configured = True
self.x=initial

def update(self):
self.x -= 1

def update_objective(self):
self.loss.append(2 ** getattr(self, 'x', np.nan))


def setUp(self):
# Mock the algorithm object

self.image_geometry = ImageGeometry(10, 2)
self.data = self.image_geometry.allocate(10)
self.mock_algorithm = self.MockAlgo(self.data)
self.file_name= 'myfile'
self.cwd = os.getcwd()
self.dir_path=os.path.join(self.cwd, 'test_tiff' )

def test_save_iterates_no_writer_no_roi(self):
# Test saving iterates to a list with no writer and no ROI
callback = callbacks.SaveIterates(interval=1, file_name= self.file_name, dir_path=self.dir_path)

# Call the callback multiple times and increment iteration
self.mock_algorithm.run(5, callbacks=[callback])

# Check if iterates are saved correctly
files = glob.glob(os.path.join(glob.escape(self.dir_path), '*'))
assert len(files) == 6
reader = TIFFStackReader(file_name = self.dir_path)
read = reader.read()
for i in range(6):
np.testing.assert_array_equal(read[i], (10-i)*np.ones((2,10)))
[os.remove(file) for file in files]
os.rmdir(self.dir_path)


def test_save_iterates_with_roi(self):
# Test saving iterates with an ROI applied
roi = {'horizontal_x': (0, 2, 1)}

callback = callbacks.SaveIterates(interval=1, file_name= self.file_name, dir_path=self.dir_path, roi=roi)

# Call the callback and check if slicer was used
callback(self.mock_algorithm)
# Check if iterates are saved correctly
files = glob.glob(os.path.join(glob.escape(self.dir_path), '*'))
assert len(files) == 1
reader = TIFFStackReader(file_name = self.dir_path)
read = reader.read()
np.testing.assert_array_equal(read, 10*np.ones([2, 2]))
[os.remove(file) for file in files]
os.rmdir(self.dir_path)

def test_save_iterates_with_interval(self):
# Test saving iterates with a specified interval
callback = callbacks.SaveIterates(interval=2, file_name= self.file_name, dir_path=self.dir_path)

# Call the callback multiple times and increment iteration
self.mock_algorithm.run(5, callbacks=[callback])

# Check if iterates are saved correctly
files = glob.glob(os.path.join(glob.escape(self.dir_path), '*'))
print(files)
self.assertEqual( len(files), 3)
reader = TIFFStackReader(file_name = self.dir_path)
read = reader.read()
np.testing.assert_array_equal(read[0], (10-0)*np.ones((2,10)))
np.testing.assert_array_equal(read[1], (10-2)*np.ones((2,10)))
np.testing.assert_array_equal(read[2], (10-4)*np.ones((2,10)))
[os.remove(file) for file in files]
os.rmdir(self.dir_path)







class TestADMM(unittest.TestCase):
def setUp(self):
ig = ImageGeometry(2, 3, 2)
Expand Down
10 changes: 10 additions & 0 deletions docs/source/optimisation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,16 @@ A list of :code:`Callback` s to be executed each iteration can be passed to `Alg

Built-in callbacks include:

.. autoclass:: cil.optimisation.utilities.callbacks.SaveIterates
:members:

.. autoclass:: cil.optimisation.utilities.callbacks.EarlyStoppingObjectiveValue
:members:

.. autoclass:: cil.optimisation.utilities.callbacks.CGLSEarlyStopping
:members:


.. autoclass:: cil.optimisation.utilities.callbacks.ProgressCallback
:members:

Expand Down
Loading