Skip to content

Commit

Permalink
Check for person properties (#40)
Browse files Browse the repository at this point in the history
* check properties

* check that schemas are equal

* don't call it 'field'

* properties is a set
  • Loading branch information
swo authored Dec 19, 2024
1 parent d436c06 commit 239a728
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 16 deletions.
22 changes: 21 additions & 1 deletion ringvax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@


class Simulation:
PROPERTIES = {
"infector",
"generation",
"t_exposed",
"t_infectious",
"t_recovered",
"infection_rate",
"detected",
"detect_method",
"t_detected",
"infection_times",
}

def __init__(self, params: dict[str, Any], seed: Optional[int] = None):
self.params = params
self.seed = seed
Expand All @@ -15,14 +28,21 @@ def __init__(self, params: dict[str, Any], seed: Optional[int] = None):
def create_person(self) -> str:
"""Add a new person to the data"""
id = str(len(self.infections))
self.infections[id] = {}
self.infections[id] = {x: None for x in self.PROPERTIES}
return id

def update_person(self, id: str, content: dict[str, Any]) -> None:
bad_properties = set(content.keys()) - set(self.PROPERTIES)
if len(bad_properties) > 0:
raise RuntimeError(f"Properties not in schema: {bad_properties}")

self.infections[id] |= content

def get_person_property(self, id: str, property: str) -> Any:
"""Get a property of a person"""
if property not in self.PROPERTIES:
raise RuntimeError(f"Property '{property}' not in schema")

if id not in self.infections:
raise RuntimeError(f"No person with {id=}")
elif property not in self.infections[id]:
Expand Down
2 changes: 2 additions & 0 deletions ringvax/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
An infection as a polars schema
"""

assert set(infection_schema.keys()) == Simulation.PROPERTIES


def prepare_for_df(infection: dict) -> dict:
"""
Expand Down
48 changes: 33 additions & 15 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def test_generate_disease_history_nonzero(rng):
}


def test_simulate(rng):
params = {
@pytest.fixture
def base_params():
return {
"n_generations": 4,
"latent_duration": 1.0,
"infectious_duration": 3.0,
Expand All @@ -96,23 +97,40 @@ def test_simulate(rng):
"active_detection_delay": 2.0,
"max_infections": 100,
}
s = ringvax.Simulation(params=params, seed=rng)


def test_simulate(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
s.run()
assert len(s.infections) == 19


def test_simulate_max_infections(rng):
params = {
"n_generations": 4,
"latent_duration": 1.0,
"infectious_duration": 3.0,
"infection_rate": 1.0,
"p_passive_detect": 0.5,
"passive_detection_delay": 2.0,
"p_active_detect": 0.15,
"active_detection_delay": 2.0,
"max_infections": 10,
}
def test_simulate_max_infections(rng, base_params):
params = base_params
params["max_infections"] = 10
s = ringvax.Simulation(params=params, seed=rng)
s.run()
assert len(s.infections) == 10


def test_simulate_set_field(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
id = s.create_person()
s.update_person(id, {"generation": 0})
assert s.get_person_property(id, "generation") == 0


def test_simulate_error_on_bad_get_property(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
id = s.create_person()

with pytest.raises(RuntimeError, match="foo"):
s.get_person_property(id, "foo")


def test_simulate_error_on_bad_update_property(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
id = s.create_person()

with pytest.raises(RuntimeError, match="foo"):
s.update_person(id, {"foo": 0})

0 comments on commit 239a728

Please sign in to comment.