Skip to content

Commit

Permalink
Move catalog_clean to BaseEnhancedTable
Browse files Browse the repository at this point in the history
The catalog_search functionality really makes more sense as a method.
Since it seems potentially useful to all of the subclasses of
BaseEnhancedTable, I put it there.
  • Loading branch information
mwcraig committed Nov 24, 2023
1 parent 3930535 commit cf5191e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 142 deletions.
58 changes: 58 additions & 0 deletions stellarphot/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

from astropy import units as u
from astropy.coordinates import EarthLocation, SkyCoord
from astropy.table import Column, QTable, Table
Expand Down Expand Up @@ -305,6 +307,62 @@ def _update_passbands(self):
mask = self['passband'] == orig_pb
self['passband'][mask] = aavso_pb

def clean(self, remove_rows_with_mask=True, **other_restrictions):
"""
Return a catalog with only the rows that meet the criteria specified.
Parameters
----------
catalog : `astropy.table.Table`
Table of catalog information. There are no restrictions on the columns.
remove_rows_with_mask : bool, optional
If ``True``, remove rows in which one or more of the values is masked.
other_restrictions: dict, optional
Key/value pairs in which the key is the name of a column in the
catalog and the value is the criteria that values in that column
must satisfy to be kept in the cleaned catalog. The criteria must be
simple, beginning with a comparison operator and including a value.
See Examples below.
Returns
-------
same type as object whose method was called
Table with filtered data
"""
comparisons = {
'<': np.less,
'=': np.equal,
'>': np.greater,
'<=': np.less_equal,
'>=': np.greater_equal,
'!=': np.not_equal
}

recognized_comparison_ops = '|'.join(comparisons.keys())
keepers = np.ones([len(self)], dtype=bool)

if remove_rows_with_mask and self.has_masked_values:
for c in self.columns:
keepers &= ~self[c].mask

for column, restriction in other_restrictions.items():
criteria_re = re.compile(r'({})([-+a-zA-Z0-9]+)'.format(recognized_comparison_ops))
results = criteria_re.match(restriction)
if not results:
raise ValueError("Criteria {}{} not "
"understood.".format(column, restriction))
comparison_func = comparisons[results.group(1)]
comparison_value = results.group(2)
new_keepers = comparison_func(self[column],
float(comparison_value))
keepers = keepers & new_keepers

return self[keepers]


class PhotometryData(BaseEnhancedTable):
"""
Expand Down
108 changes: 108 additions & 0 deletions stellarphot/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,114 @@ def test_base_enhanced_table_from_existing_table():
assert len(test_base2['dec']) == 1


def test_base_enhanced_table_clean():
# Check that the clean method exists
test_base = BaseEnhancedTable(table_description=test_descript, input_data=testdata)
# Add a row so that we can clean something
test_base_two = test_base.copy()
test_base_two.add_row(test_base[0])
test_base_two['ra'][1] = - test_base_two['ra'][1]
test_cleaned = test_base_two.clean(ra='>0.0')
assert len(test_cleaned) == 1
assert test_cleaned == test_base


def a_table(masked=False):
test_table = Table([(1, 2, 3), (1, -1, -1)], names=('a', 'b'),
masked=masked)
test_table = BaseEnhancedTable(table_description={'a': None, 'b': None}, input_data=test_table)
return test_table


def test_bet_clean_criteria_none_removed():
"""
If all rows satisfy the criteria, none should be removed.
"""
inp = a_table()
criteria = {'a': '>0'}
out = inp.clean(**criteria)
assert len(out) == len(inp)
assert (out == inp).all()


@pytest.mark.parametrize("condition",
['>0', '=1', '!=-1', '>=1'])
def test_bet_clean_criteria_some_removed(condition):
"""
Try a few filters which remove the second row and check that it is
removed.
"""
inp = a_table()
criteria = {'b': condition}
out = inp.clean(**criteria)
assert len(out) == 1
assert (out[0] == inp[0]).all()


@pytest.mark.parametrize("criteria,error_msg", [
({'a': '5'}, "not understood"),
({'a': '<foo'}, "could not convert string")])
def test_clean_bad_criteria(criteria, error_msg):
"""
Make sure the appropriate error is raised when bad criteria are used.
"""
inp = a_table(masked=False)

with pytest.raises(ValueError, match=error_msg):
inp.clean(**criteria)


@pytest.mark.parametrize("clean_masked",
[False, True])
def test_clean_masked_handled_correctly(clean_masked):
inp = a_table(masked=True)
# Mask negative values
inp['b'].mask = inp['b'] < 0
out = inp.clean(remove_rows_with_mask=clean_masked)
if clean_masked:
assert len(out) == 1
assert (np.array(out[0]) == np.array(inp[0])).all()
else:
assert len(out) == len(inp)
assert (out == inp).all()


def test_clean_masked_and_criteria():
"""
Check whether removing masked rows and using a criteria work
together.
"""
inp = a_table(masked=True)
# Mask the first row.
inp['b'].mask = inp['b'] > 0

inp_copy = inp.copy()
# This should remove the third row.
criteria = {'a': '<=2'}

out = inp.clean(remove_rows_with_mask=True, **criteria)

# Is only one row left?
assert len(out) == 1

# Is the row that is left the same as the second row of the input?
assert (np.array(out[0]) == np.array(inp[1])).all()

# Is the input table unchanged?
assert (inp == inp_copy).all()


def test_clean_criteria_none_removed():
"""
If all rows satisfy the criteria, none should be removed.
"""
inp = a_table()
criteria = {'a': '>0'}
out = inp.clean(**criteria)
assert len(out) == len(inp)
assert (out == inp).all()


def test_base_enhanced_table_missing_column():
# Should raise exception because the RA data is missing from input data
testdata_nora = testdata.copy()
Expand Down
61 changes: 0 additions & 61 deletions stellarphot/utils/catalog_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
__all__ = [
'in_frame',
'catalog_search',
'catalog_clean',
'find_apass_stars',
'find_known_variables',
]
Expand Down Expand Up @@ -123,66 +122,6 @@ def catalog_search(frame_wcs_or_center, shape, desired_catalog,
return cat[in_fov]


def catalog_clean(catalog, remove_rows_with_mask=True,
**other_restrictions):
"""
Return a catalog with only the rows that meet the criteria specified.
Parameters
----------
catalog : `astropy.table.Table`
Table of catalog information. There are no restrictions on the columns.
remove_rows_with_mask : bool, optional
If ``True``, remove rows in which one or more of the values is masked.
other_restrictions: dict, optional
Key/value pairs in which the key is the name of a column in the
catalog and the value is the criteria that values in that column
must satisfy to be kept in the cleaned catalog. The criteria must be
simple, beginning with a comparison operator and including a value.
See Examples below.
Returns
-------
`astropy.table.Table`
Table of catalog information for stars in the field of view.
"""

comparisons = {
'<': np.less,
'=': np.equal,
'>': np.greater,
'<=': np.less_equal,
'>=': np.greater_equal,
'!=': np.not_equal
}

recognized_comparison_ops = '|'.join(comparisons.keys())
keepers = np.ones([len(catalog)], dtype=bool)

if remove_rows_with_mask and catalog.masked:
for c in catalog.columns:
keepers &= ~catalog[c].mask

for column, restriction in other_restrictions.items():
criteria_re = re.compile(r'({})([-+a-zA-Z0-9]+)'.format(recognized_comparison_ops))
results = criteria_re.match(restriction)
if not results:
raise ValueError("Criteria {}{} not "
"understood.".format(column, restriction))
comparison_func = comparisons[results.group(1)]
comparison_value = results.group(2)
new_keepers = comparison_func(catalog[column],
float(comparison_value))
keepers = keepers & new_keepers

return catalog[keepers]


def find_apass_stars(image_or_center,
radius=1,
max_mag_error=0.05,
Expand Down
83 changes: 2 additions & 81 deletions stellarphot/utils/tests/test_catalog_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from astropy.wcs import WCS
from astropy.nddata import CCDData

from ..catalog_search import catalog_clean, in_frame, \
from ..catalog_search import in_frame, \
catalog_search, find_known_variables, \
find_apass_stars, filter_catalog
find_apass_stars
from ...tests.make_wcs import make_wcs

CCD_SHAPE = [2048, 3073]
Expand All @@ -22,85 +22,6 @@ def a_table(masked=False):
return test_table


def test_clean_criteria_none_removed():
"""
If all rows satisfy the criteria, none should be removed.
"""
inp = a_table()
criteria = {'a': '>0'}
out = catalog_clean(inp, **criteria)
assert len(out) == len(inp)
assert (out == inp).all()


@pytest.mark.parametrize("condition",
['>0', '=1', '!=-1', '>=1'])
def test_clean_criteria_some_removed(condition):
"""
Try a few filters which remove the second row and check that it is
removed.
"""
inp = a_table()
criteria = {'b': condition}
out = catalog_clean(inp, **criteria)
assert len(out) == 1
assert (out[0] == inp[0]).all()


@pytest.mark.parametrize("clean_masked",
[False, True])
def test_clean_masked_handled_correctly(clean_masked):
inp = a_table(masked=True)
# Mask negative values
inp['b'].mask = inp['b'] < 0
out = catalog_clean(inp, remove_rows_with_mask=clean_masked)
if clean_masked:
assert len(out) == 1
assert (np.array(out[0]) == np.array(inp[0])).all()
else:
assert len(out) == len(inp)
assert (out == inp).all()


def test_clean_masked_and_criteria():
"""
Check whether removing masked rows and using a criteria work
together.
"""
inp = a_table(masked=True)
# Mask the first row.
inp['b'].mask = inp['b'] > 0

inp_copy = inp.copy()
# This should remove the third row.
criteria = {'a': '<=2'}

out = catalog_clean(inp, remove_rows_with_mask=True, **criteria)

# Is only one row left?
assert len(out) == 1

# Is the row that is left the same as the second row of the input?
assert (np.array(out[0]) == np.array(inp[1])).all()

# Is the input table unchanged?
assert (inp == inp_copy).all()


@pytest.mark.parametrize("criteria,error_msg", [
({'a': '5'}, "not understood"),
({'a': '<foo'}, "could not convert string")])
def test_clean_bad_criteria(criteria, error_msg):
"""
Make sure the appropriate error is raised when bad criteria are used.
"""
inp = a_table(masked=False)

with pytest.raises(ValueError) as e:
catalog_clean(inp, **criteria)
assert error_msg in str(e.value)


def test_in_frame():
# This wcs has the identity matrix as the coordinate transform
wcs = make_wcs()
Expand Down

0 comments on commit cf5191e

Please sign in to comment.