Skip to content

Commit

Permalink
Add test for group-by-metadata visit definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
timj committed Jan 31, 2024
1 parent d5ed208 commit a147b68
Showing 1 changed file with 78 additions and 25 deletions.
103 changes: 78 additions & 25 deletions tests/test_defineVisits.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,26 @@

import lsst.daf.butler.tests as butlerTests
from lsst.daf.butler import DataCoordinate, DimensionRecord, SerializedDimensionRecord
from lsst.obs.base import DefineVisitsTask
from lsst.obs.base import DefineVisitsConfig, DefineVisitsTask
from lsst.obs.base.instrument_tests import DummyCam
from lsst.utils.iteration import ensure_iterable

TESTDIR = os.path.dirname(__file__)
DATADIR = os.path.join(TESTDIR, "data", "visits")


class DefineVisitsTestCase(unittest.TestCase):
"""Test visit definition."""
class DefineVisitsBase:
"""General set up that can be shared."""

def setUp(self):
def setUpExposures(self):
"""Create a new butler for each test since we are changing dimension
records.
"""
self.root = tempfile.mkdtemp(dir=TESTDIR)
self.creatorButler = butlerTests.makeTestRepo(self.root, [])
self.butler = butlerTests.makeTestCollection(self.creatorButler, uniqueId=self.id())

self.config = DefineVisitsTask.ConfigClass()
self.config = self.get_config()
self.task = DefineVisitsTask(config=self.config, butler=self.butler)

# Need to register the instrument.
Expand All @@ -66,10 +66,42 @@ def setUp(self):
simple = SerializedDimensionRecord.model_validate_json(fh.read())
self.records[i] = DimensionRecord.from_simple(simple, registry=self.butler.registry)

def define_visits(
self,
exposures: list[DimensionRecord | list[DimensionRecord]],
incremental: bool,
) -> None:
for records in exposures:
records = list(ensure_iterable(records))
if "group" in self.butler.dimensions["exposure"].implied:
# This is a group + day_obs universe.
for rec in records:
self.butler.registry.syncDimensionData(
"group", dict(instrument=rec.instrument, name=rec.group)
)
self.butler.registry.syncDimensionData(
"day_obs", dict(instrument=rec.instrument, id=rec.day_obs)
)

self.butler.registry.insertDimensionData("exposure", *records)
# Include all records so far in definition.
dataIds = list(self.butler.registry.queryDataIds("exposure", instrument="DummyCam"))
self.task.run(dataIds, incremental=incremental)


class DefineVisitsTestCase(unittest.TestCase, DefineVisitsBase):
"""Test visit definition."""

def setUp(self):
self.setUpExposures()

def tearDown(self):
if self.root is not None:
shutil.rmtree(self.root, ignore_errors=True)

def get_config(self) -> DefineVisitsConfig:
return DefineVisitsTask.ConfigClass()

def assertVisits(self):
"""Check that the visits were registered as expected."""
visits = list(self.butler.registry.queryDimensionRecords("visit"))
Expand All @@ -94,26 +126,6 @@ def assertVisits(self):
},
)

def define_visits(
self, exposures: list[DimensionRecord | list[DimensionRecord]], incremental: bool
) -> None:
for records in exposures:
records = list(ensure_iterable(records))
if "group" in self.butler.dimensions["exposure"].implied:
# This is a group + day_obs universe.
for rec in records:
self.butler.registry.syncDimensionData(
"group", dict(instrument=rec.instrument, name=rec.group)
)
self.butler.registry.syncDimensionData(
"day_obs", dict(instrument=rec.instrument, id=rec.day_obs)
)

self.butler.registry.insertDimensionData("exposure", *records)
# Include all records so far in definition.
dataIds = list(self.butler.registry.queryDataIds("exposure", instrument="DummyCam"))
self.task.run(dataIds, incremental=incremental)

def test_defineVisits(self):
# Test visit definition with all the records.
self.define_visits([list(self.records.values())], incremental=False) # list inside a list
Expand Down Expand Up @@ -174,5 +186,46 @@ def testPickleTask(self):
self.assertEqual(self.task.universe, copy.universe)


class DefineVisitsGroupingTestCase(unittest.TestCase, DefineVisitsBase):
"""Test visit grouping by group metadata."""

def setUp(self):
self.setUpExposures()

def tearDown(self):
if self.root is not None:
shutil.rmtree(self.root, ignore_errors=True)

def get_config(self) -> DefineVisitsConfig:
config = DefineVisitsTask.ConfigClass()
config.groupExposures.name = "by-group-metadata"
return config

def test_defineVisits(self):
# Test visit definition with all the records.
self.define_visits([list(self.records.values())], incremental=False) # list inside a list
self.assertVisits()

def assertVisits(self):
"""Check that the visits were registered as expected."""
visits = list(self.butler.registry.queryDimensionRecords("visit"))
self.assertEqual(len(visits), 2)
self.assertEqual({visit.id for visit in visits}, {2291434132550000, 2291434871810000})

# Ensure that the definitions are correct (ignoring order).
defmap = defaultdict(set)
definitions = list(self.butler.registry.queryDimensionRecords("visit_definition"))
for defn in definitions:
defmap[defn.visit].add(defn.exposure)

self.assertEqual(
dict(defmap),
{
2291434132550000: {2022040500347},
2291434871810000: {2022040500348, 2022040500349},
},
)


if __name__ == "__main__":
unittest.main()

0 comments on commit a147b68

Please sign in to comment.