Skip to content

Commit

Permalink
Change |get_global_property()| to return Option<&T::Value>. Fixes #72
Browse files Browse the repository at this point in the history
Returns None when either:

- No global properties have been set
- Global properties exist but the requested property has not
  been set.
  • Loading branch information
ekr-cfa committed Nov 4, 2024
1 parent 208945f commit ec887af
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 16 deletions.
5 changes: 4 additions & 1 deletion examples/parameter-loading/incidence_report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ fn handle_infection_status_change(
}

pub fn init(context: &mut Context) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context
.report_options()
.directory(PathBuf::from(parameters.output_dir));
Expand Down
5 changes: 4 additions & 1 deletion examples/parameter-loading/infection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use crate::Parameters;
define_rng!(InfectionRng);

fn schedule_recovery(context: &mut Context, person_id: PersonId) {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
let infection_duration = parameters.infection_duration;
let recovery_time = context.get_current_time()
+ context.sample_distr(InfectionRng, Exp::new(1.0 / infection_duration).unwrap());
Expand Down
5 changes: 4 additions & 1 deletion examples/parameter-loading/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ fn main() {

match parameters_loader::init_parameters(&mut context, &file_path) {
Ok(()) => {
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
context.init_random(parameters.seed);

for _ in 0..parameters.population {
Expand Down
5 changes: 4 additions & 1 deletion examples/parameter-loading/transmission_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ fn attempt_infection(context: &mut Context) {
context.get_person_id(context.sample_range(TransmissionRng, 0..population_size));
let person_status: InfectionStatus =
context.get_person_property(person_to_infect, InfectionStatusType);
let parameters = context.get_global_property_value(Parameters).clone();
let parameters = context
.get_global_property_value(Parameters)
.unwrap()
.clone();

if matches!(person_status, InfectionStatus::S) {
context.set_person_property(person_to_infect, InfectionStatusType, InfectionStatus::I);
Expand Down
48 changes: 36 additions & 12 deletions src/global_properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ pub trait ContextGlobalPropertiesExt {
);

/// Return value of global property T
fn get_global_property_value<T: GlobalProperty + 'static>(&self, _property: T) -> &T::Value;
fn get_global_property_value<T: GlobalProperty + 'static>(
&self,
_property: T,
) -> Option<&T::Value>;

/// Given a file path for a valid json file, deserialize parameter values
/// for a given struct T
Expand All @@ -79,12 +82,14 @@ impl GlobalPropertiesDataContainer {
.or_insert_with(|| Box::new(value));
}

fn get_global_property_value<T: GlobalProperty + 'static>(&self) -> &T::Value {
let data_container = self
.global_property_container
.get(&TypeId::of::<T>())
.expect("Global property not initialized");
data_container.downcast_ref::<T::Value>().unwrap()
#[must_use]
fn get_global_property_value<T: GlobalProperty + 'static>(&self) -> Option<&T::Value> {
let data_container = self.global_property_container.get(&TypeId::of::<T>());

match data_container {
Some(property) => Some(property.downcast_ref::<T::Value>().unwrap()),
None => None,
}
}
}

Expand All @@ -99,9 +104,14 @@ impl ContextGlobalPropertiesExt for Context {
}

#[allow(unused_variables)]
fn get_global_property_value<T: GlobalProperty + 'static>(&self, _property: T) -> &T::Value {
let data_container = self.get_data_container(GlobalPropertiesPlugin).unwrap();
data_container.get_global_property_value::<T>()
fn get_global_property_value<T: GlobalProperty + 'static>(
&self,
_property: T,
) -> Option<&T::Value> {
if let Some(data_container) = self.get_data_container(GlobalPropertiesPlugin) {
return data_container.get_global_property_value::<T>();
};
None
}

fn load_parameters_from_json<T: 'static + Debug + DeserializeOwned>(
Expand Down Expand Up @@ -139,10 +149,21 @@ mod test {
};
let mut context = Context::new();
context.set_global_property_value(DiseaseParams, params.clone());
let global_params = context.get_global_property_value(DiseaseParams).clone();
let global_params = context
.get_global_property_value(DiseaseParams)
.unwrap()
.clone();
assert_eq!(global_params.days, params.days);
assert_eq!(global_params.diseases, params.diseases);
}

#[test]
fn get_global_propert_missing() {
let context = Context::new();
let global_params = context.get_global_property_value(DiseaseParams);
assert!(global_params.is_none());
}

#[test]
fn set_parameters() {
let mut context = Context::new();
Expand All @@ -166,7 +187,10 @@ mod test {

context.set_global_property_value(Parameters, params_json);

let params_read = context.get_global_property_value(Parameters).clone();
let params_read = context
.get_global_property_value(Parameters)
.unwrap()
.clone();
assert_eq!(params_read.days, params.days);
assert_eq!(params_read.diseases, params.diseases);
}
Expand Down

0 comments on commit ec887af

Please sign in to comment.