Real-Time Video Super-Resolution on Mobile

Overview

[Challenge Website] [Workshop Website]

This repository provides the implementation of the baseline model, Mobile RRN, for the Real-Time Video Super-Resolution Challenge in the Mobile AI (MAI) Workshop @ CVPR 2022 & the Advances in Image Manipulation (AIM) Workshop @ ECCV 2022. Mobile RRN is a recurrent network for video super-resolution designed to run on mobile devices. It is derived from RRN by reducing the number of channels and by not using the previous output information.

Contents


Requirements

[back]


Dataset preparation

  • Download the REDS dataset and extract it into the data/ folder.

    The REDS dataset folder should contain three subfolders: train/, val/ and test/. Download links for these files are available on the MAI'22 Real-Time Video Super-Resolution Challenge website (registration required).
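
As a quick sanity check before training, the expected layout can be verified with a few lines of Python. This is a minimal sketch and not part of the repository: the helper name `missing_reds_splits` is hypothetical, and it only checks for the three subfolders named above, not for their contents.

```python
from pathlib import Path

# Subfolders that the REDS dataset folder is expected to contain.
EXPECTED_SPLITS = ("train", "val", "test")


def missing_reds_splits(data_dir):
    """Return the expected split folders that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in EXPECTED_SPLITS if not (root / name).is_dir()]


if __name__ == "__main__":
    # "data" matches the default data_dir; adjust it if yours differs.
    missing = missing_reds_splits("data")
    if missing:
        print("Missing REDS subfolders:", ", ".join(missing))
    else:
        print("REDS layout looks complete.")
```

An empty list means all three split folders were found.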

[back]


Training and Validation

Configuration

Before training and testing, please make sure the fields in config.yml are properly set.

log_dir: snapshot -> The directory where logs and checkpoints are written.

dataset:
    dataloader_settings: -> The dataloader settings for each split.
        train:
            batch_size: 4
            drop_remainder: True
            shuffle: True
            num_parallel_calls: 6
        val:
            batch_size: 1
    data_dir: data/ -> The directory of the REDS dataset.
    degradation: sharp_bicubic -> The degradation of images.
    train_frame_num: 10 -> The number of frames per training step.
    test_frame_num: 100 -> The number of frames per testing step.
    crop_size: 64 -> The height and width of the cropped patch.

model:
    path: model/mobile_rrn.py -> The path of model file.
    name: MobileRRN -> The name of model class.

learner:
    general:
        total_steps: 1500000 -> The number of training steps.
        log_train_info_steps: 100 -> The frequency of logging training info.
        keep_ckpt_steps: 10000 -> The frequency of saving checkpoint.
        valid_steps: 100000 -> The frequency of validation.

    optimizer: -> The module name and settings of the optimizer.
        name: Adam
        beta_1: 0.9
        beta_2: 0.999

    lr_scheduler: -> The module name and settings of the learning rate scheduler.
        name: ExponentialDecay
        initial_learning_rate: 0.0001
        decay_steps: 1000000
        decay_rate: 0.1
        staircase: True

    saver:
        restore_ckpt: null -> The path of the checkpoint to restore from.
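
To see what the lr_scheduler settings imply, the sketch below evaluates the standard exponential-decay formula with the values from the config. It assumes the scheduler behaves like the Keras ExponentialDecay schedule of the same name; the function `exponential_decay` is a plain-Python stand-in written for illustration and is not code from this repository.

```python
import math


def exponential_decay(step, initial_learning_rate=1e-4, decay_steps=1_000_000,
                      decay_rate=0.1, staircase=True):
    """Learning rate at `step` under the standard exponential-decay formula.

    lr = initial_learning_rate * decay_rate ** (step / decay_steps)
    With staircase=True the exponent is floored, so the rate drops in jumps.
    """
    exponent = step / decay_steps
    if staircase:
        exponent = math.floor(exponent)
    return initial_learning_rate * decay_rate ** exponent


if __name__ == "__main__":
    # Print the rate just before and after the single decay boundary.
    for step in (0, 999_999, 1_000_000, 1_499_999):
        print(step, exponential_decay(step))
```

Under these assumptions, with total_steps set to 1500000 the rate stays at 0.0001 for the first 1,000,000 steps and is roughly 0.00001 for the remaining 500,000.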

Training

To train the model, use the following command:

python run.py --process train --config_path config.yml

The main arguments are as follows:

process :   Process type should be train or test.
config_path :   Path of yml config file of the application.

After training, the checkpoints will be produced in log_dir.

Validation

To validate the model, use the following command:

python run.py --process test --config_path config.yml

After testing, the output images will be produced in log_dir/output.

[back]


Testing

To generate testing outputs, use the following command:

python generate_output.py --model_path model/mobile_rrn.py --model_name MobileRRN --ckpt_path snapshot/ckpt-* --data_dir REDS/test/test_sharp_bicubic/X4/ --output_dir results

The main arguments are as follows:

model_path :   Path of model file.
model_name :   Name of model class.
ckpt_path :   Path of checkpoint.
data_dir :   Directory of testing frames in REDS dataset.
output_dir :   Directory for saving output images.

[back]


Convert to tflite

To convert the Keras model to TFLite, use the following command:

python convert.py --model_path model/mobile_rrn.py --model_name MobileRRN --input_shapes 1,320,180,6:1,320,180,16 --ckpt_path snapshot/mobile_rrn_16/ckpt-* --output_tflite model.tflite

The main arguments are as follows:

model_path :   Path of model file.
model_name :   Name of model class.
input_shapes :   Input shapes, one per model input, separated by `:`.
ckpt_path :   Path of checkpoint.
output_tflite :   Path of output tflite.
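
The `--input_shapes` value packs one shape per model input into a single string. The sketch below shows one way such a string could be split into lists of integers; `parse_input_shapes` is a hypothetical helper for illustration, and convert.py may parse the argument differently.

```python
def parse_input_shapes(spec):
    """Split 'a,b,c:d,e,f' into [[a, b, c], [d, e, f]] with integer dims."""
    shapes = []
    for part in spec.split(":"):      # one entry per model input
        dims = [int(dim) for dim in part.split(",")]
        shapes.append(dims)
    return shapes


if __name__ == "__main__":
    # The example value from the conversion command above.
    print(parse_input_shapes("1,320,180,6:1,320,180,16"))
```

For the command above, this yields two four-dimensional shapes that differ only in the last dimension (6 and 16).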

[back]


TFLite inference on Mobile

We provide two ways to evaluate the mobile performance of your TFLite models:

  • AI Benchmark: An app that lets you load your model and run it locally on your own Android devices with various acceleration options (e.g. CPU, GPU, APU, etc.).
  • TFLite Neuron Delegate: You can build MediaTek's Neuron Delegate runner yourself.

[back]


Folder structure

│
├── data/ -> The directory that holds the REDS dataset.
├── dataset/
│   ├── dataset_builder.py -> Builds the dataset loader.
│   ├── reds.py -> Define the class of REDS dataset.
│   └── transform.py -> Define the transform functions for augmentation.
├── learner/
│   ├── learner.py -> Define the learner for training and testing.
│   ├── metric.py -> Implement the metric functions.
│   └── saver.py -> Define the saver to save and load checkpoints.
├── model/
│   └── mobile_rrn.py -> Define Mobile RRN architecture.
├── snapshot/ -> The directory where logs and checkpoints are written.
├── util/
│   ├── common_util.py -> Define the utility functions for general purpose.
│   ├── constant_util.py -> Global constant definitions.
│   ├── logger.py -> Define logging utilities.
│   └── plugin.py -> Define plugin utilities.
├── config.yml -> Configuration yaml file.
├── convert.py -> Convert keras model to tflite.
└── run.py -> Generic main function for VSR experiments.

[back]


Model Optimization

To make your model run faster on device, follow the network operation preferences of the target AI accelerator as closely as possible so that the model can take full advantage of it. You may also find some optimization hints in our paper: Deploying Image Deblurring across Mobile Devices: A Perspective of Quality and Latency

[back]


Reference

Revisiting Temporal Modeling for Video Super-resolution (RRN) [Github] [Paper]

[back]


License

Mediatek License: Mediatek Apache License 2.0

[back]