-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
40 lines (36 loc) · 1.52 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2017-10-25 20:42 zq <zq@mclab>
#
# Distributed under terms of the MIT license.
"""
Learn to find the differences between two similar images, using PyTorch.
"""
import torch.nn as nn
import torchvision
import torch
from torchvision import datasets, models, transforms
class DiffNetwork(nn.Module):
    """Siamese network that regresses a dense "difference" map for an image pair.

    Both images are encoded by the same truncated ResNet-18 backbone (shared
    weights). The two 512-channel feature maps are concatenated along the
    channel axis, and a small convolutional head maps the result to
    ``num_outputs`` channels.

    Args:
        pretrained: Load torchvision's pretrained ResNet-18 weights into the
            backbone. This was previously hard-coded to ``True``. Pass
            ``False`` to build the network without downloading weights, e.g.
            in tests or before loading a saved checkpoint.
        num_outputs: Number of channels produced by the final convolution.
            This was previously hard-coded to 35.
    """

    def __init__(self, pretrained=True, num_outputs=35):
        super(DiffNetwork, self).__init__()
        # Drop ResNet-18's last two children (global avgpool and fc) so the
        # backbone returns a spatial feature map with 512 channels.
        # ResNet-18 downsamples by 32, so a 512x512 input gives a 16x16 map,
        # which is the size the shape comments below assume.
        self.resnet18 = nn.Sequential(
            *list(models.resnet18(pretrained=pretrained).children())[:-2])
        # Layer order is kept unchanged so existing state_dict keys
        # (regression.0/.2/.4/.6/.8) still load.
        self.regression = nn.Sequential(
            # Each 4x4 conv with padding 1 shrinks the map by one pixel:
            # 16x16 -> 15x15 -> 14x14.
            nn.Conv2d(1024, 512, kernel_size=4, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # 3x3 convs with padding 1 preserve the 14x14 size.
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # Final projection has no activation, so outputs are unbounded.
            nn.Conv2d(512, num_outputs, kernel_size=3, stride=1, padding=1))

    def forward(self, inputa, inputb):
        """Compute the regression map for a pair of image batches.

        Args:
            inputa: First image batch. Presumably (N, 3, H, W), as expected
                by the ResNet-18 backbone (TODO confirm against callers).
            inputb: Second image batch. Its backbone features must have the
                same batch size and spatial size as those of ``inputa`` so
                the two can be concatenated.

        Returns:
            The regression head's output: ``num_outputs`` channels with a
            spatial size 2 smaller than the backbone feature map (14x14 for
            16x16 features).
        """
        # The same backbone (shared weights) encodes both images.
        outputa = self.resnet18(inputa)
        outputb = self.resnet18(inputb)
        # Concatenate along channels: [batch_size, 1024, 16, 16] for 512x512 inputs.
        concated_fea = torch.cat([outputa, outputb], dim=1)
        output = self.regression(concated_fea)
        return output