diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 642c452a..77fcecff 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -82,8 +82,6 @@ def load_and_filter_results(self, search, config): logger.info(f"Chunk Min. Likelihood = {results[-1].lh}") trj_batch = [] - psi_batch = [] - phi_batch = [] for i, trj in enumerate(results): # Stop as soon as we hit a result below our limit, because anything after # that is not guarrenteed to be valid due to potential on-GPU filtering. @@ -93,14 +91,15 @@ def load_and_filter_results(self, search, config): if trj.lh < max_lh: trj_batch.append(trj) - psi_batch.append(search.get_psi_curves(trj)) - phi_batch.append(search.get_phi_curves(trj)) total_count += 1 batch_size = len(trj_batch) logger.info(f"Extracted batch of {batch_size} results for total of {total_count}") if batch_size > 0: + psi_batch = search.get_psi_curves(trj_batch) + phi_batch = search.get_phi_curves(trj_batch) + result_batch = Results.from_trajectories(trj_batch, track_filtered=do_tracking) result_batch.add_psi_phi_data(psi_batch, phi_batch) diff --git a/src/kbmod/search/image_stack.cpp b/src/kbmod/search/image_stack.cpp index 6b89db3f..ea64ce36 100644 --- a/src/kbmod/search/image_stack.cpp +++ b/src/kbmod/search/image_stack.cpp @@ -39,9 +39,9 @@ void ImageStack::set_single_image(int index, LayeredImage& img, bool force_move) assert_sizes_equal(img.get_height(), height, "ImageStack image height"); if (force_move) { - images[index] = img; + images[index] = img; } else { - images[index] = std::move(img); + images[index] = std::move(img); } } @@ -186,8 +186,8 @@ static void image_stack_bindings(py::module& m) { .def("get_single_image", &is::get_single_image, py::return_value_policy::reference_internal, pydocs::DOC_ImageStack_get_single_image) .def("set_single_image", &is::set_single_image, py::arg("index"), py::arg("img"), - py::arg("force_move")=false, pydocs::DOC_ImageStack_set_single_image) - .def("append_image", &is::append_image, py::arg("img"), py::arg("force_move")=false, + py::arg("force_move") = false, pydocs::DOC_ImageStack_set_single_image) + .def("append_image", &is::append_image, py::arg("img"), py::arg("force_move") = false, pydocs::DOC_ImageStack_append_image) .def("get_obstime", &is::get_obstime, pydocs::DOC_ImageStack_get_obstime) .def("get_zeroed_time", &is::get_zeroed_time, pydocs::DOC_ImageStack_get_zeroed_time) diff --git a/src/kbmod/search/image_stack.h b/src/kbmod/search/image_stack.h index f27e3b77..20c59abc 100644 --- a/src/kbmod/search/image_stack.h +++ b/src/kbmod/search/image_stack.h @@ -36,8 +36,8 @@ class ImageStack { // Functions for setting or appending a single LayeredImage. If force_move is true, // then the code uses move semantics and destroys the input object. - void set_single_image(int index, LayeredImage& img, bool force_move=false); - void append_image(LayeredImage& img, bool force_move=false); + void set_single_image(int index, LayeredImage& img, bool force_move = false); + void append_image(LayeredImage& img, bool force_move = false); // Functions for getting or using times. double get_obstime(int index) const; diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index adb9946a..576fadee 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -130,12 +130,12 @@ static const auto DOC_StackSearch_get_psi_curves = R"doc( Parameters ---------- - trj : `kb.Trajectory` - The input trajectory. + trj : `kb.Trajectory` or `list` of `kb.Trajectory` + The input trajectory or trajectories. Returns ------- - result : `list` of `float` + result : `list` of `float` or `list` of `list` of `float` The psi values at each time step with NO_DATA replaced by 0.0. )doc"; @@ -144,12 +144,12 @@ static const auto DOC_StackSearch_get_phi_curves = R"doc( Parameters ---------- - trj : `kb.Trajectory` - The input trajectory. + trj : `kb.Trajectory` or `list` of `kb.Trajectory` + The input trajectory or trajectories. Returns ------- - result : `list` of `float` + result : `list` of `float` or `list` of `list` of `float` The phi values at each time step with NO_DATA replaced by 0.0. )doc"; diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index 951ec8ca..0123b9dc 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -237,7 +237,7 @@ uint64_t StackSearch::compute_max_results() { return num_search_pixels * params.results_per_pixel; } -std::vector StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi) { +std::vector StackSearch::extract_psi_or_phi_curve(const Trajectory& trj, bool extract_psi) { prepare_psi_phi(); const unsigned int num_times = stack.img_count(); @@ -256,17 +256,33 @@ std::vector StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e return result; } -std::vector StackSearch::get_psi_curves(Trajectory& trj) { +std::vector > StackSearch::get_psi_curves(const std::vector& trajectories) { + std::vector > all_results; + for (const auto& trj : trajectories) { + all_results.push_back(extract_psi_or_phi_curve(trj, true)); + } + return all_results; +} + +std::vector StackSearch::get_psi_curves(const Trajectory& trj) { return extract_psi_or_phi_curve(trj, true); } -std::vector StackSearch::get_phi_curves(Trajectory& trj) { +std::vector > StackSearch::get_phi_curves(const std::vector& trajectories) { + std::vector > all_results; + for (const auto& trj : trajectories) { + all_results.push_back(extract_psi_or_phi_curve(trj, false)); + } + return all_results; +} + +std::vector StackSearch::get_phi_curves(const Trajectory& trj) { return extract_psi_or_phi_curve(trj, false); } std::vector StackSearch::get_results(uint64_t start, uint64_t count) { - rs_logger->debug("Reading results [" + std::to_string(start) + ", " + - std::to_string(start + count) + ")"); + rs_logger->debug("Reading results [" + std::to_string(start) + ", " + std::to_string(start + count) + + ")"); return results.get_batch(start, count); } @@ -306,10 +322,14 @@ static void stack_search_bindings(py::module& m) { .def("get_imagestack", &ks::get_imagestack, py::return_value_policy::reference_internal, pydocs::DOC_StackSearch_get_imagestack) // For testings - .def("get_psi_curves", (std::vector(ks::*)(tj&)) & ks::get_psi_curves, + .def("get_psi_curves", (std::vector(ks::*)(const tj&)) & ks::get_psi_curves, pydocs::DOC_StackSearch_get_psi_curves) - .def("get_phi_curves", (std::vector(ks::*)(tj&)) & ks::get_phi_curves, + .def("get_phi_curves", (std::vector(ks::*)(const tj&)) & ks::get_phi_curves, pydocs::DOC_StackSearch_get_phi_curves) + .def("get_psi_curves", + (std::vector >(ks::*)(const std::vector&)) & ks::get_psi_curves) + .def("get_phi_curves", + (std::vector >(ks::*)(const std::vector&)) & ks::get_phi_curves) .def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi) .def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi) .def("get_number_total_results", &ks::get_number_total_results, diff --git a/src/kbmod/search/stack_search.h b/src/kbmod/search/stack_search.h index 43b0bac0..7a629bea 100644 --- a/src/kbmod/search/stack_search.h +++ b/src/kbmod/search/stack_search.h @@ -60,8 +60,10 @@ class StackSearch { std::vector get_results(uint64_t start, uint64_t count); // Getters for the Psi and Phi data. - std::vector get_psi_curves(Trajectory& t); - std::vector get_phi_curves(Trajectory& t); + std::vector get_psi_curves(const Trajectory& t); + std::vector get_phi_curves(const Trajectory& t); + std::vector > get_psi_curves(const std::vector& trajectories); + std::vector > get_phi_curves(const std::vector& trajectories); // Helper functions for computing Psi and Phi void prepare_psi_phi(); @@ -73,7 +75,7 @@ class StackSearch { virtual ~StackSearch(){}; protected: - std::vector extract_psi_or_phi_curve(Trajectory& trj, bool extract_psi); + std::vector extract_psi_or_phi_curve(const Trajectory& trj, bool extract_psi); // Core data and search parameters. Note the StackSearch does not own // the ImageStack and it must exist for the duration of the object's life. diff --git a/tests/test_search.py b/tests/test_search.py index 40837640..4072e8eb 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -103,106 +103,6 @@ def setUp(self): self.max_angle, ) - def test_set_get_results(self): - results = self.search.get_results(0, 10) - self.assertEqual(len(results), 0) - - trjs = [Trajectory(i, i, 0.0, 0.0) for i in range(10)] - self.search.set_results(trjs) - - # Check that we extract them all. - results = self.search.get_results(0, 10) - self.assertEqual(len(results), 10) - for i in range(10): - self.assertEqual(results[i].x, i) - - # Check that we can run past the end of the results. - results = self.search.get_results(0, 100) - self.assertEqual(len(results), 10) - - # Check that we can pull a subset. - results = self.search.get_results(2, 2) - self.assertEqual(len(results), 2) - self.assertEqual(results[0].x, 2) - self.assertEqual(results[1].x, 3) - - # Check that we can pull a subset aligned with the end. - results = self.search.get_results(8, 2) - self.assertEqual(len(results), 2) - self.assertEqual(results[0].x, 8) - self.assertEqual(results[1].x, 9) - - # Check invalid settings - self.assertRaises(RuntimeError, self.search.get_results, 0, 0) - - def test_load_and_filter_results_lh(self): - time_list = [i / self.img_count for i in range(self.img_count)] - fake_ds = FakeDataSet( - self.dim_x, - self.dim_y, - time_list, - noise_level=1.0, - psf_val=0.5, - use_seed=True, - ) - - # Create fake result trajectories with given initial likelihoods. The 1st is - # filtered by max likelihood. The two final ones are filtered by min likelihood. - trjs = [ - Trajectory(10, 10, 0, 0, 500.0, 9000.0, self.img_count), - Trajectory(20, 20, 0, 0, 110.0, 110.0, self.img_count), - Trajectory(30, 30, 0, 0, 100.0, 100.0, self.img_count), - Trajectory(40, 40, 0, 0, 50.0, 50.0, self.img_count), - Trajectory(41, 41, 0, 0, 50.0, 50.0, self.img_count), - Trajectory(42, 42, 0, 0, 50.0, 50.0, self.img_count), - Trajectory(43, 43, 0, 0, 50.0, 50.0, self.img_count), - Trajectory(50, 50, 0, 0, 1.0, 2.0, self.img_count), - Trajectory(60, 60, 0, 0, 1.0, 1.0, self.img_count), - ] - for trj in trjs: - fake_ds.insert_object(trj) - - # Create the stack search and insert the fake results. - search = StackSearch(fake_ds.stack) - search.set_results(trjs) - - # Do the loading and filtering. - config = SearchConfiguration() - overrides = { - "clip_negative": False, - "chunk_size": 500000, - "lh_level": 10.0, - "max_lh": 1000.0, - "num_cores": 1, - "num_obs": 5, - "sigmaG_lims": [25, 75], - } - config.set_multiple(overrides) - - runner = SearchRunner() - results = runner.load_and_filter_results(search, config) - - # Only two of the middle results should pass the filtering. - self.assertEqual(len(results), 6) - self.assertEqual(results["y"][0], 20) - self.assertEqual(results["y"][1], 30) - self.assertEqual(results["y"][2], 40) - self.assertEqual(results["y"][3], 41) - self.assertEqual(results["y"][4], 42) - self.assertEqual(results["y"][5], 43) - - # Rerun the search with a small chunk_size to make sure we still - # find everything. - overrides["chunk_size"] = 2 - results = runner.load_and_filter_results(search, config) - self.assertEqual(len(results), 6) - self.assertEqual(results["y"][0], 20) - self.assertEqual(results["y"][1], 30) - self.assertEqual(results["y"][2], 40) - self.assertEqual(results["y"][3], 41) - self.assertEqual(results["y"][4], 42) - self.assertEqual(results["y"][5], 43) - @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") def test_evaluate_single_trajectory(self): test_trj = Trajectory( diff --git a/tests/test_stack_search_results.py b/tests/test_stack_search_results.py new file mode 100644 index 00000000..1a8de0ed --- /dev/null +++ b/tests/test_stack_search_results.py @@ -0,0 +1,147 @@ +import unittest + +import numpy as np + +from kbmod.configuration import SearchConfiguration +from kbmod.fake_data.fake_data_creator import create_fake_times, FakeDataSet +from kbmod.run_search import SearchRunner +from kbmod.search import StackSearch, Trajectory + + +class test_search(unittest.TestCase): + def setUp(self): + self.num_times = 10 + self.width = 256 + self.height = 256 + self.num_objs = 5 + + self.times = create_fake_times(self.num_times, obs_per_day=3) + self.fake_ds = FakeDataSet(self.width, self.height, self.times) + for _ in range(self.num_objs): + self.fake_ds.insert_random_object(500) + + self.search = StackSearch(self.fake_ds.stack) + self.fake_trjs = self.fake_ds.trajectories + + def test_set_get_results(self): + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 0) + + trjs = [Trajectory(i, i, 0.0, 0.0) for i in range(10)] + self.search.set_results(trjs) + + # Check that we extract them all. + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 10) + for i in range(10): + self.assertEqual(results[i].x, i) + + # Check that we can run past the end of the results. + results = self.search.get_results(0, 100) + self.assertEqual(len(results), 10) + + # Check that we can pull a subset. + results = self.search.get_results(2, 2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0].x, 2) + self.assertEqual(results[1].x, 3) + + # Check that we can pull a subset aligned with the end. + results = self.search.get_results(8, 2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0].x, 8) + self.assertEqual(results[1].x, 9) + + # Check invalid settings + self.assertRaises(RuntimeError, self.search.get_results, 0, 0) + + def test_psi_phi_curves(self): + psi_curves = np.array(self.search.get_psi_curves(self.fake_trjs)) + self.assertEqual(psi_curves.shape[0], self.num_objs) + self.assertEqual(psi_curves.shape[1], self.num_times) + self.assertTrue(np.all(psi_curves > 0.0)) + + phi_curves = np.array(self.search.get_phi_curves(self.fake_trjs)) + self.assertEqual(phi_curves.shape[0], self.num_objs) + self.assertEqual(phi_curves.shape[1], self.num_times) + self.assertTrue(np.all(phi_curves > 0.0)) + + # Check that the batch getters give the same results as the iterative ones. + for i in range(self.num_objs): + current_psi = self.search.get_psi_curves(self.fake_trjs[i]) + self.assertTrue(np.allclose(psi_curves[i], current_psi)) + + current_phi = self.search.get_phi_curves(self.fake_trjs[i]) + self.assertTrue(np.allclose(phi_curves[i], current_phi)) + + def test_load_and_filter_results_lh(self): + time_list = [i / self.num_times for i in range(self.num_times)] + fake_ds = FakeDataSet( + self.width, + self.height, + time_list, + noise_level=1.0, + psf_val=0.5, + use_seed=True, + ) + + # Create fake result trajectories with given initial likelihoods. The 1st is + # filtered by max likelihood. The two final ones are filtered by min likelihood. + trjs = [ + Trajectory(10, 10, 0, 0, 500.0, 9000.0, self.num_times), + Trajectory(20, 20, 0, 0, 110.0, 110.0, self.num_times), + Trajectory(30, 30, 0, 0, 100.0, 100.0, self.num_times), + Trajectory(40, 40, 0, 0, 50.0, 50.0, self.num_times), + Trajectory(41, 41, 0, 0, 50.0, 50.0, self.num_times), + Trajectory(42, 42, 0, 0, 50.0, 50.0, self.num_times), + Trajectory(43, 43, 0, 0, 50.0, 50.0, self.num_times), + Trajectory(50, 50, 0, 0, 1.0, 2.0, self.num_times), + Trajectory(60, 60, 0, 0, 1.0, 1.0, self.num_times), + ] + for trj in trjs: + fake_ds.insert_object(trj) + + # Create the stack search and insert the fake results. + search = StackSearch(fake_ds.stack) + search.set_results(trjs) + + # Do the loading and filtering. + config = SearchConfiguration() + overrides = { + "clip_negative": False, + "chunk_size": 500000, + "lh_level": 10.0, + "max_lh": 1000.0, + "num_cores": 1, + "num_obs": 5, + "sigmaG_lims": [25, 75], + } + config.set_multiple(overrides) + + runner = SearchRunner() + results = runner.load_and_filter_results(search, config) + + # Only two of the middle results should pass the filtering. + self.assertEqual(len(results), 6) + self.assertEqual(results["y"][0], 20) + self.assertEqual(results["y"][1], 30) + self.assertEqual(results["y"][2], 40) + self.assertEqual(results["y"][3], 41) + self.assertEqual(results["y"][4], 42) + self.assertEqual(results["y"][5], 43) + + # Rerun the search with a small chunk_size to make sure we still + # find everything. + overrides["chunk_size"] = 2 + results = runner.load_and_filter_results(search, config) + self.assertEqual(len(results), 6) + self.assertEqual(results["y"][0], 20) + self.assertEqual(results["y"][1], 30) + self.assertEqual(results["y"][2], 40) + self.assertEqual(results["y"][3], 41) + self.assertEqual(results["y"][4], 42) + self.assertEqual(results["y"][5], 43) + + +if __name__ == "__main__": + unittest.main()