-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_analysis.py
66 lines (47 loc) · 2.17 KB
/
run_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from test_RNN import fractions, days
from vizualizations import plot_data_set_composition, make_training_history_plot
from interpret_results import save_all_phase_vs_accuracy_plot, save_class_wise_phase_vs_accuracy_plot, merge_performance_tables
# Seed NumPy's legacy global RNG once, at import time, so that any code using
# the np.random module-level functions gets a repeatable stream between runs.
default_seed = 40
np.random.seed(seed=default_seed)
def _load_sample_counts(model_dir, split_name, count_column):
    """Load one split's per-class counts, append a 'Total' row, rename 'Count'.

    Reads ``<model_dir>/<split_name>_sample.csv`` (its first column is used as
    the index), appends a row with ``Class='Total'`` and ``Count`` equal to the
    column sum, and renames the ``Count`` column to *count_column*.
    """
    df = pd.read_csv(Path(model_dir) / f"{split_name}_sample.csv", index_col=0)
    # pd.concat replaces the private DataFrame._append; ignore_index=True
    # renumbers the rows exactly as _append(..., ignore_index=True) did.
    total_row = pd.DataFrame([{'Class': 'Total', 'Count': df['Count'].sum()}])
    df = pd.concat([df, total_row], ignore_index=True)
    return df.rename(columns={'Count': count_column})


def merge_sample_tables(model_dir):
    """Combine the train/test/validation per-class counts into one table.

    Reads ``train_sample.csv``, ``test_sample.csv`` and
    ``validation_sample.csv`` from *model_dir*. Each must have ``Class`` and
    ``Count`` columns. A 'Total' row is appended to each table, and the three
    tables are joined on ``Class``. The result is written to
    ``<model_dir>/combined_sample.csv`` and printed as a LaTeX table.

    Parameters
    ----------
    model_dir : str or pathlib.Path
        Directory holding the per-split sample CSVs; the output goes here too.

    Returns
    -------
    pandas.DataFrame
        The combined table, with columns ``Class``, ``Train_count``,
        ``Test_count`` and ``val_count``.
    """
    splits = (
        ('train', 'Train_count'),
        ('test', 'Test_count'),
        # The lower-case name is kept as-is because it is the column header
        # already written to combined_sample.csv.
        ('validation', 'val_count'),
    )
    tables = [_load_sample_counts(model_dir, split, column) for split, column in splits]

    df_combined = tables[0]
    for table in tables[1:]:
        # NOTE(review): merge defaults to an inner join, so a class absent from
        # any split is dropped here, while that split's 'Total' row still counts it.
        df_combined = df_combined.merge(table, on='Class')

    df_combined.to_csv(Path(model_dir) / 'combined_sample.csv')
    print(df_combined.to_latex(index=False))
    return df_combined
def parse_args(argv=None):
    """Parse the command-line options for this analysis script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, as before. Passing an explicit list lets tests and
        other scripts call this function directly.

    Returns
    -------
    argparse.Namespace
        Has a single attribute, ``model_dir``, converted to ``pathlib.Path``.
    """
    parser = argparse.ArgumentParser(description='Run the post-training analysis for a model directory.')
    parser.add_argument(
        '--model_dir',
        type=Path,
        required=True,
        help='Directory containing best_model.h5. Results will be stored in the same directory.',
    )
    return parser.parse_args(argv)
def run_analysis(model_dir):
    """Generate the plots and summary tables for the model in *model_dir*.

    Calls each project plotting helper in turn and follows every one of them
    with ``plt.close()``. This stops the current pyplot figure from carrying
    over into the next helper. The sample and performance tables are merged
    last. The helpers are imported from other project modules, so exactly
    which files they write is not visible here.

    Parameters
    ----------
    model_dir : str or pathlib.Path
        Directory of the trained model; it is passed unchanged to every helper.
    """
    make_training_history_plot(model_dir)
    # This was the only plotting step without a close. plt.close() is a no-op
    # when no figure is active, so this is safe even if the helper closed its own.
    plt.close()
    save_all_phase_vs_accuracy_plot(model_dir, days=days)
    plt.close()
    save_class_wise_phase_vs_accuracy_plot(model_dir, days=days)
    plt.close()
    # plot the make up of all the data sets
    plot_data_set_composition(model_dir)
    plt.close()
    merge_sample_tables(model_dir)
    merge_performance_tables(model_dir)
# Script entry point: parse the command line, then run the full analysis.
if __name__ == '__main__':
    cli_args = parse_args()
    run_analysis(model_dir=cli_args.model_dir)