diff --git a/src/pyuvdata/data/1061316296_nearfield.uvh5 b/src/pyuvdata/data/1061316296_nearfield.uvh5 new file mode 100644 index 0000000000..0465a434bc Binary files /dev/null and b/src/pyuvdata/data/1061316296_nearfield.uvh5 differ diff --git a/src/pyuvdata/utils/phase_center_catalog.py b/src/pyuvdata/utils/phase_center_catalog.py index dd972e9be2..bcc7308e40 100644 --- a/src/pyuvdata/utils/phase_center_catalog.py +++ b/src/pyuvdata/utils/phase_center_catalog.py @@ -7,7 +7,7 @@ from . import RADIAN_TOL -allowed_cat_types = ["sidereal", "ephem", "unprojected", "driftscan"] +allowed_cat_types = ["sidereal", "ephem", "unprojected", "driftscan", "near_field"] def look_in_catalog( @@ -610,6 +610,8 @@ def generate_phase_center_cat_entry( "driftscan" (fixed az/el position), "unprojected" (no w-projection, equivalent to the old `phase_type` == "drift"). + "near-field" (equivalent to sidereal with the addition + of near-field corrections) cat_lon : float or ndarray Value of the longitudinal coordinate (e.g., RA, Az, l) in radians of the phase center. No default unless `cat_type="unprojected"`, in which case the @@ -684,7 +686,7 @@ def generate_phase_center_cat_entry( if not isinstance(cat_name, str): raise ValueError("cat_name must be a string.") - # We currently only have 4 supported types -- make sure the user supplied + # We currently only have 5 supported types -- make sure the user supplied # one of those if cat_type not in allowed_cat_types: raise ValueError(f"cat_type must be one of {allowed_cat_types}.") diff --git a/src/pyuvdata/utils/phasing.py b/src/pyuvdata/utils/phasing.py index 930ce8a4ba..5a9f3a330c 100644 --- a/src/pyuvdata/utils/phasing.py +++ b/src/pyuvdata/utils/phasing.py @@ -7,12 +7,13 @@ import erfa import numpy as np from astropy import units -from astropy.coordinates import Angle, Distance, EarthLocation, SkyCoord +from astropy.coordinates import AltAz, Angle, Distance, EarthLocation, SkyCoord from astropy.time import Time from astropy.utils import iers from . import _phasing from .times import get_lst_for_time +from .tools import _get_autocorrelations_mask, _nants_to_nblts, _ntimes_to_nblts try: from lunarsky import MoonLocation, SkyCoord as LunarSkyCoord, Time as LTime @@ -2085,6 +2086,8 @@ def calc_app_coords( "ephem" (RA/Dec that moves with time), "driftscan" (fixed az/el position), "unprojected" (alias for "driftscan" with (Az, Alt) = (0 deg, 90 deg)). + "near_field" (equivalent to sidereal, with the addition of + near-field corrections) time_array : float or ndarray of float or Time object Times for which the apparent coordinates are to be calculated, in UTC JD. If more than a single element, must be the same shape as lon_coord and @@ -2175,7 +2178,7 @@ def calc_app_coords( else: unique_lst = lst_array[unique_mask] - if coord_type == "sidereal": + if coord_type == "sidereal" or coord_type == "near_field": # If the coordinates are not in the ICRS frame, go ahead and transform them now if coord_frame != "icrs": icrs_ra, icrs_dec = transform_sidereal_coords( @@ -2571,3 +2574,120 @@ def uvw_track_generator( "lst": lst_array, "site_loc": site_loc, } + + +def _get_focus_xyz(uvd, focus, ra, dec): + """ + Return the x,y,z coordinates of the focal point. + + The focal point corresponds to the location of + the near-field object of interest in the ENU + frame centered on the median position of the + antennas. + + Parameters + ---------- + uvd : UVData object + UVData object + focus : float + Focal distance of the array (km) + ra : float + Right ascension of the focal point ie phase center (deg; shape (Ntimes,)) + dec : float + Declination of the focal point ie phase center (deg; shape (Ntimes,)) + + Returns + ------- + x, y, z: ndarray, ndarray, ndarray + ENU-frame coordinates of the focal point (meters) (shape (Ntimes,)) + """ + # Obtain timesteps + timesteps = Time(np.unique(uvd.time_array), format="jd") + + # Initialize sky-based coordinates using right ascension and declination + obj = SkyCoord(ra * units.deg, dec * units.deg) + + # The center of the ENU frame should be located at the median position of the array + loc = uvd.telescope.location.itrs.cartesian.xyz.value + antpos = uvd.telescope.antenna_positions + loc + x, y, z = np.median(antpos, axis=0) + + # Initialize EarthLocation object centred on the telescope + telescope = EarthLocation(x, y, z, unit=units.m) + + # Convert sky object to an AltAz frame centered on the telescope + obj = obj.transform_to(AltAz(obstime=timesteps, location=telescope)) + + # Obtain altitude and azimuth + theta, phi = obj.alt.to(units.rad), obj.az.to(units.rad) + + # Obtain x,y,z ENU coordinates + x = focus * 1e3 * np.cos(theta) * np.sin(phi) + y = focus * 1e3 * np.cos(theta) * np.cos(phi) + z = focus * 1e3 * np.sin(theta) + + return x, y, z + + +def _get_delay(uvd, focus_x, focus_y, focus_z): + """ + Calculate near-field phase/delay along the Nblts axis. + + Parameters + ---------- + uvd : UVData object + UVData object + focus_x, focus_y, focus_z : ndarray, ndarray, ndarray + ENU-frame coordinates of focal point (Each of shape (Ntimes,)) + + Returns + ------- + phi : ndarray + The phase correction to apply to each visibility along the Nblts axis + new_w : ndarray + The calculated near-field delay (or w-term) for each visibility along + the Nblts axis + """ + # Get indices to convert between Nants and Nblts + ind1, ind2 = _nants_to_nblts(uvd) + + # Antenna positions in ENU frame + antpos = uvd.telescope.get_enu_antpos() - np.median( + uvd.telescope.get_enu_antpos(), axis=0 + ) + + # Get tile positions for each baseline + tile1 = antpos[ind1] # Shape (Nblts, 3) + tile2 = antpos[ind2] # Shape (Nblts, 3) + + # Focus points have shape (Ntimes,); convert to shape (Nblts,) + t_inds = _ntimes_to_nblts(uvd) + (focus_x, focus_y, focus_z) = (focus_x[t_inds], focus_y[t_inds], focus_z[t_inds]) + + # Calculate distance from antennas to focal point + # for each visibility along the Nblts axis + r1 = np.sqrt( + (tile1[:, 0] - focus_x) ** 2 + + (tile1[:, 1] - focus_y) ** 2 + + (tile1[:, 2] - focus_z) ** 2 + ) + r2 = np.sqrt( + (tile2[:, 0] - focus_x) ** 2 + + (tile2[:, 1] - focus_y) ** 2 + + (tile2[:, 2] - focus_z) ** 2 + ) + + # Get the uvw array along the Nblts axis; select only the w's + old_w = uvd.uvw_array[:, -1] + + # Calculate near-field delay + new_w = r1 - r2 + phi = new_w - old_w + + # Remove autocorrelations + mask = _get_autocorrelations_mask(uvd) + + new_w = new_w * mask + old_w * (1 - mask) + phi = phi * mask + + return phi, new_w # Each of shape (Nblts,) diff --git a/src/pyuvdata/utils/tools.py b/src/pyuvdata/utils/tools.py index 29dadfc95c..4aae5d758c 100644 --- a/src/pyuvdata/utils/tools.py +++ b/src/pyuvdata/utils/tools.py @@ -368,3 +368,98 @@ def _sorted_unique_difference(obj1, obj2=None): List containing the difference in unique entries between obj1 and obj2. """ return sorted(set(obj1)) if obj2 is None else sorted(set(obj1).difference(obj2)) + + +def _nants_to_nblts(uvd): + """ + Obtain indices to convert (Nants,) to (Nblts,). + + Parameters + ---------- + uvd : UVData object + + Returns + ------- + ind1, ind2 : ndarray, ndarray + index pairs to compose (Nblts,) shaped arrays for each + baseline from an (Nants,) shaped array + """ + ants = uvd.telescope.antenna_numbers + + ant1 = uvd.ant_1_array + ant2 = uvd.ant_2_array + + ind1 = [] + ind2 = [] + + for i in ant1: + ind1.append(np.where(ants == i)[0][0]) + for i in ant2: + ind2.append(np.where(ants == i)[0][0]) + + return np.asarray(ind1), np.asarray(ind2) + + +def _ntimes_to_nblts(uvd): + """ + Obtain indices to convert (Ntimes,) to (Nblts,). + + Parameters + ---------- + uvd : UVData object + UVData object + + Returns + ------- + inds : ndarray + Indices that, when applied to an array of shape (Ntimes,), + correctly convert it to shape (Nblts,) + """ + unique_t = np.unique(uvd.time_array) + t = uvd.time_array + + inds = [] + for i in t: + inds.append(np.where(unique_t == i)[0][0]) + + return np.asarray(inds) + + +def _get_autocorrelations_mask(uvd): + """ + Get a (Nblts,) shaped array that masks autocorrelations. + + Parameters + ---------- + uvd : UVData object + UVData object + + Returns + ------- + mask : ndarray + array of shape (Nblts,) of 1's and 0's, + where 0 indicates an autocorrelation + """ + # Get indices along the Nblts axis corresponding to autocorrelations + autos = [] + for i in uvd.telescope.antenna_numbers: + num = uvd.antpair2ind(i, ant2=i) + + if isinstance(num, slice): + step = num.step if num.step is not None else 1 + inds = list(range(num.start, num.stop, step)) + autos.append(inds) + + # Flatten it to obtain the 1D array of autocorrelation indices + autos = np.asarray(autos).flatten() + + # Initialize mask of ones (1 = not an autocorrelation) + mask = np.ones_like(uvd.baseline_array) + + # Populate with zeros (0 = is an autocorrelation) + if ( + len(autos) > 0 + ): # Protect against the case where the uvd is already free of autos + mask[autos] = 0 + + return mask diff --git a/src/pyuvdata/uvdata/miriad.py b/src/pyuvdata/uvdata/miriad.py index af3ab3369e..5fac226bb4 100644 --- a/src/pyuvdata/uvdata/miriad.py +++ b/src/pyuvdata/uvdata/miriad.py @@ -1643,6 +1643,14 @@ def write_miriad( "Only ITRS telescope locations are supported in Miriad files." ) + if any( + entry.get("cat_type") == "near_field" + for entry in self.phase_center_catalog.values() + ): + raise NotImplementedError( + "Writing near-field phased data to miriad format is not yet supported." + ) + # change time_array and lst_array to mark beginning of integration, # per Miriad format miriad_time_array = self.time_array - self.integration_time / (24 * 3600.0) / 2 diff --git a/src/pyuvdata/uvdata/ms.py b/src/pyuvdata/uvdata/ms.py index 6cf298ffdb..b7349208ba 100644 --- a/src/pyuvdata/uvdata/ms.py +++ b/src/pyuvdata/uvdata/ms.py @@ -107,6 +107,15 @@ def write_ms( if not casa_present: # pragma: no cover raise ImportError(no_casa_message) from casa_error + if any( + entry.get("cat_type") == "near_field" + for entry in self.phase_center_catalog.values() + ): + raise NotImplementedError( + "Writing near-field phased data to Measurement Set format " + + "is not yet supported." + ) + if run_check: self.check( check_extra=check_extra, diff --git a/src/pyuvdata/uvdata/uvdata.py b/src/pyuvdata/uvdata/uvdata.py index 40204bab7c..c302bf8188 100644 --- a/src/pyuvdata/uvdata/uvdata.py +++ b/src/pyuvdata/uvdata/uvdata.py @@ -24,6 +24,7 @@ from ..telescopes import known_telescopes from ..utils import phasing as phs_utils from ..utils.io import hdf5 as hdf5_utils +from ..utils.phasing import _get_delay, _get_focus_xyz from ..uvbase import UVBase from .initializers import new_uvdata @@ -4623,11 +4624,65 @@ def _phase_dict_helper( elif (key == "cat_id") and (phase_dict[key] is not None): # If this is the cat_id, make it an int phase_dict[key] = int(phase_dict[key]) - elif not ((phase_dict[key] is None) or isinstance(phase_dict[key], str)): - phase_dict[key] = float(phase_dict[key]) - return phase_dict + def _apply_near_field_corrections(self, focus, ra, dec): + """ + Apply near-field corrections by focusing the array to the specified focal point. + + Parameters + ---------- + focus : astropy.units.Quantity object + Focal point of the array + ra : ndarray + Right ascension of the focal point ie phase center (rad; shape (Ntimes,)) + dec : ndarray + Declination of the focal point ie phase center (rad; shape (Ntimes,)) + + Returns + ------- + None (performs operations inplace) + """ + # ------- Parameters that are independent of frequency -------- + + # Obtain focal distance in km + focus = focus.to(units.km).value + + # Convert ra, dec from radians to degrees + ra, dec = np.degrees(ra), np.degrees(dec) + + # Calculate the x, y, z coordinates of the focal point + # in ENU frame for each vis along Nblts axis + focus_x, focus_y, focus_z = _get_focus_xyz(self, focus, ra, dec) + + # Calculate near-field correction at the specified timestep + # for each vis along Nblts axis + phi, new_w = _get_delay(self, focus_x, focus_y, focus_z) + + # Update old w with new w + self.uvw_array[:, -1] = new_w + + # ---------------- Frequency-dependent calculations --------------------- + + # Calculate wavelength associate with each frequency + wavelengths = 299792458 / self.freq_array + + # Reshape the phi and wavelengths arrays in order to + # be able to broadcast them together + phi = np.reshape(phi, (phi.size, 1)) # (Nblts, 1) + wavelengths = np.reshape(wavelengths, (1, wavelengths.size)) # (1, Nfreqs) + + # Calculate phase corrections at all frequencies + # -- produces shape (Nblts, Nfreqs) + phase_corrections = np.exp(-2j * np.pi * phi / wavelengths) + + # Set data at each polarization (Npols = 4 usually) + for pol in self.polarization_array: + prev = np.reshape(self.get_data(pol), (self.Nblts, self.Nfreqs, 1)) + corr = np.reshape(phase_corrections, (self.Nblts, self.Nfreqs, 1)) + + self.set_data(corr * prev, pol) + def phase( self, *, @@ -4688,7 +4743,9 @@ def phase( cat_type : str Type of phase center to be added. Must be one of: "sidereal" (fixed RA/Dec), "ephem" (RA/Dec that moves with time), - "driftscan" (fixed az/el position). Default is "sidereal". + "driftscan" (fixed az/el position), "near_field" (first applies far-field + phasing assuming sidereal phase center, then applies near-field + corrections to the specified dist). Default is "sidereal". ephem_times : ndarray of float Only used when `cat_type="ephem"`. Describes the time for which the values of `cat_lon` and `cat_lat` are caclulated, in units of JD. Shape is (Npts,). @@ -4698,10 +4755,14 @@ def phase( pm_dec : float Proper motion in Dec, in units of mas/year. Only used for sidereal phase centers. - dist : float or ndarray of float - Distance of the source, in units of pc. Only used for sidereal and ephem - phase centers. Expected to be a float for sidereal phase - centers, and an ndarray of floats of shape (Npts,) for ephem phase centers. + dist : float or ndarray of float or astropy.units.Quantity object. + Distance to the source. Used for sidereal and ephem phase centers, + and for applying near-field corrections. If passed either as a float + (for sidereal phase centers) or as an ndarray of floats of shape (Npts,) + (for ephem phase centers), will be interpreted in units of parsec for all + cat_types except near_field; in the latter case it will be interpreted + in meters. Alternatively, an astropy.units.Quantity object may be passed + instead, in which case the units will be infered automatically. vrad : float or ndarray of float Radial velocity of the source, in units of km/s. Only used for sidereal and ephem phase centers. Expected to be a float for sidereal phase @@ -4745,6 +4806,23 @@ def phase( # Before moving forward with the heavy calculations, we need to do some # basic housekeeping to make sure that we've got the coordinate data that # we need in order to proceed. + if dist is not None: + if isinstance(dist, units.Quantity): + dist_qt = copy.deepcopy(dist) + else: + if cat_type == "near_field": + dist_qt = dist * units.m + else: + dist_qt = dist * units.parsec + + dist = dist_qt.to( + units.parsec + ).value # phase_dict internally stores in parsecs + elif dist is None and cat_type == "near_field": + raise ValueError( + "dist parameter must be specified for cat_type 'near_field'" + ) + phase_dict = self._phase_dict_helper( lon=lon, lat=lat, @@ -4899,6 +4977,12 @@ def phase( if cleanup_old_sources: self._clear_unused_phase_centers() + # Lastly, apply near-field corrections if specified + if cat_type == "near_field": + self._apply_near_field_corrections( + focus=dist_qt, ra=phase_dict["cat_lon"], dec=phase_dict["cat_lat"] + ) + def phase_to_time( self, time, *, phase_frame="icrs", use_ant_pos=True, select_mask=None ): diff --git a/src/pyuvdata/uvdata/uvfits.py b/src/pyuvdata/uvdata/uvfits.py index e98e930722..c7d4c85b98 100644 --- a/src/pyuvdata/uvdata/uvfits.py +++ b/src/pyuvdata/uvdata/uvfits.py @@ -819,6 +819,14 @@ def write_uvfits( fix_autos: bool = False, ): """Write data to a uvfits file.""" + if any( + entry.get("cat_type") == "near_field" + for entry in self.phase_center_catalog.values() + ): + raise NotImplementedError( + "Writing near-field phased data to uvfits format is not yet supported." + ) + if run_check: self.check( check_extra=check_extra, diff --git a/tests/uvdata/test_uvdata.py b/tests/uvdata/test_uvdata.py index e5999bc5bf..614cda7860 100644 --- a/tests/uvdata/test_uvdata.py +++ b/tests/uvdata/test_uvdata.py @@ -885,6 +885,16 @@ def test_generic_read(): }, False, ), + ( + { + "cat_name": "near_field_test", + "ra": 0.4, + "dec": -0.3, + "cat_type": "near_field", + "dist": 10 * units.km, + }, + False, + ), ], ) def test_phase_unphase_hera(hera_uvh5, phase_kwargs, partial): @@ -917,18 +927,26 @@ def test_phase_unphase_hera(hera_uvh5, phase_kwargs, partial): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") -def test_phase_unphase_hera_one_bl(hera_uvh5): +@pytest.mark.parametrize("cat_type", ["sidereal", "near_field"]) +def test_phase_unphase_hera_one_bl(hera_uvh5, cat_type): uv_raw = hera_uvh5 # check that phase + unphase work with one baseline uv_raw_small = uv_raw.select(blt_inds=[0], inplace=False) uv_phase_small = uv_raw_small.copy() - uv_phase_small.phase(lon=Angle("23h").rad, lat=Angle("15d").rad, cat_name="foo") + uv_phase_small.phase( + lon=Angle("23h").rad, + lat=Angle("15d").rad, + cat_name="foo", + cat_type=cat_type, + dist=5000, + ) uv_phase_small.unproject_phase(cat_name="zenith") assert uv_raw_small == uv_phase_small @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") -def test_phase_unphase_hera_antpos(hera_uvh5): +@pytest.mark.parametrize("cat_type", ["sidereal", "near_field"]) +def test_phase_unphase_hera_antpos(hera_uvh5, cat_type): uv_phase = hera_uvh5.copy() uv_raw = hera_uvh5 # check that they match if you phase & unphase using antenna locations @@ -957,9 +975,19 @@ def test_phase_unphase_hera_antpos(hera_uvh5): uv_raw_new = uv_raw.copy() uv_raw_new.uvw_array = uvw_calc - uv_phase.phase(ra=0.0, dec=0.0, epoch="J2000", cat_name="foo", use_ant_pos=True) + uv_phase.phase( + ra=0.0, + dec=0.0, + epoch="J2000", + cat_name="foo", + use_ant_pos=True, + cat_type=cat_type, + dist=7000, + ) uv_phase2 = uv_raw_new.copy() - uv_phase2.phase(ra=0.0, dec=0.0, epoch="J2000", cat_name="foo") + uv_phase2.phase( + ra=0.0, dec=0.0, epoch="J2000", cat_name="foo", cat_type=cat_type, dist=7000 + ) # The uvw's only agree to ~1mm. should they be better? np.testing.assert_allclose( @@ -10124,7 +10152,7 @@ def test_print_object_multi(carma_miriad): ValueError, re.escape( "If set, cat_type must be one of ['sidereal', 'ephem', 'unprojected', " - "'driftscan']" + "'driftscan', 'near_field']" ), ], ], @@ -12479,3 +12507,86 @@ def test_pol_convention_warnings(hera_uvh5): ValueError, match="pol_convention is set but the data is uncalibrated" ): hera_uvh5.check() + + +@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") +def test_near_field_corrections(): + uvfits_raw = os.path.join(DATA_PATH, "1061316296.uvfits") + uvfits_corr = os.path.join(DATA_PATH, "1061316296_nearfield.uvh5") + + uvd_raw = UVData() + uvd_corr = UVData() + + uvd_raw.read(uvfits_raw) + uvd_corr.read(uvfits_corr) + + uvd_raw.phase( + ra=np.radians(30), + dec=np.radians(-20), + cat_name="foo", + dist=10000, + cat_type="near_field", + ) + + # History and filename don't matter here + uvd_raw.history = uvd_corr.history + uvd_raw.filename = uvd_corr.filename + + assert uvd_raw == uvd_corr + + +def test_near_field_err(): + uvfits_sample = os.path.join(DATA_PATH, "1061316296.uvfits") + + uvd = UVData() + uvd.read(uvfits_sample) + + with pytest.raises( + ValueError, match="dist parameter must be specified for cat_type 'near_field'" + ): + uvd.phase( + ra=np.radians(30), + dec=np.radians(-20), + cat_name="foo", + cat_type="near_field", + ) + + +@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") +@pytest.mark.parametrize( + "file_format, import_check, error_message", + [ + ( + "miriad", + lambda: pytest.importorskip("pyuvdata.uvdata._miriad"), + "Writing near-field phased data to miriad format is not yet supported.", + ), + ( + "ms", + lambda: pytest.importorskip("casacore"), + "Writing near-field phased data to Measurement Set format " + + "is not yet supported.", + ), + ( + "uvfits", + None, + "Writing near-field phased data to uvfits format is not yet supported.", + ), + ], +) +def test_write_near_field_err(file_format, import_check, error_message): + uvh5_sample = os.path.join(DATA_PATH, "1061316296_nearfield.uvh5") + + uvd = UVData() + uvd.read(uvh5_sample) + + if import_check: + import_check() + + with pytest.raises(NotImplementedError, match=error_message): + if file_format == "miriad": + uvd.write_miriad("test_path") + elif file_format == "ms": + uvd.write_ms("test_path") + elif file_format == "uvfits": + uvd.write_uvfits("test_path") diff --git a/tests/uvdata/test_uvh5.py b/tests/uvdata/test_uvh5.py index b3950080eb..9b04999a84 100644 --- a/tests/uvdata/test_uvh5.py +++ b/tests/uvdata/test_uvh5.py @@ -142,7 +142,8 @@ def make_old_shapes(filename): @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") -def test_read_miriad_write_uvh5_read_uvh5(paper_miriad, tmp_path): +@pytest.mark.parametrize("cat_type", ["sidereal", "near_field"]) +def test_read_miriad_write_uvh5_read_uvh5(paper_miriad, tmp_path, cat_type): """ Test a miriad file round trip. """ @@ -165,13 +166,21 @@ def test_read_miriad_write_uvh5_read_uvh5(paper_miriad, tmp_path): assert uv_in == uv_out - # also test round-tripping phased data + # also test round-tripping phased data using phase_to_time uv_in.phase_to_time(Time(np.mean(uv_in.time_array), format="jd")) uv_in.write_uvh5(testfile, clobber=True) uv_out.read_uvh5(testfile) assert uv_in == uv_out + # finally, test round-tripping phased data + # with and without applying near-field corrections + uv_in.phase(cat_name="foo", cat_type=cat_type, ra=0.6, dec=-0.1, dist=4000) + uv_in.write_uvh5(testfile, clobber=True) + uv_out.read_uvh5(testfile) + + assert uv_in == uv_out + # clean up os.remove(testfile) @@ -251,6 +260,45 @@ def test_read_uvfits_write_uvh5_read_uvh5( return +@pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") +@pytest.mark.parametrize("cat_type", ["sidereal", "near_field"]) +def test_read_uvfits_phase_write_uvh5_read_uvh5(casa_uvfits, tmp_path, cat_type): + """ + Test a uvfits file round trip after applying phasing. + """ + uv_in = casa_uvfits + + uv_out = UVData() + fname = "outtest_uvfits_phased.uvh5" + testfile = str(tmp_path / fname) + + uv_in.phase(cat_name="foo", cat_type=cat_type, ra=0.7, dec=0.2, dist=5000) + + uv_in.write_uvh5(testfile, clobber=True) + uv_out.read(testfile) + + # make sure filenames are what we expect + assert uv_in.filename == ["day2_TDEM0003_10s_norx_1src_1spw.uvfits"] + assert uv_out.filename == [fname] + uv_in.filename = uv_out.filename + + assert uv_in == uv_out + + # clean up + os.remove(testfile) + + # also test writing double-precision data_array + uv_in.data_array = uv_in.data_array.astype(np.complex128) + uv_in.write_uvh5(testfile, clobber=True) + uv_out.read(testfile) + assert uv_in == uv_out + + # clean up + os.remove(testfile) + + return + + @pytest.mark.filterwarnings("ignore:The uvw_array does not match the expected values") def test_read_uvh5_errors(tmp_path, casa_uvfits): """