Initial commit
szq0214 committed Sep 17, 2020
0 parents commit c848006
Showing 20 changed files with 1,582 additions and 0 deletions.
117 changes: 117 additions & 0 deletions README.md
@@ -0,0 +1,117 @@
# MEAL-V2

This is the official PyTorch implementation of our paper:
["MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks"]() by
[Zhiqiang Shen](http://zhiqiangshen.com/) and [Marios Savvides](https://www.cmu-biometrics.org/) from Carnegie Mellon University.

<div align=center>
<img width=70% src="https://user-images.githubusercontent.com/3794909/92182326-6f78c400-ee19-11ea-80e4-2d6e4d73ce82.png"/>
</div>

In this paper, we introduce a simple yet effective approach that can boost the vanilla ResNet-50 to 80%+ Top-1 accuracy on ImageNet without any tricks. Generally, our method is based on the recently proposed [MEAL](https://arxiv.org/abs/1812.02425), i.e., ensemble knowledge distillation via discriminators. We further simplify it through 1) adopting the similarity loss and discriminator only on the final outputs and 2) using the average of softmax probabilities from all teacher ensembles as the stronger supervision for distillation. One crucial perspective of our method is that the one-hot/hard label should not be used in the distillation process. We show that such a simple framework can achieve state-of-the-art results without involving any commonly-used tricks, such as 1) architecture modification; 2) outside training data beyond ImageNet; 3) autoaug/randaug; 4) cosine learning rate; 5) mixup/cutmix training; 6) label smoothing; etc.
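
As a rough illustration of point 2), the sketch below averages the softmax outputs of several teachers into a single soft label, which then serves as the only supervision (no one-hot label is involved). It is framework-free Python rather than the repository's code; the helper names `softmax` and `ensemble_soft_label` and the logits are made up for this example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_soft_label(teacher_logits):
    """Average the per-teacher softmax distributions for one image."""
    probs = [softmax(logits) for logits in teacher_logits]
    num_teachers = len(probs)
    # zip(*probs) groups the probabilities class by class.
    return [sum(column) / num_teachers for column in zip(*probs)]


# Hypothetical logits from two teachers on a 3-class problem.
teachers = [[2.0, 1.0, 0.1], [1.5, 1.2, 0.3]]
target = ensemble_soft_label(teachers)
print(target)
```

Because each teacher's softmax sums to one, the averaged `target` is again a valid probability distribution.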

## Citation

If you find our code helpful for your research, please cite:

@article{shen2020mealv2,
title={MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks},
author={Shen, Zhiqiang and Savvides, Marios},
journal={arXiv preprint arXiv:},
year={2020}
}

## Preparation

### 1. Requirements:
This repo is tested with:

* Python 3.6

* CUDA 10.2

* PyTorch 1.6.0

* torchvision 0.7.0

* timm 0.2.1
(pip install timm)

But it should be runnable with other PyTorch versions.

### 2. Data:
* Download ImageNet dataset following https://github.com/pytorch/examples/tree/master/imagenet#requirements.

## Results & Models

We provide pre-trained models obtained with different training settings. The table reports the training/validation resolution, the number of parameters, and the Top-1/Top-5 accuracy on the ImageNet validation set:

| Models | Resolution| #Parameters | Top-1/Top-5 | Trained models |
| :---: | :-: | :-: | :------:| :------: |
| [MEAL-V1 w/ ResNet50](https://arxiv.org/abs/1812.02425) | 224 | 25.6M |**78.21/94.01** | [GitHub](https://github.com/AaronHeee/MEAL#imagenet-model) |
| MEAL-V2 w/ ResNet50 | 224 | 25.6M | **80.67/95.09** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0NGENlMK0pYVDQM?e=GkwZ93) |
| MEAL-V2 w/ ResNet50| 380 | 25.6M | **81.72/95.81** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0T9nodVNdnklHNt?e=7oJGIy) |
| MEAL-V2 + CutMix w/ ResNet50| 224 | 25.6M | **80.98/95.35** | [Download (102.6M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0cIf5IqpBX6nl1U?e=Fig91M) |
| MEAL-V2 w/ MobileNet V3-Small 0.75| 224 | 2.04M | **67.60/87.23** | [Download (8.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0nIq1jZo36dpN7Q?e=ODcoAN) |
| MEAL-V2 w/ MobileNet V3-Small 1.0| 224 | 2.54M | **69.65/88.71** | [Download (10.3M)](https://1drv.ms/u/s!AtMVZxJ8MfxCiz9v7QqUmvQOLmTS?e=9nCWMa) |
| MEAL-V2 w/ MobileNet V3-Large 1.0 | 224 | 5.48M | **76.92/93.32** | [Download (22.1M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0Ciwz-q-P2jwtXR?e=OebKAr) |
| MEAL-V2 w/ EfficientNet-B0| 224 | 5.29M | **78.29/93.95** | [Download (21.5M)](https://1drv.ms/u/s!AtMVZxJ8MfxCi0XZLUEB3uYq3eBe?e=FJV9K1) |
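
For readers unfamiliar with the metric, here is a minimal sketch of how Top-1/Top-5 accuracy can be computed from per-class scores. It is plain Python and independent of the repository's evaluation code; `topk_accuracy` and the toy scores are assumptions made for illustration only.

```python
def topk_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(labels)


# Toy scores for 4 samples over 6 classes (not real model output).
scores = [
    [0.10, 0.50, 0.20, 0.10, 0.05, 0.05],
    [0.30, 0.10, 0.25, 0.20, 0.10, 0.05],
    [0.40, 0.30, 0.10, 0.10, 0.06, 0.04],
    [0.60, 0.10, 0.10, 0.10, 0.05, 0.05],
]
labels = [1, 2, 5, 0]

top1 = topk_accuracy(scores, labels, k=1)
top5 = topk_accuracy(scores, labels, k=5)
print(top1, top5)
```

In this toy batch the second sample is ranked second (a Top-5 hit only) and the third sample's label has the lowest score (a miss for both metrics).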


## Training & Testing
### 1. Training:
* To train a model, run script/train.sh with the desired model architecture and the path to the ImageNet dataset, for example:

```shell
# 224 x 224 ResNet-50
python train.py --save MEAL_V2_resnet50_224 --batch-size 512 -j 48 --model resnet50 --epochs 180 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 380 x 380 ResNet-50
python train.py --save MEAL_V2_resnet50_380 --batch-size 512 -j 48 --model resnet50 --image-size 380 --teacher-model tf_efficientnet_b4_ns,tf_efficientnet_b4 --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 MobileNet V3-Small 0.75
python train.py --save MEAL_V2_mobilenetv3_small_075 --batch-size 512 -j 48 --model tf_mobilenetv3_small_075 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 MobileNet V3-Small 1.0
python train.py --save MEAL_V2_mobilenetv3_small_100 --batch-size 512 -j 48 --model tf_mobilenetv3_small_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 MobileNet V3-Large 1.0
python train.py --save MEAL_V2_mobilenetv3_large_100 --batch-size 512 -j 48 --model tf_mobilenetv3_large_100 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```

```shell
# 224 x 224 EfficientNet-B0
python train.py --save MEAL_V2_efficientnet_b0 --batch-size 512 -j 48 --model tf_efficientnet_b0 --teacher-model gluon_senet154,gluon_resnet152_v1s --imagenet [imagenet-folder with train and val folders]
```
*Please reduce ``--batch-size`` if you get an "out of memory" error. We also notice that more training epochs can slightly improve the performance.*

* To resume training a model, run script/resume_train.sh with the desired model architecture, the starting training epoch, and the path to the ImageNet dataset:

```shell
sh script/resume_train.sh
```

### 2. Testing:

* To test a model, run inference.py with the desired model architecture, model path, resolution and the path to the ImageNet dataset:

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python inference.py -a resnet50 --res 224 --resume MODEL_PATH -e [imagenet-folder with train and val folders]
```

Change ``--res`` to another image resolution [224/380] and ``-a`` to another model architecture [tf\_mobilenetv3\_small\_100; tf\_mobilenetv3\_large\_100; tf\_efficientnet\_b0] to test the other trained models.

## Contact

Zhiqiang Shen, CMU (zhiqians at andrew.cmu.edu)

Any comments or suggestions are welcome!
Empty file added extensions/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions extensions/data_parallel.py
@@ -0,0 +1,32 @@
__author__ = "Hessam Bagherinezhad <[email protected]>"

from torch import nn
from torch.nn.modules import loss


class DataParallel(nn.DataParallel):
    """An extension of nn.DataParallel.
    The only extensions are:
    1) If an attribute is missing in an object of this class, it will look
       for it in the wrapped module. This is useful for getting `LR_REGIME`
       of the wrapped module for example.
    2) state_dict() of this class calls the wrapped module's state_dict(),
       hence the weights can be transferred from a data parallel wrapped
       module to a single gpu module.
    """

    def __getattr__(self, name):
        # If attribute doesn't exist in the DataParallel object this method will
        # be called. Here we first ask the super class to get the attribute, if
        # couldn't find it, we ask the underlying module that is wrapped by this
        # DataParallel to get the attribute.
        try:
            return super().__getattr__(name)
        except AttributeError:
            underlying_module = super().__getattr__('module')
            return getattr(underlying_module, name)

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)
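
The attribute fallback used above can be illustrated without PyTorch. The following is a plain-Python analogue, not the class itself: `Wrapped`, `Delegating`, and the `LR_REGIME` values are hypothetical. It relies only on the language rule that `__getattr__` is called when normal attribute lookup fails.

```python
class Wrapped:
    """Stand-in for a model that exposes a custom attribute."""
    LR_REGIME = [1, 30, 0.1]  # hypothetical values, for illustration only


class Delegating:
    """Forwards unknown attribute lookups to the wrapped object."""

    def __init__(self, module):
        self.module = module

    def __getattr__(self, name):
        # Reached only when the attribute is not found on this object.
        if name == "module":
            # Guard against infinite recursion if `module` is not set yet.
            raise AttributeError(name)
        return getattr(self.module, name)


wrapper = Delegating(Wrapped())
print(wrapper.LR_REGIME)  # resolved on the wrapped object
```

Attributes that exist on neither object still raise `AttributeError`, so typos are not silently swallowed.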
46 changes: 46 additions & 0 deletions extensions/kd_loss.py
@@ -0,0 +1,46 @@
__author__ = "Hessam Bagherinezhad <[email protected]>"

# modified by "Zhiqiang Shen <[email protected]>"

import torch
from torch.nn import functional as F
from torch.nn.modules import loss


class KLLoss(loss._Loss):
    """The KL-Divergence loss for the model and soft labels output.
    output must be a pair of (model_output, soft_labels), both NxC tensors.
    The rows of soft_labels must all add up to one (probability scores);
    however, model_output must be the pre-softmax output of the network."""

    def forward(self, output, target):
        if not self.training:
            # Loss is normal cross entropy loss between the model output and the
            # target.
            return F.cross_entropy(output, target)

        assert isinstance(output, tuple) and len(output) == 2 and \
            output[0].size() == output[1].size(), \
            "output must be a pair of tensors of same size."

        # Target is ignored at training time. Loss is defined as KL divergence
        # between the model output and the soft labels.
        model_output, soft_labels = output
        if soft_labels.requires_grad:
            raise ValueError("soft labels should not require gradients.")

        model_output_log_prob = F.log_softmax(model_output, dim=1)
        del model_output

        # Loss is -dot(model_output_log_prob, soft_labels). Prepare tensors
        # for batch matrix multiplication: (N, 1, C) x (N, C, 1) -> (N, 1, 1).
        soft_labels = soft_labels.unsqueeze(1)
        model_output_log_prob = model_output_log_prob.unsqueeze(2)

        # Compute the loss, and average for the batch.
        cross_entropy_loss = -torch.bmm(soft_labels, model_output_log_prob)
        cross_entropy_loss = cross_entropy_loss.mean()
        # Return a pair of (loss_output, model_output). Model output will be
        # used for top-1 and top-5 evaluation.
        model_output_log_prob = model_output_log_prob.squeeze(2)
        return (cross_entropy_loss, model_output_log_prob)
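
The training branch above computes a cross-entropy against soft labels, while the class is described as a KL-divergence loss. The two differ only by the entropy of the soft labels, which does not depend on the student. The snippet below is a small numeric check of that identity, CE(p, q) = KL(p || q) + H(p), in plain Python; the function names and the two distributions are made up for the example.

```python
import math


def cross_entropy(soft_labels, probs):
    """-sum_i p_i * log(q_i)"""
    return -sum(p * math.log(q) for p, q in zip(soft_labels, probs) if p > 0)


def entropy(soft_labels):
    """-sum_i p_i * log(p_i)"""
    return -sum(p * math.log(p) for p in soft_labels if p > 0)


def kl_divergence(soft_labels, probs):
    """sum_i p_i * log(p_i / q_i)"""
    return sum(p * math.log(p / q) for p, q in zip(soft_labels, probs) if p > 0)


# Hypothetical teacher soft labels and student probabilities.
soft_labels = [0.7, 0.2, 0.1]
student_probs = [0.5, 0.3, 0.2]

ce = cross_entropy(soft_labels, student_probs)
kl = kl_divergence(soft_labels, student_probs)
gap = ce - kl
print(gap, entropy(soft_labels))
```

Since the gap is a constant with respect to the student, minimizing either quantity should lead to the same student update.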
65 changes: 65 additions & 0 deletions extensions/teacher_wrapper.py
@@ -0,0 +1,65 @@
import torch
from torch import nn
from torch.nn import functional as F

import random
import numpy as np


class ModelDistillationWrapper(nn.Module):
    """Convenient wrapper class to train a model with soft labels."""

    def __init__(self, model, teacher):
        super().__init__()
        self.model = model
        self.teachers_0 = teacher
        self.combine = True

        # Since we don't want to back-prop through the teacher networks,
        # make the parameters of the teacher networks not require gradients.
        # This saves some GPU memory.
        for teacher_model in self.teachers_0:
            for param in teacher_model.parameters():
                param.requires_grad = False

        self.false = False

    @property
    def LR_REGIME(self):
        # Training with soft labels does not change the learning rate regime.
        # Return the wrapped model's LR regime.
        return self.model.LR_REGIME

    def state_dict(self):
        return self.model.state_dict()

    def forward(self, input, before=False):
        if self.training:
            if len(self.teachers_0) == 3 and not self.combine:
                # Sample one of the three teachers via a weighted index table.
                index = [0, 1, 1, 2, 2]
                idx = random.randint(0, 4)
                soft_labels_ = self.teachers_0[index[idx]](input)
                soft_labels = F.softmax(soft_labels_, dim=1)

            elif self.combine:
                # Average the logits and the softmax probabilities of all teachers.
                soft_labels_ = [torch.unsqueeze(self.teachers_0[idx](input), dim=2) for idx in range(len(self.teachers_0))]
                soft_labels_softmax = [F.softmax(i, dim=1) for i in soft_labels_]
                soft_labels_ = torch.cat(soft_labels_, dim=2).mean(dim=2)
                soft_labels = torch.cat(soft_labels_softmax, dim=2).mean(dim=2)

            else:
                # Sample one teacher uniformly at random.
                idx = random.randint(0, len(self.teachers_0) - 1)
                soft_labels_ = self.teachers_0[idx](input)
                soft_labels = F.softmax(soft_labels_, dim=1)

            # soft_labels = F.softmax(soft_labels_, dim=1)
            model_output = self.model(input)

            if before:
                return (model_output, soft_labels, soft_labels_)

            return (model_output, soft_labels)

        else:
            return self.model(input)
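
In the three-teacher branch, the table `[0, 1, 1, 2, 2]` combined with `random.randint(0, 4)` (both endpoints inclusive) gives the teachers unequal sampling weights. The short check below derives those weights in plain Python. Note that, as committed, `self.combine` is set to `True` in `__init__`, so this branch appears to be reached only if that flag is changed.

```python
from collections import Counter
from fractions import Fraction

# Same index table as in ModelDistillationWrapper.forward.
index = [0, 1, 1, 2, 2]

# Each position is equally likely, so a teacher's probability is
# its number of occurrences divided by the table length.
counts = Counter(index)
probabilities = {
    teacher: Fraction(count, len(index)) for teacher, count in counts.items()
}

for teacher, prob in sorted(probabilities.items()):
    print(f"teacher {teacher}: {prob}")
```

Teacher 0 is therefore sampled half as often as each of the other two.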
52 changes: 52 additions & 0 deletions imagenet.py
@@ -0,0 +1,52 @@
"""Dataset class for loading imagenet data."""

import os

from torch.utils import data as data_utils
from torchvision import datasets as torch_datasets
from torchvision import transforms


def get_train_loader(imagenet_path, batch_size, num_workers, image_size):
    train_dataset = ImageNet(imagenet_path, image_size, is_train=True)
    return data_utils.DataLoader(
        train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True,
        num_workers=num_workers)


def get_val_loader(imagenet_path, batch_size, num_workers, image_size):
    val_dataset = ImageNet(imagenet_path, image_size, is_train=False)
    return data_utils.DataLoader(
        val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True,
        num_workers=num_workers)


class ImageNet(torch_datasets.ImageFolder):
    """Dataset class for ImageNet dataset.
    Arguments:
        root_dir (str): Path to the dataset root directory, which must contain
            train/ and val/ directories.
        im_size (int): Side length of the square crop fed to the network.
        is_train (bool): Whether to read training or validation images.
    """
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root_dir, im_size, is_train):
        if is_train:
            root_dir = os.path.join(root_dir, 'train')
            transform = transforms.Compose([
                transforms.RandomResizedCrop(im_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
            ])
        else:
            root_dir = os.path.join(root_dir, 'val')
            # Resize the shorter side to 256/224 of the crop size, then
            # take the center crop.
            transform = transforms.Compose([
                transforms.Resize(int(256/224*im_size)),
                transforms.CenterCrop(im_size),
                transforms.ToTensor(),
                transforms.Normalize(ImageNet.MEAN, ImageNet.STD),
            ])
        super().__init__(root_dir, transform=transform)
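
The validation transform scales the resize target with the crop size so that the crop always covers the same 224/256 fraction of the resized image. The arithmetic can be checked in isolation; `val_resize_size` below is a hypothetical helper that repeats the expression from the code, evaluated for the two resolutions mentioned in the README.

```python
def val_resize_size(im_size):
    """Shorter-side resize target used before the center crop."""
    return int(256 / 224 * im_size)


# The two evaluation resolutions listed in the README.
sizes = {im_size: val_resize_size(im_size) for im_size in (224, 380)}

for im_size, resize in sizes.items():
    print(f"crop {im_size} -> resize {resize}")
```

For a 224 crop this reproduces the familiar 256-pixel resize, and for a 380 crop the truncating `int()` gives 434.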
Binary file added images/comparison.png