diff --git a/python/lsst/ip/diffim/detectAndMeasure.py b/python/lsst/ip/diffim/detectAndMeasure.py index a9b71bb2..f035ea4a 100644 --- a/python/lsst/ip/diffim/detectAndMeasure.py +++ b/python/lsst/ip/diffim/detectAndMeasure.py @@ -141,6 +141,12 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig, target=SkyObjectsTask, doc="Generate sky sources", ) + badSourceFlags = lsst.pex.config.ListField( + dtype=str, + doc="Sources with any of these flags set are removed before writing the output catalog.", + default=("base_PixelFlags_flag_offimage", + ), + ) idGenerator = DetectorVisitIdGeneratorConfig.make_field() def setDefaults(self): @@ -222,6 +228,10 @@ def __init__(self, **kwargs): self.makeSubtask("skySources") self.skySourceKey = self.schema.addField("sky_source", type="Flag", doc="Sky objects.") + # Check that the schema and config are consistent + for flag in self.config.badSourceFlags: + if flag not in self.schema: + raise pipeBase.InvalidQuantumError("Field %s not in schema" % flag) # initialize InitOutputs self.outputSchema = afwTable.SourceCatalog(self.schema) self.outputSchema.getTable().setMetadata(self.algMetadata) @@ -353,17 +363,18 @@ def processResults(self, science, matchedTemplate, difference, sources, table, fpSet = positiveFootprints fpSet.merge(negativeFootprints, self.config.growFootprint, self.config.growFootprint, False) - diaSources = afwTable.SourceCatalog(table) - fpSet.makeSources(diaSources) - self.log.info("Merging detections into %d sources", len(diaSources)) + initialDiaSources = afwTable.SourceCatalog(table) + fpSet.makeSources(initialDiaSources) + self.log.info("Merging detections into %d sources", len(initialDiaSources)) else: - diaSources = sources - self.metadata.add("nMergedDiaSources", len(diaSources)) + initialDiaSources = sources + self.metadata.add("nMergedDiaSources", len(initialDiaSources)) if self.config.doSkySources: - self.addSkySources(diaSources, difference.mask, difference.info.id) + self.addSkySources(initialDiaSources, difference.mask, difference.info.id) - self.measureDiaSources(diaSources, science, difference, matchedTemplate) + self.measureDiaSources(initialDiaSources, science, difference, matchedTemplate) + diaSources = self._removeBadSources(initialDiaSources) if self.config.doForcedMeasurement: self.measureForcedSources(diaSources, science, difference.getWcs()) @@ -376,6 +387,32 @@ def processResults(self, science, matchedTemplate, difference, sources, table, return measurementResults + def _removeBadSources(self, diaSources): + """Remove bad diaSources from the catalog. + + Parameters + ---------- + diaSources : `lsst.afw.table.SourceCatalog` + The catalog of detected sources. + + Returns + ------- + diaSources : `lsst.afw.table.SourceCatalog` + The updated catalog of detected sources, with any source that has a + flag in ``config.badSourceFlags`` set removed. + """ + nBadTotal = 0 + selector = np.ones(len(diaSources), dtype=bool) + for flag in self.config.badSourceFlags: + flags = diaSources[flag] + nBad = np.count_nonzero(flags) + if nBad > 0: + self.log.info("Found and removed %d unphysical sources with flag %s.", nBad, flag) + selector &= ~flags + nBadTotal += nBad + self.metadata.add("nRemovedBadFlaggedSources", nBadTotal) + return diaSources[selector].copy(deep=True) + def addSkySources(self, diaSources, mask, seed): """Add sources in empty regions of the difference image for measuring the background. diff --git a/tests/test_detectAndMeasure.py b/tests/test_detectAndMeasure.py index 71bd6a68..7fc3138f 100644 --- a/tests/test_detectAndMeasure.py +++ b/tests/test_detectAndMeasure.py @@ -25,6 +25,7 @@ import lsst.geom from lsst.ip.diffim import detectAndMeasure, subtractImages from lsst.ip.diffim.utils import makeTestImage +from lsst.pipe.base import InvalidQuantumError import lsst.utils.tests @@ -92,20 +93,17 @@ def _check_values(self, values, minValue=None, maxValue=None): if maxValue is not None: self.assertTrue(np.all(values <= maxValue)) - def _setup_detection(self, doApCorr=False, doMerge=False, - doSkySources=False, doForcedMeasurement=False): + def _setup_detection(self, doSkySources=False, nSkySources=5, **kwargs): """Setup and configure the detection and measurement PipelineTask. Parameters ---------- - doApCorr : `bool`, optional - Run subtask to apply aperture corrections. - doMerge : `bool`, optional - Merge positive and negative diaSources. doSkySources : `bool`, optional Generate sky sources. - doForcedMeasurement : `bool`, optional - Force photometer diaSource locations on PVI. + nSkySources : `int`, optional + The number of sky sources to add in isolated background regions. + **kwargs + Any additional config parameters to set. Returns ------- @@ -113,12 +111,10 @@ def _setup_detection(self, doApCorr=False, doMerge=False, The configured Task to use for detection and measurement. """ config = self.detectionTask.ConfigClass() - config.doApCorr = doApCorr - config.doMerge = doMerge config.doSkySources = doSkySources - config.doForcedMeasurement = doForcedMeasurement if doSkySources: - config.skySources.nSources = 5 + config.skySources.nSources = nSkySources + config.update(**kwargs) return self.detectionTask(config=config) @@ -189,6 +185,68 @@ def test_measurements_finite(self): self._check_values(output.diaSources.getY(), minValue=0, maxValue=ySize) self._check_values(output.diaSources.getPsfInstFlux()) + def test_raise_config_schema_mismatch(self): + """Check that sources with specified flags are removed from the catalog. + """ + # Configure the detection Task, and and set a config that is not in the schema + with self.assertRaises(InvalidQuantumError): + self._setup_detection(badSourceFlags=["Bogus_flag_42"]) + + def test_remove_unphysical(self): + """Check that sources with specified flags are removed from the catalog. + """ + # Set up the simulated images + noiseLevel = 1. + staticSeed = 1 + xSize = 256 + ySize = 256 + kwargs = {"psfSize": 2.4, "xSize": xSize, "ySize": ySize} + science, sources = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel, noiseSeed=6, + nSrc=1, **kwargs) + matchedTemplate, _ = makeTestImage(seed=staticSeed, noiseLevel=noiseLevel/4, noiseSeed=7, + nSrc=1, **kwargs) + difference = science.clone() + bbox = difference.getBBox() + difference.maskedImage -= matchedTemplate.maskedImage + + # Configure the detection Task, and do not remove unphysical sources + detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20, + badSourceFlags=[]) + + # Run detection and check the results + diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources + badDiaSrcNoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY()) + nBadNoRemove = np.count_nonzero(badDiaSrcNoRemove) + # Verify that unphysical sources exist + self.assertGreater(nBadNoRemove, 0) + + # Configure the detection Task, and remove unphysical sources + detectionTask = self._setup_detection(doForcedMeasurement=False, doSkySources=True, nSkySources=20, + badSourceFlags=["base_PixelFlags_flag_offimage", ]) + + # Run detection and check the results + diaSources = detectionTask.run(science, matchedTemplate, difference).diaSources + badDiaSrcDoRemove = ~bbox.contains(diaSources.getX(), diaSources.getY()) + nBadDoRemove = np.count_nonzero(badDiaSrcDoRemove) + # Verify that all sources are physical + self.assertEqual(nBadDoRemove, 0) + # Set a few centroids outside the image bounding box + nSetBad = 5 + for src in diaSources[0: nSetBad]: + src["slot_Centroid_x"] += xSize + src["slot_Centroid_y"] += ySize + src["base_PixelFlags_flag_offimage"] = True + # Verify that these sources are outside the image + badDiaSrc = ~bbox.contains(diaSources.getX(), diaSources.getY()) + nBad = np.count_nonzero(badDiaSrc) + self.assertEqual(nBad, nSetBad) + diaSourcesNoBad = detectionTask._removeBadSources(diaSources) + badDiaSrcNoBad = ~bbox.contains(diaSourcesNoBad.getX(), diaSourcesNoBad.getY()) + + # Verify that no sources outside the image bounding box remain + self.assertEqual(np.count_nonzero(badDiaSrcNoBad), 0) + self.assertEqual(len(diaSourcesNoBad), len(diaSources) - nSetBad) + def test_detect_transients(self): """Run detection on a difference image containing transients. """ @@ -202,7 +260,7 @@ def test_detect_transients(self): matchedTemplate, _ = makeTestImage(noiseLevel=noiseLevel/4, noiseSeed=7, **kwargs) # Configure the detection Task - detectionTask = self._setup_detection() + detectionTask = self._setup_detection(doMerge=False) kwargs["seed"] = transientSeed kwargs["nSrc"] = 10 kwargs["fluxLevel"] = 1000 @@ -254,7 +312,7 @@ def test_detect_dipoles(self): difference.maskedImage -= matchedTemplate.maskedImage[science.getBBox()] # Configure the detection Task - detectionTask = self._setup_detection() + detectionTask = self._setup_detection(doMerge=False) # Run detection and check the results output = detectionTask.run(science, matchedTemplate, difference) @@ -462,7 +520,7 @@ def test_detect_transients(self): subtractTask = subtractImages.AlardLuptonPreconvolveSubtractTask() # Configure the detection Task - detectionTask = self._setup_detection() + detectionTask = self._setup_detection(doMerge=False) kwargs["seed"] = transientSeed kwargs["nSrc"] = 10 kwargs["fluxLevel"] = 1000 @@ -532,7 +590,7 @@ def test_detect_dipoles(self): score = subtractTask._convolveExposure(difference, scienceKernel, subtractTask.convolutionControl) # Configure the detection Task - detectionTask = self._setup_detection() + detectionTask = self._setup_detection(doMerge=False) # Run detection and check the results output = detectionTask.run(science, matchedTemplate, difference, score)