diff --git a/src/darsia/restoration/tvd.py b/src/darsia/restoration/tvd.py index 7257de98..bee5caeb 100644 --- a/src/darsia/restoration/tvd.py +++ b/src/darsia/restoration/tvd.py @@ -1,16 +1,23 @@ """ Module wrapping TV denoising algoriths from skimage -into the realm of DarSIA, operating on darsia.Image. +into the realm of DarSIA. These can be directly integrated +in darsia.ConcentrationAnalysis, in particular as part of +the definition of a restoration object. """ import numpy as np import skimage +import darsia + +from .heterogeneous_tvd import heterogeneous_tv_denoising + class TVD: """ - Total variation denoising through skimage.restoration. + Total variation denoising interface to skimage.restoration + as well as darsia.restoration. """ @@ -23,9 +30,25 @@ def __init__(self, key: str = "", **kwargs) -> None: """ self.method = kwargs.get(key + "smoothing method", "chambolle") - self.weight = kwargs.get(key + "smoothing weight", 0.1) - self.eps = kwargs.get(key + "smoothing eps", 2e-4) - self.max_num_iter = kwargs.get(key + "smoothing max_num_iter", 200) + if self.method == "heterogeneous anisotropic bregman": + + # Internal TV method + self.weight = kwargs.get("weight", 0.1) + self.omega = kwargs.get("omega", 1) + self.penalty = kwargs.get("penalty", 1.0) + self.tvd_stopping_criterion = kwargs.get( + "tvd stopping criterion", darsia.StoppingCriterion(1e-4, 100) + ) + self.cg_stopping_criterion = kwargs.get( + "cg stopping criterion", darsia.StoppingCriterion(1e-2, 100) + ) + + else: + + # Skimage type methods + self.weight = kwargs.get(key + "smoothing weight", 0.1) + self.eps = kwargs.get(key + "smoothing eps", 2e-4) + self.max_num_iter = kwargs.get(key + "smoothing max_num_iter", 200) def __call__(self, img: np.ndarray) -> np.ndarray: """ @@ -66,6 +89,16 @@ def __call__(self, img: np.ndarray) -> np.ndarray: isotropic=True, ) + elif self.method == "heterogeneous anisotropic bregman": + img = heterogeneous_tv_denoising( + img, + weight=self.weight, + omega=self.omega, + penalty=self.penalty, + tvd_stopping_criterion=self.tvd_stopping_criterion, + cg_stopping_criterion=self.cg_stopping_criterion, + ) + else: raise ValueError(f"Method {self.method} not supported.")