Skip to content

Commit

Permalink
Merge pull request #453 from HERA-Team/update-for-uvdata3
Browse files Browse the repository at this point in the history
Update for uvdata3
  • Loading branch information
bhazelton authored Feb 14, 2024
2 parents 3e71720 + 6728a34 commit cb86b39
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 123 deletions.
9 changes: 6 additions & 3 deletions hera_qm/metrics_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,8 @@ def read_a_priori_chan_flags(a_priori_flags_yaml, freqs=None):
Numpy array of integer a priori channel index flags.
'''
apcf = []
apf = yaml.safe_load(open(a_priori_flags_yaml, 'r'))
with open(a_priori_flags_yaml, 'r') as fl:
apf = yaml.safe_load(fl)

# Load channel flags
if 'channel_flags' in apf:
Expand Down Expand Up @@ -1067,7 +1068,8 @@ def read_a_priori_int_flags(a_priori_flags_yaml, times=None, lsts=None):
Numpy array of integer a priori integration index flags.
'''
apif = []
apf = yaml.safe_load(open(a_priori_flags_yaml, 'r'))
with open(a_priori_flags_yaml, 'r') as fl:
apf = yaml.safe_load(fl)

# Load integration flags
if 'integration_flags' in apf:
Expand Down Expand Up @@ -1157,7 +1159,8 @@ def read_a_priori_ant_flags(a_priori_flags_yaml, ant_indices_only=False, by_ant_
if ant_indices_only and by_ant_pol:
raise ValueError("ant_indices_only and by_ant_pol can't both be True.")
apaf = []
apf = yaml.safe_load(open(a_priori_flags_yaml, 'r'))
with open(a_priori_flags_yaml, 'r') as fl:
apf = yaml.safe_load(fl)

# Load antenna flags
if 'ex_ants' in apf:
Expand Down
18 changes: 13 additions & 5 deletions hera_qm/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from pyuvdata import UVData
from pyuvdata import UVCal
import pyuvdata.utils as uvutils
import warnings
from pathlib import Path

pytestmark = pytest.mark.filterwarnings(
"ignore:The uvw_array does not match the expected values given the antenna positions.",
Expand Down Expand Up @@ -414,8 +416,14 @@ def test_apply_yaml_flags_errors():
pytest.raises(NotImplementedError, utils.apply_yaml_flags, 'uvdata', test_flag_jds)

# check warnings for out-of-range (too large or negative) channel and integration flags
for warn_yaml in ['a_priori_flags_maximum_channels.yaml', 'a_priori_flags_maximum_integrations.yaml',
'a_priori_flags_negative_channels.yaml', 'a_priori_flags_negative_integrations.yaml']:
yaml_path = os.path.join(DATA_PATH, warn_yaml)
with pytest.warns(None) as record:
utils.apply_yaml_flags(uvc, yaml_path, unflag_first=True)
with pytest.warns(UserWarning, match='Flagged channels were provided that exceed the maximum channel index'):
utils.apply_yaml_flags(uvc, Path(DATA_PATH) / "a_priori_flags_maximum_channels.yaml", unflag_first=True)

with pytest.warns(UserWarning, match='Flagged channels were provided with a negative channel index'):
utils.apply_yaml_flags(uvc, Path(DATA_PATH) / "a_priori_flags_negative_channels.yaml", unflag_first=True)

with pytest.warns(UserWarning, match='Flagged integrations were provided that exceed the maximum integration index'):
utils.apply_yaml_flags(uvc, Path(DATA_PATH) / "a_priori_flags_maximum_integrations.yaml", unflag_first=True)

with pytest.warns(UserWarning, match='Flagged integrations were provided with a negative integration index'):
utils.apply_yaml_flags(uvc, Path(DATA_PATH) / "a_priori_flags_negative_integrations.yaml", unflag_first=True)
3 changes: 2 additions & 1 deletion hera_qm/tests/test_vis_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def test_sequential_diff():
assert np.isclose(uvd_diff.data_array, uvd_diff2.data_array, atol=1e-5).all()

# test flag propagation
uvd.flag_array[uvd.antpair2ind(89, 96, ordered=False)[:1]] = True
inds = np.arange(uvd.flag_array.shape[0])[uvd.antpair2ind(89, 96, ordered=False)]
uvd.flag_array[inds[0]] = True
uvd_diff = vis_metrics.sequential_diff(uvd, axis=(0,), pad=False)
assert uvd_diff.get_flags(89, 96)[0].all()
assert not uvd_diff.get_flags(89, 96)[1:].any()
Expand Down
155 changes: 42 additions & 113 deletions hera_qm/tests/test_xrfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import glob
import hera_qm.ant_class as ant_class
from hera_cal import io

from pyuvdata.tests import check_warnings
from astropy.utils.exceptions import AstropyUserWarning
import warnings

test_d_file = os.path.join(DATA_PATH, 'zen.2457698.40355.xx.HH.uvcAA')
test_uvfits_file = os.path.join(DATA_PATH, 'zen.2457698.40355.xx.HH.uvcAA.uvfits')
Expand Down Expand Up @@ -1396,10 +1398,7 @@ def test_xrfi_run_step(tmpdir):

def test_xrfi_run_yaml_flags(tmpdir):
# test xrfi_run with yaml pre-flagging.
mess1 = ['This object is already a waterfall']
messages = 8 * mess1
cat1 = [UserWarning]
categories = 8 * cat1

# Spoof a couple files to use as extra inputs (xrfi_run needs two cal files and two data-like files)
tmp_path = tmpdir.strpath
fake_obs = 'zen.2457698.40355.HH'
Expand Down Expand Up @@ -1467,16 +1466,21 @@ def test_xrfi_run_yaml_flags(tmpdir):

# now test apriori flag file.
# test for different integrations modes (lsts, jds, integrations)
msg = 'This object is already a waterfall'
for test_flag in [a_priori_flag_integrations, a_priori_flag_jds, a_priori_flag_lsts]:
with pytest.warns(None) as record:
with check_warnings(UserWarning, match=msg, nwarnings=8):
# TODO: check whether this warning is expected
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in subtract")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Degrees of freedom <= 0 for slice")

warnings.filterwarnings("ignore", category=AstropyUserWarning)

xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile,
a_priori_flag_yaml=test_flag, history='Just a test', kt_size=3, throw_away_edges=False)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8
a_priori_flag_yaml=test_flag, history='Just a test',
kt_size=3, throw_away_edges=False)


for ext, label in ext_labels.items():
# by default, only cross median filter / mean filter is not performed.
if not ext in['cross_metrics1', 'cross_flags1']:
Expand Down Expand Up @@ -1505,10 +1509,7 @@ def test_xrfi_run_yaml_flags(tmpdir):
def test_xrfi_run(tmpdir):
# The warnings are because we use UVFlag.to_waterfall() on the total chisquareds
# This doesn't hurt anything, and lets us streamline the pipe
mess1 = ['This object is already a waterfall']
messages = 8 * mess1
cat1 = [UserWarning]
categories = 8 * cat1

# Spoof a couple files to use as extra inputs (xrfi_run needs two cal files and two data-like files)
tmp_path = tmpdir.strpath
fake_obs = 'zen.2457698.40355.HH'
Expand All @@ -1522,14 +1523,8 @@ def test_xrfi_run(tmpdir):
shutil.copyfile(test_uvh5_file, model_file)

# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile, history='Just a test', kt_size=3)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

outdir = os.path.join(tmp_path, 'zen.2457698.40355.xrfi')
ext_labels = {'ag_flags1': 'Abscal gains, median filter. Flags.',
Expand Down Expand Up @@ -1591,15 +1586,9 @@ def test_xrfi_run(tmpdir):
# now really do everything.
uvf_list1 = []
uvf_list1_names = []
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile,
history='Just a test', kt_size=3, cross_median_filter=True)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

for ext, label in ext_labels.items():
out = os.path.join(outdir, '.'.join([fake_obs, ext, 'h5']))
Expand All @@ -1616,15 +1605,9 @@ def test_xrfi_run(tmpdir):
# now do partial i/o and check equality of outputs.
uvf_list2 = []
uvf_list2_names = []
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile, Nwf_per_load=1,
history='Just a test', kt_size=3, cross_median_filter=True)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

for ext, label in ext_labels.items():
out = os.path.join(outdir, '.'.join([fake_obs, ext, 'h5']))
Expand Down Expand Up @@ -1829,10 +1812,7 @@ def test_xrfi_run_edgeflag(tmpdir):
# first try out a single file.
# The warnings are because we use UVFlag.to_waterfall() on the total chisquareds
# This doesn't hurt anything, and lets us streamline the pipe
mess1 = ['This object is already a waterfall']
messages = 8 * mess1
cat1 = [UserWarning]
categories = 8 * cat1

# Spoof a couple files to use as extra inputs (xrfi_run needs two cal files and two data-like files)
tmp_path = tmpdir.strpath
fake_obs = 'zen.2457698.40355.HH'
Expand All @@ -1845,14 +1825,9 @@ def test_xrfi_run_edgeflag(tmpdir):
model_file = os.path.join(tmp_path, fake_obs + '.omni_vis.uvh5')
shutil.copyfile(test_uvh5_file, model_file)
# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile, history='Just a test', kt_size=2)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

outdir = os.path.join(tmp_path, 'zen.2457698.40355.xrfi')
ext_labels = {'ag_flags1': 'Abscal gains, median filter. Flags.',
'ag_flags2': 'Abscal gains, mean filter. Flags.',
Expand Down Expand Up @@ -1937,14 +1912,8 @@ def test_xrfi_run_edgeflag(tmpdir):
model_file = os.path.join(tmp_path, fo + '.omni_vis.uvh5')
shutil.copyfile(uvf, model_file)
model_files.append(model_file)
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_files, acal_files, model_files, raw_dfiles, history='Just a test', kt_size=1)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8
flags2 = sorted(glob.glob(tmp_path + '/*.xrfi/*.HH.flags2.h5'))
assert len(flags2) == 3
uvf = UVFlag(flags2)
Expand All @@ -1960,10 +1929,7 @@ def test_xrfi_run_multifile(tmpdir):
# test xrfi_run with multiple files
# The warnings are because we use UVFlag.to_waterfall() on the total chisquareds
# This doesn't hurt anything, and lets us streamline the pipe
mess1 = ['This object is already a waterfall']
messages = 8 * mess1
cat1 = [UserWarning]
categories = 8 * cat1

# Spoof a couple files to use as extra inputs (xrfi_run needs two cal files and two data-like files)
tmp_path = tmpdir.strpath
fake_obses = ['zen.2457698.40355191.HH', 'zen.2457698.40367619.HH', 'zen.2457698.40380046.HH']
Expand All @@ -1986,15 +1952,9 @@ def test_xrfi_run_multifile(tmpdir):
model_files.append(model_file)

# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_files, acal_files, model_files, raw_dfiles,
history='Just a test', kt_size=3, cross_median_filter=True)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8
ext_labels = {'ag_flags1': 'Abscal gains, median filter. Flags.',
'ag_flags2': 'Abscal gains, mean filter. Flags.',
'ag_metrics1': 'Abscal gains, median filter.',
Expand Down Expand Up @@ -2053,12 +2013,11 @@ def test_xrfi_run_multifile(tmpdir):
if ext == 'flags2.h5':
assert np.all(uvf.flag_array)
# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_files, acal_files, model_files, raw_dfiles,
history='Just a test', kt_size=3, cross_median_filter=True,
throw_away_edges=False, clobber=True)
assert len(record) >= len(messages)
n_matched_warnings = 0

# check that the number of outdirs is 3 (one per fake observation)
outdirs = sorted(glob.glob(tmp_path + '/*.xrfi'))
assert len(outdirs) == 3
Expand All @@ -2079,10 +2038,6 @@ def test_xrfi_run_multifile(tmpdir):
def test_day_threshold_run(tmpdir):
# The warnings are because we use UVFlag.to_waterfall() on the total chisquareds
# This doesn't hurt anything, and lets us streamline the pipe
mess1 = ['This object is already a waterfall']
messages = 8 * mess1
cat1 = [UserWarning]
categories = 8 * cat1
# Spoof the files - run xrfi_run twice on spoofed files.
tmp_path = tmpdir.strpath
fake_obses = ['zen.2457698.40355.HH', 'zen.2457698.41101.HH']
Expand All @@ -2098,14 +2053,8 @@ def test_day_threshold_run(tmpdir):
shutil.copyfile(test_uvh5_file, model_file)

# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile, history='Just a test', kt_size=3, throw_away_edges=False)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

# Need to adjust time arrays when duplicating files
uvd = UVData.from_file(data_files[0], use_future_array_shapes=True)
Expand All @@ -2127,15 +2076,17 @@ def test_day_threshold_run(tmpdir):
shutil.copyfile(test_flag_integrations, a_priori_flag_integrations)

# check warnings
with pytest.warns(None) as record:
with check_warnings(
UserWarning,
match="This object is already a waterfall",
nwarnings=8
):
# TODO: these three warnings should be checked.
warnings.filterwarnings("ignore", category=AstropyUserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Degrees of freedom")
xrfi.xrfi_run(ocal_file, acal_file, model_file, data_files[1], history='Just a test', kt_size=3, clobber=True,
throw_away_edges=False, a_priori_flag_yaml=a_priori_flag_integrations, a_priori_ants_only=True)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

xrfi.day_threshold_run(data_files, history='just a test', a_priori_flag_yaml=a_priori_flag_integrations)
types = ['og', 'ox', 'ag', 'ax', 'v', 'cross', 'auto', 'omnical_chi_sq_renormed',
Expand All @@ -2157,10 +2108,6 @@ def test_day_threshold_run(tmpdir):
def test_day_threshold_run_yaml(tmpdir):
# The warnings are because we use UVFlag.to_waterfall() on the total chisquareds
# This doesn't hurt anything, and lets us streamline the pipe
mess1 = ['This object is already a waterfall']
messages = 8 * mess1
cat1 = [UserWarning]
categories = 8 * cat1
# Spoof the files - run xrfi_run twice on spoofed files.
tmp_path = tmpdir.strpath
fake_obses = ['zen.2457698.40355.HH', 'zen.2457698.41101.HH']
Expand All @@ -2176,14 +2123,8 @@ def test_day_threshold_run_yaml(tmpdir):
shutil.copyfile(test_uvh5_file, model_file)

# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, raw_dfile, history='Just a test', kt_size=3, throw_away_edges=False)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

# Need to adjust time arrays when duplicating files
uvd = UVData.from_file(data_files[0], use_future_array_shapes=True)
Expand All @@ -2203,14 +2144,8 @@ def test_day_threshold_run_yaml(tmpdir):
uvc.write_calfits(acal_file)

# check warnings
with pytest.warns(None) as record:
with check_warnings(UserWarning, match="This object is already a waterfall", nwarnings=8):
xrfi.xrfi_run(ocal_file, acal_file, model_file, data_files[1], history='Just a test', kt_size=3, clobber=True, throw_away_edges=False)
assert len(record) >= len(messages)
n_matched_warnings = 0
for i in range(len(record)):
if mess1[0] in str(record[i].message) and cat1[0] == record[i].category:
n_matched_warnings += 1
assert n_matched_warnings == 8

xrfi.day_threshold_run(data_files, history='just a test')
types = ['og', 'ox', 'ag', 'ax', 'v', 'cross', 'auto', 'omnical_chi_sq_renormed',
Expand Down Expand Up @@ -2407,16 +2342,10 @@ def test_xrfi_h1c_run():
kt_size=3)

# catch no provided data file for flagging
with pytest.warns(None) as record:
with check_warnings(UserWarning, match=['indata is None']*100 + ['K1 value 8']*91):
xrfi.xrfi_h1c_run(None, **{'filename': test_d_file, 'history': 'Just a test.',
'model_file': test_d_file, 'model_file_format': 'miriad',
'xrfi_path': xrfi_path})
assert len(record) >= 191
n_matched_warnings = 0
for i in range(len(record)):
if 'indata is None' in str(record[i].message) or 'K1 value 8' in str(record[i].message):
n_matched_warnings += 1
assert n_matched_warnings == 191


def test_xrfi_h1c_run_no_indata():
Expand Down
Loading

0 comments on commit cb86b39

Please sign in to comment.