# Network.py
import math

import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical


def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    # Orthogonal weight init (gain=std) with constant bias, a convention
    # commonly used for PPO policy/value heads.
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding, looked up by explicit position index."""

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, positions: Tensor) -> Tensor:
        # Index the table directly, so `positions` may be an integer tensor of any
        # shape; the output gains a trailing d_model dimension.
        return self.pe[positions]
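
# A minimal usage sketch (the shapes below are illustrative assumptions, not
# values from the training environment):
#
#     pe = PositionalEncoding(d_model=8, max_len=100)
#     idx = torch.tensor([[0, 1, 2], [5, 6, 7]])  # arbitrary integer positions, (2, 3)
#     pe(idx).shape                               # -> torch.Size([2, 3, 8])
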
class Actor(nn.Module):
    def __init__(self, pos_encoder):
        super().__init__()
        self.activation = nn.Tanh()
        # Project the 4 raw interval features up to the model width (8).
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder
        # Learned 1-d embeddings for the binary "fixed" and "legal operation" flags.
        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)
        # Full-width embedding for the special start/end tokens (index 0 is unused,
        # since only non-zero token ids are embedded in forward()).
        self.tokens_start_end = nn.Embedding(3, 4)
        # self.conv_transform = nn.Conv1d(5, 1, 1)
        # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
        # nn.init.constant_(self.conv_transform.bias, 0)
        # Two single-head pre-norm transformer layers: enc1 attends over the
        # intervals of one job, enc2 attends across jobs.
        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0,
                                               batch_first=True, norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0,
                                               batch_first=True, norm_first=True)
        # Per-job score head and a separate no-op head.
        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01),
        )
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01),
        )
    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        # Embed the two binary flags and concatenate them with the two scalar
        # features, recovering a 4-d feature vector per interval.
        embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
                                  self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
        # Overwrite start/end marker positions with their dedicated token embeddings.
        non_zero_tokens = tokens_start_end != 0
        t = tokens_start_end[non_zero_tokens].long()
        embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
        # Positional encodings are looked up per interval; marker tokens get none.
        pos_enc = self.pos_encoder(indexes_inter.long())
        pos_enc[non_zero_tokens] = 0
        obs = self.project(embedded_obs) + pos_enc
        # First encoder: attend over the intervals of each job independently,
        # skipping padded intervals via the key-padding mask.
        transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
        attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
        transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
        transformed_obs = transformed_obs.view(obs.shape)
        # Pool intervals into one vector per job.
        obs = transformed_obs.mean(dim=2)
        # Second encoder: attend across jobs; job pairs flagged 0 in job_resource
        # are masked out. A residual connection is added around the layer.
        job_resource = job_resource[:, :-1, :-1] == 0
        obs_action = self.enc2(obs, src_mask=job_resource) + obs
        # One logit per job plus a single pooled no-op logit; illegal actions are
        # pushed to the most negative float32 value so their probability is ~0.
        logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
        return logits.masked_fill(mask == 0, torch.finfo(torch.float32).min)
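
# Expected shapes, inferred from the indexing above (the (B, J, T) naming is an
# assumption, not from the original repo):
#
#     obs:                     (B, J, T, 4)  float - per-interval features
#     attention_interval_mask: (B, J, T)           - 1 marks padded intervals
#     job_resource:            (B, J+1, J+1)       - 0 marks masked job pairs
#     mask:                    (B, J+1)            - 0 marks illegal actions (J jobs + no-op)
#     indexes_inter:           (B, J, T)    int    - interval positions for the encoding
#     tokens_start_end:        (B, J, T)    int    - 0 ordinary, 1/2 start/end markers
#
# The returned logits have shape (B, J+1).
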
class Agent(nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        if action is None:
            # Sample all actions without replacement, producing a full preference
            # ranking rather than a single action. Note torch.log() yields -inf
            # for actions that were masked out.
            probabilities = probs.probs
            actions = torch.multinomial(probabilities, probabilities.shape[1])
            return actions, torch.log(probabilities), probs.entropy()
        else:
            return logits, probs.log_prob(action), probs.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        return probs.sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return logits
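
# A minimal smoke test; the tensor shapes and all-ones masks below are
# illustrative assumptions, not values from the training environment.
if __name__ == "__main__":
    B, J, T = 2, 3, 5                               # batch, jobs, intervals per job (assumed)
    obs = torch.zeros(B, J, T, 4)                   # fixed flag, two scalars, legal-op flag
    attention_interval_mask = torch.zeros(B, J, T)  # nothing padded
    job_resource = torch.ones(B, J + 1, J + 1)      # no job pairs masked
    mask = torch.ones(B, J + 1)                     # all J job actions plus no-op legal
    indexes_inter = torch.zeros(B, J, T)            # all intervals at position 0
    tokens_start_end = torch.zeros(B, J, T)         # no start/end markers

    agent = Agent()
    actions, log_probs, entropy = agent(obs, attention_interval_mask, job_resource,
                                        mask, indexes_inter, tokens_start_end)
    print(actions.shape, log_probs.shape, entropy.shape)  # (2, 4) (2, 4) (2,)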