Skip to content

Commit

Permalink
updated call functions to return NaN (#809)
Browse files Browse the repository at this point in the history
closes #758
* updated non TV call functions to return NaN
* simplification of wrapper calls
* check input arr is c-ordered
  • Loading branch information
gfardell authored Mar 31, 2021
1 parent 43e6373 commit 0f03a37
Showing 1 changed file with 80 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,32 @@
import warnings


class ROF_TV(Function):
class TV_Base(Function):
def __call__(self,x):
in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
EnergyValTV = TV_ENERGY(in_arr, in_arr, self.alpha, 2)
return 0.5*EnergyValTV[0]

def convex_conjugate(self,x):
return 0.0

class ROF_TV(TV_Base):
def __init__(self,lambdaReg,iterationsTV,tolerance,time_marchstep,device):
# set parameters
self.lambdaReg = lambdaReg
self.iterationsTV = iterationsTV
self.alpha = lambdaReg
self.max_iteration = iterationsTV
self.time_marchstep = time_marchstep
self.device = device # string for 'cpu' or 'gpu'
self.tolerance = tolerance

def __call__(self,x):
# evaluate objective function of TV gradient
EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2)
return 0.5*EnergyValTV[0]

def proximal(self,x,tau, out = None):
pars = {'algorithm' : ROF_TV, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularization_parameter':self.lambdaReg*tau, \
'number_of_iterations' :self.iterationsTV ,\
'time_marching_parameter':self.time_marchstep,\
'tolerance':self.tolerance}

res , info = regularisers.ROF_TV(pars['input'],
pars['regularization_parameter'],
pars['number_of_iterations'],
pars['time_marching_parameter'], pars['tolerance'], self.device)
in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res , info = regularisers.ROF_TV(in_arr,
self.alpha,
self.max_iteration,
self.time_marchstep,
self.tolerance,
self.device)

self.info = info

Expand All @@ -67,7 +67,7 @@ def proximal(self,x,tau, out = None):
out.fill(res)
return out

class FGP_TV(Function):
class FGP_TV(TV_Base):
def __init__(self, alpha=1, max_iteration=100, tolerance=1e-6, isotropic=True, nonnegativity=True, printing=False, device='cpu'):

if isotropic == True:
Expand All @@ -86,15 +86,10 @@ def __init__(self, alpha=1, max_iteration=100, tolerance=1e-6, isotropic=True, n
self.nonnegativity = nonnegativity
self.device = device # string for 'cpu' or 'gpu'

def __call__(self,x):
# evaluate objective function of TV gradient
EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.alpha, 2)
return 0.5*EnergyValTV[0]

def proximal(self,x,tau, out=None):

in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res , info = regularisers.FGP_TV(\
np.asarray(x.as_array(), dtype=np.float32),\
in_arr,\
self.alpha*tau,\
self.max_iteration,\
self.tolerance,\
Expand All @@ -107,15 +102,10 @@ def proximal(self,x,tau, out=None):
out = x.copy()
out.fill(res)
return out

def convex_conjugate(self,x):
return 0.0



class TGV(Function):

def __init__(self, regularisation_parameter, alpha1, alpha2, iter_TGV, LipshitzConstant, torelance, device ):

self.regularisation_parameter = regularisation_parameter
self.alpha1 = alpha1
self.alpha2 = alpha2
Expand All @@ -124,31 +114,20 @@ def __init__(self, regularisation_parameter, alpha1, alpha2, iter_TGV, LipshitzC
self.torelance = torelance
self.device = device


def __call__(self,x):
warnings.warn("{}: the __call__ method is not currently implemented. Returning 0.".format(self.__class__.__name__))

# TODO this is not correct, need a TGV energy same as TV
return 0.0
warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

def proximal(self, x, tau, out=None):

pars = {'algorithm' : TGV, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularisation_parameter':self.regularisation_parameter, \
'alpha1':self.alpha1,\
'alpha0':self.alpha2,\
'number_of_iterations' :self.iter_TGV ,\
'LipshitzConstant' :self.LipshitzConstant ,\
'tolerance_constant':self.torelance}

res , info = regularisers.TGV(pars['input'],
pars['regularisation_parameter'],
pars['alpha1'],
pars['alpha0'],
pars['number_of_iterations'],
pars['LipshitzConstant'],
pars['tolerance_constant'],self.device)
def proximal(self, x, tau, out=None):
in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res , info = regularisers.TGV(in_arr,
self.regularisation_parameter,
self.alpha1,
self.alpha2,
self.iter_TGV,
self.LipshitzConstant,
self.torelance,
self.device)

# info: return number of iteration and reached tolerance
# https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/master/src/Core/regularisers_CPU/TGV_core.c#L168
Expand All @@ -164,14 +143,11 @@ def proximal(self, x, tau, out=None):
return out

def convex_conjugate(self, x):
# TODO this is not correct
return 0.0
warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan


class LLT_ROF(Function):



def __init__(self, regularisation_parameterROF,
regularisation_parameterLLT,
iter_LLT_ROF, time_marching_parameter, torelance, device ):
Expand All @@ -184,36 +160,29 @@ def __init__(self, regularisation_parameterROF,
self.device = device

def __call__(self,x):
warnings.warn("{}: the __call__ method is not currently implemented. Returning 0.".format(self.__class__.__name__))

# TODO this is not correct, need a TGV energy same as TV
return 0.0
warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

def proximal(self, x, tau, out=None):

pars = {'algorithm' : LLT_ROF, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularisation_parameterROF':self.regularisation_parameterROF, \
'regularisation_parameterLLT':self.regularisation_parameterLLT,
'number_of_iterations' :self.iter_LLT_ROF ,\
'time_marching_parameter': self.time_marching_parameter,\
'tolerance_constant':self.torelance}



res , info = regularisers.LLT_ROF(pars['input'],
pars['regularisation_parameterROF'],
pars['regularisation_parameterLLT'],
pars['number_of_iterations'],
pars['time_marching_parameter'],
pars['tolerance_constant'],self.device)
in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res , info = regularisers.LLT_ROF(in_arr,
self.regularisation_parameterROF,
self.regularisation_parameterLLT,
self.iter_LLT_ROF,
self.time_marching_parameter,
self.torelance,
self.device)

# info: return number of iteration and reached tolerance
# https://github.com/vais-ral/CCPi-Regularisation-Toolkit/blob/master/src/Core/regularisers_CPU/TGV_core.c#L168
# Stopping Criteria || u^k - u^(k-1) ||_{2} / || u^{k} ||_{2}

self.info = info


def convex_conjugate(self, x):
warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

class FGP_dTV(Function):
def __init__(self, reference, alpha=1, max_iteration=100,
tolerance=1e-6, eta=0.01, isotropic=True, nonnegativity=True, device='cpu'):
Expand All @@ -236,13 +205,13 @@ def __init__(self, reference, alpha=1, max_iteration=100,
self.eta = eta

def __call__(self,x):
# evaluate objective function of TV gradient
EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.alpha, 2)
return 0.5*EnergyValTV[0]
warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

def proximal(self,x,tau, out=None):
in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res , info = regularisers.FGP_dTV(\
np.asarray(x.as_array(), dtype=np.float32),\
in_arr,\
self.reference,\
self.alpha*tau,\
self.max_iteration,\
Expand All @@ -259,40 +228,27 @@ def proximal(self,x,tau, out=None):
return out

def convex_conjugate(self, x):
# TODO this is not correct
return 0.0


warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

class SB_TV(Function):
class SB_TV(TV_Base):
def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,printing,device):
# set parameters
self.lambdaReg = lambdaReg
self.iterationsTV = iterationsTV
self.alpha = lambdaReg
self.max_iteration = iterationsTV
self.tolerance = tolerance
self.methodTV = methodTV
self.printing = printing
self.device = device # string for 'cpu' or 'gpu'

def __call__(self,x):

# evaluate objective function of TV gradient
EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2)
return 0.5*EnergyValTV[0]


def proximal(self,x,tau, out=None):
pars = {'algorithm' : SB_TV, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularization_parameter':self.lambdaReg*tau, \
'number_of_iterations' :self.iterationsTV ,\
'tolerance_constant':self.tolerance,\
'methodTV': self.methodTV}

res , info = regularisers.SB_TV(pars['input'],
pars['regularization_parameter'],
pars['number_of_iterations'],
pars['tolerance_constant'],
pars['methodTV'], self.device)
in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res , info = regularisers.SB_TV(in_arr,
self.alpha*tau,
self.max_iteration,
self.tolerance,
self.methodTV,
self.device)

self.info = info

Expand All @@ -303,8 +259,6 @@ def proximal(self,x,tau, out=None):
out.fill(res)
return out



class TNV(Function):

def __init__(self,regularisation_parameter,iterationsTNV,tolerance):
Expand All @@ -313,31 +267,25 @@ def __init__(self,regularisation_parameter,iterationsTNV,tolerance):
self.regularisation_parameter = regularisation_parameter
self.iterationsTNV = iterationsTNV
self.tolerance = tolerance


def __call__(self,x):
warnings.warn("{}: the __call__ method is not currently implemented. Returning 0.".format(self.__class__.__name__))
# evaluate objective function of TV gradient
return 0.0
warnings.warn("{}: the __call__ method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

def proximal(self,x,tau, out=None):
pars = {'algorithm' : TNV, \
'input' : np.asarray(x.as_array(), dtype=np.float32),\
'regularisation_parameter':self.regularisation_parameter, \
'number_of_iterations' :self.iterationsTNV,\
'tolerance_constant':self.tolerance}

res = regularisers.TNV(pars['input'],
pars['regularisation_parameter'],
pars['number_of_iterations'],
pars['tolerance_constant'])

#self.info = info

in_arr = np.asarray(x.as_array(), dtype=np.float32, order='C')
res = regularisers.TNV(in_arr,
self.regularisation_parameter,
self.iterationsTNV,
self.tolerance)

if out is not None:
out.fill(res)
else:
out = x.copy()
out.fill(res)
return out

def convex_conjugate(self, x):
warnings.warn("{}: the convex_conjugate method is not implemented. Returning NaN.".format(self.__class__.__name__))
return np.nan

0 comments on commit 0f03a37

Please sign in to comment.