Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize forward step for large populations #10

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
96 changes: 79 additions & 17 deletions python/pandemic_simulator/environment/pandemic_sim.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Confidential, Copyright 2020, Sony Corporation of America, All rights reserved.

from collections import defaultdict, OrderedDict
from itertools import product as cartesianproduct, combinations
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Sequence, cast
from functools import lru_cache

import numpy as np
from orderedset import OrderedSet

from .interfaces import ContactRate, ContactTracer, PandemicRegulation, PandemicSimState, PandemicTesting, \
PandemicTestResult, \
Expand Down Expand Up @@ -58,8 +57,8 @@ def __init__(self,
boolean in PandemicSimState is set to True.
:param numpy_rng: Random number generator.
"""
self._id_to_person = OrderedDict({p.id: p for p in persons})
self._id_to_location = OrderedDict({loc.id: loc for loc in locations})
self._id_to_person = {p.id: p for p in persons}
self._id_to_location = {loc.id: loc for loc in locations}
self._infection_model = infection_model
self._pandemic_testing = pandemic_testing
self._registry = registry
Expand Down Expand Up @@ -89,7 +88,7 @@ def __init__(self,
infection_above_threshold=False
)

def _compute_contacts(self, location: Location) -> OrderedSet:
def _compute_contacts(self, location: Location) -> dict:
assignees = location.state.assignees_in_location
visitors = location.state.visitors_in_location
cr = location.state.contact_rate
Expand All @@ -101,28 +100,42 @@ def _compute_contacts(self, location: Location) -> OrderedSet:
(cr.min_assignees_visitors, cr.fraction_assignees_visitors),
(cr.min_visitors, cr.fraction_visitors)]

contacts: OrderedSet = OrderedSet()

contacts_x: List = list()
contacts_y: List = list()
for grp, cst in zip(groups, constraints):
grp1, grp2 = grp
minimum, fraction = cst

possible_contacts = list(combinations(grp1, 2) if grp1 == grp2 else cartesianproduct(grp1, grp2))
num_possible_contacts = len(possible_contacts)
possible_contacts_x = []
possible_contacts_y = []
num_possible_contacts = 0

num_possible_contacts = n_choose_k(len(grp1), 2) if grp1 == grp2 else len(grp1) * len(grp2)

if len(possible_contacts) == 0:
if num_possible_contacts == 0:
continue

fraction_sample = min(1., max(0., self._numpy_rng.normal(fraction, 1e-2)))
real_fraction = max(minimum, int(fraction_sample * num_possible_contacts))

# we are using an orderedset, it's repeatable
contact_idx = self._numpy_rng.randint(0, num_possible_contacts, real_fraction)
contacts.update([possible_contacts[idx] for idx in contact_idx])

return contacts

def _compute_infection_probabilities(self, contacts: OrderedSet) -> None:
if grp1 == grp2:
possible_contacts_x, possible_contacts_y = comb2_reduced(np.asarray(grp1), contact_idx)
else:
possible_contacts_x, possible_contacts_y = prod_reduced(np.asarray(grp1), np.asarray(grp2), contact_idx)
contacts_x = np.concatenate((contacts_x, possible_contacts_x))
contacts_y = np.concatenate((contacts_y, possible_contacts_y))

# Stuff the contact pairs into a dictionary/Set, removing duplicates from repeats in contact_idx
r = dict()
for i, c in enumerate(contacts_x):
r[contacts_x[i], contacts_y[i]] = 0
return r

def _compute_infection_probabilities(self, contacts: dict) -> None:
if len(contacts) < 1:
return
infectious_states = {InfectionSummary.INFECTED, InfectionSummary.CRITICAL}

for c in contacts:
Expand Down Expand Up @@ -199,7 +212,6 @@ def step(self) -> None:
# update person contacts
for location in self._id_to_location.values():
contacts = self._compute_contacts(location)

if self._contact_tracer:
self._contact_tracer.add_contacts(contacts)

Expand Down Expand Up @@ -322,3 +334,53 @@ def reset(self) -> None:
regulation_stage=0,
infection_above_threshold=False,
)

def person_update(self, person: Person) -> None:
    """
    Advance a single person by one simulation step.

    :param person: the person to step; its step() is invoked with the simulator's current sim time and
        the simulator's contact tracer
    """
    current_time = self._state.sim_time
    tracer = self._contact_tracer
    person.step(current_time, tracer)


def prod_reduced(a: np.ndarray, b: np.ndarray, idx: Sequence[int]) -> tuple:
    """
    Return the pairs of the cartesian product A x B found at the desired flat indices only.

    The full product is never materialised. A flat index i is mapped to the pair
    (a[i // b.size], b[i % b.size]), i.e. the i-th pair when A x B is enumerated with the elements of b
    varying fastest.

    :param a: first factor of the product (array-like)
    :param b: second factor of the product (array-like)
    :param idx: flat indices into the enumeration of A x B; a list or an integer array
    :return: tuple (x, y) of arrays where (x[n], y[n]) is the pair selected by idx[n]
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # A plain list supports neither // nor %, and an empty list would become a float array that cannot
    # be used for indexing, so coerce to an integer index array up front.
    positions = np.asarray(idx, dtype=np.intp)
    return a[positions // b.size], b[positions % b.size]


def comb2_reduced(l: np.array, idx: list) -> tuple:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you improve the documentation on the arguments and what is happening inside the function?

"""
Select combinations of 2 from input l at the positions given by input idx
Uses the strict upper triangular indices of an (l.size x l.size) matrix to enumerate every index pair (i, j)
with i < j, keeps the pairs at positions idx, and extracts the corresponding elements of l as two vectors

:param: l, base array
:param: idx, positions into the enumeration of all combinations of 2 of l (not index values of l itself)
:return: tuple of two arrays holding the first and second members of the selected combinations of 2
"""
triu = np.triu_indices(l.size, 1)
return l[triu[0][idx]], l[triu[1][idx]]


@lru_cache(maxsize=None)
def n_choose_k(n: int, k: int) -> int:
"""
Calculate the number of combinations in N choose K
When K is 0 or 1, the answer is returned directly. When K > 1, iterate to accumulate the falling product
n * (n-1) * ... * (n-k+1) and the factorial k!, whose quotient equals the nCk formula n! / (k! (n-k)!); m holds the result

:return: number of ways to choose k from n
"""
m = 0
if k == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you comment on these conditions?

m = 1
if k == 1:
m = n
if k >= 2:
num, dem, op1, op2 = 1, 1, k, n
while(op1 >= 1):
num *= op2
dem *= op1
op1 -= 1
op2 -= 1
m = num//dem
return m