-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Input the contacts in parallel from numpy (#2)
* Fix deprecation of numpy types in tests * add python contacts read function * split method to add contacts * split into functions * Use map to preprocess contacts * second try in parallel * Add test for large tree
- Loading branch information
Showing
9 changed files
with
312 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"N": 19531, | ||
"t_limit": 25, | ||
"mu": 0.01 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#!/usr/bin/env python | ||
# coding: utf-8 | ||
import sys | ||
sys.path.insert(0,"..") | ||
|
||
import numpy as np | ||
import json | ||
import sib | ||
|
||
with open("data/large_tree_pars.json") as f: | ||
params = json.load(f) | ||
|
||
cts_beliefs_f = np.load("data/large_tree_data.npz") | ||
|
||
cts = cts_beliefs_f["cts"] | ||
|
||
beliefs_all = cts_beliefs_f["beliefs"] | ||
|
||
obs_all = {} | ||
for k in cts_beliefs_f.files: | ||
if "obs_" in k: | ||
u=int(k.split("_")[-1]) | ||
#print(k, u) | ||
obs_all[u] = cts_beliefs_f[k] | ||
|
||
#close file | ||
cts_beliefs_f.close() | ||
|
||
sib_pars = sib.Params(prob_r=sib.Gamma(mu=params["mu"])) | ||
|
||
N = params["N"] | ||
t_limit = params["t_limit"] | ||
cts_sib = [(int(r["i"]),int(r["j"]),int(r["t"]),r["lam"]) for r in cts] | ||
|
||
tests = [sib.Test(s==0,s==1,s==2) for s in range(3)] | ||
def make_obs_sib(N, t_limit,obs, tests): | ||
obs_list_sib =[(i,-1,t) for t in [t_limit] for i in range(N) ] | ||
obs_list_sib.extend([(r["i"],tests[r["st"]],r["t"]) for r in obs]) | ||
|
||
obs_list_sib.sort(key=lambda x: x[-1]) | ||
|
||
return obs_list_sib | ||
|
||
callback = lambda t, err, fg: print(f"iter: {t:6}, err: {err:.5e} ", end="\r") | ||
|
||
for ii,obs in obs_all.items(): | ||
fg = sib.FactorGraph(params=sib_pars) | ||
beliefs = beliefs_all[ii] | ||
#print(f"Instance {ii}") | ||
#for c in cts_sib: | ||
# fg.append_contact(*c) | ||
fg.append_contacts_npy(cts["i"], cts["j"], cts["t"], cts["lam"]) | ||
obs_list_sib = make_obs_sib(N,t_limit, obs, tests) | ||
for o in obs_list_sib: | ||
fg.append_observation(*o) | ||
|
||
sib.iterate(fg,200,1e-20,callback=callback ) | ||
print("") | ||
s=0. | ||
for i in range(len(fg.nodes)): | ||
s+=np.abs(np.array(fg.nodes[i].bt)-beliefs[i][0]).sum() | ||
s+=np.abs(np.array(fg.nodes[i].bg)-beliefs[i][1]).sum() | ||
#fg.nodes[i].bg])) | ||
print(f"instance {ii}: {s:4.3e} {s < 1e-10}") | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,5 +81,4 @@ instance 45: True | |
instance 46: True | ||
instance 47: True | ||
instance 48: True | ||
instance 49: True | ||
|
||
instance 49: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
>>> from pathlib import Path | ||
>>> import sys | ||
>>> sys.path.insert(0,'test/') | ||
>>> import numpy as np | ||
>>> import sib | ||
>>> import data_load | ||
|
||
### LOAD DATA | ||
>>> folder_data = Path("test/data/tree_check/") | ||
>>> params,contacts,observ,epidem = data_load.load_exported_data(folder_data) | ||
>>> contacts = contacts[["i","j","t","lambda"]] | ||
>>> obs_all_df = [] | ||
>>> for obs in observ: | ||
... obs_df = data_load.convert_obs_to_df(obs) | ||
... obs_all_df.append(obs_df[["i","st","t"]]) | ||
>>> n_inst = len(observ) | ||
>>> print(f"Number of instances: {n_inst}") | ||
Number of instances: 50 | ||
|
||
### TEST RESULTS | ||
>>> beliefs = np.load(folder_data / "beliefs_tree.npz") | ||
>>> tests = [sib.Test(s==0,s==1,s==2) for s in range(3)] | ||
>>> sib_pars = sib.Params(prob_r=sib.Gamma(mu=params["mu"])) | ||
>>> cts = contacts.to_records(index=False) | ||
>>> for inst in range(n_inst): | ||
... obs = list(obs_all_df[inst].to_records(index=False)) | ||
... obs = [(i,tests[s],t) for (i,s,t) in obs] | ||
... fg = sib.FactorGraph(params=sib_pars) | ||
... fg.append_contacts_npy(cts["i"], cts["j"], cts["t"], cts["lambda"]) | ||
... for o in obs: | ||
... fg.append_observation(*o) | ||
... sib.iterate(fg,200,1e-20,callback=None) | ||
... s = 0.0 | ||
... for i in range(len(fg.nodes)): | ||
... s += sum(abs(beliefs[f"{inst}_{i}"][0]-np.array(fg.nodes[i].bt))) | ||
... print(f"instance {inst}: {s < 1e-10}") | ||
instance 0: True | ||
instance 1: True | ||
instance 2: True | ||
instance 3: True | ||
instance 4: True | ||
instance 5: True | ||
instance 6: True | ||
instance 7: True | ||
instance 8: True | ||
instance 9: True | ||
instance 10: True | ||
instance 11: True | ||
instance 12: True | ||
instance 13: True | ||
instance 14: True | ||
instance 15: True | ||
instance 16: True | ||
instance 17: True | ||
instance 18: True | ||
instance 19: True | ||
instance 20: True | ||
instance 21: True | ||
instance 22: True | ||
instance 23: True | ||
instance 24: True | ||
instance 25: True | ||
instance 26: True | ||
instance 27: True | ||
instance 28: True | ||
instance 29: True | ||
instance 30: True | ||
instance 31: True | ||
instance 32: True | ||
instance 33: True | ||
instance 34: True | ||
instance 35: True | ||
instance 36: True | ||
instance 37: True | ||
instance 38: True | ||
instance 39: True | ||
instance 40: True | ||
instance 41: True | ||
instance 42: True | ||
instance 43: True | ||
instance 44: True | ||
instance 45: True | ||
instance 46: True | ||
instance 47: True | ||
instance 48: True | ||
instance 49: True |