diff --git a/tests/scheduler/test_features.py b/tests/scheduler/test_features.py index 98204df5..4bbb5a6d 100644 --- a/tests/scheduler/test_features.py +++ b/tests/scheduler/test_features.py @@ -10,6 +10,52 @@ from rubin_scheduler.utils import survey_start_mjd +def make_observations_list(nobs=1): + observations_list = [] + for i in range(0, nobs): + observation = empty_observation() + observation["mjd"] = survey_start_mjd() + i * 30 / 60 / 60 / 24 + observation["RA"] = np.radians(30) + observation["dec"] = np.radians(-20) + observation["filter"] = "r" + observation["scheduler_note"] = "test" + observations_list.append(observation) + return observations_list + + +def make_observations_arrays(observations_list, nside=32): + # Turn list of observations (that should already have useful info) + # into observations_array plus observations_hpids_array. + observations_array = np.empty(len(observations_list), dtype=observations_list[0].dtype) + for i, obs in enumerate(observations_list): + observations_array[i] = obs + # Build observations_hpids_array. + # Find list of lists of healpixels + # (should match [indxs, indxs, indxs2] from above) + pointing2indx = HpInLsstFov(nside=nside) + list_of_hpids = pointing2indx(observations_array["RA"], observations_array["dec"]) + # Unravel list-of-lists (list_of_hpids) to match against observations + hpids = [] + big_array_indx = [] + for i, indxs in enumerate(list_of_hpids): + for indx in indxs: + hpids.append(indx) + big_array_indx.append(i) + hpids = np.array(hpids, dtype=[("hpid", int)]) + # Set up format / dtype for observations_hpids_array + names = list(observations_array.dtype.names) + types = [observations_array[name].dtype for name in names] + names.append(hpids.dtype.names[0]) + types.append(hpids["hpid"].dtype) + ndt = list(zip(names, types)) + observations_hpids_array = np.empty(hpids.size, dtype=ndt) + # Populate observations_hpid_array - big_array_indx points + # between index in observations_array and index in hpid + observations_hpids_array[list(observations_array.dtype.names)] = observations_array[big_array_indx] + observations_hpids_array[hpids.dtype.names[0]] = hpids + return observations_array, observations_hpids_array + + class TestFeatures(unittest.TestCase): def test_pair_in_night(self): pin = features.PairInNight(gap_min=25.0, gap_max=45.0) @@ -283,6 +329,156 @@ def test_NObservationsCurrentSeason(self): # in these cases with added requirements .. but will leave it # to the "restore" test in test_utils.py. + def test_NObsSurvey(self): + # Make some observations to count + observations_list = make_observations_list(2) + observations_list[0]["scheduler_note"] = "survey a" + observations_list[1]["scheduler_note"] = "survey b" + observations_array, observations_hpid_array = make_observations_arrays(observations_list) + # Count the observations matching any note + count_feature = features.NObsSurvey(note=None) + # ... it matters significantly that we pass obs[0] and not obs. + for obs in observations_list: + count_feature.add_observation(obs[0]) + self.assertTrue(count_feature.feature == 2) + # and count again using add_observations_array + count_feature = features.NObsSurvey(note=None) + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature == 2) + # Count using a note to match + # Count the observations matching specific note + count_feature = features.NObsSurvey(note="survey a") + for obs in observations_list: + count_feature.add_observation(obs[0]) + self.assertTrue(count_feature.feature == 1) + # and count again using add_observations_array + count_feature = features.NObsSurvey(note="survey a") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature == 1) + # Count the observations matching subset of note + count_feature = features.NObsSurvey(note="survey") + for obs in observations_list: + count_feature.add_observation(obs[0]) + self.assertTrue(count_feature.feature == 2) + # and count again using add_observations_array + count_feature = features.NObsSurvey(note="survey") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature == 2) + + def test_LastObservation(self): + # Make some observations to count + observations_list = make_observations_list(4) + observations_list[0]["mjd"] = survey_start_mjd() + observations_list[1]["mjd"] = survey_start_mjd() + 10 + observations_list[2]["mjd"] = survey_start_mjd() + 20 + observations_list[3]["mjd"] = survey_start_mjd() + 30 + observations_list[0]["scheduler_note"] = "survey a" + observations_list[1]["scheduler_note"] = "survey" + observations_list[2]["scheduler_note"] = "survey a" + observations_list[3]["scheduler_note"] = "survey b" + observations_array, observations_hpid_array = make_observations_arrays(observations_list) + # Observations matching any note + count_feature = features.LastObservation(scheduler_note=None) + for obs in observations_list: + count_feature.add_observation(obs[0]) + self.assertTrue(count_feature.feature["mjd"] == observations_list[-1]["mjd"]) + # and count again using add_observations_array + count_feature = features.LastObservation(scheduler_note=None) + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature["mjd"] == observations_list[-1]["mjd"]) + # Observations matching a specific note. + count_feature = features.LastObservation(scheduler_note="survey a") + for obs in observations_list: + count_feature.add_observation(obs[0]) + self.assertTrue(count_feature.feature["mjd"] == observations_list[-2]["mjd"]) + # and count again using add_observations_array + count_feature = features.LastObservation(scheduler_note="survey a") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature["mjd"] == observations_list[-2]["mjd"]) + # Observations matching a subset of note. + count_feature = features.LastObservation(scheduler_note="survey") + for obs in observations_list: + count_feature.add_observation(obs[0]) + self.assertTrue(count_feature.feature["mjd"] == observations_list[-1]["mjd"]) + # and count again using add_observations_array + count_feature = features.LastObservation(scheduler_note="survey") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature["mjd"] == observations_list[-1]["mjd"]) + + def test_NObservations(self): + # Make some observations to count + observations_list = make_observations_list(12) + indexes = [] + nside = 32 + pointing2hpindx = HpInLsstFov(nside=nside) + for i, obs in enumerate(observations_list): + obs["mjd"] = survey_start_mjd() + i + obs["rotSkyPos"] = 0 + if i < 6: + obs["filter"] = "r" + else: + obs["filter"] = "g" + if i % 2 == 0: + obs["RA"] = np.radians(30) + else: + obs["RA"] = np.radians(10) + if i % 3 == 0: + obs["scheduler_note"] = "survey a" + elif i % 3 == 1: + obs["scheduler_note"] = "survey b" + else: + obs["scheduler_note"] = "survey" + indexes.append(pointing2hpindx(obs["RA"], obs["dec"], rotSkyPos=obs["rotSkyPos"])) + observations_array, observations_hpid_array = make_observations_arrays(observations_list) + # Observations matching any note or filter + count_feature = features.NObservations(filtername=None, scheduler_note=None) + for obs, indx in zip(observations_list, indexes): + count_feature.add_observation(obs[0], indx) + self.assertTrue(count_feature.feature.max() == 6) + # and count again using add_observations_array + count_feature = features.NObservations(filtername=None, scheduler_note=None) + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature.max() == 6) + # Observations matching a specific note - or are partial matches. + count_feature = features.NObservations(scheduler_note="survey a") + for obs, indx in zip(observations_list, indexes): + count_feature.add_observation(obs[0], indx) + self.assertTrue(count_feature.feature.max() == 4) + # and count again using add_observations_array + # Observations matching a specific note. + count_feature = features.NObservations(scheduler_note="survey a") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature.max() == 4) + # Observations matching a subset of note. + # It's not obvious that this is what this SHOULD do, and it's not + # used with "note" in the example /baseline scheduler. + count_feature = features.NObservations(scheduler_note="survey") + for obs, indx in zip(observations_list, indexes): + count_feature.add_observation(obs[0], indx) + self.assertTrue(count_feature.feature.max() == 2) + # and count again using add_observations_array + count_feature = features.NObservations(scheduler_note="survey") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature.max() == 2) + # Observations matching any note but specified filter. + count_feature = features.NObservations(filtername="r", scheduler_note=None) + for obs, indx in zip(observations_list, indexes): + count_feature.add_observation(obs[0], indx) + self.assertTrue(count_feature.feature.max() == 3) + # and count again using add_observations_array + count_feature = features.NObservations(filtername="r", scheduler_note=None) + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature.max() == 3) + # Observations matching specific note and specified filter. + count_feature = features.NObservations(filtername="r", scheduler_note="survey") + for obs, indx in zip(observations_list, indexes): + count_feature.add_observation(obs[0], indx) + self.assertTrue(count_feature.feature.max() == 1) + # and count again using add_observations_array + count_feature = features.NObservations(filtername="r", scheduler_note="survey") + count_feature.add_observations_array(observations_array, observations_hpid=observations_hpid_array) + self.assertTrue(count_feature.feature.max() == 1) + if __name__ == "__main__": unittest.main()