feat: match todd improvements on config system
LutingWang committed Aug 25, 2022
1 parent a8c1743 commit 3ecf4fa
Showing 5 changed files with 37 additions and 62 deletions.
10 changes: 5 additions & 5 deletions head/datasets.py
@@ -2,25 +2,25 @@
from mmdet.datasets import CocoDataset as _CocoDataset
from mmdet.datasets import CustomDataset

-from head.utils import DebugEnum
+from head.utils import Debug


class DebugMixin(CustomDataset):

    def __len__(self) -> int:
-        if DebugEnum.LESS_DATA.is_on:
+        if Debug.LESS_DATA:
            return 4
        return super().__len__()

    def load_annotations(self, *args, **kwargs):
        data_infos = super().load_annotations(*args, **kwargs)
-        if DebugEnum.LESS_DATA.is_on:
+        if Debug.LESS_DATA:
            data_infos = data_infos[:len(self)]
        return data_infos

    def load_proposals(self, *args, **kwargs):
        proposals = super().load_proposals(*args, **kwargs)
-        if DebugEnum.LESS_DATA.is_on:
+        if Debug.LESS_DATA:
            proposals = proposals[:len(self)]
        return proposals

@@ -35,7 +35,7 @@ class CocoDataset(DebugMixin, _CocoDataset):

    def load_annotations(self, *args, **kwargs):
        data_infos = super().load_annotations(*args, **kwargs)
-        if DebugEnum.LESS_DATA.is_on:
+        if Debug.LESS_DATA:
            self.coco.dataset['images'] = \
                self.coco.dataset['images'][:len(self)]
            self.img_ids = [img['id'] for img in self.coco.dataset['images']]
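
For a sense of what the mixin above buys, here is a self-contained illustration of the same pattern: a mixin caps the reported dataset length when a debug flag is on, so smoke tests touch only a handful of samples. The flag and base class are stand-ins, not the repo's real `Debug`/`CustomDataset`.

```python
# Stand-alone illustration of the DebugMixin pattern; LESS_DATA and
# ListDataset are stand-ins for Debug.LESS_DATA and the mmdet dataset.
LESS_DATA = True


class ListDataset:

    def __init__(self, items):
        self.items = items

    def __len__(self) -> int:
        return len(self.items)


class DebugMixin(ListDataset):

    def __len__(self) -> int:
        if LESS_DATA:
            return min(4, super().__len__())
        return super().__len__()


print(len(DebugMixin(range(100))))  # 4
```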
2 changes: 1 addition & 1 deletion head/detectors.py
@@ -163,7 +163,7 @@ def aug_test(self, *args, **kwargs) -> NoReturn:
class SingleTeacherDistiller(todd.distillers.SingleTeacherDistiller):

    def __init__(self, *args, teacher: dict, **kwargs):
-        teacher_model = todd.utils.ModelLoader.load_mmlab_models(
+        teacher_model = todd.base.load_open_mmlab_models(
            DETECTORS,
            **teacher,
        )
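
The distiller now obtains its teacher through `todd.base.load_open_mmlab_models` rather than the old `todd.utils.ModelLoader`. As a hedged sketch only (not todd's actual implementation), such a loader typically builds the model from an OpenMMLab registry config and optionally restores a checkpoint:

```python
# Illustrative sketch of what a helper like todd.base.load_open_mmlab_models
# plausibly does with the mmcv primitives this repo already uses; the real
# todd signature and behaviour may differ.
from typing import Optional

from mmcv.runner import load_checkpoint
from mmcv.utils import Registry


def load_teacher(registry: Registry, config: dict, ckpt: Optional[str] = None):
    model = registry.build(config)  # e.g. the mmdet DETECTORS registry
    if ckpt is not None:
        load_checkpoint(model, ckpt, map_location='cpu')
    return model.eval()  # teachers are typically kept frozen for distillation
```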
2 changes: 1 addition & 1 deletion head/train.py
@@ -230,7 +230,7 @@ def train_detector(

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
        todd.base.init_iter(runner.iter)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
        todd.base.init_iter(runner.iter)
    runner.run(data_loaders, cfg.workflow)
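
`todd.base.init_iter(runner.iter)` seeds todd's global iteration counter from the runner, so iteration-dependent logic sees the checkpoint's step count instead of zero after a resume. A minimal sketch of such a counter is below; `get_iter`/`inc_iter` are hypothetical companion names used only for illustration:

```python
# Minimal sketch of a module-level iteration counter in the spirit of
# todd.base.init_iter; get_iter and inc_iter are hypothetical helpers here.
_iter: int = 0


def init_iter(iter_: int = 0) -> None:
    """Seed the counter, e.g. with runner.iter after resuming a checkpoint."""
    global _iter
    _iter = iter_


def get_iter() -> int:
    """Current iteration, for iteration-dependent schedules."""
    return _iter


def inc_iter() -> None:
    """Advance the counter once per training step."""
    global _iter
    _iter += 1
```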
65 changes: 22 additions & 43 deletions head/utils.py
@@ -1,39 +1,21 @@
-import enum
import os
import os.path as osp

import torch
from mmcv import Config
from mmcv.cnn import NORM_LAYERS
from mmcv.runner import TextLoggerHook
-from todd.base import get_logger
+from todd.base import DebugMode, get_logger


-class _DebugEnum(enum.Enum):
-
-    @classmethod
-    def is_active(cls) -> bool:
-        return any(e.is_on for e in cls)
-
-    @property
-    def is_on(self) -> bool:
-        return any(map(os.getenv, ['DEBUG', 'DEBUG_' + self.name]))
-
-    def turn_on(self) -> None:
-        os.environ['DEBUG_' + self.name] = '1'
-
-    def turn_off(self) -> None:
-        os.environ['DEBUG_' + self.name] = ''
-
-
-class DebugEnum(_DebugEnum):
-    CUSTOM = enum.auto()
-    TRAIN_WITH_VAL_DATASET = enum.auto()
-    LESS_DATA = enum.auto()
-    LESS_BBOXES = enum.auto()
-    CPU = enum.auto()
-    SMALLER_BATCH_SIZE = enum.auto()
-    FREQUENT_EVAL = enum.auto()
+class Debug:
+    CUSTOM = DebugMode()
+    TRAIN_WITH_VAL_DATASET = DebugMode()
+    LESS_DATA = DebugMode()
+    LESS_BBOXES = DebugMode()
+    CPU = DebugMode()
+    SMALLER_BATCH_SIZE = DebugMode()
+    FREQUENT_EVAL = DebugMode()


def odps_init():
@@ -57,22 +39,19 @@ def _dump_log(*args, **kwargs):
def debug_init(debug: bool, cfg: Config):
    if torch.cuda.is_available():
        if debug:
-            DebugEnum.LESS_DATA.turn_on()
-        if DebugEnum.is_active():
-            assert not DebugEnum.CPU.is_on
-        else:
-            return
+            Debug.LESS_DATA = True
+        assert not Debug.CPU
    else:
-        DebugEnum.TRAIN_WITH_VAL_DATASET.turn_on()
-        DebugEnum.LESS_DATA.turn_on()
-        DebugEnum.LESS_BBOXES.turn_on()
-        DebugEnum.CPU.turn_on()
-        DebugEnum.SMALLER_BATCH_SIZE.turn_on()
-        DebugEnum.FREQUENT_EVAL.turn_on()
+        Debug.TRAIN_WITH_VAL_DATASET = True
+        Debug.LESS_DATA = True
+        Debug.LESS_BBOXES = True
+        Debug.CPU = True
+        Debug.SMALLER_BATCH_SIZE = True
+        Debug.FREQUENT_EVAL = True

    if (
-        DebugEnum.TRAIN_WITH_VAL_DATASET.is_on and 'data' in cfg
-        and 'train' in cfg.data and 'val' in cfg.data
+        Debug.TRAIN_WITH_VAL_DATASET and 'data' in cfg and 'train' in cfg.data
+        and 'val' in cfg.data
    ):
        data_train = cfg.data.train
        data_val = cfg.data.val
@@ -86,16 +65,16 @@ def debug_init(debug: bool, cfg: Config):
        }
        data_train.update(data_val)

-    if DebugEnum.CPU.is_on:
+    if Debug.CPU:
        cfg.fp16 = None
        NORM_LAYERS.register_module(
            'SyncBN',
            force=True,
            module=NORM_LAYERS.get('BN'),
        )

-    if DebugEnum.SMALLER_BATCH_SIZE.is_on and 'data' in cfg:
+    if Debug.SMALLER_BATCH_SIZE and 'data' in cfg:
        cfg.data.samples_per_gpu = 2

-    if DebugEnum.FREQUENT_EVAL.is_on and 'evaluation' in cfg:
+    if Debug.FREQUENT_EVAL and 'evaluation' in cfg:
        cfg.evaluation.interval = 1
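
The hand-rolled `_DebugEnum`, which toggled flags through `DEBUG`/`DEBUG_<NAME>` environment variables, is replaced by `todd.base.DebugMode` descriptors that read as plain booleans (`if Debug.LESS_DATA:`) and can simply be overwritten (`Debug.LESS_DATA = True`). Below is a self-contained sketch of how such a descriptor could behave; the environment-variable convention is carried over from the deleted `_DebugEnum`, and the real `DebugMode` may differ:

```python
# Hedged sketch of an env-var-backed flag descriptor in the spirit of
# todd.base.DebugMode; EnvFlag is a hypothetical stand-in, and the
# DEBUG / DEBUG_<NAME> convention mirrors the removed _DebugEnum.
import os


class EnvFlag:

    def __set_name__(self, owner: type, name: str) -> None:
        self._name = name

    def __get__(self, obj, objtype=None) -> bool:
        return bool(os.getenv('DEBUG') or os.getenv('DEBUG_' + self._name))


class Debug:
    LESS_DATA = EnvFlag()
    CPU = EnvFlag()


os.environ['DEBUG_LESS_DATA'] = '1'
print(Debug.LESS_DATA)  # True, resolved through the descriptor
Debug.CPU = True        # plain class-attribute override, as debug_init() does
print(Debug.CPU)        # True
```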
20 changes: 8 additions & 12 deletions tools/train.py
@@ -9,15 +9,16 @@
import warnings

import mmcv
+import todd
import torch
-from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmdet import __version__
from mmdet.apis import init_random_seed, set_random_seed
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
+from todd.base import Config, DictAction

from head.train import train_detector
from head.utils import debug_init, odps_init
@@ -62,14 +63,8 @@ def parse_args():
    )
    parser.add_argument(
        '--cfg-options',
-        nargs='+',
+        nargs='?',
        action=DictAction,
-        help='override some settings in the used config, the key-value pair '
-        'in xxx=yyy format will be merged into config file. If the value to '
-        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
-        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
-        'Note that the quotation marks are necessary and that no white space '
-        'is allowed.'
    )
    parser.add_argument(
        '--launcher',
@@ -93,9 +88,10 @@ def main():
    if args.odps:
        odps_init()

-    cfg = Config.fromfile(args.config)
+    cfg = Config.load(args.config)
    if args.cfg_options is not None:
-        cfg.merge_from_dict(args.cfg_options)
+        for k, v in args.cfg_options.items():
+            todd.base.setattr_recur(cfg, k, v)

    debug_init(args.debug, cfg)

@@ -159,10 +155,10 @@ def main():
        'Environment info:\n' + dash_line + env_info + '\n' + dash_line
    )
    meta['env_info'] = env_info
-    meta['config'] = cfg.pretty_text
+    meta['config'] = cfg.dumps()
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
-    logger.info(f'Config:\n{cfg.pretty_text}')
+    logger.info(f'Config:\n{cfg.dumps()}')

    # set random seeds
    seed = init_random_seed(args.seed)
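
On the tooling side, mmcv's `Config.fromfile`/`merge_from_dict`/`pretty_text` are swapped for todd's `Config.load`, a `setattr_recur` loop over `--cfg-options`, and `dumps()`. A hedged sketch of what a dotted-key setter in the spirit of `todd.base.setattr_recur` might do (the real helper may support more syntax, such as indices or missing intermediates):

```python
# Illustrative sketch of a dotted-key setter similar in spirit to
# todd.base.setattr_recur; the actual todd helper may behave differently.
from types import SimpleNamespace
from typing import Any


def setattr_recur(obj: Any, key: str, value: Any) -> None:
    """Set obj.a.b.c = value given key 'a.b.c'."""
    *parents, last = key.split('.')
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, last, value)


cfg = SimpleNamespace(data=SimpleNamespace(samples_per_gpu=4))
setattr_recur(cfg, 'data.samples_per_gpu', 2)  # as --cfg-options data.samples_per_gpu=2 would
print(cfg.data.samples_per_gpu)  # 2
```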
