From 8bf6291073376d475d7e072ed527d6ed6814761a Mon Sep 17 00:00:00 2001 From: Erfan Nourbakhsh Date: Tue, 7 Jan 2025 15:00:59 -0500 Subject: [PATCH 1/3] Add mid-level measurement driver task --- python/lsst/pipe/tasks/measurementDriver.py | 379 ++++++++++++++++++++ 1 file changed, 379 insertions(+) create mode 100644 python/lsst/pipe/tasks/measurementDriver.py diff --git a/python/lsst/pipe/tasks/measurementDriver.py b/python/lsst/pipe/tasks/measurementDriver.py new file mode 100644 index 000000000..fbcfcc713 --- /dev/null +++ b/python/lsst/pipe/tasks/measurementDriver.py @@ -0,0 +1,379 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +__all__ = ["MeasurementDriverConfig", "MeasurementDriverTask"] + +import logging + +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.meas.algorithms as measAlgorithms +import lsst.meas.base as measBase +import lsst.meas.deblender as measDeblender +import lsst.meas.extensions.scarlet as scarlet +import lsst.pex.config as pexConfig +import lsst.pipe.base as pipeBase +import numpy as np + +logging.basicConfig(level=logging.INFO) + + +class MeasurementDriverConfig(pexConfig.Config): + """Configuration parameters for `MeasurementDriverTask`.""" + + # To generate catalog ids consistently across subtasks. + id_generator = measBase.DetectorVisitIdGeneratorConfig.make_field() + + detection = pexConfig.ConfigurableField( + target=measAlgorithms.SourceDetectionTask, + doc="Task to detect sources to return in the output catalog.", + ) + + deblender = pexConfig.ChoiceField[str]( + doc="The deblender to use.", + default="meas_deblender", + allowed={"meas_deblender": "Deblend using meas_deblender", "scarlet": "Deblend using scarlet"}, + ) + + deblend = pexConfig.ConfigurableField( + target=measDeblender.SourceDeblendTask, doc="Split blended sources into their components." + ) + + measurement = pexConfig.ConfigurableField( + target=measBase.SingleFrameMeasurementTask, + doc="Task to measure sources to return in the output catalog.", + ) + + def __setattr__(self, key, value): + """Intercept attribute setting to trigger setDefaults when relevant + fields change. + """ + super().__setattr__(key, value) + + # This is to ensure the deblend target is set correctly whenever the + # deblender is changed. This is required because `setDefaults` is not + # automatically invoked during reconfiguration. + if key == "deblender": + self.setDefaults() + + def validate(self): + super().validate() + + # Ensure the deblend target aligns with the selected deblender. 
+ if self.deblender == "scarlet": + assert self.deblend.target == scarlet.ScarletDeblendTask + elif self.deblender == "meas_deblender": + assert self.deblend.target == measDeblender.SourceDeblendTask + elif self.deblender is not None: + raise ValueError(f"Invalid deblender value: {self.deblender}") + + def setDefaults(self): + super().setDefaults() + if self.deblender == "scarlet": + self.deblend.retarget(scarlet.ScarletDeblendTask) + elif self.deblender == "meas_deblender": + self.deblend.retarget(measDeblender.SourceDeblendTask) + + +class MeasurementDriverTask(pipeBase.Task): + """A mid-level driver for running detection, deblending (optional), and + measurement algorithms in one go. + + This driver simplifies the process of applying a small set of measurement + algorithms to images by abstracting away schema and table boilerplate. It + is particularly suited for simple use cases, such as processing images + without neighbor-noise-replacement or extensive configuration. + + Designed to streamline the measurement framework, this class integrates + detection, deblending (if enabled), and measurement into a single workflow. + + Parameters + ---------- + schema : `~lsst.afw.table.Schema` + Schema used to create the output `~lsst.afw.table.SourceCatalog`, + modified in place with fields that will be written by this task. + **kwargs : `dict` + Additional kwargs to pass to lsst.pipe.base.Task.__init__() + + Examples + -------- + Here is an example of how to use this class to run detection, deblending, + and measurement on a given exposure: + >>> from lsst.pipe.tasks.measurementDriver import MeasurementDriverTask + >>> import lsst.meas.extensions.shapeHSM # To register its plugins + >>> config = MeasurementDriverTask().ConfigClass() + >>> config.detection.thresholdValue = 5.5 + >>> config.deblender = "meas_deblender" + >>> config.deblend.tinyFootprintSize = 3 + >>> config.measurement.plugins.names |= [ + ... "base_SdssCentroid", + ... "base_SdssShape", + ... 
"ext_shapeHSM_HsmSourceMoments", + ... ] + >>> config.measurement.slots.psfFlux = None + >>> config.measurement.doReplaceWithNoise = False + >>> exposure = butler.get("deepCoadd", dataId=...) + >>> driver = MeasurementDriverTask(config=config) + >>> catalog = driver.run(exposure) + >>> catalog.writeFits("meas_catalog.fits") + """ + + ConfigClass = MeasurementDriverConfig + _DefaultName = "measurementDriver" + + def __init__(self, schema=None, **kwargs): + super().__init__(**kwargs) + + if schema is None: + # Create a minimal schema that will be extended by tasks. + self.schema = afwTable.SourceTable.makeMinimalSchema() + else: + self.schema = schema + + # Add coordinate error fields to the schema (this is to avoid errors + # such as: "Field with name 'coord_raErr' not found with type 'F'"). + afwTable.CoordKey.addErrorFields(self.schema) + + self.subtasks = ["detection", "deblend", "measurement"] + + def make_subtasks(self): + """Create subtasks based on the current configuration.""" + for name in self.subtasks: + self.makeSubtask(name, schema=self.schema) + + def run( + self, + image, + bands=None, + band=None, + mask=None, + variance=None, + psf=None, + wcs=None, + photo_calib=None, + id_generator=None, + ): + """Run detection, optional deblending, and measurement on a given + image. + + Parameters + ---------- + image: `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage` or + `~lsst.afw.image.Image` or `np.ndarray` or + `~lsst.afw.image.MultibandExposure` or + `list` of `~lsst.afw.image.Exposure` + The image on which to detect, deblend and measure sources. If + provided as a multiband exposure, or a list of `Exposure` objects, + it can be taken advantage of by the 'scarlet' deblender. When using + a list of `Exposure` objects, the ``bands`` parameter must also be + provided. + bands: `str` or `list` of `str`, optional + The bands of the input image. Required if ``image`` is provided as + a list of `Exposure` objects. 
Example: ["g", "r", "i", "z", "y"] + or "grizy". + band: `str`, optional + The target band of the image to use for detection and measurement. + Required when ``image`` is provided as a `MultibandExposure`, or a + list of `Exposure` objects. + mask: `~lsst.afw.image.Mask`, optional + The mask for the input image. Only used if ``image`` is provided + as an afw `Image` or a numpy `ndarray`. + variance: `~lsst.afw.image.Image`, optional + The variance image for the input image. Only used if ``image`` is + provided as an afw `Image` or a numpy `ndarray`. + psf: `~lsst.afw.detection.Psf`, optional + The PSF model for the input image. Will be ignored if ``image`` is + provided as an `Exposure`, `MultibandExposure`, or a list of + `Exposure` objects. + wcs: `~lsst.afw.image.Wcs`, optional + The World Coordinate System (WCS) model for the input image. Will + be ignored if ``image`` is provided as an `Exposure`, + `MultibandExposure`, or a list of `Exposure` objects. + photo_calib : `~lsst.afw.image.PhotoCalib`, optional + Photometric calibration model for the input image. Will be ignored + if ``image`` is provided as an `Exposure`, `MultibandExposure`, or + a list of `Exposure` objects. + id_generator : `~lsst.meas.base.IdGenerator`, optional + Object that generates source IDs and provides random seeds. + + Returns + ------- + catalog : `~lsst.afw.table.SourceCatalog` + The source catalog with all requested measurements. + """ + + # Only make the `deblend` subtask if it is enabled. + if self.config.deblender is None: + self.subtasks.remove("deblend") + + # Validate the configuration before running the task. + self.config.validate() + + # This guarantees the `run` method picks up the current subtask config. + self.make_subtasks() + # N.B. subtasks must be created here to handle reconfigurations, such + # as retargeting the `deblend` subtask, because the `makeSubtask` + # method locks in its config just before creating the subtask. 
If the + # subtask was already made in __init__ using the initial config, it + # cannot be retargeted now because retargeting happens at the config + # level, not the subtask level. + + if id_generator is None: + id_generator = measBase.IdGenerator() + + if isinstance(image, afwImage.MultibandExposure) or isinstance(image, list): + if self.config.deblender != "scarlet": + self.log.debug( + "Supplied a multiband exposure, or a list of exposures, while the deblender is set to " + f"'{self.config.deblender}'. A single exposure corresponding to target `band` will be " + "used for everything." + ) + if band is None: + raise ValueError( + "The target `band` must be provided when using multiband exposures or a list of " + "exposures." + ) + if isinstance(image, list): + if not all(isinstance(im, afwImage.Exposure) for im in image): + raise ValueError("All elements in the `image` list must be `Exposure` objects.") + if bands is None: + raise ValueError( + "The `bands` parameter must be provided if `image` is a list of `Exposure` objects." + ) + if not isinstance(bands, (str, list)) or ( + isinstance(bands, list) and not all(isinstance(b, str) for b in bands) + ): + raise TypeError( + "The `bands` parameter must be a string or a list of strings if provided." + ) + if len(bands) != len(image): + raise ValueError( + "The number of bands must match the number of `Exposure` objects in the list." + ) + else: + if band is None: + band = "N/A" # Just a placeholder for single-band deblending + else: + self.log.warn("The target `band` is not required when the input image is not multiband.") + if bands is not None: + self.log.warn( + "The `bands` parameter will be ignored because the input image is not multiband." 
+ ) + + if self.config.deblender == "scarlet": + if not isinstance(image, (afwImage.MultibandExposure, list, afwImage.Exposure)): + raise ValueError( + "The `image` parameter must be a `MultibandExposure`, a list of `Exposure` " + "objects, or a single `Exposure` when the deblender is set to 'scarlet'." + ) + if isinstance(image, afwImage.Exposure): + # N.B. scarlet is designed to leverage multiband information to + # differentiate overlapping sources based on their spectral and + # spatial profiles. However, it can also run on a single band + # and still give better results than 'meas_deblender'. + self.log.debug( + "Supplied a single-band exposure, while the deblender is set to 'scarlet'." + "Make sure it was intended." + ) + + # Start with some image conversions if needed. + if isinstance(image, np.ndarray): + image = afwImage.makeImageFromArray(image) + if isinstance(mask, np.ndarray): + mask = afwImage.makeMaskFromArray(mask) + if isinstance(variance, np.ndarray): + variance = afwImage.makeImageFromArray(variance) + if isinstance(image, afwImage.Image): + image = afwImage.makeMaskedImage(image, mask, variance) + + # Avoid type checker errors by being explicit from here on. + exposure: afwImage.Exposure + + # Make sure we have an `Exposure` object to work with (potentially + # along with a `MultiBandExposure` for scarlet deblending). + if isinstance(image, afwImage.Exposure): + exposure = image + elif isinstance(image, afwImage.MaskedImage): + exposure = afwImage.makeExposure(image, wcs) + if psf is not None: + exposure.setPsf(psf) + if photo_calib is not None: + exposure.setPhotoCalib(photo_calib) + elif isinstance(image, list): + # Construct a multiband exposure for scarlet deblending. + exposures = afwImage.MultibandExposure.fromExposures(bands, image) + # Select the exposure of the desired band, which will be used for + # detection and measurement. 
+ exposure = exposures[band] + elif isinstance(image, afwImage.MultibandExposure): + exposures = image + exposure = exposures[band] + else: + raise TypeError(f"Unsupported image type: {type(image)}") + + # Create a source table into which detections will be placed. + table = afwTable.SourceTable.make(self.schema, id_generator.make_table_id_factory()) + + # Detect sources and get a source catalog. + self.log.info(f"Running detection on a {exposure.width}x{exposure.height} pixel image") + detections = self.detection.run(table, exposure) + catalog = detections.sources + + # Deblend sources into their components and update the catalog. + if self.config.deblender is None: + self.log.info("Deblending is disabled; skipping deblending") + else: + self.log.info( + f"Running deblending via '{self.config.deblender}' on {len(catalog)} detection footprints" + ) + if self.config.deblender == "meas_deblender": + self.deblend.run(exposure=exposure, sources=catalog) + elif self.config.deblender == "scarlet": + if not isinstance(image, (afwImage.MultibandExposure, list)): + # We need to have a multiband exposure to satisfy scarlet + # function's signature, even when using a single band. + exposures = afwImage.MultibandExposure.fromExposures([band], [exposure]) + catalog, model_data = self.deblend.run(mExposure=exposures, mergedSources=catalog) + # The footprints need to be updated for the subsequent + # measurement. + scarlet.io.updateCatalogFootprints( + modelData=model_data, + catalog=catalog, + band=band, + imageForRedistribution=exposure, + removeScarletData=True, + updateFluxColumns=True, + ) + + # The deblender may not produce a contiguous catalog; ensure contiguity + # for the subsequent task. + if not catalog.isContiguous(): + self.log.info("Catalog is not contiguous; making it contiguous") + catalog = catalog.copy(deep=True) + + # Measure requested quantities on sources. 
+ self.measurement.run(catalog, exposure) + self.log.info( + f"Measured {len(catalog)} sources and stored them in the output " + f"catalog containing {catalog.schema.getFieldCount()} fields" + ) + + return catalog From 08f34c1e819553200a6ebb40053967afb53d9b6c Mon Sep 17 00:00:00 2001 From: Erfan Nourbakhsh Date: Tue, 21 Jan 2025 23:52:50 -0500 Subject: [PATCH 2/3] Address review comments round 1 (to be squashed) --- python/lsst/pipe/tasks/measurementDriver.py | 616 ++++++++++++-------- 1 file changed, 378 insertions(+), 238 deletions(-) diff --git a/python/lsst/pipe/tasks/measurementDriver.py b/python/lsst/pipe/tasks/measurementDriver.py index fbcfcc713..cf8fabe67 100644 --- a/python/lsst/pipe/tasks/measurementDriver.py +++ b/python/lsst/pipe/tasks/measurementDriver.py @@ -19,11 +19,19 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -__all__ = ["MeasurementDriverConfig", "MeasurementDriverTask"] +__all__ = [ + "SingleBandMeasurementDriverConfig", + "SingleBandMeasurementDriverTask", + "MultiBandMeasurementDriverConfig", + "MultiBandMeasurementDriverTask", +] import logging +import lsst.afw.detection as afwDetection +import lsst.afw.geom as afwGeom import lsst.afw.image as afwImage +import lsst.afw.math as afwMath import lsst.afw.table as afwTable import lsst.meas.algorithms as measAlgorithms import lsst.meas.base as measBase @@ -36,110 +44,102 @@ logging.basicConfig(level=logging.INFO) -class MeasurementDriverConfig(pexConfig.Config): - """Configuration parameters for `MeasurementDriverTask`.""" +class MeasurementDriverBaseConfig(pexConfig.Config): + """Base configuration for measurement driver tasks. - # To generate catalog ids consistently across subtasks. - id_generator = measBase.DetectorVisitIdGeneratorConfig.make_field() + This class provides foundational configuration for its subclasses to handle + single-band and multi-band data.
It defines the detection, deblending, + and measurement subtasks, which are intended to be executed in sequence + by the driver tasks. + """ + + idGenerator = measBase.DetectorVisitIdGeneratorConfig.make_field( + doc="Configuration for generating catalog IDs from data IDs consistently across subtasks." + ) detection = pexConfig.ConfigurableField( - target=measAlgorithms.SourceDetectionTask, - doc="Task to detect sources to return in the output catalog.", + target=measAlgorithms.SourceDetectionTask, doc="Subtask to detect sources in the image." ) deblender = pexConfig.ChoiceField[str]( - doc="The deblender to use.", + doc="Which deblender to use?", default="meas_deblender", - allowed={"meas_deblender": "Deblend using meas_deblender", "scarlet": "Deblend using scarlet"}, + allowed={ + "meas_deblender": "Deblend using meas_deblender (only single-band)", + "scarlet": "Deblend using scarlet (single- or multi-band)", + }, ) deblend = pexConfig.ConfigurableField( - target=measDeblender.SourceDeblendTask, doc="Split blended sources into their components." + target=measDeblender.SourceDeblendTask, doc="Subtask to split blended sources into components." ) measurement = pexConfig.ConfigurableField( target=measBase.SingleFrameMeasurementTask, - doc="Task to measure sources to return in the output catalog.", + doc="Subtask to measure sources and populate the output catalog", ) def __setattr__(self, key, value): - """Intercept attribute setting to trigger setDefaults when relevant - fields change. - """ + """Intercept changes to 'deblender' and retarget subtask if needed.""" super().__setattr__(key, value) # This is to ensure the deblend target is set correctly whenever the # deblender is changed. This is required because `setDefaults` is not # automatically invoked during reconfiguration. if key == "deblender": - self.setDefaults() - - def validate(self): - super().validate() - - # Ensure the deblend target aligns with the selected deblender. 
- if self.deblender == "scarlet": - assert self.deblend.target == scarlet.ScarletDeblendTask - elif self.deblender == "meas_deblender": - assert self.deblend.target == measDeblender.SourceDeblendTask - elif self.deblender is not None: - raise ValueError(f"Invalid deblender value: {self.deblender}") + self._retargetDeblend() def setDefaults(self): super().setDefaults() + self._retargetDeblend() + + def _retargetDeblend(self): if self.deblender == "scarlet": self.deblend.retarget(scarlet.ScarletDeblendTask) elif self.deblender == "meas_deblender": self.deblend.retarget(measDeblender.SourceDeblendTask) + def validate(self): + super().validate() + targetMap = { + "scarlet": scarlet.ScarletDeblendTask, + "meas_deblender": measDeblender.SourceDeblendTask, + } + + # Ensure the deblend target aligns with the selected deblender. + if self.deblend.target != (expected := targetMap.get(self.deblender)): + raise ValueError( + f"Invalid target for '{self.deblender}': expected {expected}, got {self.deblend.target}" + ) -class MeasurementDriverTask(pipeBase.Task): - """A mid-level driver for running detection, deblending (optional), and - measurement algorithms in one go. + +class MeasurementDriverBaseTask(pipeBase.Task): + """Base class for the mid-level driver running detection, deblending + (optional), and measurement algorithms in one go. This driver simplifies the process of applying a small set of measurement algorithms to images by abstracting away schema and table boilerplate. It is particularly suited for simple use cases, such as processing images without neighbor-noise-replacement or extensive configuration. - Designed to streamline the measurement framework, this class integrates - detection, deblending (if enabled), and measurement into a single workflow. - Parameters ---------- - schema : `~lsst.afw.table.Schema` + schema : Schema used to create the output `~lsst.afw.table.SourceCatalog`, modified in place with fields that will be written by this task. 
- **kwargs : `dict` + **kwargs : Additional kwargs to pass to lsst.pipe.base.Task.__init__() - Examples - -------- - Here is an example of how to use this class to run detection, deblending, - and measurement on a given exposure: - >>> from lsst.pipe.tasks.measurementDriver import MeasurementDriverTask - >>> import lsst.meas.extensions.shapeHSM # To register its plugins - >>> config = MeasurementDriverTask().ConfigClass() - >>> config.detection.thresholdValue = 5.5 - >>> config.deblender = "meas_deblender" - >>> config.deblend.tinyFootprintSize = 3 - >>> config.measurement.plugins.names |= [ - ... "base_SdssCentroid", - ... "base_SdssShape", - ... "ext_shapeHSM_HsmSourceMoments", - ... ] - >>> config.measurement.slots.psfFlux = None - >>> config.measurement.doReplaceWithNoise = False - >>> exposure = butler.get("deepCoadd", dataId=...) - >>> driver = MeasurementDriverTask(config=config) - >>> catalog = driver.run(exposure) - >>> catalog.writeFits("meas_catalog.fits") + Notes + ----- + Subclasses (e.g. single-band vs multi-band) override how inputs are built + or validated, but rely on this base for the pipeline logic. """ - ConfigClass = MeasurementDriverConfig - _DefaultName = "measurementDriver" + ConfigClass = MeasurementDriverBaseConfig + _DefaultName = "measurementDriverBase" - def __init__(self, schema=None, **kwargs): + def __init__(self, schema: afwTable.Schema = None, **kwargs: dict): super().__init__(**kwargs) if schema is None: @@ -148,87 +148,46 @@ def __init__(self, schema=None, **kwargs): else: self.schema = schema - # Add coordinate error fields to the schema (this is to avoid errors - # such as: "Field with name 'coord_raErr' not found with type 'F'"). + # Add coordinate error fields to avoid missing field issues in the + # schema. afwTable.CoordKey.addErrorFields(self.schema) - self.subtasks = ["detection", "deblend", "measurement"] + # Standard subtasks to run in sequence. 
+ self.subtaskNames = ["detection", "deblend", "measurement"] - def make_subtasks(self): - """Create subtasks based on the current configuration.""" - for name in self.subtasks: - self.makeSubtask(name, schema=self.schema) + def makeSubtasks(self): + """Construct subtasks based on the current configuration.""" + for name in self.subtaskNames: + if not hasattr(self, name): + self.makeSubtask(name, schema=self.schema) def run( - self, - image, - bands=None, - band=None, - mask=None, - variance=None, - psf=None, - wcs=None, - photo_calib=None, - id_generator=None, - ): + self, exposure: afwImage.Exposure, idGenerator: measBase.IdGenerator = None + ) -> afwTable.SourceCatalog: """Run detection, optional deblending, and measurement on a given image. Parameters ---------- - image: `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage` or - `~lsst.afw.image.Image` or `np.ndarray` or - `~lsst.afw.image.MultibandExposure` or - `list` of `~lsst.afw.image.Exposure` - The image on which to detect, deblend and measure sources. If - provided as a multiband exposure, or a list of `Exposure` objects, - it can be taken advantage of by the 'scarlet' deblender. When using - a list of `Exposure` objects, the ``bands`` parameter must also be - provided. - bands: `str` or `list` of `str`, optional - The bands of the input image. Required if ``image`` is provided as - a list of `Exposure` objects. Example: ["g", "r", "i", "z", "y"] - or "grizy". - band: `str`, optional - The target band of the image to use for detection and measurement. - Required when ``image`` is provided as a `MultibandExposure`, or a - list of `Exposure` objects. - mask: `~lsst.afw.image.Mask`, optional - The mask for the input image. Only used if ``image`` is provided - as an afw `Image` or a numpy `ndarray`. - variance: `~lsst.afw.image.Image`, optional - The variance image for the input image. Only used if ``image`` is - provided as an afw `Image` or a numpy `ndarray`. 
- psf: `~lsst.afw.detection.Psf`, optional - The PSF model for the input image. Will be ignored if ``image`` is - provided as an `Exposure`, `MultibandExposure`, or a list of - `Exposure` objects. - wcs: `~lsst.afw.image.Wcs`, optional - The World Coordinate System (WCS) model for the input image. Will - be ignored if ``image`` is provided as an `Exposure`, - `MultibandExposure`, or a list of `Exposure` objects. - photo_calib : `~lsst.afw.image.PhotoCalib`, optional - Photometric calibration model for the input image. Will be ignored - if ``image`` is provided as an `Exposure`, `MultibandExposure`, or - a list of `Exposure` objects. - id_generator : `~lsst.meas.base.IdGenerator`, optional + exposure : + The exposure on which to detect, deblend and measure sources. + idGenerator : optional Object that generates source IDs and provides random seeds. Returns ------- - catalog : `~lsst.afw.table.SourceCatalog` + catalog : The source catalog with all requested measurements. """ - - # Only make the `deblend` subtask if it is enabled. + # Make the `deblend` subtask only if it is enabled. if self.config.deblender is None: self.subtasks.remove("deblend") - # Validate the configuration before running the task. + # Validate the configuration. self.config.validate() - # This guarantees the `run` method picks up the current subtask config. - self.make_subtasks() + # Ensure this method picks up the current subtask config. + self.makeSubtasks() # N.B. subtasks must be created here to handle reconfigurations, such # as retargeting the `deblend` subtask, because the `makeSubtask` # method locks in its config just before creating the subtask. If the @@ -236,65 +195,166 @@ def run( # cannot be retargeted now because retargeting happens at the config # level, not the subtask level. 
- if id_generator is None: - id_generator = measBase.IdGenerator() - - if isinstance(image, afwImage.MultibandExposure) or isinstance(image, list): - if self.config.deblender != "scarlet": - self.log.debug( - "Supplied a multiband exposure, or a list of exposures, while the deblender is set to " - f"'{self.config.deblender}'. A single exposure corresponding to target `band` will be " - "used for everything." - ) - if band is None: - raise ValueError( - "The target `band` must be provided when using multiband exposures or a list of " - "exposures." - ) - if isinstance(image, list): - if not all(isinstance(im, afwImage.Exposure) for im in image): - raise ValueError("All elements in the `image` list must be `Exposure` objects.") - if bands is None: - raise ValueError( - "The `bands` parameter must be provided if `image` is a list of `Exposure` objects." - ) - if not isinstance(bands, (str, list)) or ( - isinstance(bands, list) and not all(isinstance(b, str) for b in bands) - ): - raise TypeError( - "The `bands` parameter must be a string or a list of strings if provided." - ) - if len(bands) != len(image): - raise ValueError( - "The number of bands must match the number of `Exposure` objects in the list." - ) + if idGenerator is None: + idGenerator = measBase.IdGenerator() + + self.exposure = exposure + + # Create an empty source table with the known schema into which + # detections will be placed next. + self.catalog = afwTable.SourceTable.make(self.schema, idGenerator.make_table_id_factory()) + + # Step 1: Detect sources in the image and populate the catalog. + self._detectSources() + + # Step 2: If enabled, deblend detected sources and update the catalog. 
+ if self.config.deblender: + self._deblendSources() else: - if band is None: - band = "N/A" # Just a placeholder for single-band deblending - else: - self.log.warn("The target `band` is not required when the input image is not multiband.") - if bands is not None: - self.log.warn( - "The `bands` parameter will be ignored because the input image is not multiband." - ) + self.log.info("Deblending is disabled; skipping deblending") + + # Step 3: Measure properties of detected/deblended sources. + self._measureSources() + + return self.catalog + + def _detectSources(self): + """Run the detection subtask to identify sources in the image.""" + self.log.info(f"Running detection on a {self.exposure.width}x{self.exposure.height} pixel exposure") + self.catalog = self.detection.run(self.catalog, self.exposure).sources + + def _deblendSources(self): + """Run the deblending subtask to separate blended sources.""" + self.log.info( + f"Deblending using '{self.config.deblender}' on {len(self.catalog)} detection footprints" + ) + if self.config.deblender == "meas_deblender": + self.deblend.run(exposure=self.exposure, sources=self.catalog) + elif self.config.deblender == "scarlet": + if not isinstance(self.exposure, afwImage.MultibandExposure): + # We need to have a multiband exposure to satisfy scarlet + # function's signature, even when using a single band. + self.band = "N/A" # Placeholder for single-band deblending + self.mExposure = afwImage.MultibandExposure.fromExposures([self.band], [self.exposure]) + self.catalog, modelData = self.deblend.run(mExposure=self.mExposure, mergedSources=self.catalog) + # The footprints need to be updated for the subsequent measurement. + scarlet.io.updateCatalogFootprints( + modelData=modelData, + catalog=self.catalog, + band=self.band, + imageForRedistribution=None, + removeScarletData=True, + updateFluxColumns=True, + ) + # The deblender may not produce a contiguous catalog; ensure contiguity + # for the subsequent task. 
+ if not self.catalog.isContiguous(): + self.log.info("Catalog is not contiguous; making it contiguous") + self.catalog = self.catalog.copy(deep=True) + + def _measureSources(self): + """Run the measurement subtask to compute properties of sources.""" + isDeblended = "and deblended" if self.config.deblender else "(not deblended)" + self.log.info(f"Measuring {len(self.catalog)} detected {isDeblended} sources") + self.measurement.run(self.catalog, self.exposure) + self.log.info( + f"Measurement complete - output catalog has " f"{self.catalog.schema.getFieldCount()} fields" + ) + +class SingleBandMeasurementDriverConfig(MeasurementDriverBaseConfig): + """Configuration for single-band measurement driver tasks. + + No additional parameters specific to single-band processing is added. + """ + + pass + + +class SingleBandMeasurementDriverTask(MeasurementDriverBaseTask): + """Mid-level driver for processing single-band data. + + Provides an additional interface for handling raw image data that is + specific to single-band scenarios. + + Examples + -------- + Here is an example of how to use this class to run detection, deblending, + and measurement on a single-band exposure: + >>> from lsst.pipe.tasks.measurementDriver import ( + ... SingleBandMeasurementDriverConfig, + ... SingleBandMeasurementDriverTask, + ... ) + >>> import lsst.meas.extensions.shapeHSM # To register its plugins + >>> config = SingleBandMeasurementDriverConfig() + >>> config.detection.thresholdValue = 5.5 + >>> config.deblender = "meas_deblender" + >>> config.deblend.tinyFootprintSize = 3 + >>> config.measurement.plugins.names |= [ + ... "base_SdssCentroid", + ... "base_SdssShape", + ... "ext_shapeHSM_HsmSourceMoments", + ... ] + >>> config.measurement.slots.psfFlux = None + >>> config.measurement.doReplaceWithNoise = False + >>> exposure = butler.get("deepCoadd", dataId=...) 
+ >>> driver = SingleBandMeasurementDriverTask(config=config) + >>> catalog = driver.run(exposure) + >>> catalog.writeFits("meas_catalog.fits") + """ + + _DefaultName = "singleBandMeasurementDriver" + ConfigClass = SingleBandMeasurementDriverConfig + + def run(self, *args, **kwargs): if self.config.deblender == "scarlet": - if not isinstance(image, (afwImage.MultibandExposure, list, afwImage.Exposure)): - raise ValueError( - "The `image` parameter must be a `MultibandExposure`, a list of `Exposure` " - "objects, or a single `Exposure` when the deblender is set to 'scarlet'." - ) - if isinstance(image, afwImage.Exposure): - # N.B. scarlet is designed to leverage multiband information to - # differentiate overlapping sources based on their spectral and - # spatial profiles. However, it can also run on a single band - # and still give better results than 'meas_deblender'. - self.log.debug( - "Supplied a single-band exposure, while the deblender is set to 'scarlet'." - "Make sure it was intended." - ) - - # Start with some image conversions if needed. + # N.B. scarlet is designed to leverage multiband information to + # differentiate overlapping sources based on their spectral and + # spatial profiles. However, it can also run on a single band and + # often give better results than 'meas_deblender'. + self.log.debug("Using 'scarlet' deblender for single-band processing; make sure it was intended") + return super().run(*args, **kwargs) + + def runFromImage( + self, + image: afwImage.MaskedImage | afwImage.Image | np.ndarray, + mask: afwImage.Mask | np.ndarray = None, + variance: afwImage.Image | np.ndarray = None, + wcs: afwGeom.SkyWcs = None, + psf: afwDetection.Psf | np.ndarray = None, + photoCalib: afwImage.PhotoCalib = None, + idGenerator: measBase.IdGenerator = None, + ) -> afwTable.SourceCatalog: + """Convert image data to an `Exposure`, then run it through the + configured subtasks. + + Parameters + ---------- + image : + Input image data. 
Will be converted into an `Exposure` before + processing. + mask : optional + Mask data for the image. Used if 'image' is a bare `array` or + `Image`. + variance : optional + Variance plane data for the image. + wcs : optional + World Coordinate System to associate with the exposure that will + be created from ``image``. + psf : optional + PSF model for the exposure. + photoCalib : optional + Photometric calibration model for the exposure. + idGenerator : optional + Generator for unique source IDs. + + Returns + ------- + catalog : + Final catalog of measured sources. + """ + # Convert raw image data into an Exposure + # exposure = self._makeExposureFromImage(image, mask, variance, wcs, psf, photoCalib) if isinstance(image, np.ndarray): image = afwImage.makeImageFromArray(image) if isinstance(mask, np.ndarray): @@ -304,76 +364,156 @@ def run( if isinstance(image, afwImage.Image): image = afwImage.makeMaskedImage(image, mask, variance) - # Avoid type checker errors by being explicit from here on. - exposure: afwImage.Exposure - - # Make sure we have an `Exposure` object to work with (potentially - # along with a `MultiBandExposure` for scarlet deblending). - if isinstance(image, afwImage.Exposure): - exposure = image - elif isinstance(image, afwImage.MaskedImage): + # By now, the input should already be - or have been converted to - a + # MaskedImage. + if isinstance(image, afwImage.MaskedImage): exposure = afwImage.makeExposure(image, wcs) - if psf is not None: - exposure.setPsf(psf) - if photo_calib is not None: - exposure.setPhotoCalib(photo_calib) - elif isinstance(image, list): - # Construct a multiband exposure for scarlet deblending. - exposures = afwImage.MultibandExposure.fromExposures(bands, image) - # Select the exposure of the desired band, which will be used for - # detection and measurement. 
- exposure = exposures[band] - elif isinstance(image, afwImage.MultibandExposure): - exposures = image - exposure = exposures[band] else: - raise TypeError(f"Unsupported image type: {type(image)}") + raise TypeError(f"Unsupported 'image' type: {type(image)}") - # Create a source table into which detections will be placed. - table = afwTable.SourceTable.make(self.schema, id_generator.make_table_id_factory()) + if psf is not None: + if isinstance(psf, np.ndarray): + # Create a FixedKernel using the array. + psf /= psf.sum() + kernel = afwMath.FixedKernel(afwImage.makeImageFromArray(psf)) + # Create a KernelPsf using the kernel. + psf = afwDetection.KernelPsf(kernel) + elif not isinstance(psf, afwDetection.Psf): + raise TypeError(f"Unsupported 'psf' type: {type(psf)}") + exposure.setPsf(psf) - # Detect sources and get a source catalog. - self.log.info(f"Running detection on a {exposure.width}x{exposure.height} pixel image") - detections = self.detection.run(table, exposure) - catalog = detections.sources + if photoCalib is not None: + exposure.setPhotoCalib(photoCalib) - # Deblend sources into their components and update the catalog. - if self.config.deblender is None: - self.log.info("Deblending is disabled; skipping deblending") - else: - self.log.info( - f"Running deblending via '{self.config.deblender}' on {len(catalog)} detection footprints" + return self.run(exposure, idGenerator=idGenerator) + + +class MultiBandMeasurementDriverConfig(MeasurementDriverBaseConfig): + """Configuration for multi-band measurement driver tasks. + + Adds a validation check to ensure the 'scarlet' deblender is used. + """ + + def validate(self): + super().validate() + if self.deblender != "scarlet": + raise ValueError( + f"Multi-band deblending requires the 'scarlet' deblender, but got '{self.deblender}'." 
) - if self.config.deblender == "meas_deblender": - self.deblend.run(exposure=exposure, sources=catalog) - elif self.config.deblender == "scarlet": - if not isinstance(image, (afwImage.MultibandExposure, list)): - # We need to have a multiband exposure to satisfy scarlet - # function's signature, even when using a single band. - exposures = afwImage.MultibandExposure.fromExposures([band], [exposure]) - catalog, model_data = self.deblend.run(mExposure=exposures, mergedSources=catalog) - # The footprints need to be updated for the subsequent - # measurement. - scarlet.io.updateCatalogFootprints( - modelData=model_data, - catalog=catalog, - band=band, - imageForRedistribution=exposure, - removeScarletData=True, - updateFluxColumns=True, - ) - # The deblender may not produce a contiguous catalog; ensure contiguity - # for the subsequent task. - if not catalog.isContiguous(): - self.log.info("Catalog is not contiguous; making it contiguous") - catalog = catalog.copy(deep=True) - # Measure requested quantities on sources. - self.measurement.run(catalog, exposure) - self.log.info( - f"Measured {len(catalog)} sources and stored them in the output " - f"catalog containing {catalog.schema.getFieldCount()} fields" - ) +class MultiBandMeasurementDriverTask(MeasurementDriverBaseTask): + """Mid-level driver for processing multi-band data. + + Provides functionality for handling a list of single-band exposures in + addition to a multi-band exposure. - return catalog + Examples + -------- + Here is an example of how to use this class to run detection, deblending, + and measurement on a multi-band exposure: + >>> from lsst.afw.image import MultibandExposure + >>> from lsst.pipe.tasks.measurementDriver import ( + ... MultiBandMeasurementDriverConfig, + ... MultiBandMeasurementDriverTask, + ... 
) + >>> import lsst.meas.extensions.shapeHSM # To register its plugins + >>> config = MultiBandMeasurementDriverConfig() + >>> config.detection.thresholdValue = 5.5 + >>> config.deblender = "scarlet" + >>> config.deblend.minSNR = 42.0 + >>> config.deblend.maxIter = 20 + >>> config.measurement.plugins.names |= [ + ... "base_SdssCentroid", + ... "base_SdssShape", + ... "ext_shapeHSM_HsmSourceMoments", + ... ] + >>> config.measurement.slots.psfFlux = None + >>> config.measurement.doReplaceWithNoise = False + >>> mExposure = MultibandExposure.fromButler( + ... butler, ["g", "r", "i"], "deepCoadd_calexp", ... + ... ) + >>> driver = MultiBandMeasurementDriverTask(config=config) + >>> catalog = driver.run(mExposure, "r") + >>> catalog.writeFits("meas_catalog.fits") + """ + + ConfigClass = MultiBandMeasurementDriverConfig + _DefaultName = "multiBandMeasurementDriver" + + def run( + self, + mExposure: afwImage.MultibandExposure | list[afwImage.Exposure], + band: str, + bands: list[str] | None = None, + idGenerator: measBase.IdGenerator = None, + ) -> afwTable.SourceCatalog: + """ + Process a multi-band exposure or a list of exposures. + + Parameters + ---------- + mExposure : + Multi-band data. May be a single `MultibandExposure` or a list of + exposures associated with different bands in which case ``bands`` + must be provided. + band : + Reference band to use for detection and measurement. + bands : optional + List of bands associated with the exposures in ``mExposure``. Only + required if ``mExposure`` is a list of single-band exposures. + idGenerator : optional + Generator for unique source IDs. + + Returns + ------- + catalog : + Catalog containing the measured sources. + """ + # Store the reference band for later use. + self.band = band + + # Convert list of exposures to a MultibandExposure if needed. Save the + # result as an instance attribute for later use. 
+ self.mExposure = self._buildMultibandExposure(mExposure, bands) + + if self.band not in self.mExposure: + raise ValueError(f"Requested band '{band}' is not present in the multiband exposure.") + + # Use the reference band for detection and measurement. + exposure = self.mExposure[self.band] + self.log.info(f"Using '{self.band}' band as the reference band for detection and measurement") + + return super().run(exposure, idGenerator=idGenerator) + + def _buildMultibandExposure( + self, exposure: afwImage.MultibandExposure | list[afwImage.Exposure], bands: list[str] | None + ) -> afwImage.MultibandExposure: + """ + Convert a list of single-band exposures to a MultibandExposure if needed. + + Parameters + ---------- + exposure : + Input multi-band data. + bands : optional + List of bands associated with the exposures in ``exposure``. Only + required if ``exposure`` is a list of single-band exposures. + + Returns + ------- + mbExposure : + Converted multi-band exposure. + """ + if isinstance(exposure, afwImage.MultibandExposure): + if bands is not None: + self.log.warn("Ignoring 'bands' argument; using bands from the input MultibandExposure") + return exposure + elif isinstance(exposure, list): + if bands is None: + raise ValueError("List of bands must be provided if 'exposure' is a list") + if len(bands) != len(exposure): + raise ValueError("Number of bands and exposures must match.") + return afwImage.MultibandExposure.fromExposures(bands, exposure) + else: + raise TypeError("'exposure' must be a MultibandExposure or a list of single-band Exposures.") From aa831cecb8a692e82972686b2cce914da191a893 Mon Sep 17 00:00:00 2001 From: Erfan Nourbakhsh Date: Wed, 29 Jan 2025 00:31:22 -0500 Subject: [PATCH 3/3] Address review comments round 2 (to be squashed) --- python/lsst/pipe/tasks/measurementDriver.py | 579 +++++++++++++------- 1 file changed, 375 insertions(+), 204 deletions(-) diff --git a/python/lsst/pipe/tasks/measurementDriver.py 
b/python/lsst/pipe/tasks/measurementDriver.py index cf8fabe67..ee2e7c804 100644 --- a/python/lsst/pipe/tasks/measurementDriver.py +++ b/python/lsst/pipe/tasks/measurementDriver.py @@ -27,6 +27,7 @@ ] import logging +from abc import ABCMeta, abstractmethod import lsst.afw.detection as afwDetection import lsst.afw.geom as afwGeom @@ -37,96 +38,86 @@ import lsst.meas.base as measBase import lsst.meas.deblender as measDeblender import lsst.meas.extensions.scarlet as scarlet -import lsst.pex.config as pexConfig +from lsst.pex.config import Config, ConfigurableField, Field import lsst.pipe.base as pipeBase import numpy as np logging.basicConfig(level=logging.INFO) -class MeasurementDriverBaseConfig(pexConfig.Config): +class MeasurementDriverBaseConfig(Config): """Base configuration for measurement driver tasks. This class provides foundational configuration for its subclasses to handle single-band and multi-band data. It defines the detection, deblending, - and measurement subtasks, which are intended to be executed in sequence - by the driver tasks. + measurement, aperture correction, and catalog calculation subtasks, which + are intended to be executed in sequence by the driver tasks. """ - idGenerator = measBase.DetectorVisitIdGeneratorConfig.make_field( - doc="Configuration for generating catalog IDs from data IDs consistently across subtasks." + doScaleVariance = Field[bool](doc="Scale variance plane using empirical noise?", default=False) + + scaleVariance = ConfigurableField( + target=measAlgorithms.ScaleVarianceTask, doc="Subtask to rescale variance plane" ) - detection = pexConfig.ConfigurableField( + doDetect = Field[bool](doc="Run the source detection algorithm?", default=True) + + detection = ConfigurableField( target=measAlgorithms.SourceDetectionTask, doc="Subtask to detect sources in the image." 
) - deblender = pexConfig.ChoiceField[str]( - doc="Which deblender to use?", - default="meas_deblender", - allowed={ - "meas_deblender": "Deblend using meas_deblender (only single-band)", - "scarlet": "Deblend using scarlet (single- or multi-band)", - }, - ) + doDeblend = Field[bool](doc="Run the source deblending algorithm?", default=True) + # N.B. The 'deblend' configurable field should be defined in subclasses. - deblend = pexConfig.ConfigurableField( - target=measDeblender.SourceDeblendTask, doc="Subtask to split blended sources into components." - ) + doMeasure = Field[bool](doc="Run the source measurement algorithm?", default=True) - measurement = pexConfig.ConfigurableField( + measurement = ConfigurableField( target=measBase.SingleFrameMeasurementTask, doc="Subtask to measure sources and populate the output catalog", ) - def __setattr__(self, key, value): - """Intercept changes to 'deblender' and retarget subtask if needed.""" - super().__setattr__(key, value) - - # This is to ensure the deblend target is set correctly whenever the - # deblender is changed. This is required because `setDefaults` is not - # automatically invoked during reconfiguration. - if key == "deblender": - self._retargetDeblend() - - def setDefaults(self): - super().setDefaults() - self._retargetDeblend() - - def _retargetDeblend(self): - if self.deblender == "scarlet": - self.deblend.retarget(scarlet.ScarletDeblendTask) - elif self.deblender == "meas_deblender": - self.deblend.retarget(measDeblender.SourceDeblendTask) - - def validate(self): - super().validate() - targetMap = { - "scarlet": scarlet.ScarletDeblendTask, - "meas_deblender": measDeblender.SourceDeblendTask, - } - - # Ensure the deblend target aligns with the selected deblender. 
- if self.deblend.target != (expected := targetMap.get(self.deblender)): - raise ValueError( - f"Invalid target for '{self.deblender}': expected {expected}, got {self.deblend.target}" - ) + psfCache = Field[int](doc="Size of psfCache", default=100) + checkUnitsParseStrict = Field[str]( + doc="Strictness of Astropy unit compatibility check, can be 'raise', 'warn' or 'silent'", + default="raise", + ) -class MeasurementDriverBaseTask(pipeBase.Task): - """Base class for the mid-level driver running detection, deblending - (optional), and measurement algorithms in one go. + doApCorr = Field[bool]( + doc="Apply aperture corrections? If yes, your image must have an aperture correction map", + default=False, + ) + + applyApCorr = ConfigurableField( + doc="Subtask to apply aperture corrections", + target=measBase.ApplyApCorrTask, + ) + + doRunCatalogCalculation = Field[bool](doc="Run catalogCalculation task?", default=False) + + catalogCalculation = ConfigurableField( + target=measBase.CatalogCalculationTask, doc="Subtask to run catalogCalculation plugins on catalog" + ) + + +class MeasurementDriverBaseTask(pipeBase.Task, metaclass=ABCMeta): + """Base class for the mid-level driver running detection, deblending, + measurement algorithms, aperture correction, and catalog calculation in + one go. This driver simplifies the process of applying a small set of measurement - algorithms to images by abstracting away schema and table boilerplate. It - is particularly suited for simple use cases, such as processing images - without neighbor-noise-replacement or extensive configuration. + algorithms to images by abstracting away Schema and table boilerplate. + Also, users don't need to Butlerize their input data. It is particularly + suited for simple use cases, such as processing images without + neighbor-noise-replacement or extensive configuration. 
Parameters ---------- schema : Schema used to create the output `~lsst.afw.table.SourceCatalog`, modified in place with fields that will be written by this task. + peakSchema : + Schema of Footprint Peaks that will be passed to the deblender. **kwargs : Additional kwargs to pass to lsst.pipe.base.Task.__init__() @@ -138,157 +129,249 @@ class MeasurementDriverBaseTask(pipeBase.Task): ConfigClass = MeasurementDriverBaseConfig _DefaultName = "measurementDriverBase" + _Deblender = "" - def __init__(self, schema: afwTable.Schema = None, **kwargs: dict): + def __init__(self, schema: afwTable.Schema = None, peakSchema: afwTable.Schema = None, **kwargs: dict): super().__init__(**kwargs) - if schema is None: - # Create a minimal schema that will be extended by tasks. - self.schema = afwTable.SourceTable.makeMinimalSchema() + # Schema for the output catalog. + self.schema = schema + + # Schema for deblender peaks. + self.peakSchema = peakSchema + + # Placeholders for subclasses to populate. + self.scaleVariance: measAlgorithms.ScaleVarianceTask + self.detection: measAlgorithms.SourceDetectionTask + self.deblend: measDeblender.SourceDeblendTask | scarlet.ScarletDeblendTask + self.measure: measBase.SingleFrameMeasurementTask + self.applyApCorr: measBase.ApplyApCorrTask + self.catalogCalculation: measBase.CatalogCalculationTask + self.exposure: afwImage.Exposure + self.catalog: afwTable.SourceCatalog + self.idGenerator: measBase.IdGenerator + + def _initializeSchema(self): + """Initialize the Schema to be used for constructing the subtasks. + + Might seem a bit clunky, but this workaround is necessary to ensure + that the Schema is consistent across all subtasks. + """ + if self.catalog is None: + if self.schema is None: + # Create a minimal Schema that will be extended by tasks. + self.schema = afwTable.SourceTable.makeMinimalSchema() + + # Add coordinate error fields to avoid missing field issues. 
+ afwTable.CoordKey.addErrorFields(self.schema) else: - self.schema = schema + # Since a catalog is provided, use its Schema as the base. + catalogSchema = self.catalog.schema - # Add coordinate error fields to avoid missing field issues in the - # schema. - afwTable.CoordKey.addErrorFields(self.schema) + # Create a SchemaMapper that maps from catalogSchema to a new one + # it will create. + self.mapper = afwTable.SchemaMapper(catalogSchema) - # Standard subtasks to run in sequence. - self.subtaskNames = ["detection", "deblend", "measurement"] + # Add everything from catalogSchema to output Schema. + self.mapper.addMinimalSchema(catalogSchema, True) - def makeSubtasks(self): - """Construct subtasks based on the current configuration.""" - for name in self.subtaskNames: - if not hasattr(self, name): - self.makeSubtask(name, schema=self.schema) + # Get the output Schema from the SchemaMapper and assign it as the + # Schema to be used for constructing the subtasks. + self.schema = self.mapper.getOutputSchema() - def run( - self, exposure: afwImage.Exposure, idGenerator: measBase.IdGenerator = None - ) -> afwTable.SourceCatalog: - """Run detection, optional deblending, and measurement on a given - image. + def _makeSubtasks(self): + """Construct subtasks based on the configuration and the Schema.""" + if self.config.doScaleVariance and not hasattr(self, "scaleVariance"): + self.makeSubtask("scaleVariance") - Parameters - ---------- - exposure : - The exposure on which to detect, deblend and measure sources. - idGenerator : optional - Object that generates source IDs and provides random seeds. + if self.config.doDetect and not hasattr(self, "detection"): + self.makeSubtask("detection", schema=self.schema) - Returns - ------- - catalog : - The source catalog with all requested measurements. - """ - # Make the `deblend` subtask only if it is enabled. 
- if self.config.deblender is None: - self.subtasks.remove("deblend") + if self.config.doDeblend and not hasattr(self, "deblend"): + self.makeSubtask("deblend", schema=self.schema, peakSchema=self.peakSchema) - # Validate the configuration. - self.config.validate() + if self.config.doMeasure and not hasattr(self, "measurement"): + self.makeSubtask("measurement", schema=self.schema) - # Ensure this method picks up the current subtask config. - self.makeSubtasks() - # N.B. subtasks must be created here to handle reconfigurations, such - # as retargeting the `deblend` subtask, because the `makeSubtask` - # method locks in its config just before creating the subtask. If the - # subtask was already made in __init__ using the initial config, it - # cannot be retargeted now because retargeting happens at the config - # level, not the subtask level. + if self.config.doApCorr and not hasattr(self, "applyApCorr"): + self.makeSubtask("applyApCorr", schema=self.schema) - if idGenerator is None: - idGenerator = measBase.IdGenerator() + if self.config.doRunCatalogCalculation and not hasattr(self, "catalogCalculation"): + self.makeSubtask("catalogCalculation", schema=self.schema) - self.exposure = exposure + # Check that all units in the Schema are valid Astropy unit strings. + self.schema.checkUnits(parse_strict=self.config.checkUnitsParseStrict) - # Create an empty source table with the known schema into which - # detections will be placed next. - self.catalog = afwTable.SourceTable.make(self.schema, idGenerator.make_table_id_factory()) + def _updateCatalogSchema(self): + """Update the Schema of the provided catalog to incorporate changes + made by the configured subtasks. + """ + # Create an empty catalog with the Schema required by the subtasks that + # are configured to run. + newCatalog = afwTable.SourceCatalog(self.schema) + + # Transfer all records from the original catalog to the new catalog, + # using the SchemaMapper to copy values. 
+ newCatalog.extend(self.catalog, mapper=self.mapper) - # Step 1: Detect sources in the image and populate the catalog. - self._detectSources() + # Replace the original catalog with the updated one, preserving the + # records while applying the updated Schema. + self.catalog = newCatalog - # Step 2: If enabled, deblend detected sources and update the catalog. - if self.config.deblender: + @abstractmethod + def run(self) -> afwTable.SourceCatalog: + """Abstract method to run detection, deblending, measurement, aperture + correction, and catalog calculation on a given exposure. + + Returns + ------- + catalog : + The source catalog with all requested measurements. + """ + + # Set up the Schema before creating subtasks. + self._initializeSchema() + + # Create subtasks, passing the same Schema to each subtask's + # constructor if need be. + self._makeSubtasks() + + # Adjust the catalog Schema to align with changes made by the subtasks. + if self.catalog is not None: + self._updateCatalogSchema() + + # Generate catalog IDs consistently across subtasks. + if self.idGenerator is None: + self.idGenerator = measBase.IdGenerator() + + # Set psfcache. + self.exposure.getPsf().setCacheCapacity(self.config.psfCache) + + # Scale variance plane. + if self.config.doScaleVariance: + varScale = self.scaleVariance.run(self.exposure.maskedImage) + self.exposure.getMetadata().add("VARIANCE_SCALE", varScale) + + if self.config.doDetect: + if self.catalog is None: + # Create an empty source table with the known Schema into which + # detected sources will be placed next. + self.table = afwTable.SourceTable.make(self.schema, self.idGenerator.make_table_id_factory()) + else: + raise RuntimeError( + "An input catalog was given to bypass detection, but detection is still on." 
+ ) + else: + if self.catalog is None: + raise RuntimeError("Cannot run without detection if no catalog is provided.") + else: + self.log.info("Using detections from provided catalog; skipping detection") + + # Detect sources in the image and populate the catalog. + if self.config.doDetect: + self._detectSources() + + # Deblend detected sources and update the catalog. + if self.config.doDeblend: + self.log.info(f"Deblending using '{self._Deblender}' on {len(self.catalog)} detection footprints") self._deblendSources() + # The deblender may not produce a contiguous catalog; ensure + # contiguity for the subsequent task. + if not self.catalog.isContiguous(): + self.log.info("Catalog is not contiguous; making it contiguous") + self.catalog = self.catalog.copy(deep=True) else: self.log.info("Deblending is disabled; skipping deblending") - # Step 3: Measure properties of detected/deblended sources. - self._measureSources() + # Measure properties of detected/deblended sources. + if self.config.doMeasure: + self._measureSources() + + # Apply aperture corrections to the catalog. + if self.config.doApCorr: + self._applyApCorr() + + # Ensure contiguity again. + if not self.catalog.isContiguous(): + self.catalog = self.catalog.copy(deep=True) + + # Run catalogCalculation on the catalog. 
+ if self.config.doRunCatalogCalculation: + self._runCatalogCalculation() + + self.log.info( + f"Run complete; output catalog has {self.catalog.schema.getFieldCount()} " + f"fields and {len(self.catalog)} records" + ) return self.catalog def _detectSources(self): """Run the detection subtask to identify sources in the image.""" self.log.info(f"Running detection on a {self.exposure.width}x{self.exposure.height} pixel exposure") - self.catalog = self.detection.run(self.catalog, self.exposure).sources + self.catalog = self.detection.run(self.table, self.exposure).sources + @abstractmethod def _deblendSources(self): - """Run the deblending subtask to separate blended sources.""" - self.log.info( - f"Deblending using '{self.config.deblender}' on {len(self.catalog)} detection footprints" - ) - if self.config.deblender == "meas_deblender": - self.deblend.run(exposure=self.exposure, sources=self.catalog) - elif self.config.deblender == "scarlet": - if not isinstance(self.exposure, afwImage.MultibandExposure): - # We need to have a multiband exposure to satisfy scarlet - # function's signature, even when using a single band. - self.band = "N/A" # Placeholder for single-band deblending - self.mExposure = afwImage.MultibandExposure.fromExposures([self.band], [self.exposure]) - self.catalog, modelData = self.deblend.run(mExposure=self.mExposure, mergedSources=self.catalog) - # The footprints need to be updated for the subsequent measurement. - scarlet.io.updateCatalogFootprints( - modelData=modelData, - catalog=self.catalog, - band=self.band, - imageForRedistribution=None, - removeScarletData=True, - updateFluxColumns=True, - ) - # The deblender may not produce a contiguous catalog; ensure contiguity - # for the subsequent task. - if not self.catalog.isContiguous(): - self.log.info("Catalog is not contiguous; making it contiguous") - self.catalog = self.catalog.copy(deep=True) + """Run the deblending subtask to separate blended sources. 
Subclasses + must implement this method to handle task-specific deblending logic. + """ + raise NotImplementedError(f"{self.__class__.__name__} has not implemented '_deblendSources'.") def _measureSources(self): """Run the measurement subtask to compute properties of sources.""" - isDeblended = "and deblended" if self.config.deblender else "(not deblended)" - self.log.info(f"Measuring {len(self.catalog)} detected {isDeblended} sources") - self.measurement.run(self.catalog, self.exposure) - self.log.info( - f"Measurement complete - output catalog has " f"{self.catalog.schema.getFieldCount()} fields" + deblendedInfo = "and deblended" if self.config.doDeblend else "(not deblended)" + self.log.info(f"Measuring {len(self.catalog)} detected {deblendedInfo} sources") + self.measurement.run( + measCat=self.catalog, exposure=self.exposure, exposureId=self.idGenerator.catalog_id ) + def _applyApCorr(self): + """Apply aperture corrections to the catalog.""" + apCorrMap = self.exposure.getInfo().getApCorrMap() + if apCorrMap is None: + self.log.warning( + "Image does not have valid aperture correction map for catalog id " + f"{self.idGenerator.catalog_id}; skipping aperture correction" + ) + else: + self.log.info("Applying aperture corrections to the catalog") + self.applyApCorr.run(catalog=self.catalog, apCorrMap=apCorrMap) -class SingleBandMeasurementDriverConfig(MeasurementDriverBaseConfig): - """Configuration for single-band measurement driver tasks. + def _runCatalogCalculation(self): + """Run the catalogCalculation subtask to compute properties of sources.""" + self.log.info(f"Running catalogCalculation on {len(self.catalog)} sources") + self.catalogCalculation.run(self.catalog) - No additional parameters specific to single-band processing is added. 
- """ - pass +class SingleBandMeasurementDriverConfig(MeasurementDriverBaseConfig): + """Configuration for the single-band measurement driver task.""" + + deblend = ConfigurableField(target=measDeblender.SourceDeblendTask, doc="Deblender for single-band data.") class SingleBandMeasurementDriverTask(MeasurementDriverBaseTask): """Mid-level driver for processing single-band data. Provides an additional interface for handling raw image data that is - specific to single-band scenarios. + specific to single-band processing. Examples -------- - Here is an example of how to use this class to run detection, deblending, - and measurement on a single-band exposure: + Here is an example of how to use this class to run variance scaling, + detection, deblending, and measurement on a single-band exposure: >>> from lsst.pipe.tasks.measurementDriver import ( ... SingleBandMeasurementDriverConfig, ... SingleBandMeasurementDriverTask, ... ) >>> import lsst.meas.extensions.shapeHSM # To register its plugins >>> config = SingleBandMeasurementDriverConfig() + >>> config.doScaleVariance = True + >>> config.doDetect = True + >>> config.doDeblend = True + >>> config.doMeasure = True + >>> config.scaleVariance.background.binSize = 64 >>> config.detection.thresholdValue = 5.5 - >>> config.deblender = "meas_deblender" >>> config.deblend.tinyFootprintSize = 3 >>> config.measurement.plugins.names |= [ ... "base_SdssCentroid", @@ -303,17 +386,38 @@ class SingleBandMeasurementDriverTask(MeasurementDriverBaseTask): >>> catalog.writeFits("meas_catalog.fits") """ - _DefaultName = "singleBandMeasurementDriver" ConfigClass = SingleBandMeasurementDriverConfig + _DefaultName = "singleBandMeasurementDriver" + _Deblender = "meas_deblender" - def run(self, *args, **kwargs): - if self.config.deblender == "scarlet": - # N.B. scarlet is designed to leverage multiband information to - # differentiate overlapping sources based on their spectral and - # spatial profiles. 
However, it can also run on a single band and - # often give better results than 'meas_deblender'. - self.log.debug("Using 'scarlet' deblender for single-band processing; make sure it was intended") - return super().run(*args, **kwargs) + def run( + self, + exposure: afwImage.Exposure, + catalog: afwTable.SourceCatalog = None, + idGenerator: measBase.IdGenerator = None, + ) -> afwTable.SourceCatalog: + """Process a single-band exposure. + + Parameters + ---------- + exposure : + The exposure on which to detect, deblend, and measure sources. + catalog : optional + Catalog to be extended by the driver task. If not provided, a new + catalog will be created either from the user-provided Schema or a + minimal Schema. It will then be populated with detected sources. + idGenerator : optional + Object that generates source IDs and provides random seeds. + + Returns + ------- + catalog : + Catalog containing the measured sources. + """ + self.exposure = exposure + self.catalog = catalog + self.idGenerator = idGenerator + return super().run() def runFromImage( self, @@ -334,7 +438,7 @@ def runFromImage( Input image data. Will be converted into an `Exposure` before processing. mask : optional - Mask data for the image. Used if 'image' is a bare `array` or + Mask data for the image. Used if ``image`` is a bare `array` or `Image`. variance : optional Variance plane data for the image. @@ -353,8 +457,7 @@ def runFromImage( catalog : Final catalog of measured sources. """ - # Convert raw image data into an Exposure - # exposure = self._makeExposureFromImage(image, mask, variance, wcs, psf, photoCalib) + # Convert raw image data into an Exposure. 
if isinstance(image, np.ndarray): image = afwImage.makeImageFromArray(image) if isinstance(mask, np.ndarray): @@ -387,31 +490,40 @@ def runFromImage( return self.run(exposure, idGenerator=idGenerator) + def _deblendSources(self): + self.deblend.run(exposure=self.exposure, sources=self.catalog) + class MultiBandMeasurementDriverConfig(MeasurementDriverBaseConfig): - """Configuration for multi-band measurement driver tasks. + """Configuration for the multi-band measurement driver task.""" - Adds a validation check to ensure the 'scarlet' deblender is used. - """ + deblend = ConfigurableField( + target=scarlet.ScarletDeblendTask, doc="Scarlet deblender for multi-band data." + ) - def validate(self): - super().validate() - if self.deblender != "scarlet": - raise ValueError( - f"Multi-band deblending requires the 'scarlet' deblender, but got '{self.deblender}'." - ) + doConserveFlux = Field[bool]( + doc="Whether to use the deblender models as templates to re-distribute the flux from " + "the 'exposure' (True), or to perform measurements on the deblender model footprints.", + default=False, + ) + + doStripHeavyFootprints = Field[bool]( + doc="Whether to strip heavy footprints from the output catalog before saving to disk. " + "This is usually done when using scarlet models to save disk space.", + default=True, + ) class MultiBandMeasurementDriverTask(MeasurementDriverBaseTask): """Mid-level driver for processing multi-band data. - Provides functionality for handling a list of single-band exposures in - addition to a multi-band exposure. + Provides functionality for handling a single-band exposure and a list of + single-band exposures in addition to a standard multi-band exposure. 
Examples -------- - Here is an example of how to use this class to run detection, deblending, - and measurement on a multi-band exposure: + Here is an example of how to use this class to run variance scaling, + detection, deblending, and measurement on a multi-band exposure: >>> from lsst.afw.image import MultibandExposure >>> from lsst.pipe.tasks.measurementDriver import ( ... MultiBandMeasurementDriverConfig, @@ -419,8 +531,12 @@ class MultiBandMeasurementDriverTask(MeasurementDriverBaseTask): ... ) >>> import lsst.meas.extensions.shapeHSM # To register its plugins >>> config = MultiBandMeasurementDriverConfig() + >>> config.doScaleVariance = True + >>> config.doDetect = True + >>> config.doDeblend = True + >>> config.doMeasure = True + >>> config.scaleVariance.background.binSize = 64 >>> config.detection.thresholdValue = 5.5 - >>> config.deblender = "scarlet" >>> config.deblend.minSNR = 42.0 >>> config.deblend.maxIter = 20 >>> config.measurement.plugins.names |= [ @@ -440,12 +556,14 @@ class MultiBandMeasurementDriverTask(MeasurementDriverBaseTask): ConfigClass = MultiBandMeasurementDriverConfig _DefaultName = "multiBandMeasurementDriver" + _Deblender = "scarlet" def run( self, mExposure: afwImage.MultibandExposure | list[afwImage.Exposure], - band: str, + band: str | None = None, bands: list[str] | None = None, + catalog: afwTable.SourceCatalog = None, idGenerator: measBase.IdGenerator = None, ) -> afwTable.SourceCatalog: """ @@ -457,11 +575,14 @@ def run( Multi-band data. May be a single `MultibandExposure` or a list of exposures associated with different bands in which case ``bands`` must be provided. - band : + band : optional Reference band to use for detection and measurement. bands : optional List of bands associated with the exposures in ``mExposure``. Only required if ``mExposure`` is a list of single-band exposures. + catalog : optional + Catalog to be extended by the driver task. If not provided, a new + catalog will be created and populated. 
idGenerator : optional Generator for unique source IDs. @@ -470,50 +591,100 @@ def run( catalog : Catalog containing the measured sources. """ + + # Basic sanity checks to ensure the inputs are consistent. + if (band is None) != (bands is None): + raise ValueError("'band' and 'bands' must be provided together or not at all.") + if band is not None and bands is not None: + if band not in bands: + raise ValueError(f"Reference band '{band}' is not in the list of bands: {bands}") + # Store the reference band for later use. self.band = band - # Convert list of exposures to a MultibandExposure if needed. Save the - # result as an instance attribute for later use. + # Convert mExposure to a MultibandExposure object if not already in + # that form. Save the result as an instance attribute for later use. self.mExposure = self._buildMultibandExposure(mExposure, bands) if self.band not in self.mExposure: raise ValueError(f"Requested band '{band}' is not present in the multiband exposure.") - # Use the reference band for detection and measurement. - exposure = self.mExposure[self.band] - self.log.info(f"Using '{self.band}' band as the reference band for detection and measurement") + # We use a reference band for band-specific tasks like detection and + # measurement. + self.exposure = self.mExposure[self.band] + self.log.info(f"Using '{self.band}' band as the reference band for band-specific tasks") - return super().run(exposure, idGenerator=idGenerator) + self.catalog = catalog + self.idGenerator = idGenerator + + return super().run() + + def _deblendSources(self): + self.catalog, modelData = self.deblend.run(mExposure=self.mExposure, mergedSources=self.catalog) + + # The footprints need to be updated for the subsequent measurement. 
+ if self.config.doConserveFlux: + imageForRedistribution = self.exposure + else: + imageForRedistribution = None + scarlet.io.updateCatalogFootprints( + modelData=modelData, + catalog=self.catalog, + band=self.band, + imageForRedistribution=imageForRedistribution, + removeScarletData=True, + updateFluxColumns=True, + ) + + # Strip HeavyFootprints to save space on disk. + if self.config.doStripHeavyFootprints: + sources = self.catalog + for source in sources[sources["parent"] != 0]: + source.setFootprint(None) def _buildMultibandExposure( - self, exposure: afwImage.MultibandExposure | list[afwImage.Exposure], bands: list[str] | None + self, mExposure: afwImage.MultibandExposure | list[afwImage.Exposure], bands: list[str] | None ) -> afwImage.MultibandExposure: - """ - Convert a list of single-band exposures to a MultibandExposure if needed. + """Convert a single-band exposure or a list of single-band exposures to + a `MultibandExposure` if not already of that type. + + No conversion is done if the input is already a `MultibandExposure`. Parameters ---------- - exposure : + mExposure : Input multi-band data. bands : optional - List of bands associated with the exposures in ``exposure``. Only - required if ``exposure`` is a list of single-band exposures. + List of bands associated with the exposures in ``mExposure``. Only + required if ``mExposure`` is a list of single-band exposures. Returns ------- - mbExposure : + mExposure : Converted multi-band exposure. 
""" - if isinstance(exposure, afwImage.MultibandExposure): + if isinstance(mExposure, afwImage.MultibandExposure): if bands is not None: - self.log.warn("Ignoring 'bands' argument; using bands from the input MultibandExposure") - return exposure - elif isinstance(exposure, list): + self.log.warn("Ignoring 'bands' argument; using bands from the input `MultibandExposure`") + return mExposure + elif isinstance(mExposure, list): if bands is None: - raise ValueError("List of bands must be provided if 'exposure' is a list") - if len(bands) != len(exposure): + raise ValueError("List of bands must be provided if 'mExposure' is a list") + if len(bands) != len(mExposure): raise ValueError("Number of bands and exposures must match.") - return afwImage.MultibandExposure.fromExposures(bands, exposure) + return afwImage.MultibandExposure.fromExposures(bands, mExposure) + elif isinstance(mExposure, afwImage.Exposure): + # N.B. Scarlet is designed to leverage multiband information to + # differentiate overlapping sources based on their spectral and + # spatial profiles. However, it can also run on a single band and + # often give better results than 'meas_deblender'. + self.log.debug("Using 'scarlet' deblender for single-band processing; make sure it was intended!") + if self.band is None: + self.band = "N/A" # Placeholder for single-band deblending + if bands is None: + bands = [self.band] + # We need to have a multiband exposure to satisfy scarlet + # function's signature, even when using a single band. + return afwImage.MultibandExposure.fromExposures(bands, [mExposure]) else: - raise TypeError("'exposure' must be a MultibandExposure or a list of single-band Exposures.") + raise TypeError(f"Unsupported 'mExposure' type: {type(mExposure)}")