-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDiffReact.py
127 lines (102 loc) · 2.74 KB
/
DiffReact.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from mpi4py.futures import MPIPoolExecutor
import pickle
from configs import Config
from systems import FHN_PDE, DiffReact
from solver import SolverRK, SolverScipy
from parareal import PararealLight
import numpy as np
import os
import sys
import warnings
from scipy.linalg import LinAlgWarning
# Silence these warning categories for the rest of the process. Each call
# inserts its filter at the front of the filter list, exactly as three
# consecutive filterwarnings() calls would.
for _ignored_category in (UserWarning, LinAlgWarning, RuntimeWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category
# Command line: only the LAST three argv entries are used, in the order
# <N> <mdl> <dx>; anything before them (e.g. the script path) is ignored.
N, mdl, dx = sys.argv[-3:]
N, dx = int(N), int(dx)
# d_y mirrors dx (presumably a square grid); it is not used below.
d_y = dx
# Final time of the integration window; the run uses tspan=(0, T).
T = 20

# Per-resolution solver settings, keyed by dx:
#     dx: (mul, G, F, mulF, required N)
# The first four values are forwarded as SolverScipy(f, mul, mulF, G, F, ...);
# G/F are presumably the Parareal coarse/fine integrators (confirm against
# SolverScipy). Each resolution is only valid with the listed N.
_DX_SETTINGS = {
    19: (1, 'RK1', 'RK4', 1800e2, 64),
    28: (1, 'RK1', 'RK4', 1200e2, 128),
    41: (1, 'RK1', 'RK4', 800e2, 256),
    77: (1, 'RK4', 'RK8', 300e2, 512),
    113: (2, 'RK4', 'RK8', 400e2, 512),
    164: (4, 'RK4', 'RK8', 500e2, 512),
    235: (8, 'RK4', 'RK8', 600e2, 512),
}

try:
    mul, G, F, mulF, _required_N = _DX_SETTINGS[dx]
except KeyError:
    raise ValueError(
        f'Invalid dx val: {dx} (expected one of {sorted(_DX_SETTINGS)})'
    ) from None

# Explicit check rather than `assert`, which is stripped under `python -O`.
if N != _required_N:
    raise ValueError(f'dx={dx} requires N={_required_N}, got N={N}')
# The PDE system at the resolution selected on the command line.
ode = DiffReact(d_x=dx, use_jax=False, normalization='-11')
_vector_field = ode.get_vector_field()


def f(t, u):
    """Evaluate the system's right-hand side at time *t* and state *u*.

    Thin module-level wrapper around ``ode.get_vector_field()``. It is
    presumably here so the callable handed to SolverScipy can be pickled by
    name when work is shipped to MPI workers; confirm before inlining it.
    """
    rhs = _vector_field(t, u)
    return rhs
# Extra keyword arguments forwarded to PararealLight.run for each CLI model
# name. 'para' passes no `model`, so PararealLight's default is used.
_RUN_KWARGS = {
    'para': {},
    'elm': {'model': 'elm', 'degree': 1, 'm': 3},
    'nngp': {'model': 'nngp', 'nn': 20},
    'gp': {'model': 'gpjax'},
}


def _run_kwargs(mdl):
    """Return a fresh dict of extra PararealLight.run kwargs for model *mdl*.

    Raises:
        ValueError: if *mdl* is not one of the supported model names.
    """
    try:
        return dict(_RUN_KWARGS[mdl])
    except KeyError:
        raise ValueError(
            f'Unknown model type {mdl!r}; expected one of {sorted(_RUN_KWARGS)}'
        ) from None


def _available_workers():
    """Return SLURM_NTASKS minus one, the size of the MPI worker pool.

    The one task held back is presumably the rank running this script.

    Raises:
        RuntimeError: if SLURM_NTASKS is not set in the environment.
    """
    ntasks = os.getenv('SLURM_NTASKS')
    if ntasks is None:
        raise RuntimeError(
            'SLURM_NTASKS is not set; this script expects to run inside a SLURM job'
        )
    return int(ntasks) - 1


if __name__ == '__main__':
    workers = _available_workers()
    print('Total workers', workers)
    print(N, mdl, dx, ode.name)

    # Fail fast on bad input before any pool or solver is created.
    run_kwargs = _run_kwargs(mdl)
    if workers < N:
        raise RuntimeError(
            f'Need at least N={N} workers, but only {workers} are available'
        )

    ########## CHANGE THIS #########
    dir_name = 'DiffReactScal'
    ################################
    name = f'{dir_name}_{dx}_{N}_{mdl}'
    # generate folder; exist_ok avoids the exists()-then-mkdir() race
    os.makedirs(dir_name, exist_ok=True)

    ########## CHANGE THIS #########
    s = SolverScipy(f, mul, mulF, G, F, verbose=False, use_jax=False)
    p = PararealLight(ode, s, tspan=(0, T), N=N)
    #####################################

    # The context manager guarantees pool.shutdown() even if the run fails.
    with MPIPoolExecutor(workers) as pool:
        # run the code, storing intermediates in custom folder
        res = p.run(pool=pool, parall='mpi', light=True, **run_kwargs)
        # Drop the per-iteration breakdown so only summary timings print.
        res['timings'].pop('by_iter', None)
        print(res['timings'])
        # dump the final result
        p.store(name=name, path=dir_name)