-
Notifications
You must be signed in to change notification settings - Fork 0
/
analysis.py
executable file
·81 lines (70 loc) · 2.59 KB
/
analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os.path as osp, os
import json
import pandas as pd
import joblib
import util
# Grid-search config used when the caller does not supply one. It is resolved
# relative to util.GS_CONFIG_DIR and, in that case, only determines which
# parameter columns appear in the report (no grouping is done).
DEFAULT_GS_CONFIG = "sddpg/default.json"


def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _get_param(config, key):
    """Return the value at dotted *key* (e.g. ``"a.b.c"``) in nested dict *config*.

    Raises:
        KeyError: if any component of the dotted path is missing.
    """
    for part in key.split('.'):
        config = config[part]
    return config


def _param_keys(params):
    """Flatten grid entries into individual config keys.

    A single grid entry may list several comma-separated config keys; each
    one becomes its own column in the report, in the order given.
    """
    return [key for param in params for key in param.split(',')]


def _collect_rows(data_dir, keys):
    """Build one ``[max_ret, exp_name, *param_values]`` row per run directory.

    Every immediate sub-directory of *data_dir* is treated as one run and must
    contain ``vars.pkl`` (with a ``"max_ret"`` entry) and ``config.json``.
    Runs that cannot be read are reported and skipped so that one broken run
    does not abort the whole report.

    Raises:
        FileNotFoundError: if *data_dir* does not exist or is not a directory.
    """
    try:
        exp_names = next(os.walk(data_dir))[1]
    except StopIteration:
        # os.walk yields nothing for a missing directory; surface that clearly
        # instead of leaking a bare StopIteration to the caller.
        raise FileNotFoundError(
            "No experiment directory found at %s" % data_dir) from None

    rows = []
    # Sort so the collection order (and tie order in the report) is reproducible.
    for exp in sorted(exp_names):
        exp_dir = osp.join(data_dir, exp)
        try:
            # NOTE(review): joblib.load unpickles the file; only run this on
            # trusted log directories.
            run_vars = joblib.load(osp.join(exp_dir, "vars.pkl"))
            run_config = _load_json(osp.join(exp_dir, "config.json"))
            row = [run_vars["max_ret"], exp]
            row.extend(_get_param(run_config, key) for key in keys)
        except Exception as e:
            # Deliberate best effort: report which run failed and why, then move on.
            print("Skipping %s: %r" % (exp, e))
            continue
        rows.append(row)
    return rows


def _summarize(df, keys, groupby):
    """Return the report table sorted by ``max_ret``, best first.

    When *groupby* is set and there are parameter *keys*, runs are aggregated
    per parameter combination into mean (reported as ``max_ret``), std, max
    and min of ``max_ret``; otherwise every run is listed individually.
    """
    if groupby and not keys:
        # groupby([]) raises "No group keys passed!"; fall back to per-run rows.
        print("Grid config has no parameters; listing individual runs.")
    if not (groupby and keys):
        return df.sort_values("max_ret", ascending=False)

    stats = df.groupby(keys)["max_ret"].agg(["mean", "std", "max", "min"])
    stats.columns = ["max_ret", "std", "max", "min"]
    stats = stats.reset_index()
    # Statistics first, then the parameter columns.
    stats = stats[["max_ret", "std", "max", "min"] + keys]
    return stats.sort_values("max_ret", ascending=False)


def anaylsis(env_name, exp_name, gs_config=None, save_report=False):
    """Summarize ``max_ret`` across the runs of a grid-search experiment.

    Runs are read from ``util.LOG_DIR/<env_name>/<exp_name>/<run>/``. The top
    of the table is printed, and the full table can also be written to
    ``util.REPORT_DIR/<env_name>_<exp_name>.txt``.

    Args:
        env_name: environment directory under ``util.LOG_DIR``.
        exp_name: experiment directory under *env_name*.
        gs_config: grid-search config path relative to ``util.GS_CONFIG_DIR``.
            When given (truthy), runs are grouped by its grid parameters.
            When omitted, ``DEFAULT_GS_CONFIG`` only selects which parameter
            columns are shown and every run is listed.
        save_report: if true, also write the full table as text.

    Returns:
        pandas.DataFrame: the report table, sorted by ``max_ret`` descending.

    Raises:
        FileNotFoundError: if the grid config or experiment directory is missing.
    """
    print('Generating Report...')
    groupby = bool(gs_config)
    if not groupby:
        gs_config = DEFAULT_GS_CONFIG
    grid = _load_json(osp.join(util.GS_CONFIG_DIR, gs_config))["grid"]
    keys = _param_keys(grid.keys())

    data_dir = osp.join(util.LOG_DIR, env_name, exp_name)
    rows = _collect_rows(data_dir, keys)

    df = pd.DataFrame(rows, columns=["max_ret", "name"] + keys)
    # Stringify everything (presumably so that list/dict-valued parameters can
    # be used as groupby keys), then restore max_ret as a number.
    df = df.astype(str)
    df["max_ret"] = pd.to_numeric(df["max_ret"])

    res = _summarize(df, keys, groupby)
    print(res.head())

    if save_report:
        os.makedirs(util.REPORT_DIR, exist_ok=True)
        report_file = osp.join(
            util.REPORT_DIR, env_name + "_" + exp_name + ".txt")
        with open(report_file, "w", encoding="utf-8") as f:
            res.to_string(f)
        print("report saved to " + report_file)
    return res
def main():
    """Command-line entry point: parse arguments and run the analysis.

    Without ``--gs_config`` the user is asked to confirm use of
    ``DEFAULT_GS_CONFIG`` (no grouping); Ctrl+C or end-of-input aborts.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Summarize max_ret across the runs of a grid-search experiment.")
    parser.add_argument('env_name', type=str,
                        help="environment directory under util.LOG_DIR")
    parser.add_argument('exp_name', type=str,
                        help="experiment directory under env_name")
    parser.add_argument('--gs_config', type=str, default=None,
                        help="grid-search config relative to util.GS_CONFIG_DIR; "
                             "enables grouping by grid parameters")
    parser.add_argument('--save_report', action="store_true",
                        help="also write the full table to util.REPORT_DIR")
    args = parser.parse_args()

    # Mirror anaylsis(): any falsy gs_config (None or "") means "use the default".
    if not args.gs_config:
        try:
            input(
                "No Grid Search Config Provided. Using %s.\n"
                "Groupby params disabled. Confirm? or ctrl+C to abort."
                % DEFAULT_GS_CONFIG)
        except (KeyboardInterrupt, EOFError):
            # Aborting at the prompt is an expected path, not a crash.
            raise SystemExit("Aborted.")
    anaylsis(args.env_name, args.exp_name, args.gs_config, args.save_report)


if __name__ == "__main__":
    main()