layers.py
# Copyright (c) 2020-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import dgl.function as fn
import torch
import torch.nn as nn

# Small epsilon constant for numerical stability (e.g., guarding denominators).
EOS = 1e-10


class GCNConv_dense(nn.Module):
    """GCN layer that propagates features through a dense (or torch-sparse) adjacency matrix."""

    def __init__(self, input_size, output_size):
        super(GCNConv_dense, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def init_para(self):
        self.linear.reset_parameters()

    def forward(self, input, A, sparse=False):
        # Transform features first, then aggregate over neighbors: A @ (X W).
        hidden = self.linear(input)
        if sparse:
            output = torch.sparse.mm(A, hidden)
        else:
            output = torch.matmul(A, hidden)
        return output
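

# A minimal usage sketch for GCNConv_dense. Illustrative only: the feature
# sizes and the row-normalized random adjacency below are assumptions for
# demonstration, not values taken from this project.
def _example_gcn_dense():
    torch.manual_seed(0)
    n_nodes, in_dim, out_dim = 4, 8, 16
    x = torch.randn(n_nodes, in_dim)
    A = torch.rand(n_nodes, n_nodes)
    A = A / (A.sum(dim=1, keepdim=True) + EOS)  # row-normalize the adjacency
    conv = GCNConv_dense(in_dim, out_dim)
    out = conv(x, A)  # dense propagation: A @ (x W)
    assert out.shape == (n_nodes, out_dim)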


class GCNConv_dgl(nn.Module):
    """GCN layer operating on a DGL graph whose edge weights are stored under key 'w'."""

    def __init__(self, input_size, output_size):
        super(GCNConv_dgl, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x, g):
        with g.local_scope():
            g.ndata['h'] = self.linear(x)
            # Weighted message passing: scale source features by edge weight 'w', then sum.
            g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum(msg='m', out='h'))
            return g.ndata['h']
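

# A minimal usage sketch for GCNConv_dgl. Illustrative only: the toy graph,
# the added self-loops, and the random per-edge scalar weights stored under
# key 'w' are assumptions for demonstration.
def _example_gcn_dgl():
    import dgl
    torch.manual_seed(0)
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]), num_nodes=4)
    g = dgl.add_self_loop(g)
    g.edata['w'] = torch.rand(g.num_edges())  # one scalar weight per edge
    x = torch.randn(4, 8)
    conv = GCNConv_dgl(8, 16)
    out = conv(x, g)
    assert out.shape == (4, 16)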


class Diag(nn.Module):
    """Learnable diagonal linear map: scales each input dimension by a trainable weight."""

    def __init__(self, input_size):
        super(Diag, self).__init__()
        self.W = nn.Parameter(torch.ones(input_size))
        self.input_size = input_size

    def forward(self, input):
        # Multiplying by diag(W) is an elementwise per-feature scaling.
        hidden = input @ torch.diag(self.W)
        return hidden
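

# A minimal usage sketch for Diag (illustrative only: the shapes are
# assumptions). It also checks the design identity that x @ diag(W) equals
# the elementwise product x * W.
def _example_diag():
    torch.manual_seed(0)
    x = torch.randn(4, 8)
    scale = Diag(8)
    out = scale(x)
    assert torch.allclose(out, x * scale.W)


if __name__ == "__main__":
    _example_gcn_dense()
    _example_gcn_dgl()
    _example_diag()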