diff --git a/Wrappers/Python/cil/optimisation/functions/TotalVariation.py b/Wrappers/Python/cil/optimisation/functions/TotalVariation.py index 6c438c76a8..316b05d69c 100644 --- a/Wrappers/Python/cil/optimisation/functions/TotalVariation.py +++ b/Wrappers/Python/cil/optimisation/functions/TotalVariation.py @@ -254,10 +254,6 @@ def __call__(self, x): def proximal(self, x, tau, out=None): r""" Returns the proximal operator of the TotalVariation function at :code:`x` .""" - if id(x)==id(out): - raise InPlaceError(message="TotalVariation.proximal cannot be used in place") - - if self.strong_convexity_constant > 0: strongly_convex_factor = (1 + tau * self.strong_convexity_constant) @@ -266,7 +262,8 @@ def proximal(self, x, tau, out=None): solution = self._fista_on_dual_rof(x, tau, out=out) if self.strong_convexity_constant > 0: - x *= strongly_convex_factor + if id(x) != id(solution): + x *= strongly_convex_factor tau *= strongly_convex_factor return solution @@ -306,12 +303,16 @@ def _fista_on_dual_rof(self, x, tau, out=None): if out is None: out = self.gradient_operator.domain_geometry().allocate(0) + if id(x) == id(out): + x_eval= x.copy() + else: + x_eval = x for k in range(self.iterations): t0 = t self.gradient_operator.adjoint(tmp_q, out=out) - out.sapyb(tau_reg_neg, x, 1.0, out=out) + out.sapyb(tau_reg_neg, x_eval, 1.0, out=out) self.projection_C(out, tau=None, out=out) self.gradient_operator.direct(out, out=p1) diff --git a/Wrappers/Python/cil/optimisation/operators/Operator.py b/Wrappers/Python/cil/optimisation/operators/Operator.py index 1afec3c156..1e6f294b88 100644 --- a/Wrappers/Python/cil/optimisation/operators/Operator.py +++ b/Wrappers/Python/cil/optimisation/operators/Operator.py @@ -151,7 +151,8 @@ def domain(self): @property def range(self): return self.range_geometry() - + + def __rmul__(self, scalar): '''Defines the multiplication by a scalar on the left diff --git a/Wrappers/Python/test/test_out_in_place.py b/Wrappers/Python/test/test_out_in_place.py index d1395864a1..6dfcc63a4d 100644 --- a/Wrappers/Python/test/test_out_in_place.py +++ b/Wrappers/Python/test/test_out_in_place.py @@ -111,6 +111,7 @@ def setUp(self): (WeightedL2NormSquared(weight=b_ig), ig), (TotalVariation(backend='c', warm_start=False, max_iteration=100), ig), (TotalVariation(backend='numpy', warm_start=False, max_iteration=100), ig), + (TotalVariation(backend='numpy', warm_start=False, max_iteration=100, strong_convexity_constant=0.5), ig), (OperatorCompositionFunction(L2NormSquared(), A), ig), (MixedL21Norm(), bg), (SmoothMixedL21Norm(epsilon=0.3), bg),