Skip to content

Commit

Permalink
Merge pull request #60 from beckermr/iimage
Browse files Browse the repository at this point in the history
ENH add InterpolatedImage
  • Loading branch information
beckermr authored Nov 10, 2023
2 parents 0d547a0 + e0fd4ce commit 7e8cbc1
Show file tree
Hide file tree
Showing 22 changed files with 2,575 additions and 266 deletions.
6 changes: 1 addition & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* `Transformation`
* `Shear`
* `Convolve`
* `InterpolatedImage` and `Interpolant`
* Added implementation of fundamental operations:
* `drawImage`
* `drawReal`
Expand All @@ -24,10 +25,5 @@
* Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects

* Caveats
* Currently the FFT convolution does not perform kwrapping of hermitian images,
so it will lead to erroneous results on underesolved images that need k-space wrapping.
Wrapping for real images is implemented. K-space images arise from doing convolutions
via FFTs and so one would expect that underresolved images with convolutions may not be
rendered as accurately.
* Real space convolution and photon shooting methods are not
yet implemented in drawImage.
8 changes: 5 additions & 3 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,14 @@
from .gaussian import Gaussian
from .box import Box, Pixel
from .gsobject import GSObject

# Interpolation
from .moffat import Moffat

from .sum import Add, Sum
from .transform import Transform, Transformation
from .convolve import Convolve, Convolution, Deconvolution, Deconvolve

# WCS
from .wcs import (
BaseWCS,
AffineTransform,
JacobianWCS,
OffsetWCS,
Expand All @@ -77,7 +75,11 @@
Quintic,
Lanczos,
)
from .interpolatedimage import InterpolatedImage, _InterpolatedImage

# packages kept separate
from . import bessel
from . import fits

# this one is specific to jax_galsim
from . import core
8 changes: 4 additions & 4 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs):
# Save the construction parameters (as they are at this point) as attributes so they
# can be inspected later if necessary.
if bool(real_space):
raise NotImplementedError("Real space convolutions are not implemented")
raise NotImplementedError("Real-space convolutions are not implemented")
self._real_space = bool(real_space)

# Figure out what gsparams to use
Expand Down Expand Up @@ -296,7 +296,7 @@ def _max_sb(self):
return self.flux / jnp.sum(jnp.array(area_list))

def _xValue(self, pos):
raise NotImplementedError("Not implemented")
raise NotImplementedError("Real-space convolutions are not implemented")

def _kValue(self, kpos):
kv_list = [
Expand All @@ -305,10 +305,10 @@ def _kValue(self, kpos):
return jnp.prod(jnp.array(kv_list))

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Not implemented")
raise NotImplementedError("Real-space convolutions are not implemented")

def _shoot(self, photons, rng):
raise NotImplementedError("Not implemented")
raise NotImplementedError("Photon shooting convolutions are not implemented")

def _drawKImage(self, image, jac=None):
image = self.obj_list[0]._drawKImage(image, jac)
Expand Down
58 changes: 58 additions & 0 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def compute_major_minor_from_jacobian(jac):
h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0])
h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0])
major = 0.5 * jnp.abs(h1 + h2)
minor = 0.5 * jnp.abs(h1 - h2)
return major, minor


def convert_to_float(x):
Expand Down Expand Up @@ -43,6 +53,54 @@ def cast_scalar_to_int(x):
return x


def is_equal_with_arrays(x, y):
"""Return True if the data is equal, False otherwise. Handles jax.Array types."""
if isinstance(x, list):
if isinstance(y, list) and len(x) == len(y):
for vx, vy in zip(x, y):
if not is_equal_with_arrays(vx, vy):
return False
return True
else:
return False
elif isinstance(x, tuple):
if isinstance(y, tuple) and len(x) == len(y):
for vx, vy in zip(x, y):
if not is_equal_with_arrays(vx, vy):
return False
return True
else:
return False
elif isinstance(x, set):
if isinstance(y, set) and len(x) == len(y):
for vx, vy in zip(sorted(x), sorted(y)):
if not is_equal_with_arrays(vx, vy):
return False
return True
else:
return False
elif isinstance(x, dict):
if isinstance(y, dict) and len(x) == len(y):
for kx, vx in x.items():
if kx not in y or (not is_equal_with_arrays(vx, y[kx])):
return False
return True
else:
return False
elif isinstance(x, jax.Array) and jnp.ndim(x) > 0:
if isinstance(y, jax.Array) and y.shape == x.shape:
return jnp.array_equal(x, y)
else:
return False
elif (isinstance(x, jax.Array) and jnp.ndim(x) == 0) or (
isinstance(y, jax.Array) and jnp.ndim(y) == 0
):
# this case covers comparing an array scalar to a python scalar or vice versa
return jnp.array_equal(x, y)
else:
return x == y


def _recurse_list_to_tuple(x):
if isinstance(x, list):
return tuple(_recurse_list_to_tuple(v) for v in x)
Expand Down
106 changes: 105 additions & 1 deletion jax_galsim/core/wrap_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@jax.jit
def wrap_nonhermition(im, xmin, ymin, nxwrap, nywrap):
def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap):
def _body_j(j, vals):
i, im = vals

Expand Down Expand Up @@ -33,3 +33,107 @@ def _body_i(i, vals):

im = jax.lax.fori_loop(0, im.shape[0], _body_i, im)
return im


@jax.jit
def expand_hermitian_x(im):
return jnp.concatenate([im[:, 1:][::-1, ::-1].conjugate(), im], axis=1)


@jax.jit
def contract_hermitian_x(im):
return im[:, im.shape[1] // 2 :]


@jax.jit
def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny):
im_exp = expand_hermitian_x(im)
im_exp = wrap_nonhermitian(
im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny
)
return contract_hermitian_x(im_exp)


@jax.jit
def expand_hermitian_y(im):
return jnp.concatenate([im[1:, :][::-1, ::-1].conjugate(), im], axis=0)


@jax.jit
def contract_hermitian_y(im):
return im[im.shape[0] // 2 :, :]


@jax.jit
def wrap_hermitian_y(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny):
im_exp = expand_hermitian_y(im)
im_exp = wrap_nonhermitian(
im_exp, wrap_xmin - im_xmin, wrap_ymin - im_ymin, wrap_nx, wrap_ny
)
return contract_hermitian_y(im_exp)


# I am leaving this code here for posterity. It has a bug that I cannot find.
# It tries to be more clever instead of simply expanding the hermitian image to
# it's full shape, wrapping everything, and then contracting. -MRB
# @jax.jit
# def wrap_hermitian_x(im, im_xmin, im_ymin, wrap_xmin, wrap_ymin, wrap_nx, wrap_ny):
# def _body_j(j, vals):
# i, im = vals

# # first do zero or positive x freq
# im_y = i + im_ymin
# im_x = j + im_xmin
# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin
# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin
# wrap_yind = wrap_y - im_ymin
# wrap_xind = wrap_x - im_xmin
# im = jax.lax.cond(
# wrap_xind >= 0,
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond(
# jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y) != 0,
# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j]),
# lambda im, wrap_yind, wrap_xind: im,
# im,
# wrap_yind,
# wrap_xind,
# ),
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im,
# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind,
# )

# # now do neg x freq
# im_y = -im_y
# im_x = -im_x
# wrap_y = (im_y - wrap_ymin) % wrap_ny + wrap_ymin
# wrap_x = (im_x - wrap_xmin) % wrap_nx + wrap_xmin
# wrap_yind = wrap_y - im_ymin
# wrap_xind = wrap_x - im_xmin
# im = jax.lax.cond(
# im_x != 0,
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond(
# wrap_xind >= 0,
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: jax.lax.cond(
# (jnp.abs(wrap_x - im_x) + jnp.abs(wrap_y - im_y)) != 0,
# lambda im, wrap_yind, wrap_xind: im.at[wrap_yind, wrap_xind].add(im[i, j].conjugate()),
# lambda im, wrap_yind, wrap_xind: im,
# im,
# wrap_yind,
# wrap_xind,
# ),
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im,
# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind,
# ),
# lambda wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind: im,
# wrap_x, im_x, wrap_y, im_y, im, wrap_yind, wrap_xind,
# )

# return [i, im]

# def _body_i(i, vals):
# im = vals
# _, im = jax.lax.fori_loop(0, im.shape[1], _body_j, [i, im])
# return im

# im = jax.lax.fori_loop(0, im.shape[0], _body_i, im)
# return im
5 changes: 3 additions & 2 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from jax._src.numpy.util import _wraps

from jax_galsim.core.utils import is_equal_with_arrays
from jax_galsim.gsparams import GSParams
from jax_galsim.position import Position, PositionD, PositionI
from jax_galsim.utilities import parse_pos_args
Expand Down Expand Up @@ -178,7 +179,7 @@ def __neg__(self):
def __eq__(self, other):
return (self is other) or (
(type(other) is self.__class__)
and (self.tree_flatten() == other.tree_flatten())
and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten())
)

@_wraps(_galsim.GSObject.xValue)
Expand Down Expand Up @@ -771,7 +772,7 @@ def drawFFT_makeKImage(self, image):
with jax.ensure_compile_time_eval():
Nk = self.gsparams.maximum_fft_size
N = Nk
dk = 2.0 * np.pi / (N * image.scale)
dk = 2.0 * np.pi / (N * image.scale)
else:
# Start with what this profile thinks a good size would be given the image's pixel scale.
N = self.getGoodImageSize(image.scale)
Expand Down
Loading

0 comments on commit 7e8cbc1

Please sign in to comment.