Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp the interface for timestamps #351

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,6 @@ def load_images(
print(f"Loaded {len(images)} images")
stack = kb.ImageStack(images)

# Create a list of visit times and visit times shifted to 0.0.
min_time = min(visit_times)
zero_shifted = [(t - min_time) for t in visit_times]
stack.set_times(zero_shifted)
print("Times set", flush=True)

return (stack, wcs_list, visit_times)


Expand Down
53 changes: 22 additions & 31 deletions src/kbmod/search/image_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@ namespace py = pybind11;
namespace search {
ImageStack::ImageStack(const std::vector<std::string>& filenames, const std::vector<PSF>& psfs) {
verbose = true;
reset_images();
images = std::vector<LayeredImage>();
load_images(filenames, psfs);
extract_image_times();
set_time_origin();

global_mask = RawImage(get_width(), get_height());
global_mask.set_all_pix(0.0);
}

ImageStack::ImageStack(const std::vector<LayeredImage>& imgs) {
verbose = true;
images = imgs;
extract_image_times();
set_time_origin();

global_mask = RawImage(get_width(), get_height());
global_mask.set_all_pix(0.0);
}
Expand All @@ -42,39 +40,32 @@ namespace search {
if (verbose) std::cout << "\n";
}

void ImageStack::extract_image_times() {
// Load image times
image_times = std::vector<float>();
for (auto& i : images) {
image_times.push_back(float(i.get_obstime()));
}
}

void ImageStack::set_time_origin() {
// Set beginning time to 0.0
double initial_time = image_times[0];
for (auto& t : image_times) t = t - initial_time;
}

LayeredImage& ImageStack::get_single_image(int index) {
if (index < 0 || index > images.size()) throw std::out_of_range("ImageStack index out of bounds.");
return images[index];
}

void ImageStack::set_single_image(int index, LayeredImage& img) {
float ImageStack::get_obstime(int index) const {
if (index < 0 || index > images.size()) throw std::out_of_range("ImageStack index out of bounds.");
images[index] = img;
return images[index].get_obstime();
}

void ImageStack::set_times(const std::vector<float>& times) {
if (times.size() != img_count())
throw std::runtime_error("List of times provided does not match the number of images!");
image_times = times;
set_time_origin();
float ImageStack::get_zeroed_time(int index) const {
if (index < 0 || index > images.size()) throw std::out_of_range("ImageStack index out of bounds.");
return images[index].get_obstime() - images[0].get_obstime();
}

void ImageStack::reset_images() { images = std::vector<LayeredImage>(); }

std::vector<float> ImageStack::build_zeroed_times() const {
std::vector<float> zeroed_times = std::vector<float>();
if (images.size() > 0) {
float t0 = images[0].get_obstime();
for (auto& i : images) {
zeroed_times.push_back(i.get_obstime() - t0);
}
}
return zeroed_times;
}

void ImageStack::convolve_psf() {
for (auto& i : images) i.convolve_psf();
}
Expand Down Expand Up @@ -141,9 +132,9 @@ namespace search {
.def("get_single_image", &is::get_single_image,
py::return_value_policy::reference_internal,
pydocs::DOC_ImageStack_get_single_image)
.def("set_single_image", &is::set_single_image, pydocs::DOC_ImageStack_set_single_image)
.def("get_times", &is::get_times, pydocs::DOC_ImageStack_get_times)
.def("set_times", &is::set_times, pydocs::DOC_ImageStack_set_times )
.def("get_obstime", &is::get_obstime, pydocs::DOC_ImageStack_get_obstime)
.def("get_zeroed_time", &is::get_zeroed_time, pydocs::DOC_ImageStack_get_zeroed_time)
.def("build_zeroed_times", &is::build_zeroed_times, pydocs::DOC_ImageStack_build_zeroed_times)
.def("img_count", &is::img_count, pydocs::DOC_ImageStack_img_count)
.def("apply_mask_flags", &is::apply_mask_flags, pydocs::DOC_ImageStack_apply_mask_flags)
.def("apply_mask_threshold", &is::apply_mask_threshold, pydocs::DOC_ImageStack_apply_mask_threshold)
Expand Down
13 changes: 4 additions & 9 deletions src/kbmod/search/image_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ namespace search {
unsigned get_height() const { return images.size() > 0 ? images[0].get_height() : 0; }
unsigned get_npixels() const { return images.size() > 0 ? images[0].get_npixels() : 0; }
std::vector<LayeredImage>& get_images() { return images; }
const std::vector<float>& get_times() const { return image_times; }
float* get_timesDataRef() { return image_times.data(); }
LayeredImage& get_single_image(int index);

// Simple setters.
void set_times(const std::vector<float>& times);
void reset_images();
void set_single_image(int index, LayeredImage& img);
// Functions for getting times.
float get_obstime(int index) const;
float get_zeroed_time(int index) const;
std::vector<float> build_zeroed_times() const; // Linear cost.

// Apply makes to all the images.
void apply_global_mask(int flags, int threshold);
Expand All @@ -49,12 +47,9 @@ namespace search {

private:
void load_images(const std::vector<std::string>& filenames, const std::vector<PSF>& psfs);
void extract_image_times();
void set_time_origin();
void create_global_mask(int flags, int threshold);
std::vector<LayeredImage> images;
RawImage global_mask;
std::vector<float> image_times;
bool verbose;
};

Expand Down
26 changes: 14 additions & 12 deletions src/kbmod/search/pydocs/image_stack_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,32 @@ namespace pydocs {
static const auto DOC_ImageStack = R"doc(
todo
)doc";

static const auto DOC_ImageStack_get_images = R"doc(
todo
)doc";

static const auto DOC_ImageStack_get_single_image = R"doc(
todo
static const auto DOC_ImageStack_img_count = R"doc(
Returns the number of images in the stack.
)doc";

static const auto DOC_ImageStack_set_single_image = R"doc(
todo
static const auto DOC_ImageStack_get_single_image = R"doc(
Returns a single LayeredImage for a given index.
)doc";

static const auto DOC_ImageStack_get_times = R"doc(
todo
static const auto DOC_ImageStack_get_obstime = R"doc(
Returns a single image's observation time in MJD.
)doc";

static const auto DOC_ImageStack_set_times = R"doc(
todo
static const auto DOC_ImageStack_get_zeroed_time = R"doc(
Returns a single image's observation time relative to that
of the first image.
)doc";

static const auto DOC_ImageStack_img_count = R"doc(
todo
)doc";
static const auto DOC_ImageStack_build_zeroed_times = R"doc(
Construct an array of time differentials between each image
in the stack and the first image.
")doc";

static const auto DOC_ImageStack_apply_mask_flags = R"doc(
todo
Expand Down
10 changes: 6 additions & 4 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ namespace search {
end_timer();

// Create a data stucture for the per-image data.
std::vector<float> image_times = stack.build_zeroed_times();
PerImageData img_data;
img_data.num_images = stack.img_count();
img_data.image_times = stack.get_timesDataRef();
img_data.image_times = image_times.data();
if (params.use_corr) img_data.bary_corrs = &bary_corrs[0];

// Compute the encoding parameters for psi and phi if needed.
Expand Down Expand Up @@ -428,9 +429,10 @@ namespace search {
const int height = stack.get_height();

// Create a data stucture for the per-image data.
std::vector<float> image_times = stack.build_zeroed_times();
PerImageData img_data;
img_data.num_images = num_images;
img_data.image_times = stack.get_timesDataRef();
img_data.image_times = image_times.data();

// Allocate space for the results.
const int num_trajectories = t_array.size();
Expand Down Expand Up @@ -479,7 +481,7 @@ namespace search {
}

PixelPos StackSearch::get_trajectory_position(const Trajectory& t, int i) const {
float time = stack.get_times()[i];
float time = stack.get_zeroed_time(i);
if (use_corr) {
return {t.x + time * t.vx + bary_corrs[i].dx + t.x * bary_corrs[i].dxdx + t.y * bary_corrs[i].dxdy,
t.y + time * t.vy + bary_corrs[i].dy + t.x * bary_corrs[i].dydx +
Expand Down Expand Up @@ -513,7 +515,7 @@ namespace search {
int img_size = imgs.size();
std::vector<float> lightcurve;
lightcurve.reserve(img_size);
const std::vector<float>& times = stack.get_times();
std::vector<float> times = stack.build_zeroed_times();
for (int i = 0; i < img_size; ++i) {
/* Do not use get_pixel_interp(), because results from create_curves must
* be able to recover the same likelihoods as the ones reported by the
Expand Down
6 changes: 3 additions & 3 deletions tests/test_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def test_apply_stamp_filter(self):
int(self.img_count / 2),
)

mjds = np.array(stack.get_times())
kb_post_process = PostProcess(self.config, mjds)
zeroed_times = np.array(stack.build_zeroed_times())
kb_post_process = PostProcess(self.config, zeroed_times)

keep = kb_post_process.load_and_filter_results(
search,
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_clustering(self):
cluster_params["y_size"] = self.dim_y
cluster_params["vel_lims"] = [self.min_vel, self.max_vel]
cluster_params["ang_lims"] = [self.min_angle, self.max_angle]
cluster_params["mjd"] = np.array(self.stack.get_times())
cluster_params["mjd"] = np.array(self.stack.build_zeroed_times())

trjs = [
self._make_trajectory(10, 11, 1, 2, 100.0),
Expand Down
19 changes: 8 additions & 11 deletions tests/test_image_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
60, # dim_y = 60 pixels,
2.0, # noise_level
4.0, # variance
2.0 * i, # time
2.0 * i + 1.0, # time
self.p[i],
)

Expand All @@ -38,27 +38,24 @@ def test_create(self):
def test_access(self):
# Test we can access an individual image.
img = self.im_stack.get_single_image(1)
self.assertEqual(img.get_obstime(), 2.0)
self.assertEqual(img.get_obstime(), 3.0)
self.assertEqual(img.get_name(), "layered_test_1")

# Test an out of bounds access.
with self.assertRaises(IndexError):
img = self.im_stack.get_single_image(self.num_images + 1)

def test_times(self):
times = self.im_stack.get_times()
# Check that we can access specific times.
self.assertEqual(self.im_stack.get_obstime(1), 3.0)
self.assertEqual(self.im_stack.get_zeroed_time(1), 2.0)

# Check that we can build the full zeroed times list.
times = self.im_stack.build_zeroed_times()
self.assertEqual(len(times), self.num_images)
for i in range(self.num_images):
self.assertEqual(times[i], 2.0 * i)

new_times = [3.0 * i for i in range(self.num_images)]
self.im_stack.set_times(new_times)

times2 = self.im_stack.get_times()
self.assertEqual(len(times2), self.num_images)
for i in range(self.num_images):
self.assertEqual(times2[i], 3.0 * i)

def test_apply_mask(self):
# Nothing is initially masked.
for i in range(self.num_images):
Expand Down
16 changes: 8 additions & 8 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_sci_viz_stamps(self):
sci_stamps = self.search.get_stamps(self.trj, 2)
self.assertEqual(len(sci_stamps), self.imCount)

times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for i in range(self.imCount):
self.assertEqual(sci_stamps[i].get_width(), 5)
self.assertEqual(sci_stamps[i].get_height(), 5)
Expand All @@ -278,7 +278,7 @@ def test_stacked_sci(self):
self.assertEqual(sci.get_height(), 5)

# Compute the true stacked pixel for the middle of the track.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
sum_middle = 0.0
for i in range(self.imCount):
t = times[i]
Expand Down Expand Up @@ -309,7 +309,7 @@ def test_median_stamps_trj(self):
self.assertEqual(medianStamps1.get_height(), 5)

# Compute the true median pixel for the middle of the track.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
pix_values0 = []
pix_values1 = []
for i in range(self.imCount):
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_mean_stamps_trj(self):
self.assertEqual(meanStamp1.get_height(), 5)

# Compute the true median pixel for the middle of the track.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
pix_sum0 = 0.0
pix_sum1 = 0.0
pix_count0 = 0.0
Expand Down Expand Up @@ -605,7 +605,7 @@ def test_coadd_cpu(self):
self.assertEqual(medianStamps[0].get_height(), 2 * params.radius + 1)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down Expand Up @@ -656,7 +656,7 @@ def test_coadd_gpu(self):
self.assertEqual(medianStamps[0].get_height(), 2 * params.radius + 1)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down Expand Up @@ -702,7 +702,7 @@ def test_coadd_cpu_use_inds(self):
meanStamps = self.search.get_coadded_stamps([self.trj, self.trj], inds, params, False)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down Expand Up @@ -751,7 +751,7 @@ def test_coadd_gpu_use_inds(self):
meanStamps = self.search.get_coadded_stamps([self.trj, self.trj], inds, params, True)

# Compute the true summed and mean pixels for all of the pixels in the stamp.
times = self.stack.get_times()
times = self.stack.build_zeroed_times()
for stamp_x in range(2 * params.radius + 1):
for stamp_y in range(2 * params.radius + 1):
x_offset = stamp_x - params.radius
Expand Down