forked from btglr/RL_for_dynamic_scheduling
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
68 lines (59 loc) · 2.12 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import numpy as np
import time
from env import DAGEnv
import heft
import string
import matplotlib.pyplot as plt
import pandas as pd
import torch
from model_old import Net, SimpleNet, SimpleNet2, ResNetG, SimpleNetMax, SimpleNetW, SimpleNetWSage, ModelHeterogene
import pickle as pkl
import torch
from collections import namedtuple
if __name__ == '__main__':
    # Benchmark the per-step latency of a trained scheduling policy on
    # DAGEnv instances of increasing tile count, then write one row per tile
    # count to results/time.csv.
    #
    # NOTE(review): torch.load unpickles the whole model object, so the file
    # must come from a trusted source. PyTorch 2.6+ defaults torch.load to
    # weights_only=True, which may refuse a fully pickled model -- confirm
    # against the installed torch version.
    model_path = 'model_examples/cholesky_n=8_nGPU=2_nCPU=2/model_window=0.pth'
    model = torch.load(model_path)
    model.eval()

    env_type = 'chol'
    nGPU = 2
    window = 0
    tile_sizes = [2, 4, 6, 8, 10, 12]
    # p_input always has this many slots: nGPU ones followed by zeros
    # (presumably the remaining slots are CPUs, cf. "nCPU=2" in model_path).
    n_processors = 4

    w_list = []
    n_node_list = []
    tile_list = []
    mean_time = []

    for n in tile_sizes:
        p_input = np.array([1] * nGPU + [0] * (n_processors - nGPU))
        env = DAGEnv(n, p_input, window, env_type=env_type)
        print(env.is_homogene)
        print("|V|: ", len(env.task_data.x))
        observation = env.reset()
        print(observation.keys())
        print(observation['graph'].x.shape)

        done = False
        time_step = 0
        total_time = 0.0
        while not done:
            # Time one full decision: policy forward pass plus env transition.
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            with torch.no_grad():
                policy, _ = model(observation)
                action_raw = policy.argmax().detach().cpu().numpy()
            ready_nodes = observation['ready'].squeeze(1).to(torch.bool)
            # The last policy index maps to action -1 (presumably "schedule
            # nothing this step"); any other index selects among the ready
            # nodes and is translated to that node's number.
            if action_raw == policy.shape[-1] - 1:
                action = -1
            else:
                action = observation['node_num'][ready_nodes][action_raw].detach().numpy()[0]
            observation, _, done, _ = env.step(action)
            total_time += time.perf_counter() - start_time
            time_step += 1

        # done starts False, so the loop ran at least once and time_step >= 1.
        avg_step_time = total_time / time_step
        w_list.append(window)
        n_node_list.append(env.num_nodes)
        tile_list.append(n)
        # Store the mean per-step latency, not just the final step's
        # duration, so the 'time' column matches the value printed below.
        mean_time.append(avg_step_time)
        print('n_node:', env.num_nodes)
        print(avg_step_time)

    execution_time = pd.DataFrame({'w': w_list,
                                   'n_node': n_node_list,
                                   'tiles': tile_list,
                                   'time': mean_time})
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs("results", exist_ok=True)
    execution_time.to_csv("results/time.csv")