-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathrun_agent.py
45 lines (36 loc) · 1.74 KB
/
run_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from argparse import ArgumentParser
import pickle
import aicrowd_gym
import minerl
from openai_vpt.agent import MineRLAgent
def main(model, weights, env, n_episodes=3, max_steps=int(1e9), show=False):
    """Roll out a pretrained agent in a MineRL environment.

    Args:
        model: Path to the pickled '.model' file; its
            ``["model"]["args"]`` entry supplies the policy and
            policy-head keyword arguments for ``MineRLAgent``.
        weights: Path to the '.weights' file passed to
            ``agent.load_weights``.
        env: Environment id passed to ``aicrowd_gym.make``.
        n_episodes: Number of episodes to run.
        max_steps: Upper bound on steps per episode; an episode also ends
            early as soon as ``env.step`` reports ``done``.
        show: If True, call ``env.render()`` after every step.
    """
    # Using aicrowd_gym is important! Your submission will not work otherwise
    env = aicrowd_gym.make(env)
    try:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load '.model' files from a trusted source.
        with open(model, "rb") as model_file:
            agent_parameters = pickle.load(model_file)
        model_args = agent_parameters["model"]["args"]
        policy_kwargs = model_args["net"]["args"]
        pi_head_kwargs = model_args["pi_head_opts"]
        # Coerce the temperature to a float before building the policy head
        # (presumably it may be stored as another numeric type — confirm).
        pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
        agent = MineRLAgent(env, policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs)
        agent.load_weights(weights)

        for _ in range(n_episodes):
            # NOTE(review): the agent itself is not reset between episodes;
            # if MineRLAgent keeps recurrent state it carries over from the
            # previous episode — confirm whether that is intended.
            obs = env.reset()
            for _ in range(max_steps):
                action = agent.get_action(obs)
                # ESC is not part of the predictions model.
                # For baselines, we just set it to zero.
                # We leave proper execution as an exercise for the participants :)
                action["ESC"] = 0
                obs, _, done, _ = env.step(action)
                if show:
                    env.render()
                if done:
                    break
    finally:
        # Release the environment even if loading the agent or a rollout
        # step raises.
        env.close()
def build_parser():
    """Build the command-line parser for running a pretrained agent.

    Returns:
        An ``ArgumentParser`` whose parsed namespace carries ``weights``,
        ``model``, ``env``, ``n_episodes``, ``max_steps`` and ``show``.
    """
    # Pass the text as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name in
    # the usage line.
    parser = ArgumentParser(description="Run pretrained models on MineRL environment")
    parser.add_argument("--weights", type=str, required=True, help="Path to the '.weights' file to be loaded.")
    parser.add_argument("--model", type=str, required=True, help="Path to the '.model' file to be loaded.")
    parser.add_argument("--env", type=str, required=True, help="Environment id passed to aicrowd_gym.make.")
    # Defaults mirror main()'s keyword defaults so omitting them keeps the
    # previous behavior.
    parser.add_argument("--n_episodes", type=int, default=3, help="Number of episodes to run.")
    parser.add_argument("--max_steps", type=int, default=int(1e9), help="Maximum number of steps per episode.")
    parser.add_argument("--show", action="store_true", help="Render the environment.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(
        args.model,
        args.weights,
        args.env,
        n_episodes=args.n_episodes,
        max_steps=args.max_steps,
        show=args.show,
    )