Skip to content

Commit

Permalink
updates to MIRI LRS resample wcs
Browse files Browse the repository at this point in the history
  • Loading branch information
jemorrison committed Feb 11, 2025
1 parent 46c1b40 commit 8adf05e
Showing 1 changed file with 28 additions and 115 deletions.
143 changes: 28 additions & 115 deletions jwst/resample/resample_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,22 +523,19 @@ def build_interpolated_output_wcs(self, input_models):
all_dec_slit = []
xstop = 0

#all_wcs = [m.meta.wcs for m in input_models]
all_wcs = [m.meta.wcs for m in input_models]

Check warning on line 526 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L526

Added line #L526 was not covered by tests


print('********in build interpolated output')
for im, model in enumerate(input_models):
wcs = model.meta.wcs
bbox = wcs.bounding_box
grid = wcstools.grid_from_bounding_box(bbox)
ra, dec, lam = np.array(wcs(*grid))
print('shape of ra and dec', ra.shape, dec.shape, lam.shape)

# Handle vertical (MIRI) or horizontal (NIRSpec) dispersion. The
# following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
spectral_axis = find_dispersion_axis(model)
spatial_axis = spectral_axis ^ 1

print('spectral_axis', spectral_axis, spatial_axis)
# Compute the wavelength array, trimming NaNs from the ends
# In many cases, a whole slice is NaNs, so ignore those warnings
with warnings.catch_warnings():
Expand All @@ -552,7 +549,7 @@ def build_interpolated_output_wcs(self, input_models):
# sampling.

# Steps to do this for first input model:
# 1. find the middle of the spectrum in wavelength
# 1. Find the middle of the spectrum in wavelength
# 2. Pull out the ra and dec at the center of the slit.
# 3. Find the mean ra,dec and the center of the slit this will
# represent the tangent point
Expand All @@ -562,14 +559,9 @@ def build_interpolated_output_wcs(self, input_models):
# the spatial scale of the output wcs
if im == 0:
all_wavelength = np.append(all_wavelength, wavelength_array)
print("* bbox ", bbox)
print(spectral_axis)
print(bbox[0][1], bbox[0][0])
print(bbox[1][0], bbox[1][1])
# find the center ra and dec for this slit at central wavelength
lam_center_index = int((bbox[spectral_axis][1] -
bbox[spectral_axis][0]) / 2)
print('lam center index',lam_center_index)
if spatial_axis == 0:
# MIRI LRS, the WCS x axis is spatial
ra_slice = ra[lam_center_index, :]
Expand All @@ -582,9 +574,6 @@ def build_interpolated_output_wcs(self, input_models):
ra_center_pt = np.nanmean(wrap_ra(ra_slice))
dec_center_pt = np.nanmean(dec_slice)

print('**** ra_center_pt',ra_center_pt)
print('**** dec_center_pt',dec_center_pt)

# convert ra and dec to tangent projection
tan = Pix2Sky_TAN()
native2celestial = RotateNative2Celestial(ra_center_pt, dec_center_pt, 180)
Expand All @@ -595,20 +584,18 @@ def build_interpolated_output_wcs(self, input_models):
warnings.simplefilter("ignore", RuntimeWarning) # was ignore. need to make more specific
# at this center of slit find x,y tangent projection - x_tan, y_tan
x_tan, y_tan = undist2sky1.inverse(ra, dec)
print('ra and dec shape', ra.shape, dec.shape)
print('x_tan', x_tan.shape)

# pull out data from center
if spectral_axis == 0: # MIRI LRS, the WCS x axis is spatial
if spectral_axis == 0:

Check warning on line 589 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L589

Added line #L589 was not covered by tests
x_tan_array = x_tan.T[lam_center_index]
y_tan_array = y_tan.T[lam_center_index]
else:
else: # MIRI LRS, the WCS x axis is spatial
x_tan_array = x_tan[lam_center_index]
y_tan_array = y_tan[lam_center_index]

x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
y_tan_array = y_tan_array[~np.isnan(y_tan_array)]
print(x_tan_array.shape)
print(y_tan_array.shape)

# estimate the spatial sampling
fitter = LinearLSQFitter()
fit_model = Linear1D()
Expand All @@ -620,8 +607,6 @@ def build_interpolated_output_wcs(self, input_models):
pix_to_xtan = fitter(fit_model, x_idx, x_tan_array)
pix_to_ytan = fitter(fit_model, y_idx, y_tan_array)

print('ystop y_idx', ystop, y_idx)
print('xstop y_idx', xstop, x_idx)
# append all ra and dec values to use later to find min and max
# ra and dec
ra_use = ra[~np.isnan(ra)].flatten()
Expand Down Expand Up @@ -692,11 +677,10 @@ def build_interpolated_output_wcs(self, input_models):
# Account for vertical or horizontal dispersion on detector
mapping.inverse = Mapping((2, 1) if spatial_axis else (1, 2))

print('** swap_xy', swap_xy)
# The final transform
# redefine the ra, dec center tangent point to include all data

# check if all_ra crosses 0 degrees - this makes it hard to
# Check if all_ra crosses 0 degrees - this makes it hard to
# define the min and max ra correctly
all_ra = wrap_ra(all_ra)
ra_min = np.amin(all_ra)
Expand All @@ -707,111 +691,39 @@ def build_interpolated_output_wcs(self, input_models):
dec_max = np.amax(all_dec)
dec_center_final = (dec_max + dec_min) / 2.0

print('ra',ra_min, ra_max)
print('dec',dec_min, dec_max)
print(ra_center_final, dec_center_final)
# define transforms
tan = Pix2Sky_TAN()
if len(input_models) == 1: # single model use ra_center_pt to be consistent
# with how resample was done before
ra_center_final = ra_center_pt
dec_center_final = dec_center_pt

native2celestial = RotateNative2Celestial(ra_center_final, dec_center_final, 180)
undist2sky = tan | native2celestial

# && try this
# at this center of slit find x,y tangent projection - x_tan, y_tan
x_tan, y_tan = undist2sky1.inverse(all_ra, all_dec)
print('x_tan', x_tan.shape)
# pull out data from center
if spectral_axis == 0: # MIRI LRS, the WCS x axis is spatial
x_tan_array = x_tan.T[lam_center_index]
y_tan_array = y_tan.T[lam_center_index]
else:
x_tan_array = x_tan[lam_center_index]
y_tan_array = y_tan[lam_center_index]

x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
y_tan_array = y_tan_array[~np.isnan(y_tan_array)]
print(x_tan_array.shape)
print(y_tan_array.shape)
# estimate the spatial sampling
fitter = LinearLSQFitter()
fit_model = Linear1D()

xstop = x_tan_array.shape[0] * self.pscale_ratio
x_idx = np.linspace(0, xstop, x_tan_array.shape[0], endpoint=False)
ystop = y_tan_array.shape[0] * self.pscale_ratio
y_idx = np.linspace(0, ystop, y_tan_array.shape[0], endpoint=False)
pix_to_xtan = fitter(fit_model, x_idx, x_tan_array)
pix_to_ytan = fitter(fit_model, y_idx, y_tan_array)

## Use all the wcs
min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent(

Check warning on line 704 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L704

Added line #L704 was not covered by tests
all_wcs, undist2sky.inverse)
diff_y = np.abs(max_tan_y - min_tan_y)
diff_x = np.abs(max_tan_x - min_tan_x)
pix_to_tan_slope_y = np.abs(pix_to_ytan.slope)
slope_sign_y = np.sign(pix_to_ytan.slope)
pix_to_tan_slope_x = np.abs(pix_to_xtan.slope)
slope_sign_x = np.sign(pix_to_xtan.slope)

Check warning on line 711 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L706-L711

Added lines #L706 - L711 were not covered by tests

# find the spatial size of the output - same in x,y
if swap_xy:
_, x_tan_all = undist2sky.inverse(all_ra, all_dec)
pix_to_tan_slope = pix_to_ytan.slope
ny = int(np.ceil(diff_y / pix_to_tan_slope_y)) + 1

Check warning on line 714 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L714

Added line #L714 was not covered by tests
else:
x_tan_all, _ = undist2sky.inverse(all_ra, all_dec)
pix_to_tan_slope = pix_to_xtan.slope

x_min = np.amin(x_tan_all)
x_max = np.amax(x_tan_all)
x_size = int(np.ceil((x_max - x_min) / np.absolute(pix_to_tan_slope)))
print('pix_to_xtan', pix_to_ytan.intercept, pix_to_xtan.intercept)
print('x min max', x_min, x_max, x_size)

#if swap_xy:
# pix_to_ytan.intercept = -0.5 * (x_size - 1) * pix_to_ytan.slope
# pix_to_ytan.intercept = -0.5 * (44 - 1) * pix_to_ytan.slope
#else:
# pix_to_xtan.intercept = -0.5 * (x_size - 1) * pix_to_xtan.slope


print('pix_to_xtan', pix_to_ytan.intercept, pix_to_xtan.intercept)

## pulled from nirspec code
#min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent(
# all_wcs, undist2sky.inverse)
#diff_y = np.abs(max_tan_y - min_tan_y)
#diff_x = np.abs(max_tan_x - min_tan_x)
#print('*** diff x y ', diff_x, diff_y)
#pix_to_tan_slope_y = np.abs(pix_to_ytan.slope)
#slope_sign_y = np.sign(pix_to_ytan.slope)
#pix_to_tan_slope_x = np.abs(pix_to_xtan.slope)
#slope_sign_x = np.sign(pix_to_xtan.slope)
#print(pix_to_tan_slope_y, slope_sign_y)
#print(pix_to_tan_slope_x, slope_sign_x)

#if swap_xy:
# ny = int(np.ceil(diff_y / pix_to_tan_slope_y)) + 1
#else:
# ny = int(np.ceil(diff_x / pix_to_tan_slope_x)) + 1
#print('** ny ***', ny)
#offset_y = ny/2 * pix_to_tan_slope_y - diff_y/2
#offset_x = ny/2 * pix_to_tan_slope_x - diff_x/2

#if slope_sign_y > 0:
# zero_value_y = min_tan_y
#if swap_xy:
# zero_value_y = max_tan_y

#if slope_sign_x > 0:
# zero_value_x = min_tan_x
#else:
# zero_value_x = max_tan_x
#zero_value_y = 0
#zero_value_x = 0
#print('*****', zero_value_y, zero_value_x)

#pix_to_ytan.intercept = zero_value_y - slope_sign_y * offset_y
#pix_to_xtan.intercept = zero_value_x - slope_sign_x * offset_x
##
ny = int(np.ceil(diff_x / pix_to_tan_slope_x)) + 1

Check warning on line 716 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L716

Added line #L716 was not covered by tests

offset_y = (ny)/2 * pix_to_tan_slope_y
offset_x = (ny)/2 * pix_to_tan_slope_x
pix_to_ytan.intercept = - slope_sign_y * offset_y
pix_to_xtan.intercept = - slope_sign_x * offset_x

Check warning on line 721 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L718-L721

Added lines #L718 - L721 were not covered by tests

# single model: use size of x_tan_array
# to be consistent with method before
if len(input_models) == 1:
x_size = int(np.ceil(xstop))
ny = int(np.ceil(xstop))

Check warning on line 726 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L726

Added line #L726 was not covered by tests

# define the output wcs
transform = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & pix_to_wavelength
Expand All @@ -828,10 +740,11 @@ def build_interpolated_output_wcs(self, input_models):

output_wcs = WCS(pipeline)


# compute the output array size in WCS axes order, i.e. (x, y)
output_array_size = [0, 0]
output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array)))
output_array_size[spatial_axis] = x_size
output_array_size[spatial_axis] = ny

Check warning on line 747 in jwst/resample/resample_spec.py

View check run for this annotation

Codecov / codecov/patch

jwst/resample/resample_spec.py#L747

Added line #L747 was not covered by tests

# turn the size into a numpy shape in (y, x) order
output_wcs.array_shape = output_array_size[::-1]
Expand Down

0 comments on commit 8adf05e

Please sign in to comment.