forked from pabloswfly/genomcmcgan
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path symmetric.py
25 lines (21 loc) · 897 Bytes
/
symmetric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
class Symmetric(nn.Module):
    """Symmetric (permutation-invariant) reduction layer for CNNs.

    Reduces the input along ``axis`` with an order-independent summary
    statistic, so the result does not change if elements are permuted along
    that axis. The reduced dimension is kept with size 1 (``keepdim=True``),
    so the output has the same number of dimensions as the input.

    Args:
        function: Name of the summary statistic. One of ``"sum"``,
            ``"mean"``, ``"min"`` or ``"max"``.
        axis: Dimension of the input tensor to reduce.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    # Statistic names understood by ``forward``; used in the error message.
    _FUNCTIONS = ("sum", "mean", "min", "max")

    def __init__(self, function: str, axis: int, **kwargs):
        # PyTorch requires the parent nn.Module.__init__ to run before any
        # attribute is assigned on a subclass instance.
        super().__init__(**kwargs)
        self.function = function
        self.axis = axis

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce ``x`` along ``self.axis`` using ``self.function``.

        Args:
            x: Input tensor.

        Returns:
            Tensor shaped like ``x`` except that ``self.axis`` has size 1.

        Raises:
            ValueError: If ``self.function`` is not a supported statistic.
        """
        if self.function == "sum":
            return torch.sum(x, dim=self.axis, keepdim=True)
        if self.function == "mean":
            return torch.mean(x, dim=self.axis, keepdim=True)
        if self.function == "min":
            # With ``dim`` given, torch.min/torch.max return (values, indices);
            # only the values are needed.
            return torch.min(x, dim=self.axis, keepdim=True).values
        if self.function == "max":
            return torch.max(x, dim=self.axis, keepdim=True).values
        # Previously an unknown name fell through and raised UnboundLocalError.
        raise ValueError(
            f"Unsupported function {self.function!r}; "
            f"expected one of {self._FUNCTIONS}"
        )

    def extra_repr(self) -> str:
        """Show the layer configuration in ``repr(module)``."""
        return f"function={self.function!r}, axis={self.axis}"