From f46faaec43df379ec2584c6c198d58fc8025c38e Mon Sep 17 00:00:00 2001
From: chaofengc
Date: Mon, 9 Oct 2023 14:31:33 +0800
Subject: [PATCH] feat: :triangular_flag_on_post: add datasets PIQ2023, GFIQA
 and metric `topiq_nr-face`.

---
 README.md                                    |   1 +
 docs/Dataset_Preparation.md                  |   5 +-
 docs/ModelCard.md                            |  87 +++----
 docs/index.rst                               |  36 +++
 options/default_dataset_opt.yml              |  14 ++
 .../train/TOPIQ/train_TOPIQ_res50_gfiqa.yml  | 116 +++++++++
 pyiqa/archs/topiq_arch.py                    |  13 +-
 pyiqa/data/base_iqa_dataset.py               |   6 +-
 pyiqa/data/general_nr_dataset.py             |   2 +-
 pyiqa/data/piq_dataset.py                    |  71 ++++++
 pyiqa/default_model_configs.py               |  11 +
 pyiqa/models/sr_model.py                     | 231 ------------------
 pyiqa/train.py                               |   3 +-
 scripts/process_gfiqa.py                     |  42 ++++
 scripts/process_piq.py                       |  58 +++++
 15 files changed, 411 insertions(+), 285 deletions(-)
 create mode 100644 options/train/TOPIQ/train_TOPIQ_res50_gfiqa.yml
 create mode 100644 pyiqa/data/piq_dataset.py
 delete mode 100644 pyiqa/models/sr_model.py
 create mode 100644 scripts/process_gfiqa.py
 create mode 100644 scripts/process_piq.py

diff --git a/README.md b/README.md
index 85b219f..4e5b93a 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@ This is a image quality assessment toolbox with **pure python and pytorch**. We
 ---
 ### :triangular_flag_on_post: Updates/Changelog
+- **Oct 09, 2023**. Add datasets: [PIQ2023](https://github.com/DXOMARK-Research/PIQ2023), [GFIQA](http://database.mmsp-kn.de/gfiqa-20k-database.html). Add metric `topiq_nr-face`.
 - **Aug 15, 2023**. Add `st-lpips` and `laion_aes`. Refer to official repo at [ShiftTolerant-LPIPS](https://github.com/abhijay9/ShiftTolerant-LPIPS) and [improved-aesthetic-predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor)
 - **Aug 05, 2023**. Add our work [TOPIQ](https://arxiv.org/abs/2308.03060) with remarkable performance on almost all benchmarks via efficient Resnet50 backbone. Use it with `topiq_fr, topiq_nr, topiq_iaa` for Full-Reference, No-Reference and Aesthetic assessment respectively.
 - **March 30, 2023**. Add [URanker](https://github.com/RQ-Wu/UnderwaterRanker) for IQA of under water images.
diff --git a/docs/Dataset_Preparation.md b/docs/Dataset_Preparation.md
index f1f6eb2..7c3ee83 100644
--- a/docs/Dataset_Preparation.md
+++ b/docs/Dataset_Preparation.md
@@ -16,11 +16,12 @@
 | PieAPP | *2AFC* | AVA | *Aesthetic* |
 | KADID-10k | | KonIQ-10k(++) | |
 | LIVEM | | LIVEChallange | |
-| LIVE | | | |
-| TID2013 | | | |
+| LIVE | | [PIQ2023](https://github.com/DXOMARK-Research/PIQ2023) | Portrait IQA dataset |
+| TID2013 | | [GFIQA](http://database.mmsp-kn.de/gfiqa-20k-database.html) | Face IQA dataset |
 | TID2008 | | | |
 | CSIQ | | | |
 
+Please see [Awesome Image Quality Assessment](https://github.com/chaofengc/Awesome-Image-Quality-Assessment) for more details.
 
 ## Resources
 
diff --git a/docs/ModelCard.md b/docs/ModelCard.md
index 98acfc7..0de6d14 100644
--- a/docs/ModelCard.md
+++ b/docs/ModelCard.md
@@ -2,49 +2,53 @@
 
 ## General FR/NR Methods
 
+List all available model names with:
+```python
+import pyiqa
+print(pyiqa.list_models())
+```
 
-| FR Method | Backward |
-| ------------------------ | ------------------ |
-| AHIQ | :white_check_mark: |
-| PieAPP | :white_check_mark: |
-| LPIPS | :white_check_mark: |
-| DISTS | :white_check_mark: |
-| WaDIQaM | :white_check_mark: |
-| CKDN[1](#fn1) | :white_check_mark: |
-| FSIM | :white_check_mark: |
-| SSIM | :white_check_mark: |
-| MS-SSIM | :white_check_mark: |
-| CW-SSIM | :white_check_mark: |
-| PSNR | :white_check_mark: |
-| VIF | :white_check_mark: |
-| GMSD | :white_check_mark: |
-| NLPD | :white_check_mark: |
-| VSI | :white_check_mark: |
-| MAD | :white_check_mark: |
+| FR Method | Model names | Description |
+| ------------------------ | ------------------ | ------------ |
+| TOPIQ | `topiq_fr`, `topiq_fr-pipal` | Proposed in [this paper](https://arxiv.org/abs/2308.03060) |
+| AHIQ | `ahiq` | |
+| PieAPP | `pieapp` | |
+| LPIPS | `lpips`, `lpips-vgg`, `stlpips`, `stlpips-vgg` | |
+| DISTS | `dists` | |
+| WaDIQaM | | *No pretrained models available* |
+| CKDN[1](#fn1) | `ckdn` | |
+| FSIM | `fsim` | |
+| SSIM | `ssim`, `ssimc` | Gray input (y channel), color input |
+| MS-SSIM | `ms_ssim` | |
+| CW-SSIM | `cw_ssim` | |
+| PSNR | `psnr`, `psnry` | Color input, gray input (y channel) |
+| VIF | `vif` | |
+| GMSD | `gmsd` | |
+| NLPD | `nlpd` | |
+| VSI | `vsi` | |
+| MAD | `mad` | |
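+
+A minimal usage sketch for the FR metrics above (the image paths are placeholders):
+
+```python
+import pyiqa
+import torch
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# create a full-reference metric and score a distorted image against its reference
+fr_metric = pyiqa.create_metric('topiq_fr', device=device)
+score = fr_metric('./distorted.jpg', './reference.jpg')
+print(score, fr_metric.lower_better)  # lower_better tells whether smaller scores mean better quality
+```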
 
-| NR Method | Backward |
-| ---------------------------- | ------------------------ |
-| FID | :heavy_multiplication_x: |
-| CLIPIQA(+) | :white_check_mark: |
-| MANIQA | :white_check_mark: |
-| MUSIQ | :white_check_mark: |
-| DBCNN | :white_check_mark: |
-| PaQ-2-PiQ | :white_check_mark: |
-| HyperIQA | :white_check_mark: |
-| NIMA | :white_check_mark: |
-| WaDIQaM | :white_check_mark: |
-| CNNIQA | :white_check_mark: |
-| NRQM(Ma)[2](#fn2) | :heavy_multiplication_x: |
-| PI(Perceptual Index) | :heavy_multiplication_x: |
-| BRISQUE | :white_check_mark: |
-| ILNIQE | :white_check_mark: |
-| NIQE | :white_check_mark: |
+| NR Method | Model names | Description |
+| ---------------------------- | ------------------------ | ------ |
+| TOPIQ | `topiq_nr`, `topiq_nr-flive`, `topiq_nr-spaq` | [TOPIQ](https://arxiv.org/abs/2308.03060) with different datasets, `koniq` by default |
+| TReS | `tres`, `tres-koniq`, `tres-flive` | TReS with different datasets, `koniq` by default |
+| FID | `fid` | Statistical distance between two datasets |
+| CLIPIQA(+) | `clipiqa`, `clipiqa+`, `clipiqa+_vitL14_512`, `clipiqa+_rn50_512` | CLIPIQA(+) with different backbones, RN50 by default |
+| MANIQA | `maniqa`, `maniqa-kadid`, `maniqa-koniq`, `maniqa-pipal` | MANIQA with different datasets, `koniq` by default |
+| MUSIQ | `musiq`, `musiq-koniq`, `musiq-spaq`, `musiq-paq2piq`, `musiq-ava` | MUSIQ with different datasets, `koniq` by default |
+| DBCNN | `dbcnn` | |
+| PaQ-2-PiQ | `paq2piq` | |
+| HyperIQA | `hyperiqa` | |
+| NIMA | `nima`, `nima-vgg16-ava` | Aesthetic metric trained with AVA dataset |
+| WaDIQaM | | *No pretrained models available* |
+| CNNIQA | `cnniqa` | |
+| NRQM(Ma)[2](#fn2) | `nrqm` | No backward |
+| PI(Perceptual Index) | `pi` | No backward |
+| BRISQUE | `brisque` | No backward |
+| ILNIQE | `ilniqe` | No backward |
+| NIQE | `niqe` | No backward |
 
 [1] This method uses a distorted image as the reference. Please refer to the paper for details. 
 [2] Currently, only a naive random forest regression is implemented, which **does not** support backward.
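+
+NR metrics only need the test image. A minimal sketch with a random tensor as a stand-in input:
+
+```python
+import pyiqa
+import torch
+
+# inputs are RGB tensors of shape (N, C, H, W) with values in [0, 1]
+nr_metric = pyiqa.create_metric('topiq_nr')
+img = torch.rand(1, 3, 384, 384)  # stand-in for a real image batch
+print(nr_metric(img))
+```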
@@ -53,6 +57,7 @@
 
 | Task           | Method  | Description |
 | -------------- | ------- | ----------- |
+| Face IQA       | `topiq_nr-face` | TOPIQ model trained with the face IQA dataset (GFIQA) |
 | Underwater IQA | URanker | A ranking-based underwater image quality assessment (UIQA) method, AAAI2023, [Arxiv](https://arxiv.org/abs/2208.06857), [Github](https://github.com/RQ-Wu/UnderwaterRanker) |
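+
+A usage sketch for the new face metric (the portrait path is a placeholder; per the `test_img_size: 512` default added later in this patch, inputs are resized to 512 at test time):
+
+```python
+import pyiqa
+
+# 'topiq_nr-face' is registered in pyiqa/default_model_configs.py by this patch
+face_metric = pyiqa.create_metric('topiq_nr-face')
+print(face_metric('./portrait.jpg'))  # higher score indicates better perceived quality
+```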
 
 ## Outputs of Different Metrics
 
diff --git a/docs/index.rst b/docs/index.rst
index f8d987d..d06a0d3 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -6,6 +6,11 @@
 Welcome to pyiqa's documentation!
 =================================
 
+``pyiqa`` is an image quality assessment toolbox **with pure Python and PyTorch**. We provide reimplementations of many mainstream full-reference (FR) and no-reference (NR) metrics (results are calibrated with official MATLAB scripts where they exist). **With GPU acceleration, most of our implementations are much faster than MATLAB.**
+
+Basic Information
+-------------------------
+
 .. toctree::
    :maxdepth: 1
 
@@ -14,6 +19,9 @@
    ModelCard
    benchmark
 
+API Tools and References
+-------------------------
+
 .. toctree::
    :maxdepth: 2
 
@@ -21,6 +29,34 @@
    metrics_implement
    training_tools
    Dataset_Preparation
+
+
+Citation
+==================================
+
+If you find our code helpful to your research, please consider using the following citation:
+::
+
+    @misc{pyiqa,
+      title={{IQA-PyTorch}: PyTorch Toolbox for Image Quality Assessment},
+      author={Chaofeng Chen and Jiadi Mo},
+      year={2022},
+      howpublished = "[Online]. Available: \url{https://github.com/chaofengc/IQA-PyTorch}"
+    }
+
+
+Please also consider citing our new work **TOPIQ** if it is useful to you:
+::
+
+    @misc{chen2023topiq,
+      title={TOPIQ: A Top-down Approach from Semantics to Distortions for Image Quality Assessment},
+      author={Chaofeng Chen and Jiadi Mo and Jingwen Hou and Haoning Wu and Liang Liao and Wenxiu Sun and Qiong Yan and Weisi Lin},
+      year={2023},
+      eprint={2308.03060},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+    }
 
 Indices and tables
 ==================
diff --git a/options/default_dataset_opt.yml b/options/default_dataset_opt.yml
index e1ef2f4..aab59f6 100644
--- a/options/default_dataset_opt.yml
+++ b/options/default_dataset_opt.yml
@@ -139,3 +139,17 @@ bapps:
   type: BAPPSDataset
   dataroot_target: './datasets/PerceptualSimilarity/dataset'
   meta_info_file: './datasets/meta_info/meta_info_BAPPSDataset.csv'
+
+piq:
+  name: PIQ2023
+  type: PIQDataset
+  dataroot_target: ./datasets/PIQ
+  meta_info_file: ./datasets/meta_info/meta_info_PIQDataset.csv
+  split_index: 1
+
+gfiqa:
+  name: GFIQA
+  type: GeneralNRDataset
+  dataroot_target: ./datasets/GFIQA/image
+  meta_info_file: ./datasets/meta_info/meta_info_GFIQADataset.csv
+  split_file: ./datasets/meta_info/gfiqa_seed123.pkl
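For reference, a minimal sketch of turning the `gfiqa` option block above into a dataset object with pyiqa's dataset builder (assuming the images, meta info file, and split file are prepared as described; `phase` and `img_range` are added here because the loaders expect them, and the exact required keys may vary):

```python
import yaml
from pyiqa.data import build_dataset

# load the 'gfiqa' block from the options file added above
with open('options/default_dataset_opt.yml') as f:
    opt = yaml.safe_load(f)['gfiqa']
opt.update({'phase': 'train', 'img_range': 1})

dataset = build_dataset(opt)
sample = dataset[0]
print(len(dataset), sample['img'].shape, sample['mos_label'])
```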
diff --git a/options/train/TOPIQ/train_TOPIQ_res50_gfiqa.yml b/options/train/TOPIQ/train_TOPIQ_res50_gfiqa.yml
new file mode 100644
index 0000000..fe11e41
--- /dev/null
+++ b/options/train/TOPIQ/train_TOPIQ_res50_gfiqa.yml
@@ -0,0 +1,116 @@
+name: 002_CFANet_Res50_gfiqa
+# name: debug_model
+model_type: GeneralIQAModel
+num_gpu: 1 # set num_gpu: 0 for cpu mode
+manual_seed: 123
+
+define: &img_size_oneside 512
+define: &img_size [*img_size_oneside, *img_size_oneside]
+
+define: &backbone resnet50
+
+define: &train_batch_size 16
+define: &test_batch_size 1
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: GFIQA
+    type: GeneralNRDataset
+    dataroot_target: ./datasets/GFIQA/image
+    meta_info_file: ./datasets/meta_info/meta_info_GFIQADataset.csv
+    split_file: ./datasets/meta_info/gfiqa_seed123.pkl
+
+    augment:
+      hflip: true
+    img_range: 1
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 4
+    batch_size_per_gpu: *train_batch_size
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+  val:
+    name: GFIQA
+    type: GeneralNRDataset
+    dataroot_target: ./datasets/GFIQA/image
+    meta_info_file: ./datasets/meta_info/meta_info_GFIQADataset.csv
+    split_file: ./datasets/meta_info/gfiqa_seed123.pkl
+
+# network structures
+network:
+  type: CFANet
+  use_ref: false
+  pretrained: false
+  num_crop: 1
+  num_attn_layers: 1
+  crop_size: *img_size
+  semantic_model_name: *backbone
+  block_pool: weighted_avg
+
+# path
+path:
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  optim:
+    type: AdamW
+    lr: !!float 3e-5
+    weight_decay: !!float 1e-5
+
+  scheduler:
+    type: CosineAnnealingLR
+    T_max: 50
+    eta_min: 0
+    # type: StepLR
+    # step_size: !!float 1e9
+    # gamma: 1.0
+
+  total_iter: 20000
+  total_epoch: 200
+  warmup_iter: -1  # no warm up
+
+  # losses
+  mos_loss_opt:
+    type: MSELoss
+    loss_weight: !!float 1.0
+
+  metric_loss_opt:
+    type: NiNLoss
+    loss_weight: !!float 1.0
+
+# validation settings
+val:
+  val_freq: !!float 800
+  save_img: false
+  pbar: true
+
+  key_metric: srcc # if this metric improves, update all metrics; if not specified, each best metric result is updated separately
+  metrics:
+    srcc:
+      type: calculate_srcc
+
+    plcc:
+      type: calculate_plcc
+
+# logging settings
+logger:
+  print_freq: 100
+  save_latest_freq: !!float 500
+  log_imgs_freq: 1000
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
+
+find_unused_parameters: True
diff --git a/pyiqa/archs/topiq_arch.py b/pyiqa/archs/topiq_arch.py
index f92e002..d68b1c2 100644
--- a/pyiqa/archs/topiq_arch.py
+++ b/pyiqa/archs/topiq_arch.py
@@ -32,6 +32,7 @@
     'cfanet_nr_spaq_res50': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/cfanet_nr_spaq_res50-a7f799ac.pth',
     'cfanet_iaa_ava_res50': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/cfanet_iaa_ava_res50-3cd62bb3.pth',
     'cfanet_iaa_ava_swin': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/cfanet_iaa_ava_swin-393b41b4.pth',
+    'topiq_nr_gfiqa_res50': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/topiq_nr_gfiqa_res50-d76bf1ae.pth',
 }
 
 
@@ -186,7 +187,7 @@ def __init__(self,
                  pretrained_model_path=None,
                  out_act=False,
                  block_pool='weighted_avg',
-                 iaa_img_size=384,
+                 test_img_size=None,
                  default_mean=IMAGENET_DEFAULT_MEAN,
                  default_std=IMAGENET_DEFAULT_STD,
                  ):
@@ -202,7 +203,7 @@
         self.num_class = num_class
         self.block_pool = block_pool
-        self.iaa_img_size = iaa_img_size
+        self.test_img_size = test_img_size
 
         # =============================================================
         # define semantic backbone network
@@ -357,12 +358,12 @@ def dist_func(self, x, y, eps=1e-12):
 
     def forward_cross_attention(self, x, y=None):
 
-        # resize image when testing IAA model
+        # resize image when testing
         if not self.training:
-            if self.model_name == 'cfanet_iaa_ava_res50':
-                x = TF.resize(x, self.iaa_img_size, antialias=True)  # keep aspect ratio for CNN backbone
-            elif self.model_name == 'cfanet_iaa_ava_swin':
+            if self.model_name == 'cfanet_iaa_ava_swin':
                 x = TF.resize(x, [384, 384], antialias=True)  # swin require square inputs
+            elif self.test_img_size is not None:
+                x = TF.resize(x, self.test_img_size, antialias=True)  # keep aspect ratio for CNN backbones
 
         x = self.preprocess(x)
         if self.use_ref:
diff --git a/pyiqa/data/base_iqa_dataset.py b/pyiqa/data/base_iqa_dataset.py
index e33ddcb..44bb099 100644
--- a/pyiqa/data/base_iqa_dataset.py
+++ b/pyiqa/data/base_iqa_dataset.py
@@ -61,12 +61,14 @@ def mos_normalize(self, opt):
 
         def normalize(mos_label):
             mos_label = (mos_label - mos_range[0]) / (mos_range[1] - mos_range[0])
+            # convert to higher-is-better when the raw label is lower-is-better
             if mos_lower_better:
                 mos_label = 1 - mos_label
             return mos_label
 
-        self.paths_mos = [(p, normalize(m)) for p, m in self.paths_mos]
-        self.logger.info(f'mos_label is normalized from {mos_range}, lower_better[{mos_lower_better}] to [0, 1], higher better.')
+        for item in self.paths_mos:
+            item[1] = normalize(float(item[1]))
+        self.logger.info(f'mos_label is normalized from range {mos_range} (lower_better={mos_lower_better}) to [0, 1], higher better.')
 
     def get_transforms(self, opt):
         transform_list = []
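The `mos_normalize` logic above in isolation, as a quick sanity check (the range values are made up):

```python
# standalone sketch of the min-max + flip normalization in mos_normalize
mos_range = (1, 5)        # raw MOS range declared in the dataset options
mos_lower_better = True   # True when the raw label is a distortion score

def normalize(mos):
    mos = (mos - mos_range[0]) / (mos_range[1] - mos_range[0])
    return 1 - mos if mos_lower_better else mos

print(normalize(1.0))  # -> 1.0, best quality after flipping
print(normalize(5.0))  # -> 0.0, worst quality
```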
diff --git a/pyiqa/data/general_nr_dataset.py b/pyiqa/data/general_nr_dataset.py
index 7ca0c7f..f89ad53 100644
--- a/pyiqa/data/general_nr_dataset.py
+++ b/pyiqa/data/general_nr_dataset.py
@@ -17,7 +17,7 @@ def init_path_mos(self, opt):
 
     def __getitem__(self, index):
         img_path = self.paths_mos[index][0]
-        mos_label = self.paths_mos[index][1]
+        mos_label = float(self.paths_mos[index][1])
         img_pil = Image.open(img_path).convert('RGB')
 
         img_tensor = self.trans(img_pil) * self.img_range
diff --git a/pyiqa/data/piq_dataset.py b/pyiqa/data/piq_dataset.py
new file mode 100644
index 0000000..8942cb3
--- /dev/null
+++ b/pyiqa/data/piq_dataset.py
@@ -0,0 +1,71 @@
+import torch
+import os
+import csv
+from PIL import Image
+
+from pyiqa.utils.registry import DATASET_REGISTRY
+from pyiqa.utils import get_root_logger
+from .general_nr_dataset import GeneralNRDataset
+
+
+@DATASET_REGISTRY.register()
+class PIQDataset(GeneralNRDataset):
+    """PIQ2023 portrait IQA dataset with attribute labels and device/scene splits.
+    """
+    def init_path_mos(self, opt):
+        logger = get_root_logger()
+        target_img_folder = opt['dataroot_target']
+        attr = opt.get('attribute', 'Overall')
+
+        assert attr in ['Details', 'Exposure', 'Overall'], f'attribute should be in [Details, Exposure, Overall], got {attr}'
+
+        logger.info(f'Training on PIQ2023 dataset with attribute [{attr}]')
+
+        with open(opt['meta_info_file'], 'r') as fin:
+            csvreader = csv.reader(fin)
+            name_mos = list(csvreader)[1:]
+
+        # keep only the rows of the requested attribute and expand image paths
+        self.paths_mos = []
+        for item in name_mos:
+            if attr in item[0]:
+                item[0] = os.path.join(target_img_folder, item[0])
+                self.paths_mos.append(item)
+
+    def get_split(self, opt):
+        """Get split for the PIQ2023 dataset:
+            1: device split
+            2: scene split
+        """
+        logger = get_root_logger()
+        split_index = opt.get('split_index', None)
+        if split_index is not None:
+            assert split_index in [1, 2], 'split index should be 1 (device split) or 2 (scene split)'
+            assert self.phase in ['train', 'test'], f'PIQDataset has no {self.phase} split'
+
+            logger.info(f'Training on PIQ2023 dataset with split [{split_index}] (1: device split; 2: scene split)')
+
+            # the last two meta columns are [DeviceSplit, SceneSplit], so
+            # split_index - 3 maps 1 -> -2 (device) and 2 -> -1 (scene)
+            new_paths_mos = []
+            for item in self.paths_mos:
+                if self.phase == 'train' and item[split_index - 3] == 'Train':
+                    new_paths_mos.append(item)
+                elif self.phase == 'test' and item[split_index - 3] == 'Test':
+                    new_paths_mos.append(item)
+
+            self.paths_mos = new_paths_mos
+
+    def __getitem__(self, index):
+
+        img_path = self.paths_mos[index][0]
+        mos_label = float(self.paths_mos[index][1])
+        img_pil = Image.open(img_path).convert('RGB')
+
+        img_tensor = self.trans(img_pil) * self.img_range
+        mos_label_tensor = torch.Tensor([mos_label])
+
+        # scene index stored in the meta info file (4th column from the end)
+        scene_idx = int(self.paths_mos[index][-4])
+
+        return {'img': img_tensor, 'mos_label': mos_label_tensor, 'img_path': img_path, 'scene_idx': scene_idx}
diff --git a/pyiqa/default_model_configs.py b/pyiqa/default_model_configs.py
index 8102060..9c4adc2 100644
--- a/pyiqa/default_model_configs.py
+++ b/pyiqa/default_model_configs.py
@@ -401,6 +401,16 @@
         },
         'metric_mode': 'NR',
     },
+    'topiq_nr-face': {
+        'metric_opts': {
+            'type': 'CFANet',
+            'semantic_model_name': 'resnet50',
+            'model_name': 'topiq_nr_gfiqa_res50',
+            'use_ref': False,
+            'test_img_size': 512,
+        },
+        'metric_mode': 'NR',
+    },
     'topiq_fr': {
         'metric_opts': {
             'type': 'CFANet',
@@ -440,6 +450,7 @@
             'inter_dim': 512,
             'num_heads': 8,
             'num_class': 10,
+            'test_img_size': 384,
         },
         'metric_mode': 'NR',
     },
diff --git a/pyiqa/models/sr_model.py b/pyiqa/models/sr_model.py
deleted file mode 100644
index 9ed99c1..0000000
--- a/pyiqa/models/sr_model.py
+++ /dev/null
@@ -1,231 +0,0 @@
-import torch
-from collections import OrderedDict
-from os import path as osp
-from tqdm import tqdm
-
-from pyiqa.archs import build_network
-from pyiqa.losses import build_loss
-from pyiqa.metrics import calculate_metric
-from pyiqa.utils import 
get_root_logger, imwrite, tensor2img -from pyiqa.utils.registry import MODEL_REGISTRY -from .base_model import BaseModel - - -@MODEL_REGISTRY.register() -class SRModel(BaseModel): - """Base SR model for single image super-resolution.""" - - def __init__(self, opt): - super(SRModel, self).__init__(opt) - - # define network - self.net_g = build_network(opt['network_g']) - self.net_g = self.model_to_device(self.net_g) - self.print_network(self.net_g) - - # load pretrained models - load_path = self.opt['path'].get('pretrain_network_g', None) - if load_path is not None: - param_key = self.opt['path'].get('param_key_g', 'params') - self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) - - if self.is_train: - self.init_training_settings() - - def init_training_settings(self): - self.net_g.train() - train_opt = self.opt['train'] - - self.ema_decay = train_opt.get('ema_decay', 0) - if self.ema_decay > 0: - logger = get_root_logger() - logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') - # define network net_g with Exponential Moving Average (EMA) - # net_g_ema is used only for testing on one GPU and saving - # There is no need to wrap with DistributedDataParallel - self.net_g_ema = build_network(self.opt['network_g']).to(self.device) - # load pretrained model - load_path = self.opt['path'].get('pretrain_network_g', None) - if load_path is not None: - self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') - else: - self.model_ema(0) # copy net_g weight - self.net_g_ema.eval() - - # define losses - if train_opt.get('pixel_opt'): - self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) - else: - self.cri_pix = None - - if train_opt.get('perceptual_opt'): - self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) - else: - self.cri_perceptual = None - - if self.cri_pix is None and self.cri_perceptual is None: - raise ValueError('Both pixel and perceptual losses are None.') - - # set up optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - - def setup_optimizers(self): - train_opt = self.opt['train'] - optim_params = [] - for k, v in self.net_g.named_parameters(): - if v.requires_grad: - optim_params.append(v) - else: - logger = get_root_logger() - logger.warning(f'Params {k} will not be optimized.') - - optim_type = train_opt['optim_g'].pop('type') - self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) - self.optimizers.append(self.optimizer_g) - - def feed_data(self, data): - self.lq = data['lq'].to(self.device) - if 'gt' in data: - self.gt = data['gt'].to(self.device) - - def optimize_parameters(self, current_iter): - self.optimizer_g.zero_grad() - self.output = self.net_g(self.lq) - - l_total = 0 - loss_dict = OrderedDict() - # pixel loss - if self.cri_pix: - l_pix = self.cri_pix(self.output, self.gt) - l_total += l_pix - loss_dict['l_pix'] = l_pix - # perceptual loss - if self.cri_perceptual: - l_percep, l_style = self.cri_perceptual(self.output, self.gt) - if l_percep is not None: - l_total += l_percep - loss_dict['l_percep'] = l_percep - if l_style is not None: - l_total += l_style - loss_dict['l_style'] = l_style - - l_total.backward() - self.optimizer_g.step() - - self.log_dict = self.reduce_loss_dict(loss_dict) - - if self.ema_decay > 0: - self.model_ema(decay=self.ema_decay) - - def test(self): - if hasattr(self, 'net_g_ema'): - self.net_g_ema.eval() - with torch.no_grad(): - 
self.output = self.net_g_ema(self.lq) - else: - self.net_g.eval() - with torch.no_grad(): - self.output = self.net_g(self.lq) - self.net_g.train() - - def dist_validation(self, dataloader, current_iter, tb_logger, save_img): - if self.opt['rank'] == 0: - self.nondist_validation(dataloader, current_iter, tb_logger, save_img) - - def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): - dataset_name = dataloader.dataset.opt['name'] - with_metrics = self.opt['val'].get('metrics') is not None - use_pbar = self.opt['val'].get('pbar', False) - - if with_metrics: - if not hasattr(self, 'metric_results'): # only execute in the first run - self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} - # initialize the best metric results for each dataset_name (supporting multiple validation datasets) - self._initialize_best_metric_results(dataset_name) - # zero self.metric_results - if with_metrics: - self.metric_results = {metric: 0 for metric in self.metric_results} - - metric_data = dict() - if use_pbar: - pbar = tqdm(total=len(dataloader), unit='image') - - for idx, val_data in enumerate(dataloader): - img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] - self.feed_data(val_data) - self.test() - - visuals = self.get_current_visuals() - sr_img = tensor2img([visuals['result']]) - metric_data['img'] = sr_img - if 'gt' in visuals: - gt_img = tensor2img([visuals['gt']]) - metric_data['img2'] = gt_img - del self.gt - - # tentative for out of GPU memory - del self.lq - del self.output - torch.cuda.empty_cache() - - if save_img: - if self.opt['is_train']: - save_img_path = osp.join(self.opt['path']['visualization'], img_name, - f'{img_name}_{current_iter}.png') - else: - if self.opt['val']['suffix']: - save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, - f'{img_name}_{self.opt["val"]["suffix"]}.png') - else: - save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, - f'{img_name}_{self.opt["name"]}.png') - imwrite(sr_img, save_img_path) - - if with_metrics: - # calculate metrics - for name, opt_ in self.opt['val']['metrics'].items(): - self.metric_results[name] += calculate_metric(metric_data, opt_) - if use_pbar: - pbar.update(1) - pbar.set_description(f'Test {img_name}') - if use_pbar: - pbar.close() - - if with_metrics: - for metric in self.metric_results.keys(): - self.metric_results[metric] /= (idx + 1) - # update the best metric result - self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) - - self._log_validation_metric_values(current_iter, dataset_name, tb_logger) - - def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): - log_str = f'Validation {dataset_name}\n' - for metric, value in self.metric_results.items(): - log_str += f'\t # {metric}: {value:.4f}' - if hasattr(self, 'best_metric_results'): - log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' - f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') - log_str += '\n' - - logger = get_root_logger() - logger.info(log_str) - if tb_logger: - for metric, value in self.metric_results.items(): - tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) - - def get_current_visuals(self): - out_dict = OrderedDict() - out_dict['lq'] = self.lq.detach().cpu() - out_dict['result'] = self.output.detach().cpu() - if hasattr(self, 'gt'): - out_dict['gt'] = self.gt.detach().cpu() - return out_dict - - def save(self, epoch, 
current_iter):
-        if hasattr(self, 'net_g_ema'):
-            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
-        else:
-            self.save_network(self.net_g, 'net_g', current_iter)
-        self.save_training_state(epoch, current_iter)
diff --git a/pyiqa/train.py b/pyiqa/train.py
index 796245e..93469f1 100644
--- a/pyiqa/train.py
+++ b/pyiqa/train.py
@@ -220,8 +220,7 @@ def train_pipeline(root_path, opt=None, args=None):
 
         # validation
         if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
-            if len(val_loaders) > 1:
-                logger.warning('Multiple validation datasets are *only* supported by SRModel.')
+            logger.info(f'Validating with {len(val_loaders)} validation dataset(s).')
             for val_loader in val_loaders:
                 model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
 
diff --git a/scripts/process_gfiqa.py b/scripts/process_gfiqa.py
new file mode 100644
index 0000000..884ec5f
--- /dev/null
+++ b/scripts/process_gfiqa.py
@@ -0,0 +1,42 @@
+import random
+import pickle
+import pandas as pd
+
+
+def get_random_splits(seed=123):
+    random.seed(seed)
+
+    meta_info_file = '../datasets/meta_info/meta_info_GFIQADataset.csv'
+    meta_info = pd.read_csv(meta_info_file)
+    img_list = meta_info['img_name'].tolist()
+
+    total_num = len(img_list)
+
+    all_img_index = list(range(total_num))
+    num_splits = 10
+    save_path = '../datasets/meta_info/gfiqa_seed123.pkl'
+
+    ratio = [0.7, 0.1, 0.2]  # train/val/test
+
+    # generate num_splits independent random train/val/test partitions
+    split_info = {}
+    for i in range(num_splits):
+        random.shuffle(all_img_index)
+        sep1 = int(total_num * ratio[0])
+        sep2 = sep1 + int(total_num * ratio[1])
+        split_info[i + 1] = {
+            'train': all_img_index[:sep1],
+            'val': all_img_index[sep1:sep2],
+            'test': all_img_index[sep2:],
+        }
+
+    with open(save_path, 'wb') as sf:
+        pickle.dump(split_info, sf)
+
+
+if __name__ == '__main__':
+    get_random_splits()
diff --git a/scripts/process_piq.py b/scripts/process_piq.py
new file mode 100644
index 0000000..12ef718
--- /dev/null
+++ b/scripts/process_piq.py
@@ -0,0 +1,58 @@
+import csv
+
+
+def get_meta_info(root_dir, save_meta_path):
+    attrs = ['Details', 'Exposure', 'Overall']
+
+    # read the official train/test splits once; they are shared by all attributes
+    device_split = {}
+    reader = csv.reader(open(f'{root_dir}/Device Split.csv'))
+    next(reader)
+    for item in reader:
+        device_split[item[0]] = item[1]
+
+    scene_split = {}
+    reader = csv.reader(open(f'{root_dir}/Scene Split.csv'))
+    next(reader)
+    for item in reader:
+        scene_split[item[0]] = item[1]
+
+    header_all = None
+    rows_all = []
+    for att in attrs:
+        # read labels of the current attribute
+        lpath = f'{root_dir}/Scores_{att}.csv'
+        lreader = csv.reader(open(lpath, 'r'))
+        header = next(lreader)
+        header_all = header + ['DeviceSplit', 'SceneSplit']
+
+        for tmp_row in lreader:
+            img_name = tmp_row[0].split('\\')[1]
+
+            # look up the device/scene split of the current image; empty if unknown
+            ds = device_split.get(img_name, '')
+            ss = ''
+            for k, v in scene_split.items():
+                if k in img_name:
+                    ss = v
+
+            tmp_row += [ds, ss]
+            tmp_row[0] = tmp_row[0].replace('\\', '/')
+            rows_all.append(tmp_row)
+
+    with open(save_meta_path, 'w') as file:
+        csv_writer = csv.writer(file)
+        csv_writer.writerow(header_all)
+        csv_writer.writerows(rows_all)
+
+
+if __name__ == '__main__':
+    get_meta_info('../datasets/PIQ', '../datasets/meta_info/meta_info_PIQDataset.csv')
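A quick sanity check of the artifacts the two scripts work with (file names follow the defaults above; the dataset files must already be prepared):

```python
import pickle

import pandas as pd

# the meta info CSV read by process_gfiqa.py, plus the split pickle it writes
meta = pd.read_csv('./datasets/meta_info/meta_info_GFIQADataset.csv')
with open('./datasets/meta_info/gfiqa_seed123.pkl', 'rb') as f:
    splits = pickle.load(f)

train_idx = splits[1]['train']  # first of the 10 random 7:1:2 splits
print(len(meta), len(train_idx))
print(meta['img_name'].iloc[train_idx[:3]].tolist())
```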