Skip to content

Commit

Permalink
perf(tsn): 对视频均匀分段再随机采样
Browse files Browse the repository at this point in the history
  • Loading branch information
zjykzj committed Aug 29, 2020
1 parent d626546 commit 1cb8731
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@
checkpointer = CheckPointer(model, optimizer=optimizer, scheduler=lr_scheduler, save_dir=output_dir,
save_to_disk=True, logger=None)

train_model('MobileNet_v2', model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes, checkpointer,
train_model('ResNet50', model, criterion, optimizer, lr_scheduler, data_loaders, data_sizes, checkpointer,
epoches=epoches, device=device)
24 changes: 20 additions & 4 deletions tsn/data/hmdb51.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import cv2
from PIL import Image
import random
import os
import numpy as np
Expand Down Expand Up @@ -59,10 +60,21 @@ def __getitem__(self, index: int):
assert index < len(self.video_list)
target = self.cate_list[index]

# 视频帧数
video_length = self.img_num_list[index]
# 每一段帧数
seg_length = int(video_length / self.num_seg)
num_list = list()
if 'RGBDiff' in self.modality:
num_list = sorted(random.sample(range(self.img_num_list[index] - 1), self.num_seg))
# 在每段中随机挑选一帧
for i in range(self.num_seg):
# 如果使用`RGBDiff`,需要采集前后两帧进行差分
# random.randint(a, b) -> [a, b]
num_list.append(random.randint(i * seg_length, (i + 1) * seg_length - 2))
else:
num_list = sorted(random.sample(range(self.img_num_list[index]), self.num_seg))
# 在每段中随机挑选一帧
for i in range(self.num_seg):
num_list.append(random.randint(i * seg_length, (i + 1) * seg_length - 1))
video_path = os.path.join(self.data_dir, self.video_list[index])

image_list = list()
Expand All @@ -76,10 +88,14 @@ def __getitem__(self, index: int):
image_list.append(img)
if 'RGBDiff' in self.modality:
img1_path = os.path.join(video_path, 'img_{:0>5d}.jpg'.format(num))
img1 = cv2.imread(img1_path)
# img1 = cv2.imread(img1_path, cv2.IMREAD_COLOR)
img1 = np.array(Image.open(img1_path))

img2_path = os.path.join(video_path, 'img_{:0>5d}.jpg'.format(num + 1))
img2 = cv2.imread(img2_path)
# img2 = cv2.imread(img2_path, cv2.IMREAD_COLOR)
img2 = np.array(Image.open(img2_path))

# print(img1.shape, img2.shape)
img = rgbdiff(img1, img2)
if self.transform:
img = self.transform(img)
Expand Down

0 comments on commit 1cb8731

Please sign in to comment.