-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcustom_cnn.py
38 lines (31 loc) · 1.35 KB
/
custom_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch as th
import torch.nn as nn
import gym
class CustomCNN(BaseFeaturesExtractor):
    """
    Small convolutional feature extractor for channels-first observations.

    The network applies a 1x1 convolution (32 filters, stride 1), a 3x3
    convolution (64 filters, stride 2, no padding), flattens the result and
    projects it to ``features_dim`` with a linear layer; every layer is
    followed by a ReLU.

    :param observation_space: (gym.Space) Observation space whose ``shape[0]``
        is used as the number of input channels.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units in the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first).
        # Re-ordering will be done by pre-processing or a wrapper.
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # The flattened size depends on the observation's spatial dimensions,
        # so measure it with one forward pass on a sampled observation
        # ([None] adds the batch axis). no_grad avoids building a graph for
        # this throwaway pass.
        with th.no_grad():
            dummy_batch = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(dummy_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        Extract features from a batch of observations.

        :param observations: (th.Tensor) Batch passed through the convolutional
            stack and then the linear head.
        :return: (th.Tensor) Output of the linear head, whose last layer has
            ``features_dim`` units.
        """
        return self.linear(self.cnn(observations))