diff --git a/demo/__init__.py b/demo/__init__.py new file mode 100644 index 0000000..9ab16e0 --- /dev/null +++ b/demo/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + +""" +@date: 2020/10/30 下午3:22 +@file: __init__.py.py +@author: zj +@description: +""" diff --git a/demo/tsn_r50_ucf101_rgb_raw_dense_1x16x4.yaml b/demo/tsn_r50_ucf101_rgb_raw_dense_1x16x4.yaml index 36c87fb..400c438 100644 --- a/demo/tsn_r50_ucf101_rgb_raw_dense_1x16x4.yaml +++ b/demo/tsn_r50_ucf101_rgb_raw_dense_1x16x4.yaml @@ -3,12 +3,14 @@ NUM_NODES: 1 RANK_ID: 0 DIST_BACKEND: "nccl" RNG_SEED: 1 -OUTPUT_DIR: 'outputs/tsn_r50_ucf101_rgb_raw_dense_1x16x4' DATASETS: MODALITY: 'RGB' + TYPE: 'RawFrame' + SAMPLE_STRATEGY: 'DenseSample' CLIP_LEN: 1 FRAME_INTERVAL: 16 NUM_CLIPS: 4 + NUM_SAMPLE_POSITIONS: 1 TRANSFORM: MEAN: (0.485, 0.456, 0.406) STD: (0.229, 0.224, 0.225) @@ -21,9 +23,8 @@ MODEL: NAME: 'TSN' PRETRAINED: '' SYNC_BN: True - INPUT_SIZE: (224, 224, 3) BACKBONE: - NAME: 'resnet50' + NAME: 'ResNet50' PARTIAL_BN: False TORCHVISION_PRETRAINED: True ZERO_INIT_RESIDUAL: True @@ -48,4 +49,4 @@ VISUALIZATION: DISPLAY_HEIGHT: 0 OUTPUT_FPS: -1 OUTPUT_FILE: "" - LABEL_FILE_PATH: 'data/ucf101/annotations/classInd.txt' + LABEL_FILE_PATH: 'data/ucf101/annotations/classInd.txt' \ No newline at end of file diff --git a/tsn/visualization/predictor/predictor.py b/tsn/visualization/predictor/predictor.py index abf2b98..62c0368 100644 --- a/tsn/visualization/predictor/predictor.py +++ b/tsn/visualization/predictor/predictor.py @@ -3,7 +3,7 @@ import torch -from tsn.model.build import build_model +from tsn.model.recognizers.build import build_recognizer from tsn.data.transforms.build import build_transform from .util import process_cv2_inputs from tsn.util.distributed import get_device, get_local_rank @@ -27,7 +27,7 @@ def __init__(self, cfg): device = get_device() # Build the video model and print model statistics. - self.model = build_model(cfg, device) + self.model = build_recognizer(cfg, device) self.model.eval() self.transform = build_transform(cfg, is_train=False)