Skip to content

Commit

Permalink
clean up fine tune code
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Dec 5, 2017
1 parent b6a9171 commit db33426
Showing 1 changed file with 39 additions and 41 deletions.
80 changes: 39 additions & 41 deletions fine_tune/success_fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ def train_model(network, criterion, optimizer, scheduler, trainLoader, valLoader
plt.figure()
plt.xlabel('epoch')
plt.ylabel('loss scores')
# axes = plt.gca()
# axes.set_ylim([1.7,2.0])
plt.plot(xrange(n_epochs), train_loss_arr)
plt.plot(xrange(n_epochs), valid_loss_arr)

Expand All @@ -270,45 +268,45 @@ def train_model(network, criterion, optimizer, scheduler, trainLoader, valLoader
plt.savefig(model_name+'/acctop5.png')


# vgg16 = models.vgg16(pretrained = True)
# for param in vgg16.parameters():
# param.requires_grad = False
# vgg16.classifier = nn.Sequential(
# nn.Linear(25088, 4096),
# nn.ReLU(),
# nn.Dropout(0.5),
# nn.Linear(4096, 4096),
# nn.ReLU(),
# nn.Dropout(0.5),
# nn.Linear(4096, len(idx_to_info))
# )

# optimizer = optim.SGD(vgg16.classifier.parameters(), lr = 0.001)
# scheduler = LambdaLR(optimizer, lambda e: 1 if e < 200/2 else 0.1)
# criterion = nn.CrossEntropyLoss()

# print 'VGG16 models loaded'
# train_model(vgg16, criterion, optimizer, scheduler, trainLoader, valLoader, n_epochs = 200, model_name='vgg16')


# resnet18 = models.resnet18(pretrained=True)
# for param in resnet18.parameters():
# param.requires_grad = False
# resnet18.fc = nn.Sequential(
# nn.Linear(512, 512),
# nn.ReLU(),
# nn.Dropout(0.5),
# nn.Linear(512, 512),
# nn.ReLU(),
# nn.Dropout(0.5),
# nn.Linear(512, len(idx_to_info))
# )
# optimizer = optim.SGD(resnet18.fc.parameters(), lr = 0.001)
# scheduler = LambdaLR(optimizer, lambda e: 1 if e < 200/2 else 0.1)
# criterion = nn.CrossEntropyLoss()

# print 'resnet18 models loaded'
# train_model(resnet18, criterion, optimizer, scheduler, trainLoader, valLoader, n_epochs = 200, model_name='resnet18')
vgg16 = models.vgg16(pretrained = True)
for param in vgg16.parameters():
param.requires_grad = False
vgg16.classifier = nn.Sequential(
nn.Linear(25088, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, len(idx_to_info))
)

optimizer = optim.SGD(vgg16.classifier.parameters(), lr = 0.001)
scheduler = LambdaLR(optimizer, lambda e: 1 if e < 200/2 else 0.1)
criterion = nn.CrossEntropyLoss()

print 'VGG16 models loaded'
train_model(vgg16, criterion, optimizer, scheduler, trainLoader, valLoader, n_epochs = 200, model_name='vgg16')


resnet18 = models.resnet18(pretrained=True)
for param in resnet18.parameters():
param.requires_grad = False
resnet18.fc = nn.Sequential(
nn.Linear(512, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, len(idx_to_info))
)
optimizer = optim.SGD(resnet18.fc.parameters(), lr = 0.001)
scheduler = LambdaLR(optimizer, lambda e: 1 if e < 200/2 else 0.1)
criterion = nn.CrossEntropyLoss()

print 'resnet18 models loaded'
train_model(resnet18, criterion, optimizer, scheduler, trainLoader, valLoader, n_epochs = 200, model_name='resnet18')

resnet152 = models.resnet152(pretrained=True)
for param in resnet152.parameters():
Expand Down

0 comments on commit db33426

Please sign in to comment.