#!/usr/bin/env python3
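"""Train or evaluate the recurrent policy's LSTM from pre-collected rollouts,
without stepping a simulation environment.

Example invocation (flags as defined by parser() below):
    python3 train_lstm_without_env.py --train 1 --trial 1 --iter 100 \
        --dir ./saved/lstm_dataset --recon 0 0 1
"""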
import argparse
import os
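
# Pin all CUDA work to GPU 1; this must be set before torch initializes CUDA.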
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from os.path import join, exists

import numpy as np
import torch
from stable_baselines3.common.utils import get_device

from mav_baselines.torch.recurrent_ppo.policies import MultiInputLstmPolicy
from mav_baselines.torch.recurrent_ppo.ppo_recurrent import RecurrentPPO


def configure_random_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)


def parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", type=int, default=1,
                        help="Train the LSTM (1) or evaluate it (0)")
    parser.add_argument("--retrain", type=int, default=1,
                        help="Whether to resume from a previous LSTM checkpoint")
    parser.add_argument("--trial", type=int, default=1, help="PPO trial number")
    parser.add_argument("--iter", type=int, default=100, help="PPO iteration number")
    parser.add_argument("--dir", type=str, default="./saved/lstm_dataset",
                        help="Where to place rollouts")
    parser.add_argument("--recon", nargs="+", type=int, default=[0, 0, 1],
                        help="Reconstruction members: past, now, future")
    parser.add_argument("--lstm_exp", type=str, default="LSTM",
                        help="LSTM experiment name")
    return parser


def main():
    args = parser().parse_args()
    configure_random_seed(92)
    rsg_root = os.path.dirname(os.path.abspath(__file__))
    log_dir = rsg_root + "/saved"
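
    # Load the pretrained VAE checkpoint (64-dim latent space, matching
    # features_dim below) that encodes observations for the LSTM.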
    vae_logdir = os.environ["AVOIDBENCH_PATH"] + "/../mavrl/exp_vae/"
    vae_file = join(vae_logdir, "vae_64_new", "best.tar")
    assert exists(vae_file), "No trained VAE in the logdir..."
    state_vae = torch.load(vae_file)
    print("Loading VAE at epoch {} "
          "with test error {}".format(state_vae["epoch"], state_vae["precision"]))
    device = get_device("auto")
    weight = (os.environ["AVOIDBENCH_PATH"] +
              "/../mavrl/saved/RecurrentPPO_{0}/Policy/iter_{1:05d}.pth".format(
                  args.trial, args.iter))
    saved_variables = torch.load(weight, map_location=device)

    # Create the policy object and restrict training to the LSTM weights.
    saved_variables["data"]["only_lstm_training"] = True
    policy = MultiInputLstmPolicy(
        features_dim=64,
        reconstruction_members=args.recon,
        reconstruction_steps=10,
        **saved_variables["data"],
    )
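    # Squash the action head with tanh so outputs match the bounded action
    # space (mirrors use_tanh_act=True below).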
    policy.action_net = torch.nn.Sequential(policy.action_net, torch.nn.Tanh())
    policy.load_state_dict(saved_variables["state_dict"], strict=False)
    policy.to(device)
    print(args.recon)
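
    # Build RecurrentPPO configured for LSTM-only training: rollouts are read
    # from lstm_dataset_path instead of being collected from an environment.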
    model = RecurrentPPO(
        tensorboard_log=log_dir,
        policy=policy,
        policy_kwargs=dict(
            activation_fn=torch.nn.ReLU,
            net_arch=[dict(pi=[256, 256], vf=[512, 512])],
        ),
        use_tanh_act=True,
        gae_lambda=0.95,
        gamma=0.99,
        n_steps=1000,
        n_seq=1,
        ent_coef=0.0,
        vf_coef=0.5,
        max_grad_norm=0.5,
        lstm_layer=1,
        batch_size=500,
        n_epochs=2000,
        clip_range=0.2,
        use_sde=False,  # generalized State-Dependent Exploration (gSDE) is not supported here
        retrain=args.retrain,
        verbose=1,
        only_lstm_training=True,
        state_vae=state_vae,
        states_dim=0,
        reconstruction_members=args.recon,
        reconstruction_steps=10,
        train_lstm_without_env=True,
        lstm_dataset_path=args.dir,
        lstm_weight_saved_path=args.lstm_exp,
    )
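
    # --train 1 trains the LSTM on the saved dataset; otherwise run the
    # stand-alone LSTM evaluation.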
    if args.train:
        model.train_lstm_from_dataset()
    else:
        model.test_lstm_seperate()


if __name__ == "__main__":
    main()