Skip to content

Commit

Permalink
fix MixedL21Norm (#417)
Browse files Browse the repository at this point in the history
* fix MixedL21Norm

* merged master

* re added future imports
  • Loading branch information
epapoutsellis authored and paskino committed Oct 30, 2019
1 parent a4d3aa5 commit 9bd69b7
Showing 1 changed file with 87 additions and 75 deletions.
162 changes: 87 additions & 75 deletions Wrappers/Python/ccpi/optimisation/functions/MixedL21Norm.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2019 Science Technology Facilities Council
# Copyright 2019 University of Manchester
#
# This work is part of the Core Imaging Library developed by Science Technology
# Facilities Council and University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# CCP in Tomographic Imaging (CCPi) Core Imaging Library (CIL).

# Copyright 2017 UKRI-STFC
# Copyright 2017 University of Manchester

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
Expand All @@ -25,17 +23,15 @@

from ccpi.optimisation.functions import Function, ScaledFunction
from ccpi.framework import BlockDataContainer
import numpy as np

import functools
import numpy

class MixedL21Norm(Function):


r'''MixedL21Norm: .. math:: f(x) = ||x||_{2,1} = \int \|x\|_{2} dx
where x is a vector/tensor vield
'''
f(x) = ||x||_{2,1} = \sum |x|_{2}
'''

def __init__(self, **kwargs):
Expand All @@ -45,13 +41,15 @@ def __init__(self, **kwargs):

def __call__(self, x):

'''Evaluates MixedL21Norm at point x
''' Evaluates L2,1Norm at point x
:param: x: is a BlockDataContainer
:param: x is a BlockDataContainer
'''
if not isinstance(x, BlockDataContainer):
raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))
tmp = x.get_item(0) * 0

tmp = x.get_item(0) * 0.
for el in x.containers:
tmp += el.power(2.)
return tmp.sqrt().sum()
Expand All @@ -62,80 +60,68 @@ def gradient(self, x, out=None):

def convex_conjugate(self,x):

r'''Convex conjugate of of MixedL21Norm:
Indicator function of .. math:: ||\cdot||_{2, \infty}
which is either 0 if .. math:: ||x||_{2, \infty}<1 or \infty
''' This is the Indicator function of ||\cdot||_{2, \infty}
which is either 0 if ||x||_{2, \infty} or \infty
'''

return 0.0


def proximal(self, x, tau, out=None):

r'''Proximal operator of MixedL21Norm at x:
.. math:: prox_{\tau * f(x)
'''
pass

def proximal_conjugate(self, x, tau, out=None):

r'''Proximal operator of the convex conjugate of MixedL21Norm at x:
.. math:: prox_{\tau * f^{*}}(x)
if out is None:

tmp = sum([ el*el for el in x.containers]).sqrt()
res = (tmp - tau).maximum(0.0) * x/tmp
return res

else:

tmp = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 ).sqrt()
res = (tmp - tau).maximum(0.0) * x/tmp

'''
for el in res.containers:
el.as_array()[np.isnan(el.as_array())]=0

out.fill(res)


def proximal_conjugate(self, x, tau, out=None):


if out is None:
# tmp = [ el*el for el in x.containers]
# res = sum(tmp).sqrt().maximum(1.0)
# frac = [el/res for el in x.containers]
# return BlockDataContainer(*frac)
tmp = x.get_item(0) * 0
for el in x.containers:
tmp += el.power(2.)
tmp.sqrt(out=tmp)
tmp.maximum(1.0, out=tmp)
frac = [ el.divide(tmp) for el in x.containers ]
tmp = x.get_item(0) * 0
for el in x.containers:
tmp += el.power(2.)
tmp.sqrt(out=tmp)
tmp.maximum(1.0, out=tmp)
frac = [ el.divide(tmp) for el in x.containers ]
return BlockDataContainer(*frac)


else:

res1 = functools.reduce(lambda a,b: a + b*b, x.containers, x.get_item(0) * 0 )
if False:
res = res1.sqrt().maximum(1.0)
x.divide(res, out=out)
else:
res1.sqrt(out=res1)
res1.maximum(1.0, out=res1)
x.divide(res1, out=out)

res1.sqrt(out=res1)
res1.maximum(1.0, out=res1)
x.divide(res1, out=out)


def __rmul__(self, scalar):

'''Multiplication of MixedL21Norm with a scalar
Returns: ScaledFunction
''' Multiplication of MixedL21Norm with a scalar
Returns: ScaledFunction
'''
return ScaledFunction(self, scalar)


def sqrt_maximum(x, a):
y = numpy.sqrt(x)
if y >= a:
return y
else:
return a
#
if __name__ == '__main__':

M, N, K = 2,3,5
from ccpi.framework import BlockGeometry
M, N, K = 2,3,50
from ccpi.framework import BlockGeometry, ImageGeometry
import numpy

ig = ImageGeometry(M, N)
Expand All @@ -145,8 +131,9 @@ def sqrt_maximum(x, a):
U = BG.allocate('random_int')

# Define no scale and scaled
alpha = 0.5
f_no_scaled = MixedL21Norm()
f_scaled = 0.5 * MixedL21Norm()
f_scaled = alpha * MixedL21Norm()

# call

Expand Down Expand Up @@ -174,11 +161,36 @@ def sqrt_maximum(x, a):

numpy.testing.assert_array_almost_equal(res_no_out[1].as_array(), \
res_out[1].as_array(), decimal=4)
#


tau = 0.4
d1 = f_scaled.proximal(U, tau)

tmp = (U.get_item(0)**2 + U.get_item(1)**2).sqrt()

d2 = (tmp - alpha*tau).maximum(0) * U/tmp

numpy.testing.assert_array_almost_equal(d1.get_item(0).as_array(), \
d2.get_item(0).as_array(), decimal=4)

numpy.testing.assert_array_almost_equal(d1.get_item(1).as_array(), \
d2.get_item(1).as_array(), decimal=4)

out1 = BG.allocate('random_int')


f_scaled.proximal(U, tau, out = out1)

numpy.testing.assert_array_almost_equal(out1.get_item(0).as_array(), \
d1.get_item(0).as_array(), decimal=4)

numpy.testing.assert_array_almost_equal(out1.get_item(1).as_array(), \
d1.get_item(1).as_array(), decimal=4)
#







0 comments on commit 9bd69b7

Please sign in to comment.