-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_CarRacing_MDN_RNN_dataset.py
80 lines (66 loc) · 2.95 KB
/
generate_CarRacing_MDN_RNN_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import shutil
import time
import gym
import numpy as np
import torch
import imageio
from arguments import get_args
from utils.env_util import adjust_action, adjust_obs
class CarRacing_MDN_RNN_Dataset_Generator:
    """Encode recorded CarRacing episodes with a trained VAE and save the
    results as the MDN-RNN training set.

    Each ``.npz`` file in ``vae_dataset_path`` must contain ``obs`` and
    ``action`` arrays. Files are processed in sorted name order.

    For the k-th file (1-based), the VAE encoder is run over all of its
    frames. The output is written to ``<save_path>/epoch_<k>.npz`` with the
    keys ``mu``, ``logvar``, ``obs`` (channel-permuted) and ``action``.

    The encoder output for the first frame of every episode is also collected
    into ``<save_path>/init.npz`` under the keys ``init_mu`` and
    ``init_logvar``.
    """

    def __init__(self, vae_dataset_path, weight_path, data_save_path, device):
        """Load the VAE checkpoint and prepare the output directory.

        Args:
            vae_dataset_path: directory containing the episode ``.npz`` files.
            weight_path: checkpoint saved as a whole model object. The loaded
                object must expose ``encoder(x) -> (mu, logvar)``.
            data_save_path: output directory, joined onto the directory that
                contains this script. An absolute path is used as-is.
            device: torch device (e.g. ``'cuda:0'`` or ``'cpu'``) used for the
                model weights and the input tensors.
        """
        super().__init__()
        self.vae_dataset_path = vae_dataset_path
        self.data_save_path = data_save_path
        self.device = device
        # Map the weights onto the requested device rather than a hard-coded
        # 'cuda:0', so CPU-only hosts and other GPUs work as well.
        # NOTE: torch.load unpickles arbitrary objects - load trusted files only.
        self.vae = torch.load(weight_path, map_location=device)
        # Inference only: disable dropout / use running batch-norm statistics.
        self.vae.eval()
        # Portable script-directory lookup. The previous split('\\') logic
        # only worked with Windows path separators.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        self.prefix = os.path.join(script_dir, data_save_path)
        self.save_path = os.path.join(self.prefix, 'mdn_rnn')
        os.makedirs(self.save_path, exist_ok=True)

    def _episode_files(self):
        """Return the sorted paths of the episode ``.npz`` files.

        Raises:
            FileNotFoundError: if ``vae_dataset_path`` is not a directory.
        """
        if not os.path.isdir(self.vae_dataset_path):
            print('vae数据集目录不存在!')
            raise FileNotFoundError(self.vae_dataset_path)
        # Sort so that the epoch_<k>.npz numbering is reproducible (os.listdir
        # order is arbitrary). Skip entries that are not .npz archives.
        names = sorted(n for n in os.listdir(self.vae_dataset_path) if n.endswith('.npz'))
        return [os.path.join(self.vae_dataset_path, n) for n in names]

    def _iter_episodes(self):
        """Yield ``(obs, action)`` float32 tensors on ``self.device``, one
        episode at a time.

        ``obs`` is permuted with ``(0, 3, 1, 2)``. This presumably converts
        channels-last frames ``(T, H, W, C)`` to channels-first (confirm
        against the recorder).
        """
        for path in self._episode_files():
            # Open each archive once and close it afterwards. Indexing an
            # NpzFile reads the array fully, so the arrays stay valid.
            with np.load(path) as data:
                obs = data['obs']
                action = data['action']
            obs = torch.tensor(obs, dtype=torch.float).permute(0, 3, 1, 2).to(self.device)
            action = torch.tensor(action, dtype=torch.float).to(self.device)
            yield obs, action

    def parsing_data(self):
        """Load every episode into memory.

        Returns:
            ``(obss, actions)``: lists holding one tensor per episode file,
            in sorted file-name order.

        Raises:
            FileNotFoundError: if ``vae_dataset_path`` is not a directory.
        """
        obss = []
        actions = []
        for obs, action in self._iter_episodes():
            obss.append(obs)
            actions.append(action)
        return obss, actions

    def get_data(self):
        """Encode every episode and write ``epoch_<k>.npz`` plus ``init.npz``.

        Episodes are streamed one at a time, so only a single episode's frames
        are held on ``self.device`` at once.

        Raises:
            FileNotFoundError: if ``vae_dataset_path`` is not a directory.
        """
        init_mus = []
        init_logvars = []
        # Gradients are not needed. no_grad avoids building an autograd graph
        # over every frame of the episode.
        with torch.no_grad():
            for i, (obs, action) in enumerate(self._iter_episodes()):
                mu, logvar = self.vae.encoder(obs)
                mu = mu.detach().cpu().numpy()
                logvar = logvar.detach().cpu().numpy()
                # Latent statistics of the episode's first frame (-> init.npz).
                init_mus.append(mu[0])
                init_logvars.append(logvar[0])
                save_file_name = os.path.join(self.save_path, f'epoch_{i + 1}.npz')
                np.savez_compressed(
                    save_file_name,
                    mu=mu,
                    logvar=logvar,
                    obs=obs.detach().cpu().numpy(),
                    action=action.detach().cpu().numpy(),
                )
        save_file_name = os.path.join(self.save_path, 'init.npz')
        np.savez_compressed(
            save_file_name,
            init_mu=np.array(init_mus),
            init_logvar=np.array(init_logvars),
        )
if __name__ == '__main__':
    # Script entry point: encode the recorded episodes with a trained VAE
    # checkpoint and write the resulting MDN-RNN dataset.
    # NOTE(review): the two paths below are machine-specific Windows paths;
    # adjust them (or expose them as CLI arguments) before running elsewhere.
    cli_args = get_args()
    checkpoint = r'F:\our_code\RL\world_model\weights\vae_train\vae_2024.10.23.11.35.31\5000.pt'
    episodes_dir = r'F:\our_code\RL\world_model\data\CarRacing\random'
    CarRacing_MDN_RNN_Dataset_Generator(
        vae_dataset_path=episodes_dir,
        weight_path=checkpoint,
        data_save_path=cli_args.data_save_path,
        device=cli_args.device,
    ).get_data()