# MEAL-V2

This is the official PyTorch implementation of our paper:
["MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks"]() by
[Zhiqiang Shen](http://zhiqiangshen.com/) and [Marios Savvides](https://www.cmu-biometrics.org/) from Carnegie Mellon University.

<div align=center>
<img width=70% src="https://user-images.githubusercontent.com/3794909/92182326-6f78c400-ee19-11ea-80e4-2d6e4d73ce82.png"/>
</div>

In this paper, we introduce a simple yet effective approach that boosts the vanilla ResNet-50 to 80%+ Top-1 accuracy on ImageNet without any tricks. Generally, our method is based on the recently proposed [MEAL](https://arxiv.org/abs/1812.02425), i.e., ensemble knowledge distillation via discriminators. We further simplify it by 1) adopting the similarity loss and discriminator only on the final outputs and 2) using the average of softmax probabilities from all teacher ensembles as the stronger supervision for distillation. One crucial perspective of our method is that the one-hot/hard label should not be used in the distillation process. We show that such a simple framework can achieve state-of-the-art results without involving any commonly-used tricks, such as 1) architecture modification; 2) outside training data beyond ImageNet; 3) autoaug/randaug; 4) cosine learning rate; 5) mixup/cutmix training; 6) label smoothing; etc.
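The gist of that recipe fits in a few lines. Below is a minimal sketch assuming generic `student`/`teachers` modules (it is not this repo's actual training loop, and the similarity-loss/discriminator part of the full method is omitted): the distillation target is the average of the teachers' softmax probabilities, and no hard label enters the loss.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student, teachers, images):
    # Average the frozen teachers' softmax outputs to form the soft target.
    with torch.no_grad():
        soft_labels = torch.stack(
            [F.softmax(t(images), dim=1) for t in teachers]).mean(dim=0)
    log_probs = F.log_softmax(student(images), dim=1)
    # Cross-entropy against the soft labels: -sum(p * log q), batch-averaged.
    return -(soft_labels * log_probs).sum(dim=1).mean()
```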
## Citation

If you find our code helpful for your research, please cite:

    @article{shen2020mealv2,
      title={MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks},
      author={Shen, Zhiqiang and Savvides, Marios},
      journal={arXiv preprint arXiv:},
      year={2020}
    }
## Preparation

### 1. Requirements:
This repo is tested with:

* Python 3.6
* CUDA 10.2
* PyTorch 1.6.0
* torchvision 0.7.0
* timm 0.2.1 (`pip install timm`)

But it should be runnable with other PyTorch versions.

### 2. Data:
* Download the ImageNet dataset following https://github.com/pytorch/examples/tree/master/imagenet#requirements.
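After following those instructions, the dataset should be in the standard torchvision `ImageFolder` layout: one subdirectory per class under both `train/` and `val/`. A quick sanity check, with a hypothetical root path:

```python
import os

imagenet_root = "/path/to/imagenet"  # hypothetical path; replace with yours
for split in ("train", "val"):
    split_dir = os.path.join(imagenet_root, split)
    # Each split must contain one subdirectory per class (e.g. n01440764/).
    assert os.path.isdir(split_dir), f"missing {split}/ directory"
    assert any(os.scandir(split_dir)), f"{split}/ has no class folders"
```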
## Results & Models

We provide pre-trained models from different training settings. The table reports the training/validation resolution, #parameters, and Top-1/Top-5 accuracy on the ImageNet validation set:

| Models | Resolution | #Parameters | Top-1/Top-5 | Trained models |
| :---: | :---: | :---: | :---: | :---: |
| [MEAL-V1 w/ ResNet50](https://arxiv.org/abs/1812.02425) | 224 | 25.6M | **78.21/94.01** | [GitHub](https://github.com/AaronHeee/MEAL#imagenet-model) |
| MEAL-V2 w/ ResNet50 | 224 | 25.6M | **80.67/95.09** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0NGENlMK0pYVDQM?e=GkwZ93) |
| MEAL-V2 w/ ResNet50 | 380 | 25.6M | **81.72/95.81** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0T9nodVNdnklHNt?e=7oJGIy) |
| MEAL-V2 + CutMix w/ ResNet50 | 224 | 25.6M | **80.98/95.35** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0cIf5IqpBX6nl1U?e=Fig91M) |
| MEAL-V2 w/ MobileNet V3-Small 0.75 | 224 | 2.04M | **67.60/87.23** | [Download (8.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0nIq1jZo36dpN7Q?e=ODcoAN) |
| MEAL-V2 w/ MobileNet V3-Small 1.0 | 224 | 2.54M | **69.65/88.71** | [Download (10.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCiz9v7QqUmvQOLmTS?e=9nCWMa) |
| MEAL-V2 w/ MobileNet V3-Large 1.0 | 224 | 5.48M | **76.92/93.32** | [Download (22.1M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0Ciwz-q-P2jwtXR?e=OebKAr) |
| MEAL-V2 w/ EfficientNet-B0 | 224 | 5.29M | **78.29/93.95** | [Download (21.5M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0XZLUEB3uYq3eBe?e=FJV9K1) |
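To load one of the downloaded checkpoints outside this repo, something along these lines should work; the file name and the key layout of the released checkpoints are assumptions here, so adjust if the weights are nested differently:

```python
import timm
import torch

# Hypothetical filename; use the path of the downloaded checkpoint.
checkpoint = torch.load("MEAL_V2_resnet50_224.pth", map_location="cpu")
# Weights may be stored directly or nested under a key such as 'state_dict'.
state_dict = checkpoint.get("state_dict", checkpoint)
model = timm.create_model("resnet50", num_classes=1000)
model.load_state_dict(state_dict)
model.eval()
```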
## Training & Testing
### 1. Training:
* To train a model, run `script/train.sh` with the desired model architecture and the path to the ImageNet dataset, for example:

```shell
# 224 x 224 ResNet-50
python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --epochs 180 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 380 x 380 ResNet-50
python train.py --save MEAL_V2_resnet50_380 --batch-size 512 -j 48 --model resnet50 --image-size 380 --teacher-model tf_efficientnet_b4_ns,tf_efficientnet_b4 --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 MobileNet V3-Small 0.75
python train.py --save MEAL_V2_mobilenetv3_small_075 --batch-size 512 -j 48 --model tf_mobilenetv3_small_075 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 MobileNet V3-Small 1.0
python train.py --save MEAL_V2_mobilenetv3_small_100 --batch-size 512 -j 48 --model tf_mobilenetv3_small_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 MobileNet V3-Large 1.0
python train.py --save MEAL_V2_mobilenetv3_large_100 --batch-size 512 -j 48 --model tf_mobilenetv3_large_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 EfficientNet-B0
python train.py --save MEAL_V2_efficientnet_b0 --batch-size 512 -j 48 --model tf_efficientnet_b0 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

*Please reduce the `--batch-size` if you get an "out of memory" error. We also notice that more training epochs can slightly improve the performance.*
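The `--teacher-model` flag takes a comma-separated list of timm model names; presumably the training script builds the frozen teacher ensemble roughly like this (a sketch, not the repo's exact code):

```python
import timm

teacher_names = "gluon_senet154,gluon_resnet152_v1s"  # from --teacher-model
teachers = [timm.create_model(name, pretrained=True).eval()
            for name in teacher_names.split(",")]
```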
* To resume training a model, run `script/resume_train.sh` with the desired model architecture, the starting training epoch, and the path to the ImageNet dataset:

```shell
sh script/resume_train.sh
```
### 2. Testing:

* To test a model, run `inference.py` with the desired model architecture, model path, resolution, and the path to the ImageNet dataset:

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python inference.py -a resnet50 --res 224 --resume MODEL_PATH -e [imagenet-folder with train and val folders]
```

Change `--res` to another image resolution [224/380] and `-a` to another model architecture [tf_mobilenetv3_small_100; tf_mobilenetv3_large_100; tf_efficientnet_b0] to test the other trained models.
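The Top-1/Top-5 numbers reported above follow the standard definition; for reference, a self-contained sketch of the computation (not the repo's exact evaluation code):

```python
import torch

def topk_accuracy(logits, targets, ks=(1, 5)):
    """Fraction of samples whose true class is within the top-k predictions."""
    _, pred = logits.topk(max(ks), dim=1)      # N x max(ks) class indices
    correct = pred.eq(targets.unsqueeze(1))    # N x max(ks) booleans
    return [correct[:, :k].any(dim=1).float().mean().item() for k in ks]
```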
## Contact

Zhiqiang Shen, CMU (zhiqians at andrew.cmu.edu)

Any comments or suggestions are welcome!
__author__ = "Hessam Bagherinezhad <[email protected]>"

from torch import nn


class DataParallel(nn.DataParallel):
    """An extension of nn.DataParallel.

    The only extensions are:
        1) If an attribute is missing in an object of this class, it will look
           for it in the wrapped module. This is useful for getting `LR_REGIME`
           of the wrapped module, for example.
        2) state_dict() of this class calls the wrapped module's state_dict(),
           hence the weights can be transferred from a data-parallel wrapped
           module to a single-GPU module.
    """

    def __getattr__(self, name):
        # If an attribute doesn't exist in the DataParallel object, this
        # method is called. First ask the superclass for the attribute; if it
        # can't be found there, ask the underlying module that is wrapped by
        # this DataParallel.
        try:
            return super().__getattr__(name)
        except AttributeError:
            underlying_module = super().__getattr__('module')
            return getattr(underlying_module, name)

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)
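
# Usage sketch (illustrative, not part of the original file): attribute
# lookups fall through to the wrapped model, and state_dict() keeps
# single-GPU keys (no 'module.' prefix).
if __name__ == "__main__":
    import torch
    from torchvision import models

    net = DataParallel(models.resnet50())
    out = net(torch.randn(2, 3, 224, 224))
    torch.save(net.state_dict(), "checkpoint.pth")  # loads without DataParallel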
__author__ = "Hessam Bagherinezhad <[email protected]>"

# modified by "Zhiqiang Shen <[email protected]>"

import torch
from torch.nn import functional as F
from torch.nn.modules import loss


class KLLoss(loss._Loss):
    """The KL-divergence loss between the model output and the soft labels.

    At training time, `output` must be a pair of (model_output, soft_labels),
    both NxC tensors. The rows of soft_labels must each sum to one
    (probability scores), whereas model_output must be the pre-softmax output
    of the network.
    """

    def forward(self, output, target):
        if not self.training:
            # At evaluation time the loss is the normal cross-entropy loss
            # between the model output and the target.
            return F.cross_entropy(output, target)

        assert type(output) == tuple and len(output) == 2 and \
            output[0].size() == output[1].size(), \
            "output must be a pair of tensors of the same size."

        # The target is ignored at training time. The loss is defined as the
        # KL divergence between the model output and the soft labels.
        model_output, soft_labels = output
        if soft_labels.requires_grad:
            raise ValueError("soft labels should not require gradients.")

        model_output_log_prob = F.log_softmax(model_output, dim=1)
        del model_output

        # The loss is -dot(model_output_log_prob, soft_labels). Prepare the
        # tensors for batch matrix multiplication.
        soft_labels = soft_labels.unsqueeze(1)
        model_output_log_prob = model_output_log_prob.unsqueeze(2)

        # Compute the loss and average it over the batch.
        cross_entropy_loss = -torch.bmm(soft_labels, model_output_log_prob)
        cross_entropy_loss = cross_entropy_loss.mean()

        # Return a pair of (loss_output, model_output). The model output will
        # be used for Top-1 and Top-5 evaluation.
        model_output_log_prob = model_output_log_prob.squeeze(2)
        return (cross_entropy_loss, model_output_log_prob)
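
# Usage sketch (illustrative shapes: a 1000-class problem with batch size 8;
# not part of the original file).
if __name__ == "__main__":
    criterion = KLLoss()

    # Training mode: pass (pre-softmax student output, teacher soft labels);
    # the hard target is ignored.
    criterion.train()
    student_logits = torch.randn(8, 1000, requires_grad=True)
    with torch.no_grad():
        soft_labels = F.softmax(torch.randn(8, 1000), dim=1)
    loss, log_probs = criterion((student_logits, soft_labels), target=None)
    loss.backward()

    # Eval mode: behaves like plain cross-entropy against hard labels.
    criterion.eval()
    ce = criterion(torch.randn(8, 1000), torch.randint(0, 1000, (8,)))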
import random

import torch
from torch import nn
from torch.nn import functional as F


class ModelDistillationWrapper(nn.Module):
    """Convenience wrapper class to train a model with soft labels."""

    def __init__(self, model, teacher):
        super().__init__()
        self.model = model
        self.teachers_0 = teacher
        self.combine = True

        # Since we don't want to back-prop through the teacher networks, make
        # their parameters not require gradients. This saves some GPU memory.
        for teacher_model in self.teachers_0:
            for param in teacher_model.parameters():
                param.requires_grad = False

    @property
    def LR_REGIME(self):
        # Training with soft labels does not change the learning rate regime.
        # Return the wrapped model's LR regime.
        return self.model.LR_REGIME

    def state_dict(self):
        return self.model.state_dict()

    def forward(self, input, before=False):
        if not self.training:
            return self.model(input)

        if len(self.teachers_0) == 3 and not self.combine:
            # Randomly sample one of the three teachers, weighting the last
            # two more heavily.
            index = [0, 1, 1, 2, 2]
            idx = random.randint(0, 4)
            soft_labels_ = self.teachers_0[index[idx]](input)
            soft_labels = F.softmax(soft_labels_, dim=1)
        elif self.combine:
            # Average the pre-softmax outputs and the softmax probabilities
            # over all teachers.
            soft_labels_ = [torch.unsqueeze(self.teachers_0[idx](input), dim=2)
                            for idx in range(len(self.teachers_0))]
            soft_labels_softmax = [F.softmax(i, dim=1) for i in soft_labels_]
            soft_labels_ = torch.cat(soft_labels_, dim=2).mean(dim=2)
            soft_labels = torch.cat(soft_labels_softmax, dim=2).mean(dim=2)
        else:
            # Randomly sample a single teacher.
            idx = random.randint(0, len(self.teachers_0) - 1)
            soft_labels_ = self.teachers_0[idx](input)
            soft_labels = F.softmax(soft_labels_, dim=1)

        model_output = self.model(input)

        if before:
            # Also return the pre-softmax teacher outputs.
            return (model_output, soft_labels, soft_labels_)
        return (model_output, soft_labels)
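
# Usage sketch (illustrative, not part of the original file; building the
# teachers via timm mirrors the README's --teacher-model flag and is an
# assumption about what train.py does).
if __name__ == "__main__":
    import timm

    student = timm.create_model("resnet50")
    teachers = [timm.create_model(name, pretrained=True).eval()
                for name in ("gluon_senet154", "gluon_resnet152_v1s")]
    wrapper = ModelDistillationWrapper(student, teachers)

    wrapper.train()
    images = torch.randn(2, 3, 224, 224)  # stand-in batch
    model_output, soft_labels = wrapper(images)  # feed this pair to KLLoss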
"""Dataset class for loading ImageNet data."""

import os

from torch.utils import data as data_utils
from torchvision import datasets as torch_datasets
from torchvision import transforms


def get_train_loader(imagenet_path, batch_size, num_workers, image_size):
    train_dataset = ImageNet(imagenet_path, image_size, is_train=True)
    return data_utils.DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True,
        num_workers=num_workers)


def get_val_loader(imagenet_path, batch_size, num_workers, image_size):
    val_dataset = ImageNet(imagenet_path, image_size, is_train=False)
    return data_utils.DataLoader(
        val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True,
        num_workers=num_workers)


class ImageNet(torch_datasets.ImageFolder):
    """Dataset class for the ImageNet dataset.

    Arguments:
        root_dir (str): Path to the dataset root directory, which must contain
            train/ and val/ directories.
        im_size (int): Side length of the square crops fed to the network.
        is_train (bool): Whether to read training or validation images.
    """
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root_dir, im_size, is_train):
        if is_train:
            root_dir = os.path.join(root_dir, 'train')
            transform = transforms.Compose([
                transforms.RandomResizedCrop(im_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
            ])
        else:
            root_dir = os.path.join(root_dir, 'val')
            # Resize keeping the standard 256/224 ratio, then center-crop.
            transform = transforms.Compose([
                transforms.Resize(int(256 / 224 * im_size)),
                transforms.CenterCrop(im_size),
                transforms.ToTensor(),
                transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
            ])
        super().__init__(root_dir, transform=transform)
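
# Usage sketch (illustrative path and sizes; not part of the original file).
if __name__ == "__main__":
    train_loader = get_train_loader(
        "/path/to/imagenet", batch_size=256, num_workers=8, image_size=224)
    val_loader = get_val_loader(
        "/path/to/imagenet", batch_size=256, num_workers=8, image_size=224)
    images, targets = next(iter(train_loader))
    # images: 256 x 3 x 224 x 224 normalized tensors; targets: class indices.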