From 36de0a6591d7661e0d5e26e67d623efa34d842be Mon Sep 17 00:00:00 2001 From: Eric Rescorla <141454109+ekr-cfa@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:44:55 -0800 Subject: [PATCH] Change |get_global_property()| to return Option<&T::Value>. Fixes #72 (#73) Returns None when either: - No global properties have been set - Global properties exist but the requested property has not been set. --- .../parameter-loading/incidence_report.rs | 5 +- .../parameter-loading/infection_manager.rs | 5 +- examples/parameter-loading/main.rs | 5 +- .../parameter-loading/transmission_manager.rs | 5 +- .../exposure_manager.rs | 15 ++++-- .../incidence_report.rs | 5 +- .../infection_manager.rs | 15 ++++-- examples/time-varying-infection/main.rs | 5 +- .../time-varying-infection/periodic_report.rs | 5 +- .../population_loader.rs | 10 +++- src/global_properties.rs | 48 ++++++++++++++----- 11 files changed, 96 insertions(+), 27 deletions(-) diff --git a/examples/parameter-loading/incidence_report.rs b/examples/parameter-loading/incidence_report.rs index 9afc67f..0ace9ba 100644 --- a/examples/parameter-loading/incidence_report.rs +++ b/examples/parameter-loading/incidence_report.rs @@ -32,7 +32,10 @@ fn handle_infection_status_change( } pub fn init(context: &mut Context) { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context .report_options() .directory(PathBuf::from(parameters.output_dir)); diff --git a/examples/parameter-loading/infection_manager.rs b/examples/parameter-loading/infection_manager.rs index b157dd1..86cc046 100644 --- a/examples/parameter-loading/infection_manager.rs +++ b/examples/parameter-loading/infection_manager.rs @@ -12,7 +12,10 @@ use crate::Parameters; define_rng!(InfectionRng); fn schedule_recovery(context: &mut Context, person_id: PersonId) { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); let infection_duration = parameters.infection_duration; let recovery_time = context.get_current_time() + context.sample_distr(InfectionRng, Exp::new(1.0 / infection_duration).unwrap()); diff --git a/examples/parameter-loading/main.rs b/examples/parameter-loading/main.rs index a968a07..c1eae20 100644 --- a/examples/parameter-loading/main.rs +++ b/examples/parameter-loading/main.rs @@ -31,7 +31,10 @@ fn main() { match parameters_loader::init_parameters(&mut context, &file_path) { Ok(()) => { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context.init_random(parameters.seed); for _ in 0..parameters.population { diff --git a/examples/parameter-loading/transmission_manager.rs b/examples/parameter-loading/transmission_manager.rs index 4b7d1b9..5cac66e 100644 --- a/examples/parameter-loading/transmission_manager.rs +++ b/examples/parameter-loading/transmission_manager.rs @@ -17,7 +17,10 @@ fn attempt_infection(context: &mut Context) { context.get_person_id(context.sample_range(TransmissionRng, 0..population_size)); let person_status: InfectionStatus = context.get_person_property(person_to_infect, InfectionStatusType); - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); if matches!(person_status, InfectionStatus::S) { context.set_person_property(person_to_infect, InfectionStatusType, InfectionStatus::I); diff --git a/examples/time-varying-infection/exposure_manager.rs b/examples/time-varying-infection/exposure_manager.rs index 13a18c0..b4f30ae 100644 --- a/examples/time-varying-infection/exposure_manager.rs +++ b/examples/time-varying-infection/exposure_manager.rs @@ -39,7 +39,10 @@ fn inverse_sampling_infection(context: &mut Context) -> f64 { let s: f64 = context.sample_distr(ExposureRng, Exp1); // get the time by following the formula described above // first need to get the simulation's sin_shift - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); let sin_shift = parameters.foi_sin_shift; let foi = parameters.foi; let f = func!(move |t| foi_t(t, foi, sin_shift)); @@ -91,7 +94,10 @@ mod test { }; let mut context = Context::new(); context.set_global_property_value(Parameters, p_values); - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context.init_random(parameters.seed); init(&mut context); context.add_person(); @@ -124,7 +130,10 @@ mod test { }; let mut context = Context::new(); context.set_global_property_value(Parameters, p_values); - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context.init_random(parameters.seed); // empirical mean let mut sum = 0.0; diff --git a/examples/time-varying-infection/incidence_report.rs b/examples/time-varying-infection/incidence_report.rs index 94c9b5c..1e67b88 100644 --- a/examples/time-varying-infection/incidence_report.rs +++ b/examples/time-varying-infection/incidence_report.rs @@ -31,7 +31,10 @@ fn handle_infection_status_change( } pub fn init(context: &mut Context) { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context .report_options() .directory(PathBuf::from(parameters.output_dir)); diff --git a/examples/time-varying-infection/infection_manager.rs b/examples/time-varying-infection/infection_manager.rs index 27f2b50..4619ff1 100644 --- a/examples/time-varying-infection/infection_manager.rs +++ b/examples/time-varying-infection/infection_manager.rs @@ -16,7 +16,10 @@ fn recovery_cdf(context: &mut Context, time_spent_infected: f64) -> f64 { } fn n_eff_inv_infec(context: &mut Context) -> f64 { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); // get number of infected people let mut n_infected = 0; for usize_id in 0..context.get_current_population() { @@ -61,7 +64,10 @@ fn handle_infection_status_change( context: &mut Context, event: PersonPropertyChangeEvent, ) { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); if matches!(event.current, DiseaseStatus::I) { // recall resampling rate is sum of maximum foi rate and gamma // maximum foi rate is foi * 2 -- the 2 because foi is sin(t + c) + 1 @@ -112,7 +118,10 @@ mod test { let mut context = Context::new(); context.set_global_property_value(Parameters, p_values); - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context.init_random(parameters.seed); init(&mut context); diff --git a/examples/time-varying-infection/main.rs b/examples/time-varying-infection/main.rs index 4ce8a34..bbef9a7 100644 --- a/examples/time-varying-infection/main.rs +++ b/examples/time-varying-infection/main.rs @@ -20,7 +20,10 @@ fn main() { match parameters_loader::init_parameters(&mut context, &file_path) { Ok(()) => { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context.init_random(parameters.seed); exposure_manager::init(&mut context); diff --git a/examples/time-varying-infection/periodic_report.rs b/examples/time-varying-infection/periodic_report.rs index a843166..fc1c850 100644 --- a/examples/time-varying-infection/periodic_report.rs +++ b/examples/time-varying-infection/periodic_report.rs @@ -50,7 +50,10 @@ fn count_people_and_send_report(context: &mut Context, report_period: f64) { } pub fn init(context: &mut Context) { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context .report_options() .directory(PathBuf::from(parameters.output_dir)); diff --git a/examples/time-varying-infection/population_loader.rs b/examples/time-varying-infection/population_loader.rs index b989dde..fe95bb7 100644 --- a/examples/time-varying-infection/population_loader.rs +++ b/examples/time-varying-infection/population_loader.rs @@ -18,7 +18,10 @@ define_person_property_with_default!(DiseaseStatusType, DiseaseStatus, DiseaseSt define_person_property!(InfectionTime, f64); pub fn init(context: &mut Context) { - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); for _ in 0..parameters.population { context.add_person(); } @@ -51,7 +54,10 @@ mod tests { }; let mut context = Context::new(); context.set_global_property_value(Parameters, p_values); - let parameters = context.get_global_property_value(Parameters).clone(); + let parameters = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); context.init_random(parameters.seed); init(&mut context); diff --git a/src/global_properties.rs b/src/global_properties.rs index 8943f72..9525ffe 100644 --- a/src/global_properties.rs +++ b/src/global_properties.rs @@ -52,7 +52,10 @@ pub trait ContextGlobalPropertiesExt { ); /// Return value of global property T - fn get_global_property_value(&self, _property: T) -> &T::Value; + fn get_global_property_value( + &self, + _property: T, + ) -> Option<&T::Value>; /// Given a file path for a valid json file, deserialize parameter values /// for a given struct T @@ -79,12 +82,14 @@ impl GlobalPropertiesDataContainer { .or_insert_with(|| Box::new(value)); } - fn get_global_property_value(&self) -> &T::Value { - let data_container = self - .global_property_container - .get(&TypeId::of::()) - .expect("Global property not initialized"); - data_container.downcast_ref::().unwrap() + #[must_use] + fn get_global_property_value(&self) -> Option<&T::Value> { + let data_container = self.global_property_container.get(&TypeId::of::()); + + match data_container { + Some(property) => Some(property.downcast_ref::().unwrap()), + None => None, + } } } @@ -99,9 +104,14 @@ impl ContextGlobalPropertiesExt for Context { } #[allow(unused_variables)] - fn get_global_property_value(&self, _property: T) -> &T::Value { - let data_container = self.get_data_container(GlobalPropertiesPlugin).unwrap(); - data_container.get_global_property_value::() + fn get_global_property_value( + &self, + _property: T, + ) -> Option<&T::Value> { + if let Some(data_container) = self.get_data_container(GlobalPropertiesPlugin) { + return data_container.get_global_property_value::(); + }; + None } fn load_parameters_from_json( @@ -139,10 +149,21 @@ mod test { }; let mut context = Context::new(); context.set_global_property_value(DiseaseParams, params.clone()); - let global_params = context.get_global_property_value(DiseaseParams).clone(); + let global_params = context + .get_global_property_value(DiseaseParams) + .unwrap() + .clone(); assert_eq!(global_params.days, params.days); assert_eq!(global_params.diseases, params.diseases); } + + #[test] + fn get_global_propert_missing() { + let context = Context::new(); + let global_params = context.get_global_property_value(DiseaseParams); + assert!(global_params.is_none()); + } + #[test] fn set_parameters() { let mut context = Context::new(); @@ -166,7 +187,10 @@ mod test { context.set_global_property_value(Parameters, params_json); - let params_read = context.get_global_property_value(Parameters).clone(); + let params_read = context + .get_global_property_value(Parameters) + .unwrap() + .clone(); assert_eq!(params_read.days, params.days); assert_eq!(params_read.diseases, params.diseases); }