-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
executable file
·86 lines (78 loc) · 2.82 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
from matplotlib import pyplot as plt
# Select matplotlib's non-interactive Agg backend as soon as this module is
# imported: the plotting helpers below render straight to image files via
# savefig, so no GUI/display is needed (works on headless machines).
# NOTE: this switches the pyplot backend process-wide for any importer.
plt.switch_backend('agg')
# Project directory layout. All paths are relative to the current working
# directory and end with "/" so a file name can be appended with plain
# string concatenation.
DATA_DIR = "data/"
CONFIG_DIR = "config/"

# Configuration sub-directories (all under CONFIG_DIR).
ENV_CONFIG_DIR, MODEL_CONFIG_DIR, GS_CONFIG_DIR = (
    CONFIG_DIR + "envs/",
    CONFIG_DIR + "models/",
    CONFIG_DIR + "gs/",
)

# Data sub-directories (all under DATA_DIR).
SETTING_DIR, LOG_DIR, REPORT_DIR = (
    DATA_DIR + "setting/",
    DATA_DIR + "log/",
    DATA_DIR + "report/",
)
def plot_actions(actions, act_high, fig_file):
    """Save histograms of action parameters to ``fig_file``.

    Parameters
    ----------
    actions : array-like
        Either 2-D, indexed ``[sample, param]``, or 3-D, indexed
        ``[test, sample, param]``. For 2-D input one row of histograms is
        drawn (titled "param j"). For 3-D input the first row pools every
        test together (titled "total of param j") and row ``i + 1`` shows
        test ``i`` alone.
    act_high : sequence of float
        Per-parameter upper bound; histogram ``j`` uses 100 bins over
        ``[0, act_high[j]]`` (numpy ignores values outside that range).
    fig_file : str or path-like
        Destination passed to ``savefig``.

    Raises
    ------
    ValueError
        If ``actions`` is neither 2-D nor 3-D.
    """
    actions = np.asarray(actions)
    if actions.ndim == 2:
        n_test = 0  # no per-test rows, only the pooled row
        n_params = actions.shape[1]
        pooled_title = "param %d"
    elif actions.ndim == 3:
        n_test, _, n_params = actions.shape
        pooled_title = "total of param %d"
    else:
        raise ValueError(
            "actions must be 2-D or 3-D, got shape %s" % (actions.shape,))

    n_rows = n_test + 1
    fig = plt.figure(figsize=(n_params * 8, n_rows * 5))
    try:
        # First row: every sample of parameter j (across all tests if 3-D).
        for j in range(n_params):
            ax = fig.add_subplot(n_rows, n_params, j + 1)
            ax.hist(actions[..., j].ravel(), bins=100, range=(0, act_high[j]))
            ax.set_title(pooled_title % j)
        # Remaining rows (3-D only): one row per test.
        for i in range(n_test):
            for j in range(n_params):
                ax = fig.add_subplot(n_rows, n_params,
                                     (i + 1) * n_params + j + 1)
                ax.hist(actions[i, :, j], bins=100, range=(0, act_high[j]))
                ax.set_title("test %d param %d" % (i, j))
        fig.savefig(fig_file)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # pyplot does not accumulate open figures.
        plt.close(fig)
def plot_adv(act, adv, fig_file, n_sigma=4, nbins=20):
    """Scatter advantages against each action parameter and save to ``fig_file``.

    Samples whose advantage is ``n_sigma`` or more standard deviations from
    the mean advantage are dropped first. Each subplot then shows the
    remaining (action, advantage) points in blue, overlaid with the per-bin
    mean advantage +/- one standard deviation (red), binned by that action
    parameter.

    Parameters
    ----------
    act : array-like
        2-D, indexed ``[sample, param]``; one subplot per column.
    adv : array-like
        Advantage per sample -- assumed 1-D and row-aligned with ``act``.
    fig_file : str or path-like
        Destination passed to ``savefig``.
    n_sigma : float, optional
        Outlier cut-off in standard deviations (default 4).
    nbins : int, optional
        Number of equal-width bins per parameter (default 20).
    """
    act = np.asarray(act)
    adv = np.asarray(adv)
    _, n_params = act.shape

    # Drop advantage outliers. With zero spread a strict
    # "< n_sigma * std" test would reject *every* sample, so only filter
    # when the standard deviation is positive.
    adv_std = np.std(adv)
    if adv_std > 0:
        keep = np.abs(adv - np.mean(adv)) < n_sigma * adv_std
        act = act[keep]
        adv = adv[keep]

    fig = plt.figure(figsize=(n_params * 8, 5))
    try:
        for i in range(n_params):
            act_i = act[:, i]
            count, edges = np.histogram(act_i, bins=nbins)
            # Reuse the same edges so all three histograms share bins.
            sum_y, _ = np.histogram(act_i, bins=edges, weights=adv)
            sum_y2, _ = np.histogram(act_i, bins=edges, weights=adv * adv)
            with np.errstate(divide="ignore", invalid="ignore"):
                # Empty bins give 0/0 -> NaN, which errorbar leaves as a gap.
                mean = sum_y / count
                # E[y^2] - E[y]^2 can dip just below zero through rounding;
                # clamp so a zero-variance bin gives 0 rather than NaN.
                std = np.sqrt(np.maximum(sum_y2 / count - mean * mean, 0.0))
            centers = (edges[1:] + edges[:-1]) / 2
            ax = fig.add_subplot(1, n_params, i + 1)
            ax.plot(act_i, adv, 'bo', zorder=-1)
            ax.errorbar(centers, mean, yerr=std, fmt='r-', capsize=4)
            ax.set_title("param %d" % i)
        fig.savefig(fig_file)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def plot_seq_actions(actions, act_high, fig_file):
    """Plot each action parameter as a sequence and save to ``fig_file``.

    Parameters
    ----------
    actions : array-like
        2-D, indexed ``[step, param]``; one subplot per parameter, with the
        row index on the x-axis.
    act_high : object
        Unused -- presumably kept for signature parity with ``plot_actions``.
    fig_file : str or path-like
        Destination passed to ``savefig``.
    """
    actions = np.asarray(actions)
    _, n_params = actions.shape
    fig = plt.figure(figsize=(n_params * 8, 5))
    try:
        for j in range(n_params):
            ax = fig.add_subplot(1, n_params, j + 1)
            ax.plot(actions[:, j])
            ax.set_title("param %d" % j)
        fig.savefig(fig_file)
    finally:
        # Release the figure even if plotting or saving fails, so pyplot
        # does not accumulate open figures across calls.
        plt.close(fig)
def plot_seq_obs_and_actions(obs, actions, act_high, fig_file):
    """Plot each action parameter over time next to the first observation feature.

    Each subplot ``j`` draws ``actions[:, j]`` in blue on the left y-axis
    and ``obs[:, 0]`` in red on a twin right y-axis that shares the x-axis.
    The figure is saved to ``fig_file``.

    Parameters
    ----------
    obs : array-like
        2-D, indexed ``[step, feature]``; only column 0 is plotted (the
        same curve appears in every subplot).
    actions : array-like
        2-D, indexed ``[step, param]``; one subplot per parameter.
    act_high : object
        Unused -- presumably kept for signature parity with ``plot_actions``.
    fig_file : str or path-like
        Destination passed to ``savefig``.
    """
    obs = np.asarray(obs)
    actions = np.asarray(actions)
    _, n_params = actions.shape
    fig = plt.figure(figsize=(n_params * 8, 5))
    try:
        for j in range(n_params):
            ax_act = fig.add_subplot(1, n_params, j + 1)
            ax_act.plot(actions[:, j], 'b-')
            ax_obs = ax_act.twinx()
            ax_obs.plot(obs[:, 0], 'r-')
            # Title the primary axes explicitly instead of relying on
            # pyplot's current-axes state after twinx().
            ax_act.set_title("param %d" % j)
        fig.savefig(fig_file)
    finally:
        # Close only this figure: closing "all" would also discard any
        # figures the caller has open.
        plt.close(fig)