From fdeee7bb6778fec582847e02de8bac7ff2899dbd Mon Sep 17 00:00:00 2001 From: Mikael Brudfors Date: Tue, 10 Aug 2021 11:25:44 +0000 Subject: [PATCH] fix: asuper-resolution ow works when given multiple repeats --- unires/_core.py | 5 ++++- unires/_project.py | 31 +++++++++++++++++-------------- unires/_update.py | 8 ++++---- unires/struct.py | 2 +- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/unires/_core.py b/unires/_core.py index b95d5c5..b675f9c 100644 --- a/unires/_core.py +++ b/unires/_core.py @@ -463,7 +463,10 @@ def _resample_inplane(x, sett): # make grid D = I.clone() for i in range(3): - D[i, i] = sett.vx[i] / vx_x[i] + if isinstance(sett.vx, (list, tuple)): + D[i, i] = sett.vx[i] / vx_x[i] + else: + D[i, i] = sett.vx / vx_x[i] if D[i, i] < 1.0: D[i, i] = 1 if float((I - D).abs().sum()) < 1e-4: diff --git a/unires/_project.py b/unires/_project.py index d1aaf99..e77f958 100644 --- a/unires/_project.py +++ b/unires/_project.py @@ -67,30 +67,33 @@ def _proj(operator, dat, x, y, method='super-resolution', do=True, diff (str, optional): Gradient difference operator, defaults to 'forward'. Returns: - dat (torch.tensor()): Projected image data (dim_y|dim_x). + dat_p (torch.tensor()): Projected image data (dim_y|dim_x). """ if operator == 'AtA': + # AtA + # (dat = y) if not do: # return dat - operator = 'none' - dat1 = rho * y.lam ** 2 * _DtD(dat, vx_y=vx_y, bound=bound, diff=diff) + operator = 'none' dat = dat[None, None, ...] - dat = x[n].tau * _proj_apply(operator, dat, x[n].po, method=method, + # sum likelihood terms + dat_p = x[n].tau * _proj_apply(operator, dat, x[n].po, method=method, bound=bound, interpolation=interpolation) for n1 in range(1, len(x)): - dat = dat + x[n1].tau * _proj_apply(operator, dat, x[n1].po, method=method, - bound=bound, interpolation=interpolation) - dat = dat[0, 0, ...] - dat += dat1 - else: # A, At + dat_p += x[n1].tau * _proj_apply(operator, dat, x[n1].po, method=method, + bound=bound, interpolation=interpolation) + dat_p = dat_p[0, 0, ...] + # add prior term + dat_p += rho * y.lam ** 2 * _DtD(dat[0, 0, ...], vx_y=vx_y, bound=bound, diff=diff) + else: + # A, At + # (dat = x or y) if not do: # return dat operator = 'none' - dat = dat[None, None, ...] - dat = _proj_apply(operator, dat, x[n].po, method=method, - bound=bound, interpolation=interpolation) - dat = dat[0, 0, ...] + dat_p = _proj_apply(operator, dat[None, None, ...], x[n].po, method=method, + bound=bound, interpolation=interpolation)[0, 0, ...] - return dat + return dat_p def _proj_apply(operator, dat, po, method='super-resolution', bound='zero', interpolation='linear'): diff --git a/unires/_update.py b/unires/_update.py index fe3ed19..8474751 100644 --- a/unires/_update.py +++ b/unires/_update.py @@ -121,7 +121,7 @@ def _update_admm(x, y, z, w, rho, tmp, obj, n_iter, sett): t0 = _print_info('fit-update', sett, 'y', n_iter) # PRINT for c in range(len(x)): # Loop over channels # RHS - tmp[:] = 0 + tmp[:] = 0.0 for n in range(len(x[c])): # Loop over observations of channel 'c' # _ = _print_info('int', sett, n) # PRINT tmp += x[c][n].tau * _proj('At', x[c][n].dat, x[c], y[c], method=sett.method, do=sett.do_proj, @@ -409,19 +409,19 @@ def _compute_nll(x, y, sett, rho, sum_dtype=torch.float64): vx_y = voxel_size(y[0].mat).float() nll_xy = torch.tensor(0, device=sett.device, dtype=torch.float64) for c in range(len(x)): - # Neg. log-likelihood term + # Sum neg. log-likelihood term for n in range(len(x[c])): msk = x[c][n].dat != 0 Ay = _proj('A', y[c].dat, x[c], y[c], n=n, method=sett.method, do=sett.do_proj, bound=sett.bound, interpolation=sett.interpolation) nll_xy += 0.5 * x[c][n].tau * torch.sum((x[c][n].dat[msk] - Ay[msk]) ** 2, dtype=sum_dtype) - # Neg. log-prior term + # Sum neg. log-prior term Dy = y[c].lam * im_gradient(y[c].dat, vx=vx_y, bound=sett.bound, which=sett.diff) if c > 0: nll_y += torch.sum(Dy ** 2, dim=0) else: nll_y = torch.sum(Dy ** 2, dim=0) - + # Neg. log-prior term nll_y = torch.sum(torch.sqrt(nll_y), dtype=sum_dtype) return nll_xy + nll_y, nll_xy, nll_y diff --git a/unires/struct.py b/unires/struct.py index e1b7cff..94759de 100644 --- a/unires/struct.py +++ b/unires/struct.py @@ -78,7 +78,7 @@ def __init__(self): self.do_print: int = 1 # Print progress to terminal (0, 1, 2, 3) self.do_proj: bool = None # Use projection matrices, defined in format_output() self.do_res_origin: bool = False # Resets origin, if CT data - self.force_inplane_res: bool = True # Force in-plane resolution of observed data to be greater or equal to recon vx + self.force_inplane_res: bool = False # Force in-plane resolution of observed data to be greater or equal to recon vx self.fov: str = 'brain' # If crop=True, uses this field-of-view ('brain'|'head'). self.gap: float = 0.0 # Slice gap, between 0 and 1 self.interpolation: str = 'linear' # Interpolation order (see nitorch.spatial)