Skip to content

Commit

Permalink
Update componets to calculate and use custom paf for preterm
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Mar 6, 2025
1 parent 6b750e3 commit 1e3dd2a
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 18 deletions.
1 change: 0 additions & 1 deletion src/vivarium_gates_mncnh/components/intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event

from vivarium_gates_mncnh.constants import data_values
from vivarium_gates_mncnh.constants.data_keys import NO_CPAP_RISK
Expand Down
3 changes: 1 addition & 2 deletions src/vivarium_gates_mncnh/components/lbwsg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from vivarium.framework.engine import Builder
from vivarium.framework.lookup import LookupTable
from vivarium.framework.population import SimulantData
from vivarium.framework.resource import Resource
from vivarium.framework.values import Pipeline
from vivarium_public_health.risks.data_transformations import (
get_exposure_post_processor,
Expand Down Expand Up @@ -91,7 +90,7 @@ def initialization_requirements(self) -> dict[str, list[str]]:

def setup(self, builder: Builder) -> None:
# Paf pipeline needs to be registered before the super setup is called
self.acmr_paf = builder.value.register_value_producer(
self.paf = builder.value.register_value_producer(
f"lbwsg_paf_on_{self.target.name}.{self.target.measure}",
source=self.lookup_tables["population_attributable_fraction"],
component=self,
Expand Down
33 changes: 29 additions & 4 deletions src/vivarium_gates_mncnh/components/neonatal_causes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.lookup import LookupTable
from vivarium.framework.values import Pipeline

from vivarium_gates_mncnh.constants import data_keys
from vivarium_gates_mncnh.constants.data_values import (
Expand Down Expand Up @@ -36,25 +38,29 @@ def __init__(self, neonatal_cause: str) -> None:
#####################

def setup(self, builder: Builder) -> None:
self.acmr_paf = builder.value.get_value(PIPELINES.ACMR_PAF)
# This is the ACMR PAF pipeline. For preterm we will get the custom preterm PAF
self.paf = self.get_paf(self, builder)
# Register csmr pipeline
self.csmr = builder.value.register_value_producer(
f"cause.{self.neonatal_cause}.cause_specific_mortality_rate",
source=self.get_normalized_csmr,
component=self,
required_resources=[PIPELINES.ACMR_PAF],
required_resources=[self.paf],
)
builder.value.register_value_modifier(
"death_in_age_group_probability",
modifier=self.modify_death_in_age_group_probability,
component=self,
required_resources=[PIPELINES.ACMR_PAF, PIPELINES.DEATH_IN_AGE_GROUP_PROBABILITY],
required_resources=[self.paf],
)

##################
# Helper methods #
##################

def get_paf(self, builder: Builder) -> Pipeline:
return builder.value.get_value(PIPELINES.ACMR_PAF)

def load_csmr(self, builder: Builder) -> pd.DataFrame:
csmr = builder.data.load(f"cause.{self.neonatal_cause}.cause_specific_mortality_rate")
csmr = csmr.rename(columns=CHILD_LOOKUP_COLUMN_MAPPER)
Expand All @@ -64,7 +70,7 @@ def get_normalized_csmr(self, index: pd.Index) -> pd.Series:
# CSMR = CSMR * (1-PAF) * RR
# NOTE: There is LBWSG RR on this pipeline
raw_csmr = self.lookup_tables["csmr"](index)
normalizing_constant = 1 - self.acmr_paf(index)
normalizing_constant = 1 - self.paf(index)
normalized_csmr = raw_csmr * normalizing_constant

return normalized_csmr
Expand All @@ -80,6 +86,25 @@ def modify_death_in_age_group_probability(


class PretermBirth(NeonatalCause):
@property
def configuration_defaults(self) -> dict:
return {
self.name: {
"data_sources": {
"csmr": self.load_csmr,
"paf": self.load_paf,
}
}
}

def load_paf(self, builder: Builder) -> pd.DataFrame:
paf = builder.data.load(data_keys.PRETERM_BIRTH.PAF)
paf = paf.rename(columns=CHILD_LOOKUP_COLUMN_MAPPER)
return paf

def get_paf(self, _: Builder) -> LookupTable:
return self.lookup_tables["paf"]

def get_normalized_csmr(self, index: pd.Index) -> pd.Series:
pop = self.population_view.get(index)
ga_greater_than_37 = pop[COLUMNS.GESTATIONAL_AGE] >= 37
Expand Down
1 change: 1 addition & 0 deletions src/vivarium_gates_mncnh/constants/data_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def log_name(self):
class __NeonatalPretermBirth(NamedTuple):
# Keys that will be loaded into the artifact. must have a colon type declaration
CSMR: str = "cause.neonatal_preterm_birth.cause_specific_mortality_rate"
PAF: str = "cause.neonatal_preterm_birth.population_attributable_fraction"

@property
def name(self):
Expand Down
18 changes: 10 additions & 8 deletions src/vivarium_gates_mncnh/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_data(
data_keys.LBWSG.EXPOSURE: load_lbwsg_exposure,
data_keys.LBWSG.RELATIVE_RISK: load_lbwsg_rr,
data_keys.LBWSG.RELATIVE_RISK_INTERPOLATOR: load_lbwsg_interpolated_rr,
data_keys.LBWSG.PAF: load_lbwsg_paf,
data_keys.LBWSG.PAF: load_paf_data,
data_keys.ANC.ESTIMATE: load_anc_proportion,
data_keys.MATERNAL_SEPSIS.RAW_INCIDENCE_RATE: load_standard_data,
data_keys.MATERNAL_SEPSIS.CSMR: load_standard_data,
Expand All @@ -81,6 +81,7 @@ def get_data(
data_keys.OBSTRUCTED_LABOR.CSMR: load_standard_data,
data_keys.OBSTRUCTED_LABOR.YLD_RATE: load_maternal_disorder_yld_rate,
data_keys.PRETERM_BIRTH.CSMR: load_standard_data,
data_keys.PRETERM_BIRTH.PAF: load_paf_data,
data_keys.NEONATAL_SEPSIS.CSMR: load_standard_data,
data_keys.NEONATAL_ENCEPHALOPATHY.CSMR: load_standard_data,
data_keys.NO_CPAP_RISK.P_RDS: load_p_rds,
Expand Down Expand Up @@ -377,11 +378,15 @@ def make_interpolator(log_rr_for_age_sex_draw: pd.Series) -> RectBivariateSpline
return log_rr_interpolator


def load_lbwsg_paf(
def load_paf_data(
key: str, location: str, years: Optional[Union[int, str, list[int]]]
) -> pd.DataFrame:
if key != data_keys.LBWSG.PAF:
raise ValueError(f"Unrecognized key {key}")
if key == data_keys.LBWSG.PAF:
filename = (
"calculated_lbwsg_paf_on_cause.all_causes.cause_specific_mortality_rate.parquet"
)
else:
filename = "calculated_lbwsg_paf_on_cause.all_causes.cause_specific_mortality_rate_preterm.parquet"

location_mapper = {
"Ethiopia": "ethiopia",
Expand All @@ -391,10 +396,7 @@ def load_lbwsg_paf(

output_dir = paths.PAF_DIR / location_mapper[location]

df = pd.read_parquet(
output_dir
/ "calculated_lbwsg_paf_on_cause.all_causes.cause_specific_mortality_rate.parquet"
)
df = pd.read_parquet(output_dir / filename)
if "input_draw" in df.columns:
df = df.assign(input_draw="draw_" + df.input_draw.astype(str))
else:
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium_gates_mncnh/model_specifications/model_spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ components:
- AgelessPopulation("population.scaling_factor")
- Pregnancy()
- ChildrenBirthExposure()
- Intrapartum()
# - Intrapartum()
- ResultsStratifier()
- BirthObserver()
- AntenatalCare()
Expand All @@ -33,7 +33,7 @@ components:
- NeonatalCause('neonatal_sepsis_and_other_neonatal_infections')
- NeonatalCause('neonatal_encephalopathy_due_to_birth_asphyxia_and_trauma')
- NeonatalMortality()
- NoCPAPRisk()
# - NoCPAPRisk()
# Add model observers below here
- ANCObserver()
- MaternalDisordersBurdenObserver()
Expand All @@ -43,7 +43,8 @@ components:
configuration:
input_data:
input_draw_number: 0
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/cpap/ethiopia.hdf"
artifact_path: /home/albrja/scratch/artifacts/ethiopia.hdf
# artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/cpap/ethiopia.hdf"
interpolation:
order: 0
extrapolate: True
Expand Down

0 comments on commit 1e3dd2a

Please sign in to comment.