Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Psi/Phi curve generation out of Python loop #712

Merged
merged 8 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def load_and_filter_results(self, search, config):
logger.info(f"Chunk Min. Likelihood = {results[-1].lh}")

trj_batch = []
psi_batch = []
phi_batch = []
for i, trj in enumerate(results):
# Stop as soon as we hit a result below our limit, because anything after
# that is not guarrenteed to be valid due to potential on-GPU filtering.
Expand All @@ -93,14 +91,15 @@ def load_and_filter_results(self, search, config):

if trj.lh < max_lh:
trj_batch.append(trj)
psi_batch.append(search.get_psi_curves(trj))
phi_batch.append(search.get_phi_curves(trj))
total_count += 1

batch_size = len(trj_batch)
logger.info(f"Extracted batch of {batch_size} results for total of {total_count}")

if batch_size > 0:
psi_batch = search.get_psi_curves(trj_batch)
phi_batch = search.get_phi_curves(trj_batch)

result_batch = Results.from_trajectories(trj_batch, track_filtered=do_tracking)
result_batch.add_psi_phi_data(psi_batch, phi_batch)

Expand Down
8 changes: 4 additions & 4 deletions src/kbmod/search/image_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ void ImageStack::set_single_image(int index, LayeredImage& img, bool force_move)
assert_sizes_equal(img.get_height(), height, "ImageStack image height");

if (force_move) {
images[index] = img;
images[index] = img;
} else {
images[index] = std::move(img);
images[index] = std::move(img);
}
}

Expand Down Expand Up @@ -186,8 +186,8 @@ static void image_stack_bindings(py::module& m) {
.def("get_single_image", &is::get_single_image, py::return_value_policy::reference_internal,
pydocs::DOC_ImageStack_get_single_image)
.def("set_single_image", &is::set_single_image, py::arg("index"), py::arg("img"),
py::arg("force_move")=false, pydocs::DOC_ImageStack_set_single_image)
.def("append_image", &is::append_image, py::arg("img"), py::arg("force_move")=false,
py::arg("force_move") = false, pydocs::DOC_ImageStack_set_single_image)
.def("append_image", &is::append_image, py::arg("img"), py::arg("force_move") = false,
pydocs::DOC_ImageStack_append_image)
.def("get_obstime", &is::get_obstime, pydocs::DOC_ImageStack_get_obstime)
.def("get_zeroed_time", &is::get_zeroed_time, pydocs::DOC_ImageStack_get_zeroed_time)
Expand Down
4 changes: 2 additions & 2 deletions src/kbmod/search/image_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class ImageStack {

// Functions for setting or appending a single LayeredImage. If force_move is true,
// then the code uses move semantics and destroys the input object.
void set_single_image(int index, LayeredImage& img, bool force_move=false);
void append_image(LayeredImage& img, bool force_move=false);
void set_single_image(int index, LayeredImage& img, bool force_move = false);
void append_image(LayeredImage& img, bool force_move = false);

// Functions for getting or using times.
double get_obstime(int index) const;
Expand Down
12 changes: 6 additions & 6 deletions src/kbmod/search/pydocs/stack_search_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ static const auto DOC_StackSearch_get_psi_curves = R"doc(

Parameters
----------
trj : `kb.Trajectory`
The input trajectory.
trj : `kb.Trajectory` or `list` of `kb.Trajectory`
The input trajectory or trajectories.

Returns
-------
result : `list` of `float`
result : `list` of `float` or `list` of `list` of `float`
The psi values at each time step with NO_DATA replaced by 0.0.
)doc";

Expand All @@ -144,12 +144,12 @@ static const auto DOC_StackSearch_get_phi_curves = R"doc(

Parameters
----------
trj : `kb.Trajectory`
The input trajectory.
trj : `kb.Trajectory` or `list` of `kb.Trajectory`
The input trajectory or trajectories.

Returns
-------
result : `list` of `float`
result : `list` of `float` or `list` of `list` of `float`
The phi values at each time step with NO_DATA replaced by 0.0.
)doc";

Expand Down
24 changes: 22 additions & 2 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,33 @@ std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e
return result;
}

std::vector<std::vector<float> > StackSearch::get_psi_curves(std::vector<Trajectory>& trajectories) {
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::vector<float> > all_results;
for (auto& trj : trajectories) {
all_results.push_back(extract_psi_or_phi_curve(trj, true));
}
return all_results;
}

std::vector<float> StackSearch::get_psi_curves(Trajectory& trj) {
return extract_psi_or_phi_curve(trj, true);
}

std::vector<std::vector<float> > StackSearch::get_phi_curves(std::vector<Trajectory>& trajectories) {
std::vector<std::vector<float> > all_results;
for (auto& trj : trajectories) {
all_results.push_back(extract_psi_or_phi_curve(trj, false));
}
return all_results;
}

std::vector<float> StackSearch::get_phi_curves(Trajectory& trj) {
return extract_psi_or_phi_curve(trj, false);
}

std::vector<Trajectory> StackSearch::get_results(uint64_t start, uint64_t count) {
rs_logger->debug("Reading results [" + std::to_string(start) + ", " +
std::to_string(start + count) + ")");
rs_logger->debug("Reading results [" + std::to_string(start) + ", " + std::to_string(start + count) +
")");
return results.get_batch(start, count);
}

Expand Down Expand Up @@ -310,6 +326,10 @@ static void stack_search_bindings(py::module& m) {
pydocs::DOC_StackSearch_get_psi_curves)
.def("get_phi_curves", (std::vector<float>(ks::*)(tj&)) & ks::get_phi_curves,
pydocs::DOC_StackSearch_get_phi_curves)
.def("get_psi_curves",
(std::vector<std::vector<float> >(ks::*)(std::vector<tj>&)) & ks::get_psi_curves)
.def("get_phi_curves",
(std::vector<std::vector<float> >(ks::*)(std::vector<tj>&)) & ks::get_phi_curves)
.def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi)
.def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi)
.def("get_number_total_results", &ks::get_number_total_results,
Expand Down
2 changes: 2 additions & 0 deletions src/kbmod/search/stack_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class StackSearch {
// Getters for the Psi and Phi data.
std::vector<float> get_psi_curves(Trajectory& t);
std::vector<float> get_phi_curves(Trajectory& t);
std::vector<std::vector<float> > get_psi_curves(std::vector<Trajectory>& trajectories);
std::vector<std::vector<float> > get_phi_curves(std::vector<Trajectory>& trajectories);

// Helper functions for computing Psi and Phi
void prepare_psi_phi();
Expand Down
100 changes: 0 additions & 100 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,106 +103,6 @@ def setUp(self):
self.max_angle,
)

def test_set_get_results(self):
results = self.search.get_results(0, 10)
self.assertEqual(len(results), 0)

trjs = [Trajectory(i, i, 0.0, 0.0) for i in range(10)]
self.search.set_results(trjs)

# Check that we extract them all.
results = self.search.get_results(0, 10)
self.assertEqual(len(results), 10)
for i in range(10):
self.assertEqual(results[i].x, i)

# Check that we can run past the end of the results.
results = self.search.get_results(0, 100)
self.assertEqual(len(results), 10)

# Check that we can pull a subset.
results = self.search.get_results(2, 2)
self.assertEqual(len(results), 2)
self.assertEqual(results[0].x, 2)
self.assertEqual(results[1].x, 3)

# Check that we can pull a subset aligned with the end.
results = self.search.get_results(8, 2)
self.assertEqual(len(results), 2)
self.assertEqual(results[0].x, 8)
self.assertEqual(results[1].x, 9)

# Check invalid settings
self.assertRaises(RuntimeError, self.search.get_results, 0, 0)

def test_load_and_filter_results_lh(self):
time_list = [i / self.img_count for i in range(self.img_count)]
fake_ds = FakeDataSet(
self.dim_x,
self.dim_y,
time_list,
noise_level=1.0,
psf_val=0.5,
use_seed=True,
)

# Create fake result trajectories with given initial likelihoods. The 1st is
# filtered by max likelihood. The two final ones are filtered by min likelihood.
trjs = [
Trajectory(10, 10, 0, 0, 500.0, 9000.0, self.img_count),
Trajectory(20, 20, 0, 0, 110.0, 110.0, self.img_count),
Trajectory(30, 30, 0, 0, 100.0, 100.0, self.img_count),
Trajectory(40, 40, 0, 0, 50.0, 50.0, self.img_count),
Trajectory(41, 41, 0, 0, 50.0, 50.0, self.img_count),
Trajectory(42, 42, 0, 0, 50.0, 50.0, self.img_count),
Trajectory(43, 43, 0, 0, 50.0, 50.0, self.img_count),
Trajectory(50, 50, 0, 0, 1.0, 2.0, self.img_count),
Trajectory(60, 60, 0, 0, 1.0, 1.0, self.img_count),
]
for trj in trjs:
fake_ds.insert_object(trj)

# Create the stack search and insert the fake results.
search = StackSearch(fake_ds.stack)
search.set_results(trjs)

# Do the loading and filtering.
config = SearchConfiguration()
overrides = {
"clip_negative": False,
"chunk_size": 500000,
"lh_level": 10.0,
"max_lh": 1000.0,
"num_cores": 1,
"num_obs": 5,
"sigmaG_lims": [25, 75],
}
config.set_multiple(overrides)

runner = SearchRunner()
results = runner.load_and_filter_results(search, config)

# Only two of the middle results should pass the filtering.
self.assertEqual(len(results), 6)
self.assertEqual(results["y"][0], 20)
self.assertEqual(results["y"][1], 30)
self.assertEqual(results["y"][2], 40)
self.assertEqual(results["y"][3], 41)
self.assertEqual(results["y"][4], 42)
self.assertEqual(results["y"][5], 43)

# Rerun the search with a small chunk_size to make sure we still
# find everything.
overrides["chunk_size"] = 2
results = runner.load_and_filter_results(search, config)
self.assertEqual(len(results), 6)
self.assertEqual(results["y"][0], 20)
self.assertEqual(results["y"][1], 30)
self.assertEqual(results["y"][2], 40)
self.assertEqual(results["y"][3], 41)
self.assertEqual(results["y"][4], 42)
self.assertEqual(results["y"][5], 43)

@unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)")
def test_evaluate_single_trajectory(self):
test_trj = Trajectory(
Expand Down
Loading
Loading