diff --git a/improc/jigsaw.py b/improc/jigsaw.py new file mode 100644 index 00000000..889c45df --- /dev/null +++ b/improc/jigsaw.py @@ -0,0 +1,203 @@ +import numpy as np + + +def cut(im, cut_size, overlap=0, pad_value=None): + """Cut an image into small cutouts, with set intervals and some overlap. + + Parameters + ---------- + im: np.ndarray + A 2D image that needs to be cut into pieces. + cut_size: scalar int or 2-tuple of int + The size of the output cutouts. + Can be a scalar (in which case the cutouts are square) + or a 2-tuple so that the shape of each cutout is equal + to this parameter. + overlap: scalar float or integer or string + The fraction of overlap between the cutouts, in the range [0,1). + If given as an integer (or float larger or equal to 1.0) will be + interpreted as the number of pixels that overlap between adjacent cutouts. + The default is 0 (no overlap). + For fractional overlap, the number of pixels that overlap + between adjacent cutouts (in each dimension) is ceil(cut_size * overlap). + That means the next cutout will start inside the current cutout + at pixel number floor(cut_size * (1 - overlap)). + If given as a string, will be interpreted as a window function. + Currently only "hanning" is supported. + See the stitch() function for more details. + pad_value: scalar float or string, optional + If given a value, will use that as + a filler for any part of the cutouts + that lie outside the original image + (this can happen if the last pixels + reach out of the image). + Default is NaN. + If given as a string equal to "nan" will also + use it as NaN. If given as "zero" will use 0. + + Returns + ------- + cutouts: 3D np.ndarray + A 3D array where the first dimension is the number of cutouts, + and the other two dimensions are the cutout height and width. + corners: list of 2-tuples + Each item on the list contains the lower corner (x,y) pixel positions + of the relevant cutout in the coordinates of the full image. + """ + S = im.shape # size of the input + if np.isscalar(cut_size): + C = (cut_size, cut_size) + elif isinstance(cut_size, tuple) and len(cut_size) == 2: + C = cut_size + else: + raise TypeError("cut_size must be a scalar or 2-tuple") + + if not isinstance(overlap, (int, float, np.number, str)): + raise ValueError("Overlap must be a scalar, a 2D array, or a string") + + if isinstance(overlap, str): + if overlap == 'hanning': + overlap = 0.5 # we don't apply the window here, only set up the correct overlap for it + else: + raise ValueError(f'Unknown overlap type "{overlap}". Use "hanning" or a scalar or a 2D array.') + + if overlap < 0: + raise ValueError("Overlap must be a non-negative number") + if overlap < 1: + pix_overlap = tuple(int(np.ceil(c * overlap)) for c in C) + else: + pix_overlap = (int(overlap), int(overlap)) + + if isinstance(pad_value, str): + if pad_value == 'nan': + pad_value = np.nan + elif pad_value == 'zero': + pad_value = 0 + else: + raise ValueError('pad_value must be a scalar or "nan" or "zero"') + + # estimate the number of cutouts in each dimension + num_x = int(np.ceil(S[1] / (C[1] - pix_overlap[1]))) + 1 + num_y = int(np.ceil(S[0] / (C[0] - pix_overlap[0]))) + 1 + + corners = [] + + for i in range(num_x): + corner_x = i * (C[1] - pix_overlap[1]) - pix_overlap[1] + if corner_x >= S[1]: + break + + for j in range(num_y): + corner_y = j * (C[0] - pix_overlap[0]) - pix_overlap[0] + if corner_y >= S[0]: + break + + corners.append((corner_x, corner_y)) + + output = np.empty((len(corners), C[0], C[1])) + + for i, (x, y) in enumerate(corners): + low_x = max(x, 0) + high_x = min(x + C[1], S[1]) + low_y = max(y, 0) + high_y = min(y + C[0], S[0]) + + cutout = np.full(C, pad_value) + cutout[low_y - y:high_y - y, low_x - x:high_x - x] = im[low_y:high_y, low_x:high_x] + + output[i] = cutout + + return output, corners + + +def stitch(cutouts, im_shape, corners=None, overlap=None): + """Rebuild a full image from cutouts. + + Parameters + ---------- + cutouts: np.ndarray + A 3D array where the first dimension is the number of cutouts, + and the other two dimensions are the cutout height and width. + im_shape: scalar int or 2-tuple of int + The shape of the full image. + corners: list of 2-tuples (optional) + Each item on the list contains the lower corner (x,y) pixel positions + of the relevant cutout in the coordinates of the full image. + If not given, will try to guess this from the im_shape, + by assuming overlap=0. + overlap: float scalar or np.ndarray or string (optional) + The overlap fraction, or the overlap number of pixels. + Will zero out the top and right edges of the cutouts. + If given as a fraction, the number of pixels that are not + zeroed out, from the beginning of each cutout, + is floor(cut_size * (1 - overlap)). + + Another option is to give an array of the same size as one of the cutouts. + The cutouts are multiplied by this array before adding them + to make the full image. This is useful for overlapping images, + as it allows to blend the cutouts in a smooth way. + + Can also give a string, which will be interpreted as one of these options: + - 'hanning': a Hanning window, which is a cosine window, defined as + 0.5 * (1 - cos(2 * pi * x / (cut_size - 1))), where x is the pixel position. + + If not given, will not apply any overlap window, which will alter the pixel values + of the output image in any case where there is overlap! + + Returns + ------- + A 2D array with the full image. + """ + C = cutouts.shape[1:] # size of the cutouts + N = cutouts.shape[0] # number of cutouts + S = im_shape + if isinstance(S, (int, float, np.number)): + S = (S, S) + if S is not None and len(S) != 2: + raise ValueError("im_shape must be a scalar or 2-tuple") + + if corners is None: # infer corner_list from im_shape + corners = [] + for i in range(N): + corners.append((i % (S[1] // C[1]) * C[1], i // (S[1] // C[1]) * C[0])) + + # check if we can interpret the "overlap" input to make a window + window = None # by default there is no window array + if isinstance(overlap, str): + if overlap == 'hanning': + window = np.hanning(C[0])[:, None] * np.hanning(C[1])[None, :] + else: + raise ValueError(f'Unknown overlap type "{overlap}". Use "hanning" or a scalar or a 2D array.') + elif np.isscalar(overlap): + if overlap < 0: + raise ValueError("Overlap must be a non-negative number") + if overlap < 1: + pix_overlap = tuple(int(np.ceil(overlap * c)) for c in C) + else: + pix_overlap = (int(overlap), int(overlap)) + window = np.ones(C) + window[-pix_overlap[0]:, :] = 0 + window[:, -pix_overlap[1]:] = 0 + elif overlap is not None: + raise ValueError('overlap must be a scalar, a 2D array, or a string') + + if window is not None: + if window.shape != C: + raise ValueError('window must have the same shape as the cutouts') + + window = window[None, :, :] # add a dimension to broadcast + cutouts = cutouts * window + + im = np.zeros(S) + for i, (x, y) in enumerate(corners): + low_x = max(x, 0) + high_x = min(x + C[1], S[1]) + low_y = max(y, 0) + high_y = min(y + C[0], S[0]) + im[low_y:high_y, low_x:high_x] += cutouts[i, low_y - y:high_y - y, low_x - x:high_x - x] + + return im + + + + diff --git a/pipeline/subtraction.py b/pipeline/subtraction.py index a9b56975..8615e72e 100644 --- a/pipeline/subtraction.py +++ b/pipeline/subtraction.py @@ -71,10 +71,16 @@ def get_process_name(self): def __setattr__(self, key, value): # make sure the pad_value in this dictionary is always the same value (as a critical parameter it should be) if key == 'subtiling': + if value is None: + value = {} if 'pad_value' in value: pad_value = value['pad_value'] - if pad_value is None or pad_value == 'nan' or np.isnan(pad_value): - value['pad_value'] = np.nan + if ( + pad_value is None or + (isinstance(pad_value, str) and pad_value.lower() == 'nan') or + (isinstance(pad_value, np.number) and np.isnan(pad_value)) + ): + value['pad_value'] = 'nan' super().__setattr__(key, value) @@ -205,21 +211,28 @@ def _subtract_zogy(self, new_image, ref_image): ref_tile_data, _ = jigsaw.cut(ref_image_data, **self.pars.subtiling) # corners should be the same! tile_shape = ref_tile_data.shape[1:] - centers = [(c[0] + tile_shape[1] // 2, c[1] + tile_shape[2] // 2) for c in corners] # x,y in flipped order! + centers = [(c[0] + tile_shape[0] // 2, c[1] + tile_shape[1] // 2) for c in corners] # x,y in flipped order! # find the PSF clip in the center of each tile - new_tile_psfs = [new_image.psf.get_clip(c) for c in centers] - ref_tile_psfs = [ref_image.psf.get_clip(c) for c in centers] + new_tile_psfs = [new_image.psf.get_clip(*c) for c in centers] + ref_tile_psfs = [ref_image.psf.get_clip(*c) for c in centers] # get noise arrays, find their median internally in zogy_subtract - new_tile_noises, _ = jigsaw.cut(new_image_noise, **self.pars.subtiling) - ref_tile_noises, _ = jigsaw.cut(ref_image_noise, **self.pars.subtiling) + if isinstance(new_image_noise, np.ndarray): + new_tile_noises, _ = jigsaw.cut(new_image_noise, **self.pars.subtiling) + else: + new_tile_noises = [new_image_noise] * len(corners) + + if isinstance(ref_image_noise, np.ndarray): + ref_tile_noises, _ = jigsaw.cut(ref_image_noise, **self.pars.subtiling) + else: + ref_tile_noises = [ref_image_noise] * len(corners) # flux ZPs are assumed uniform across the image... # loop over the tiles and make the subtractions - outputs = [{}] * len(corners) - for i, out in enumerate(outputs): + outputs = [] + for i in range(len(corners)): out = zogy_subtract( ref_tile_data[i], new_tile_data[i], @@ -230,6 +243,7 @@ def _subtract_zogy(self, new_image, ref_image): ref_image_flux_zp, new_image_flux_zp, ) + outputs.append(out) output = {} tiled_output = {} @@ -246,7 +260,12 @@ def _subtract_zogy(self, new_image, ref_image): tiled_output[key][i] = value for key, value in tiled_output.items(): # for each type of array, e.g., sub_image, score, etc. - output[key] = jigsaw.stitch(value, corners, new_image.data.shape) + output[key] = jigsaw.stitch( + value, + new_image.data.shape, + corners, + overlap=self.pars.subtiling.get('overlap') + ) else: # do not use subtiling output = zogy_subtract(