# Basic behavioural cloning
# Note: this uses gradient accumulation in batches of one
#       to perform training.
#       This fits on even smaller GPUs (tested on an 8GB one),
#       but is slow.
from argparse import ArgumentParser
import pickle
import time

import gym
import minerl
import torch as th
import numpy as np

from openai_vpt.agent import PI_HEAD_KWARGS, MineRLAgent
from data_loader import DataLoader
from openai_vpt.lib.tree_util import tree_map

# Originally this code was designed for a small dataset of ~20 demonstrations per task.
# The settings below might not be optimal for the full BASALT dataset (thousands of demonstrations).
# Use this flag to switch between the two settings.
USING_FULL_DATASET = True
EPOCHS = 1 if USING_FULL_DATASET else 2
# Needs to be <= number of videos
BATCH_SIZE = 64 if USING_FULL_DATASET else 16
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = 100 if USING_FULL_DATASET else 20
DEVICE = "cuda"
LOSS_REPORT_RATE = 100
# Tuned with a bit of trial and error
LEARNING_RATE = 0.000181
# OpenAI VPT BC weight decay
# WEIGHT_DECAY = 0.039428
WEIGHT_DECAY = 0.0
# A KL loss towards the original model was not used in OpenAI VPT,
# but is used here to keep the finetuned policy close to the pretrained one
KL_LOSS_WEIGHT = 1.0
MAX_GRAD_NORM = 5.0
MAX_BATCHES = 2000 if USING_FULL_DATASET else int(1e9)
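# Note: with the full-dataset settings above, training stops after at most
# MAX_BATCHES * BATCH_SIZE = 2000 * 64 = 128,000 samples have been drawn
# (some of these are skipped as null actions or end-of-episode markers).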


def load_model_parameters(path_to_model_file):
    agent_parameters = pickle.load(open(path_to_model_file, "rb"))
    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    return policy_kwargs, pi_head_kwargs
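

# Overview of the training procedure below:
#   1. Load the pretrained VPT agent twice: one copy is finetuned, the other
#      is kept frozen as a reference for the KL term.
#   2. Freeze everything except the final layers (policy.net.lastlayer and
#      policy.pi_head).
#   3. For each sample, accumulate gradients of -log_prob(action) plus a KL
#      penalty towards the frozen policy, stepping the optimizer once per batch.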
def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)

    # To create the model with the right environment settings.
    # All BASALT environments share the same settings, so any of them works here
    env = gym.make("MineRLBasaltFindCave-v0")
    agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    agent.load_weights(in_weights)
    # Create a copy which will keep the original parameters
    original_agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    original_agent.load_weights(in_weights)
    env.close()

    policy = agent.policy
    original_policy = original_agent.policy

    # Freeze most params; only the final layers selected below are trained
    for param in policy.parameters():
        param.requires_grad = False
    # Unfreeze final layers
    trainable_parameters = []
    for param in policy.net.lastlayer.parameters():
        param.requires_grad = True
        trainable_parameters.append(param)
    for param in policy.pi_head.parameters():
        param.requires_grad = True
        trainable_parameters.append(param)
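    # Only the parameters collected above receive gradients and optimizer state,
    # which keeps memory requirements low enough for small GPUs (see the note at
    # the top of this file).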

    # Parameters taken from the OpenAI VPT paper
    optimizer = th.optim.Adam(
        trainable_parameters,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    data_loader = DataLoader(
        dataset_dir=data_dir,
        n_workers=N_WORKERS,
        batch_size=BATCH_SIZE,
        n_epochs=EPOCHS,
    )

    start_time = time.time()

    # Keep track of the hidden state per episode/trajectory.
    # DataLoader provides a unique id for each episode, which will
    # be different even for the same trajectory when it is loaded
    # up again
    episode_hidden_states = {}
    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)

    loss_sum = 0
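    # Each sample in a batch is processed individually: gradients accumulate over
    # the batch via repeated loss.backward() calls, and the optimizer steps once
    # per batch ("gradient accumulation in batches of one", see the header note).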
    for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
        batch_loss = 0
        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
            if image is None and action is None:
                # A work-item was done. Remove hidden state
                if episode_id in episode_hidden_states:
                    removed_hidden_state = episode_hidden_states.pop(episode_id)
                    del removed_hidden_state
                continue

            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
            if agent_action is None:
                # Action was null
                continue

            agent_obs = agent._env_obs_to_agent({"pov": image})
            if episode_id not in episode_hidden_states:
                episode_hidden_states[episode_id] = policy.initial_state(1)
            agent_state = episode_hidden_states[episode_id]

            pi_distribution, _, new_agent_state = policy.get_output_for_observation(
                agent_obs,
                agent_state,
                dummy_first
            )

            with th.no_grad():
                original_pi_distribution, _, _ = original_policy.get_output_for_observation(
                    agent_obs,
                    agent_state,
                    dummy_first
                )

            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)
            kl_div = policy.get_kl_of_action_dists(pi_distribution, original_pi_distribution)

            # Make sure we do not try to backprop through sequence
            # (fails with current accumulation)
            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
            episode_hidden_states[episode_id] = new_agent_state

            # Finally, update the agent to increase the probability of the
            # taken action.
            # Remember to take mean over batch losses
            loss = (-log_prob + KL_LOSS_WEIGHT * kl_div) / BATCH_SIZE
            batch_loss += loss.item()
            loss.backward()

        th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += batch_loss
        if batch_i % LOSS_REPORT_RATE == 0:
            time_since_start = time.time() - start_time
            print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
            loss_sum = 0

        if batch_i > MAX_BATCHES:
            break

    state_dict = policy.state_dict()
    th.save(state_dict, out_weights)
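

# Example invocation (paths below are illustrative):
#   python behavioural_cloning.py \
#       --data-dir data/MineRLBasaltFindCave-v0 \
#       --in-model data/VPT-models/foundation-model-1x.model \
#       --in-weights data/VPT-models/foundation-model-1x.weights \
#       --out-weights train/MineRLBasaltFindCaveBC.weights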
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on")
    parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned")
    parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned")
    parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved")
    args = parser.parse_args()

    behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)