Skip to content

Commit

Permalink
Change |get_global_property()| to return Option<&T::Value>. Fixes #72 (
Browse files Browse the repository at this point in the history
…#73)

Returns None when either:

- No global properties have been set
- Global properties exist but the requested property has not
  been set.
  • Loading branch information
ekr-cfa authored Nov 5, 2024
1 parent 5fe79e9 commit 36de0a6
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 27 deletions.
5 changes: 4 additions & 1 deletion examples/parameter-loading/incidence_report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ fn handle_infection_status_change(
}

pub fn init(context: &mut Context) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context
.report_options()
.directory(PathBuf::from(parameters.output_dir));
Expand Down
5 changes: 4 additions & 1 deletion examples/parameter-loading/infection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use crate::Parameters;
define_rng!(InfectionRng);

fn schedule_recovery(context: &mut Context, person_id: PersonId) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
let infection_duration = parameters.infection_duration;
let recovery_time = context.get_current_time()
+ context.sample_distr(InfectionRng, Exp::new(1.0 / infection_duration).unwrap());
Expand Down
5 changes: 4 additions & 1 deletion examples/parameter-loading/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ fn main() {

match parameters_loader::init_parameters(&mut context, &file_path) {
Ok(()) => {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);

for _ in 0..parameters.population {
Expand Down
5 changes: 4 additions & 1 deletion examples/parameter-loading/transmission_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ fn attempt_infection(context: &mut Context) {
context.get_person_id(context.sample_range(TransmissionRng, 0..population_size));
let person_status: InfectionStatus =
context.get_person_property(person_to_infect, InfectionStatusType);
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();

if matches!(person_status, InfectionStatus::S) {
context.set_person_property(person_to_infect, InfectionStatusType, InfectionStatus::I);
Expand Down
15 changes: 12 additions & 3 deletions examples/time-varying-infection/exposure_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ fn inverse_sampling_infection(context: &mut Context) -> f64 {
let s: f64 = context.sample_distr(ExposureRng, Exp1);
// get the time by following the formula described above
// first need to get the simulation's sin_shift
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
let sin_shift = parameters.foi_sin_shift;
let foi = parameters.foi;
let f = func!(move |t| foi_t(t, foi, sin_shift));
Expand Down Expand Up @@ -91,7 +94,10 @@ mod test {
};
let mut context = Context::new();
context.set_global_property_value(Parameters, p_values);
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);
init(&mut context);
context.add_person();
Expand Down Expand Up @@ -124,7 +130,10 @@ mod test {
};
let mut context = Context::new();
context.set_global_property_value(Parameters, p_values);
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);
// empirical mean
let mut sum = 0.0;
Expand Down
5 changes: 4 additions & 1 deletion examples/time-varying-infection/incidence_report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ fn handle_infection_status_change(
}

pub fn init(context: &mut Context) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context
.report_options()
.directory(PathBuf::from(parameters.output_dir));
Expand Down
15 changes: 12 additions & 3 deletions examples/time-varying-infection/infection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ fn recovery_cdf(context: &mut Context, time_spent_infected: f64) -> f64 {
}

fn n_eff_inv_infec(context: &mut Context) -> f64 {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
// get number of infected people
let mut n_infected = 0;
for usize_id in 0..context.get_current_population() {
Expand Down Expand Up @@ -61,7 +64,10 @@ fn handle_infection_status_change(
context: &mut Context,
event: PersonPropertyChangeEvent<DiseaseStatusType>,
) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
if matches!(event.current, DiseaseStatus::I) {
// recall resampling rate is sum of maximum foi rate and gamma
// maximum foi rate is foi * 2 -- the 2 because foi is sin(t + c) + 1
Expand Down Expand Up @@ -112,7 +118,10 @@ mod test {
let mut context = Context::new();

context.set_global_property_value(Parameters, p_values);
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);
init(&mut context);

Expand Down
5 changes: 4 additions & 1 deletion examples/time-varying-infection/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ fn main() {

match parameters_loader::init_parameters(&mut context, &file_path) {
Ok(()) => {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);

exposure_manager::init(&mut context);
Expand Down
5 changes: 4 additions & 1 deletion examples/time-varying-infection/periodic_report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ fn count_people_and_send_report(context: &mut Context, report_period: f64) {
}

pub fn init(context: &mut Context) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context
.report_options()
.directory(PathBuf::from(parameters.output_dir));
Expand Down
10 changes: 8 additions & 2 deletions examples/time-varying-infection/population_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ define_person_property_with_default!(DiseaseStatusType, DiseaseStatus, DiseaseSt
define_person_property!(InfectionTime, f64);

pub fn init(context: &mut Context) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
for _ in 0..parameters.population {
context.add_person();
}
Expand Down Expand Up @@ -51,7 +54,10 @@ mod tests {
};
let mut context = Context::new();
context.set_global_property_value(Parameters, p_values);
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);
init(&mut context);

Expand Down
48 changes: 36 additions & 12 deletions src/global_properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ pub trait ContextGlobalPropertiesExt {
);

/// Return value of global property T
fn get_global_property_value<T: GlobalProperty + 'static>(&self, _property: T) -> &T::Value;
fn get_global_property_value<T: GlobalProperty + 'static>(
&self,
_property: T,
) -> Option<&T::Value>;

/// Given a file path for a valid json file, deserialize parameter values
/// for a given struct T
Expand All @@ -79,12 +82,14 @@ impl GlobalPropertiesDataContainer {
.or_insert_with(|| Box::new(value));
}

fn get_global_property_value<T: GlobalProperty + 'static>(&self) -> &T::Value {
let data_container = self
.global_property_container
.get(&TypeId::of::<T>())
.expect("Global property not initialized");
data_container.downcast_ref::<T::Value>().unwrap()
#[must_use]
fn get_global_property_value<T: GlobalProperty + 'static>(&self) -> Option<&T::Value> {
let data_container = self.global_property_container.get(&TypeId::of::<T>());

match data_container {
Some(property) => Some(property.downcast_ref::<T::Value>().unwrap()),
None => None,
}
}
}

Expand All @@ -99,9 +104,14 @@ impl ContextGlobalPropertiesExt for Context {
}

#[allow(unused_variables)]
fn get_global_property_value<T: GlobalProperty + 'static>(&self, _property: T) -> &T::Value {
let data_container = self.get_data_container(GlobalPropertiesPlugin).unwrap();
data_container.get_global_property_value::<T>()
fn get_global_property_value<T: GlobalProperty + 'static>(
&self,
_property: T,
) -> Option<&T::Value> {
if let Some(data_container) = self.get_data_container(GlobalPropertiesPlugin) {
return data_container.get_global_property_value::<T>();
};
None
}

fn load_parameters_from_json<T: 'static + Debug + DeserializeOwned>(
Expand Down Expand Up @@ -139,10 +149,21 @@ mod test {
};
let mut context = Context::new();
context.set_global_property_value(DiseaseParams, params.clone());
let global_params = context.get_global_property_value(DiseaseParams).clone();
let global_params = context
.get_global_property_value(DiseaseParams)
.unwrap()
.clone();
assert_eq!(global_params.days, params.days);
assert_eq!(global_params.diseases, params.diseases);
}

#[test]
fn get_global_propert_missing() {
let context = Context::new();
let global_params = context.get_global_property_value(DiseaseParams);
assert!(global_params.is_none());
}

#[test]
fn set_parameters() {
let mut context = Context::new();
Expand All @@ -166,7 +187,10 @@ mod test {

context.set_global_property_value(Parameters, params_json);

let params_read = context.get_global_property_value(Parameters).clone();
let params_read = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
assert_eq!(params_read.days, params.days);
assert_eq!(params_read.diseases, params.diseases);
}
Expand Down

0 comments on commit 36de0a6

Please sign in to comment.