forked from jbkjr/train-procgen-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: run_pets.py
79 lines (72 loc) · 2.11 KB
/
run_pets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
from helper_local import add_pets_args
from hyperparameter_optimization import run_pets_hyperparameters
from pets.pets import run_pets
def normal(argv=None):
    """Parse PETS command-line arguments and launch a PETS run.

    Builds an ``argparse.ArgumentParser``, lets ``add_pets_args`` register the
    PETS-specific options on it, parses the arguments, and hands the resulting
    namespace to ``run_pets``.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous behavior.
    """
    parser = argparse.ArgumentParser()
    # add_pets_args returns the parser it was given (it is reassigned here as in
    # the original code) -- presumably with PETS options added; see helper_local.
    parser = add_pets_args(parser)
    args = parser.parse_args(argv)
    run_pets(args)
if __name__ == '__main__':
    # Two experiment configurations are kept side by side for reference.
    # Only the one passed to run_pets_hyperparameters at the bottom is
    # executed; change the selection there to switch experiments. (Both used
    # to be assigned to the same name, which silently discarded the first.)

    # "pgt2" configuration -- currently NOT run.
    pgt2_hparams = {
        "env_name": "cartpole_continuous",
        "trial_length": 500,
        "num_trials": 40,
        "hid_size": 512,
        "learning_rate": 0.000672,
        "model_batch_size": 16,
        "num_layers": 6,
        "seed": 6033,
        "use_wandb": True,
        "wandb_tags": ["pgt2"],
        "use_custom_reward_fn": False,
        "num_epochs": 1000,
        "overfit": True,
        "drop_same": True,
        "min_cart_mass": 1.0,
        "max_cart_mass": 1.0,
        "min_pole_mass": 0.1,
        "max_pole_mass": 0.1,
        "min_force_mag": 10.,
        "max_force_mag": 10.,
    }

    # "pgt1" configuration -- the one that is actually run.
    pgt1_hparams = {
        "alpha": 0.1,
        "detect_nan": False,
        "deterministic": False,
        "dyn_model": "pets.pets_models.GraphTransitionPets",
        "elite_ratio": 0.1,
        "ensemble_size": 5,
        "env_name": "cartpole_continuous",
        "exp_name": "",
        "hid_size": 156,
        "learning_rate": 0.000255,
        # NOTE(review): this pins a timestamped directory from an earlier run;
        # confirm whether run_pets_hyperparameters is meant to reuse it.
        "logdir": "logs/pets/cartpole_continuous/2024-08-05__02-43-29__seed_6033",
        "model_batch_size": 22,
        "num_checkpoints": 4,
        "num_epochs": 50,
        "num_iterations": 5,
        "num_layers": 6,
        "num_particles": 20,
        "num_trials": 20,
        "patience": 50,
        "planning_horizon": 15,
        "population_size": 500,
        "render": False,
        "replan_freq": 1,
        "seed": 6033,
        "trial_length": 500,
        "use_wandb": True,
        "validation_ratio": 0.05,
        "wandb_tags": ["pgt1"],
        # "1": "graph-transition",
        "weight_decay": 0.00005,
        "drop_same": True,
        "min_cart_mass": 1.0,
        "max_cart_mass": 1.0,
        "min_pole_mass": 0.1,
        "max_pole_mass": 0.1,
        "min_force_mag": 10.,
        "max_force_mag": 10.,
    }

    hparams = pgt1_hparams
    run_pets_hyperparameters(hparams)