diff --git a/mnist_project/src/model.py b/mnist_project/src/model.py index cd00054..a0ec4f0 100644 --- a/mnist_project/src/model.py +++ b/mnist_project/src/model.py @@ -4,17 +4,17 @@ class MNISTModel(nn.Module): def __init__(self): super(MNISTModel, self).__init__() - self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1) - self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) - self.fc1 = nn.Linear(32 * 7 * 7, 128) - self.fc2 = nn.Linear(128, 10) + self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1) + self.fc1 = nn.Linear(16 * 7 * 7, 64) + self.fc2 = nn.Linear(64, 10) self.pool = nn.MaxPool2d(2, 2) self.relu = nn.ReLU() def forward(self, x): x = self.pool(self.relu(self.conv1(x))) # 14x14 x = self.pool(self.relu(self.conv2(x))) # 7x7 - x = x.view(-1, 32 * 7 * 7) + x = x.view(-1, 16 * 7 * 7) x = self.relu(self.fc1(x)) x = self.fc2(x) return x \ No newline at end of file