diff --git a/tests/test_defineVisits.py b/tests/test_defineVisits.py index f0e13a40..e09281fe 100644 --- a/tests/test_defineVisits.py +++ b/tests/test_defineVisits.py @@ -28,7 +28,7 @@ import lsst.daf.butler.tests as butlerTests from lsst.daf.butler import DataCoordinate, DimensionRecord, SerializedDimensionRecord -from lsst.obs.base import DefineVisitsTask +from lsst.obs.base import DefineVisitsConfig, DefineVisitsTask from lsst.obs.base.instrument_tests import DummyCam from lsst.utils.iteration import ensure_iterable @@ -36,10 +36,10 @@ DATADIR = os.path.join(TESTDIR, "data", "visits") -class DefineVisitsTestCase(unittest.TestCase): - """Test visit definition.""" +class DefineVisitsBase: + """General set up that can be shared.""" - def setUp(self): + def setUpExposures(self): """Create a new butler for each test since we are changing dimension records. """ @@ -47,7 +47,7 @@ def setUp(self): self.creatorButler = butlerTests.makeTestRepo(self.root, []) self.butler = butlerTests.makeTestCollection(self.creatorButler, uniqueId=self.id()) - self.config = DefineVisitsTask.ConfigClass() + self.config = self.get_config() self.task = DefineVisitsTask(config=self.config, butler=self.butler) # Need to register the instrument. @@ -66,10 +66,42 @@ def setUp(self): simple = SerializedDimensionRecord.model_validate_json(fh.read()) self.records[i] = DimensionRecord.from_simple(simple, registry=self.butler.registry) + def define_visits( + self, + exposures: list[DimensionRecord | list[DimensionRecord]], + incremental: bool, + ) -> None: + for records in exposures: + records = list(ensure_iterable(records)) + if "group" in self.butler.dimensions["exposure"].implied: + # This is a group + day_obs universe. + for rec in records: + self.butler.registry.syncDimensionData( + "group", dict(instrument=rec.instrument, name=rec.group) + ) + self.butler.registry.syncDimensionData( + "day_obs", dict(instrument=rec.instrument, id=rec.day_obs) + ) + + self.butler.registry.insertDimensionData("exposure", *records) + # Include all records so far in definition. + dataIds = list(self.butler.registry.queryDataIds("exposure", instrument="DummyCam")) + self.task.run(dataIds, incremental=incremental) + + +class DefineVisitsTestCase(unittest.TestCase, DefineVisitsBase): + """Test visit definition.""" + + def setUp(self): + self.setUpExposures() + def tearDown(self): if self.root is not None: shutil.rmtree(self.root, ignore_errors=True) + def get_config(self) -> DefineVisitsConfig: + return DefineVisitsTask.ConfigClass() + def assertVisits(self): """Check that the visits were registered as expected.""" visits = list(self.butler.registry.queryDimensionRecords("visit")) @@ -94,26 +126,6 @@ def assertVisits(self): }, ) - def define_visits( - self, exposures: list[DimensionRecord | list[DimensionRecord]], incremental: bool - ) -> None: - for records in exposures: - records = list(ensure_iterable(records)) - if "group" in self.butler.dimensions["exposure"].implied: - # This is a group + day_obs universe. - for rec in records: - self.butler.registry.syncDimensionData( - "group", dict(instrument=rec.instrument, name=rec.group) - ) - self.butler.registry.syncDimensionData( - "day_obs", dict(instrument=rec.instrument, id=rec.day_obs) - ) - - self.butler.registry.insertDimensionData("exposure", *records) - # Include all records so far in definition. - dataIds = list(self.butler.registry.queryDataIds("exposure", instrument="DummyCam")) - self.task.run(dataIds, incremental=incremental) - def test_defineVisits(self): # Test visit definition with all the records. self.define_visits([list(self.records.values())], incremental=False) # list inside a list @@ -174,5 +186,46 @@ def testPickleTask(self): self.assertEqual(self.task.universe, copy.universe) +class DefineVisitsGroupingTestCase(unittest.TestCase, DefineVisitsBase): + """Test visit grouping by group metadata.""" + + def setUp(self): + self.setUpExposures() + + def tearDown(self): + if self.root is not None: + shutil.rmtree(self.root, ignore_errors=True) + + def get_config(self) -> DefineVisitsConfig: + config = DefineVisitsTask.ConfigClass() + config.groupExposures.name = "by-group-metadata" + return config + + def test_defineVisits(self): + # Test visit definition with all the records. + self.define_visits([list(self.records.values())], incremental=False) # list inside a list + self.assertVisits() + + def assertVisits(self): + """Check that the visits were registered as expected.""" + visits = list(self.butler.registry.queryDimensionRecords("visit")) + self.assertEqual(len(visits), 2) + self.assertEqual({visit.id for visit in visits}, {2291434132550000, 2291434871810000}) + + # Ensure that the definitions are correct (ignoring order). + defmap = defaultdict(set) + definitions = list(self.butler.registry.queryDimensionRecords("visit_definition")) + for defn in definitions: + defmap[defn.visit].add(defn.exposure) + + self.assertEqual( + dict(defmap), + { + 2291434132550000: {2022040500347}, + 2291434871810000: {2022040500348, 2022040500349}, + }, + ) + + if __name__ == "__main__": unittest.main()