buffer.py
import warnings

import numpy as np
import torch


class ReplayBuffer:
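    """Fixed-capacity replay buffer for whole episodes.

    Stores up to ``max_num_seq`` trajectories as zero-padded sequences of
    length ``seq_len`` (states keep one extra step for the final observation)
    and samples them into preallocated batch tensors of size ``batch_size``.
    """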
def __init__(self, state_shape, action_shape, max_num_seq=int(2**12), seq_len=30,
batch_size=32, stored_on_gpu=False, obs_uint8=False):
        super().__init__()
self.seq_len = seq_len
self.state_shape = state_shape
self.action_shape = action_shape
self.batch_size = batch_size
self.obs_uint8 = obs_uint8
self.stored_on_gpu = stored_on_gpu
if self.stored_on_gpu:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.state_dtype = torch.float32 if not obs_uint8 else torch.uint8
self.length = torch.zeros((max_num_seq,), dtype=torch.int, device=self.device)
self.masks = torch.zeros((max_num_seq, seq_len,), dtype=torch.float32, device=self.device)
self.actions = torch.zeros((max_num_seq, seq_len, *action_shape), dtype=torch.float32, device=self.device)
self.states = torch.zeros((max_num_seq, seq_len + 1, *state_shape), dtype=self.state_dtype, device=self.device)
self.rewards = torch.zeros((max_num_seq, seq_len,), dtype=torch.float32, device=self.device)
self.dones = torch.zeros((max_num_seq, seq_len), dtype=torch.float32, device=self.device)
self.tail, self.size = 0, 0
self.max_num_seq = max_num_seq
        self.sum_steps = 0
        # Start min_length at seq_len so the first appended episode sets the true minimum.
        self.min_length = seq_len
        self.max_length = 0
        # Preallocate batch buffers for sampling, on the same device as the stored data
        self.length_b = torch.zeros((batch_size,), dtype=torch.int, device=self.device)
        self.masks_b = torch.zeros((batch_size, seq_len,), dtype=torch.bool, device=self.device)
self.actions_b = torch.zeros((batch_size, seq_len, *self.action_shape), dtype=torch.float32, device=self.device)
self.states_b = torch.zeros((batch_size, seq_len + 1, *self.state_shape), dtype=torch.float32, device=self.device)
self.rewards_b = torch.zeros((batch_size, seq_len,), dtype=torch.float32, device=self.device)
self.dones_b = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=self.device)

    def append_episode(self, states, actions, r, done, length):
        if length < 1:
            warnings.warn("Episode length < 1, not recorded!")
            return
        # Expected shapes: (trajectory_length, *data_shape); states carries one extra final observation
        if self.obs_uint8 and (states.dtype != np.uint8):
            # Quantise float observations in [0, 1] to uint8 to save memory
            states = np.clip(states, 0, 1)
            states = (255 * states).astype(np.uint8)
        # If overwriting an old episode, remove its step count before adding the new one
        if self.length[self.tail] != 0:
            self.sum_steps -= int(self.length[self.tail])
        self.length[self.tail] = length
        self.sum_steps += length
        self.min_length = min(self.min_length, length)
        self.max_length = max(self.max_length, length)
        # Zero-pad every field beyond the episode length; masks mark the valid steps
        self.masks[self.tail][:length] = 1
        self.masks[self.tail][length:] = 0
self.dones[self.tail][:length] = torch.from_numpy(done[:length]).to(device=self.device)
self.dones[self.tail][length:] = 1
self.states[self.tail][:length + 1] = torch.from_numpy(states[:length + 1]).to(device=self.device)
self.states[self.tail][length + 1:] = 0
self.actions[self.tail][:length] = torch.from_numpy(actions[:length]).to(device=self.device)
self.actions[self.tail][length:] = 0
self.rewards[self.tail][:length] = torch.from_numpy(r[:length]).to(device=self.device)
self.rewards[self.tail][length:] = 0
self.tail = (self.tail + 1) % self.max_num_seq
self.size = min(self.size + 1, self.max_num_seq)

    def sample_batch(self):
        # Draw episode indices uniformly, with replacement, from the filled part of the buffer
        sampled_episodes = torch.from_numpy(np.random.choice(self.size, [self.batch_size])).to(torch.int64)
self.masks_b.fill_(0)
self.actions_b.fill_(0)
self.states_b.fill_(0)
self.rewards_b.fill_(0)
self.dones_b.fill_(0)
self.length_b[:] = self.length[sampled_episodes]
self.actions_b[:] = self.actions[sampled_episodes]
self.rewards_b[:] = self.rewards[sampled_episodes]
self.dones_b[:] = self.dones[sampled_episodes]
self.masks_b[:] = self.masks[sampled_episodes]
        if self.obs_uint8:
            # Convert stored uint8 observations back to floats in [0, 1]
            self.states_b[:] = (self.states[sampled_episodes].to(torch.float32)) / 255
else:
self.states_b[:] = self.states[sampled_episodes]
return self.states_b, self.actions_b, self.rewards_b, self.dones_b, self.masks_b, self.length_b
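

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): exercises the buffer with
    # random data on CPU. The state/action shapes, episode lengths, and reward values
    # below are illustrative assumptions, not taken from the original code.
    state_shape, action_shape, seq_len = (4,), (2,), 30
    buffer = ReplayBuffer(state_shape, action_shape, max_num_seq=16, seq_len=seq_len,
                          batch_size=8, stored_on_gpu=False, obs_uint8=False)

    # Append a few random episodes of varying length (<= seq_len)
    for _ in range(10):
        length = np.random.randint(5, seq_len + 1)
        states = np.random.rand(seq_len + 1, *state_shape).astype(np.float32)
        actions = np.random.rand(seq_len, *action_shape).astype(np.float32)
        rewards = np.random.rand(seq_len).astype(np.float32)
        dones = np.zeros(seq_len, dtype=np.float32)
        dones[length - 1] = 1.0  # mark the final step of the episode
        buffer.append_episode(states, actions, rewards, dones, length)

    states_b, actions_b, rewards_b, dones_b, masks_b, length_b = buffer.sample_batch()
    print(states_b.shape, actions_b.shape, masks_b.shape, length_b)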