Skip to content

Commit

Permalink
Refactor dipoleFitTask to give better centroids
Browse files Browse the repository at this point in the history
* Enforce plugin order, so that DipoleFit can fall back on SdssCentroid
for non-dipoles.
* DipoleFitPlugin can run at "centroid" order, because it simultaneously fits
centroids and fluxes.
* Switch centroid slot after SdssCentroid, so that the "best" centroid
comes from DipoleFit (even if it's just copied over).
* Switch dipole centroid field names to better match centroid slot
convention (foo_x/foo_y, for plugin foo).
* Rename DipoleFitTask default name to match Task name convention.

Cleanup tests to pass with the better centroids:
* Loosen flux tolerance.
* Remove test of "unphysical" sources that relied on bad centroiding
pushing the sky sources off the image.
* Remove tests that relied on old centroider behavior when measuring
on unmerged footprints.
  • Loading branch information
parejkoj committed Mar 21, 2024
1 parent 62593c7 commit 2489c08
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 89 deletions.
78 changes: 45 additions & 33 deletions python/lsst/ip/diffim/dipoleFitTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import lsst.afw.image as afwImage
import lsst.meas.base as measBase
import lsst.afw.table as afwTable
import lsst.afw.detection as afwDet
import lsst.geom as geom
import lsst.pex.exceptions as pexExcept
Expand Down Expand Up @@ -104,13 +103,12 @@ class DipoleFitTaskConfig(measBase.SingleFrameMeasurementConfig):
def setDefaults(self):
measBase.SingleFrameMeasurementConfig.setDefaults(self)

# This task also runs DipoleFitPlugin directly in DipoleFitTask, which
# writes outputs to "ip_diffim_DipoleFit" entries.
self.plugins.names = ["base_CircularApertureFlux",
self.plugins.names = ["base_SdssCentroid",
"ip_diffim_DipoleFit",
"base_CircularApertureFlux",
"base_PixelFlags",
"base_SkyCoord",
"base_PsfFlux",
"base_SdssCentroid",
"base_SdssShape",
"base_GaussianFlux",
]
Expand All @@ -121,7 +119,8 @@ def setDefaults(self):
self.slots.modelFlux = None
self.slots.gaussianFlux = None
self.slots.shape = "base_SdssShape"
self.slots.centroid = "ip_diffim_NaiveDipoleCentroid"
# This will be switched to "ip_diffim_DipoleFit" as this task runs.
self.slots.centroid = "base_SdssCentroid"
self.doReplaceWithNoise = False


Expand All @@ -135,21 +134,28 @@ class DipoleFitTask(measBase.SingleFrameMeasurementTask):
"""

ConfigClass = DipoleFitTaskConfig
_DefaultName = "ip_diffim_DipoleFit"
_DefaultName = "dipoleFit"

def __init__(self, schema, algMetadata=None, **kwargs):

measBase.SingleFrameMeasurementTask.__init__(self, schema, algMetadata, **kwargs)

dpFitPluginConfig = self.config.plugins['ip_diffim_DipoleFit']

self.dipoleFitter = DipoleFitPlugin(dpFitPluginConfig, name=self._DefaultName,
schema=schema, metadata=algMetadata,
logName=self.log.name)
super().__init__(schema, algMetadata, **kwargs)

# Enforce a specific plugin order, so that DipoleFit can fall back on
# SdssCentroid for non-dipoles
self.plugins_pre = self.plugins.copy()
self.plugins_post = self.plugins.copy()
self.plugins_pre.clear()
self.plugins_pre["base_SdssCentroid"] = self.plugins["base_SdssCentroid"]
self.plugins_post.pop("base_SdssCentroid")
self.dipoleFit = self.plugins_post.pop("ip_diffim_DipoleFit")
del self.plugins

@timeMethod
def run(self, sources, exposure, posExp=None, negExp=None, **kwargs):
"""Run dipole measurement and classification
"""Run dipole measurement and classification.
Run SdssCentroid first, then switch the centroid slot, then DipoleFit
then the rest; DipoleFit will fall back on SdssCentroid for sources
not containing positive+negative peaks.
Parameters
----------
Expand All @@ -168,14 +174,15 @@ def run(self, sources, exposure, posExp=None, negExp=None, **kwargs):
**kwargs
Additional keyword arguments for `lsst.meas.base.sfm.SingleFrameMeasurementTask`.
"""
self.plugins = self.plugins_pre
super().run(sources, exposure, **kwargs)

measBase.SingleFrameMeasurementTask.run(self, sources, exposure, **kwargs)

if not sources:
return

sources.schema.getAliasMap().set("slot_Centroid", "ip_diffim_DipoleFit")
for source in sources:
self.dipoleFitter.measure(source, exposure, posExp, negExp)
self.dipoleFit.measureDipoles(source, exposure, posExp, negExp)

self.plugins = self.plugins_post
super().run(sources, exposure, **kwargs)


class DipoleModel:
Expand Down Expand Up @@ -981,12 +988,10 @@ class DipoleFitPlugin(measBase.SingleFramePlugin):

@classmethod
def getExecutionOrder(cls):
"""Set execution order to `FLUX_ORDER`.
This includes algorithms that require both `getShape()` and `getCentroid()`,
in addition to a Footprint and its Peaks.
"""This algorithm simultaneously fits the centroid and flux, and does
not require any previous centroid fit.
"""
return cls.FLUX_ORDER
return cls.CENTROID_ORDER

def __init__(self, config, name, schema, metadata, logName=None):
if logName is None:
Expand All @@ -1010,15 +1015,15 @@ def _setupSchema(self, config, name, schema, metadata):
self.fluxKey = measBase.FluxResultKey.addFields(schema, name, doc)

self.posCentroidKey = measBase.CentroidResultKey.addFields(schema,
schema.join(name, "pos", "centroid"),
schema.join(name, "pos"),
"Dipole positive lobe centroid position.",
measBase.UncertaintyEnum.NO_UNCERTAINTY)
self.negCentroidKey = measBase.CentroidResultKey.addFields(schema,
schema.join(name, "neg", "centroid"),
schema.join(name, "neg"),
"Dipole negative lobe centroid position.",
measBase.UncertaintyEnum.NO_UNCERTAINTY)
self.centroidKey = measBase.CentroidResultKey.addFields(schema,
schema.join(name, "centroid"),
name,
"Dipole centroid position.",
measBase.UncertaintyEnum.NO_UNCERTAINTY)

Expand Down Expand Up @@ -1058,7 +1063,12 @@ def _setupSchema(self, config, name, schema, metadata):
schema.join(name, "flag", "edge"), type="Flag",
doc="Flag set when dipole is too close to edge of image")

def measure(self, measRecord, exposure, posExp=None, negExp=None):
def measure(self, *args):
"""No op: the real work of this task is done in `measureDipoles`.
"""
pass

def measureDipoles(self, measRecord, exposure, posExp=None, negExp=None):
"""Perform the non-linear least squares minimization on the putative dipole source.
Parameters
Expand Down Expand Up @@ -1105,7 +1115,9 @@ def measure(self, measRecord, exposure, posExp=None, negExp=None):
measRecord.set(self.classificationAttemptedFlagKey, False)
self.fail(measRecord, measBase.MeasurementError('not a dipole', self.FAILURE_NOT_DIPOLE))
if not self.config.fitAllDiaSources:
return result
measRecord[self.centroidKey.getX()] = measRecord["base_SdssCentroid_x"]
measRecord[self.centroidKey.getY()] = measRecord["base_SdssCentroid_y"]
return

try:
alg = self.DipoleFitAlgorithmClass(exposure, posImage=posExp, negImage=negExp)
Expand All @@ -1125,7 +1137,7 @@ def measure(self, measRecord, exposure, posExp=None, negExp=None):
if result is None:
measRecord.set(self.classificationFlagKey, False)
measRecord.set(self.classificationAttemptedFlagKey, False)
return result
return

self.log.debug("Dipole fit result: %d %s", measRecord.getId(), str(result))

Expand Down
53 changes: 5 additions & 48 deletions tests/test_detectAndMeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DetectAndMeasureTestBase:

def _check_diaSource(self, refSources, diaSource, refIds=None,
matchDistance=1., scale=1., usePsfFlux=True,
rtol=0.021, atol=None):
rtol=0.025, atol=None):
"""Match a diaSource with a source in a reference catalog
and compare properties.
Expand Down Expand Up @@ -227,17 +227,6 @@ def test_remove_unphysical(self):
bbox = difference.getBBox()
difference.maskedImage -= matchedTemplate.maskedImage

# Configure the detection Task, and do not remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=[])

# Run detection and check the results
diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources
badDiaSrcNoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY())
nBadNoRemove = np.count_nonzero(badDiaSrcNoRemove)
# Verify that unphysical sources exist
self.assertGreater(nBadNoRemove, 0)

# Configure the detection Task, and remove unphysical sources
detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20,
badSourceFlags=["base_PixelFlags_flag_offimage", ])
Expand Down Expand Up @@ -351,26 +340,11 @@ def test_detect_dipoles(self):
matchedTemplate.mask.array[...] = np.roll(matchedTemplate.mask.array[...], offset, axis=0)
difference.maskedImage -= matchedTemplate.maskedImage[science.getBBox()]

# Configure the detection Task
detectionTask = self._setup_detection(doMerge=False)

# Run detection and check the results
detectionTask = self._setup_detection(doMerge=True)
output = detectionTask.run(science, matchedTemplate, difference)
self.assertIn(dipoleFlag, output.diaSources.schema.getNames())
nSourcesDet = len(sources)
self.assertEqual(len(output.diaSources), 2*nSourcesDet)
self.assertEqual(len(output.diaSources), len(sources))
refIds = []
# The diaSource check should fail if we don't merge positive and negative footprints
for diaSource in output.diaSources:
with self.assertRaises(AssertionError):
self._check_diaSource(sources, diaSource, refIds=refIds, scale=0,
atol=np.sqrt(fluxRange*fluxLevel))

detectionTask2 = self._setup_detection(doMerge=True)
output2 = detectionTask2.run(science, matchedTemplate, difference)
self.assertEqual(len(output2.diaSources), nSourcesDet)
refIds = []
for diaSource in output2.diaSources:
if diaSource[dipoleFlag]:
self._check_diaSource(sources, diaSource, refIds=refIds, scale=0,
rtol=0.05, atol=None, usePsfFlux=False)
Expand Down Expand Up @@ -720,28 +694,11 @@ def test_detect_dipoles(self):
scienceKernel = science.psf.getKernel()
score = subtractTask._convolveExposure(difference, scienceKernel, subtractTask.convolutionControl)

# Configure the detection Task
detectionTask = self._setup_detection(doMerge=False)

# Run detection and check the results
detectionTask = self._setup_detection(doMerge=True)
output = detectionTask.run(science, matchedTemplate, difference, score)
self.assertIn(dipoleFlag, output.diaSources.schema.getNames())
nSourcesDet = len(sources)
# Since we did not merge the dipoles, each source should result in
# both a positive and a negative diaSource
self.assertEqual(len(output.diaSources), 2*nSourcesDet)
self.assertEqual(len(output.diaSources), len(sources))
refIds = []
# The diaSource check should fail if we don't merge positive and negative footprints
for diaSource in output.diaSources:
with self.assertRaises(AssertionError):
self._check_diaSource(sources, diaSource, refIds=refIds, scale=0,
atol=np.sqrt(fluxRange*fluxLevel))

detectionTask2 = self._setup_detection(doMerge=True)
output2 = detectionTask2.run(science, matchedTemplate, difference, score)
self.assertEqual(len(output2.diaSources), nSourcesDet)
refIds = []
for diaSource in output2.diaSources:
if diaSource[dipoleFlag]:
self._check_diaSource(sources, diaSource, refIds=refIds, scale=0,
rtol=0.05, atol=None, usePsfFlux=False)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_dipoleFitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ def _checkTaskOutput(self, dipoleTestImage, sources, rtol=None):
self.assertFloatsAlmostEqual((result['ip_diffim_DipoleFit_pos_instFlux']
+ abs(result['ip_diffim_DipoleFit_neg_instFlux']))/2.,
dipoleTestImage.flux[i], rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_centroid_x'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_x'],
dipoleTestImage.xc[i] + offsets[i], rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_centroid_y'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_y'],
dipoleTestImage.yc[i] + offsets[i], rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_centroid_x'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_x'],
dipoleTestImage.xc[i] - offsets[i], rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_centroid_y'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_y'],
dipoleTestImage.yc[i] - offsets[i], rtol=rtol)
# Note this is dependent on the noise (variance) being realistic in the image.
# otherwise it throws off the chi2 estimate, which is used for classification:
Expand All @@ -183,16 +183,16 @@ def _checkTaskOutput(self, dipoleTestImage, sources, rtol=None):
(result2['ip_diffim_PsfDipoleFlux_pos_instFlux']
+ abs(result2['ip_diffim_PsfDipoleFlux_neg_instFlux']))/2.,
rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_centroid_x'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_x'],
result2['ip_diffim_PsfDipoleFlux_pos_centroid_x'],
rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_centroid_y'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_pos_y'],
result2['ip_diffim_PsfDipoleFlux_pos_centroid_y'],
rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_centroid_x'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_x'],
result2['ip_diffim_PsfDipoleFlux_neg_centroid_x'],
rtol=rtol)
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_centroid_y'],
self.assertFloatsAlmostEqual(result['ip_diffim_DipoleFit_neg_y'],
result2['ip_diffim_PsfDipoleFlux_neg_centroid_y'],
rtol=rtol)

Expand Down

0 comments on commit 2489c08

Please sign in to comment.