forked from AIC-DGU/MTGEA
Showing 94 changed files with 6,806 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
### Code version (Git Hash) and PyTorch version | ||
|
||
### Dataset used | ||
|
||
### Expected behavior | ||
|
||
### Actual behavior | ||
|
||
### Steps to reproduce the behavior | ||
|
||
### Other comments |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
Copyright (c) 2018, Multimedia Laboratary, The Chinese University of Hong Kong | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
* Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
* Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,48 @@ | ||
# MTGEA | ||
## MTGEA | ||
|
||
## Acknowledgement | ||
|
||
Our framework is extended from the following repositories. We thank the authors for releasing their code. | ||
- The 2-stream framework of our code is based on [ST-GCN](https://github.com/yysijie/st-gcn/blob/master/OLD_README.md). | ||
- The attention mechanism is based on [Mega](https://github.com/thecharm/Mega). | ||
|
||
## Prerequisites | ||
- Python3 (>=3.7) | ||
- PyTorch (>=1.6) | ||
- Other Python libraries can be installed by `pip install -r requirements.txt` | ||
|
||
|
||
### Installation | ||
``` shell | ||
git clone https://github.com/gw16/MTGEA.git; cd MTGEA | ||
cd torchlight; python setup.py install; cd .. | ||
``` | ||
|
||
### Data Preparation | ||
- For the dataset, you can download the pre-processed data from [here](https://drive.google.com/file/d/1wBEGb_rIJLsroDIDYG0_OJ_cb8f_MR3Q/view?usp=sharing) and the raw data from [here](https://drive.google.com/file/d/19nnycJ2FcgdqylE0g-a_lzDCq6RZewdD/view?usp=sharing). | ||
|
||
## Training and Testing | ||
To train an MTGEA model, run | ||
``` | ||
python main.py recognition -c config/mtgea/<dataset>/train.yaml [--work_dir <work folder for double train>] --phase 'double_train' | ||
``` | ||
where ```<dataset>``` must be [DGUHA_Dataset](https://drive.google.com/file/d/1wBEGb_rIJLsroDIDYG0_OJ_cb8f_MR3Q/view?usp=sharing); we recommend naming ```<dataset>``` "dguha_dataset". | ||
Training outputs (**model weights**, configurations, and log files) are saved under ```<work folder for double train>``` (```./work_dir``` by default, which is not recommended). | ||
|
||
After training, evaluate the trained model with this command: | ||
``` | ||
python main.py recognition -c config/mtgea/<dataset>/test.yaml --weights <path to model weights from double train work folder> --phase 'double_test' | ||
``` | ||
|
||
Then, to freeze the Kinect stream and train the MTGEA model on point clouds alone, run this command: | ||
``` | ||
python main.py recognition -c config/mtgea/<dataset>/test.yaml --weights <path to model weights from double train work folder> --phase 'freezing_train' [--work_dir <work folder for freezing train>] | ||
``` | ||
Finally, evaluate the model from the freezing train work folder with this command: | ||
``` | ||
python main.py recognition -c config/mtgea/<dataset>/test.yaml --weights <path to model weights from freezing train work folder> --phase 'freezing_test' | ||
``` | ||
An example of testing from a pretrained model: | ||
``` | ||
python main.py recognition -c config/mtgea/<dataset>/test.yaml --weights '/path/MTGEA/saved_best_model/mtgea_model(with_ahc).pt' --phase 'freezing_test' | ||
``` |
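The four phases above run in a fixed order, and each later phase consumes weights written by an earlier one. The following is a minimal Python sketch of that ordering, not part of the repository. It uses only the flags shown in the commands above (`-c`, `--weights`, `--work_dir`, `--phase`). The work folder names and the `<weights file>` placeholders are assumptions, because the README does not name the weight files written to each work folder.

```python
import shlex

# Hypothetical locations; replace the <weights file> placeholders with real file names.
DATASET = "dguha_dataset"
DOUBLE_DIR = "work_dir/double_train"
FREEZE_DIR = "work_dir/freezing_train"
DOUBLE_WEIGHTS = DOUBLE_DIR + "/<weights file>"
FREEZE_WEIGHTS = FREEZE_DIR + "/<weights file>"


def build_command(phase, config, weights=None, work_dir=None):
    """Assemble one main.py invocation as an argument list."""
    cmd = ["python", "main.py", "recognition", "-c", config, "--phase", phase]
    if weights is not None:
        cmd += ["--weights", weights]
    if work_dir is not None:
        cmd += ["--work_dir", work_dir]
    return cmd


def pipeline(dataset=DATASET):
    """Return the four commands in the order the README describes."""
    train_cfg = f"config/mtgea/{dataset}/train.yaml"
    test_cfg = f"config/mtgea/{dataset}/test.yaml"
    return [
        build_command("double_train", train_cfg, work_dir=DOUBLE_DIR),
        build_command("double_test", test_cfg, weights=DOUBLE_WEIGHTS),
        build_command("freezing_train", test_cfg, weights=DOUBLE_WEIGHTS,
                      work_dir=FREEZE_DIR),
        build_command("freezing_test", test_cfg, weights=FREEZE_WEIGHTS),
    ]


if __name__ == "__main__":
    for cmd in pipeline():
        # Print a shell-safe version of each command instead of running it.
        print(" ".join(shlex.quote(part) for part in cmd))
```

Printing the commands rather than executing them keeps the sketch safe to run before the placeholders are filled in.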
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
weights: ./models/st_gcn.ntu-xsub.pt | ||
|
||
# feeder | ||
feeder: feeder.feeder.Feeder | ||
test_feeder_args: | ||
  data_path: ./data/dguha_dataset/dguha_pcl_ahc_test.npy | ||
  label_path: ./data/dguha_dataset/dguha_test_label.pkl | ||
|
||
# skeleton_test_feeder_args | ||
skeleton_test_feeder_args: | ||
  data_path: ./data/dguha_dataset/dguha_sk_test.npy | ||
  label_path: ./data/dguha_dataset/dguha_test_label.pkl | ||
|
||
# model | ||
model: net.mtgea.MTGEA | ||
|
||
model_args: | ||
|
||
  base_model: net.st_gcn.Model | ||
  in_channels: 3 | ||
  num_class: 32 | ||
  output_class: 7 | ||
  dropout: 0.5 | ||
  edge_importance_weighting: True | ||
  graph_args: | ||
    layout: 'ntu-rgb+d' | ||
    strategy: 'spatial' | ||
|
||
# test | ||
phase: double_test | ||
device: [3] | ||
test_batch_size: 13 | ||
|
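The configuration above names its classes by dotted path (`feeder.feeder.Feeder`, `net.mtgea.MTGEA`, `net.st_gcn.Model`). A minimal sketch of how such a path could be resolved to a class with the standard library follows. The helper name `import_class` is illustrative, and this is not necessarily how `main.py` resolves these entries.

```python
import importlib


def import_class(dotted_path):
    """Resolve 'package.module.ClassName' to the class object it names."""
    module_name, _, class_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Standard-library stand-in so the sketch runs anywhere; inside the
# repository the argument would be a path such as 'feeder.feeder.Feeder'.
OrderedDict = import_class("collections.OrderedDict")
```

Keeping class names as strings in the YAML file lets the feeder or model be swapped without editing Python code.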
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import tools |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# sys | ||
import os | ||
import sys | ||
import numpy as np | ||
import random | ||
# import pickle | ||
import pickle5 as pickle | ||
# torch | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torch.nn.functional as F | ||
from torchvision import datasets, transforms | ||
|
||
# visualization | ||
import time | ||
|
||
# operation | ||
from . import tools | ||
|
||
class Feeder(torch.utils.data.Dataset): | ||
    """ Feeder for skeleton-based action recognition | ||
    Arguments: | ||
        data_path: the path to '.npy' data, the shape of data should be (N, C, T, V, M) | ||
        label_path: the path to label | ||
        random_choose: If true, randomly choose a portion of the input sequence | ||
        random_move: If true, perform randomly but continuously changed transformation to input sequence | ||
        window_size: The length of the output sequence | ||
        debug: If true, only use the first 100 samples | ||
        mmap: If true, memory-map the '.npy' file instead of loading it fully into memory | ||
    """ | ||
|
||
    def __init__(self, | ||
                 data_path, | ||
                 label_path, | ||
                 random_choose=False, | ||
                 random_move=False, | ||
                 window_size=-1, | ||
                 debug=False, | ||
                 mmap=True): | ||
        self.debug = debug | ||
        self.data_path = data_path | ||
        self.label_path = label_path | ||
        self.random_choose = random_choose | ||
        self.random_move = random_move | ||
        self.window_size = window_size | ||
|
||
        self.load_data(mmap) | ||
|
||
    def load_data(self, mmap): | ||
        # data: N C T V M | ||
|
||
        # load label | ||
        with open(self.label_path, 'rb') as f: | ||
            self.label = pickle.load(f) | ||
            # self.sample_name, self.label = pickle.load(f) | ||
|
||
        # load data | ||
        if mmap: | ||
            self.data = np.load(self.data_path, mmap_mode='r') | ||
        else: | ||
            self.data = np.load(self.data_path) | ||
|
||
        # sample_name is not loaded above, so only label and data are truncated | ||
        if self.debug: | ||
            self.label = self.label[0:100] | ||
            self.data = self.data[0:100] | ||
|
||
        self.N, self.C, self.T, self.V, self.M = self.data.shape | ||
|
||
    def __len__(self): | ||
        return len(self.label) | ||
|
||
    def __getitem__(self, index): | ||
        # get data | ||
        data_numpy = np.array(self.data[index]) | ||
        label = self.label[index] | ||
|
||
        # processing | ||
        if self.random_choose: | ||
            data_numpy = tools.random_choose(data_numpy, self.window_size) | ||
        elif self.window_size > 0: | ||
            data_numpy = tools.auto_pading(data_numpy, self.window_size) | ||
        if self.random_move: | ||
            data_numpy = tools.random_move(data_numpy) | ||
|
||
        return data_numpy, label |
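The feeder above loads its `.npy` file with `mmap_mode='r'` and copies one sample at a time with `np.array(...)`. The following sketch illustrates that pattern on a synthetic array. The shape values and the file name `toy.npy` are made up for illustration and do not come from the dataset.

```python
import os
import tempfile

import numpy as np

# Synthetic data with an assumed (N, C, T, V, M) layout; the sizes are arbitrary.
N, C, T, V, M = 4, 3, 10, 25, 1
data = np.arange(N * C * T * V * M, dtype=np.float32).reshape(N, C, T, V, M)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.npy")
    np.save(path, data)

    # mmap_mode='r' maps the file read-only; samples are read only when indexed.
    mapped = np.load(path, mmap_mode="r")
    is_memmap = isinstance(mapped, np.memmap)

    # Copy a single sample into ordinary memory, as __getitem__ does.
    sample = np.array(mapped[2])
    sample_shape = sample.shape

    # Release the mapping before the temporary directory is deleted.
    del mapped
```

Memory mapping keeps start-up cheap for large arrays, at the cost of a disk read on each sample access.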
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# sys | ||
import os | ||
import sys | ||
import numpy as np | ||
import random | ||
import pickle | ||
import json | ||
# torch | ||
import torch | ||
import torch.nn as nn | ||
from torchvision import datasets, transforms | ||
|
||
# operation | ||
from . import tools | ||
|
||
|
||
class Feeder_kinetics(torch.utils.data.Dataset): | ||
    """ Feeder for skeleton-based action recognition in kinetics-skeleton dataset | ||
    Arguments: | ||
        data_path: the path to the directory of per-sample '.json' skeleton files; the output data has shape (N, C, T, V, M) | ||
        label_path: the path to the '.json' label file | ||
        ignore_empty_sample: If true, skip samples that have no skeleton sequence | ||
        random_choose: If true, randomly choose a portion of the input sequence | ||
        random_shift: If true, randomly pad zeros at the beginning or end of sequence | ||
        random_move: If true, perform randomly but continuously changed transformation to input sequence | ||
        window_size: The length of the output sequence | ||
        pose_matching: If true, match the pose between two frames | ||
        num_person_in: The number of people the feeder can observe in the input sequence | ||
        num_person_out: The number of people the feeder keeps in the output sequence | ||
        debug: If true, only use the first 2 samples | ||
    """ | ||
|
||
    def __init__(self, | ||
                 data_path, | ||
                 label_path, | ||
                 ignore_empty_sample=True, | ||
                 random_choose=False, | ||
                 random_shift=False, | ||
                 random_move=False, | ||
                 window_size=-1, | ||
                 pose_matching=False, | ||
                 num_person_in=5, | ||
                 num_person_out=2, | ||
                 debug=False): | ||
        self.debug = debug | ||
        self.data_path = data_path | ||
        self.label_path = label_path | ||
        self.random_choose = random_choose | ||
        self.random_shift = random_shift | ||
        self.random_move = random_move | ||
        self.window_size = window_size | ||
        self.num_person_in = num_person_in | ||
        self.num_person_out = num_person_out | ||
        self.pose_matching = pose_matching | ||
        self.ignore_empty_sample = ignore_empty_sample | ||
|
||
        self.load_data() | ||
|
||
    def load_data(self): | ||
        # load file list | ||
        self.sample_name = os.listdir(self.data_path) | ||
|
||
        if self.debug: | ||
            self.sample_name = self.sample_name[0:2] | ||
|
||
        # load label | ||
        label_path = self.label_path | ||
        with open(label_path) as f: | ||
            label_info = json.load(f) | ||
|
||
        sample_id = [name.split('.')[0] for name in self.sample_name] | ||
        self.label = np.array( | ||
            [label_info[id]['label_index'] for id in sample_id]) | ||
        has_skeleton = np.array( | ||
            [label_info[id]['has_skeleton'] for id in sample_id]) | ||
|
||
        # ignore samples that have no skeleton sequence | ||
        if self.ignore_empty_sample: | ||
            self.sample_name = [ | ||
                s for h, s in zip(has_skeleton, self.sample_name) if h | ||
            ] | ||
            self.label = self.label[has_skeleton] | ||
|
||
        # output data shape (N, C, T, V, M) | ||
        self.N = len(self.sample_name) #sample | ||
        self.C = 3 #channel | ||
        self.T = 300 #frame | ||
        self.V = 18 #joint | ||
        self.M = self.num_person_out #person | ||
|
||
    def __len__(self): | ||
        return len(self.sample_name) | ||
|
||
    def __iter__(self): | ||
        return self | ||
|
||
    def __getitem__(self, index): | ||
|
||
        # output shape (C, T, V, M) | ||
        # get data | ||
        sample_name = self.sample_name[index] | ||
        sample_path = os.path.join(self.data_path, sample_name) | ||
        with open(sample_path, 'r') as f: | ||
            video_info = json.load(f) | ||
|
||
        # fill data_numpy | ||
        data_numpy = np.zeros((self.C, self.T, self.V, self.num_person_in)) | ||
        for frame_info in video_info['data']: | ||
            frame_index = frame_info['frame_index'] | ||
            for m, skeleton_info in enumerate(frame_info["skeleton"]): | ||
                if m >= self.num_person_in: | ||
                    break | ||
                pose = skeleton_info['pose'] | ||
                score = skeleton_info['score'] | ||
                data_numpy[0, frame_index, :, m] = pose[0::2] | ||
                data_numpy[1, frame_index, :, m] = pose[1::2] | ||
                data_numpy[2, frame_index, :, m] = score | ||
|
||
        # centralization | ||
        data_numpy[0:2] = data_numpy[0:2] - 0.5 | ||
        data_numpy[0][data_numpy[2] == 0] = 0 | ||
        data_numpy[1][data_numpy[2] == 0] = 0 | ||
|
||
        # get & check label index | ||
        label = video_info['label_index'] | ||
        assert (self.label[index] == label) | ||
|
||
        # data augmentation | ||
        if self.random_shift: | ||
            data_numpy = tools.random_shift(data_numpy) | ||
        if self.random_choose: | ||
            data_numpy = tools.random_choose(data_numpy, self.window_size) | ||
        elif self.window_size > 0: | ||
            data_numpy = tools.auto_pading(data_numpy, self.window_size) | ||
        if self.random_move: | ||
            data_numpy = tools.random_move(data_numpy) | ||
|
||
        # sort by score | ||
        sort_index = (-data_numpy[2, :, :, :].sum(axis=1)).argsort(axis=1) | ||
        for t, s in enumerate(sort_index): | ||
            data_numpy[:, t, :, :] = data_numpy[:, t, :, s].transpose((1, 2, | ||
                                                                       0)) | ||
        data_numpy = data_numpy[:, :, :, 0:self.num_person_out] | ||
|
||
        # match poses between 2 frames | ||
        if self.pose_matching: | ||
            data_numpy = tools.openpose_match(data_numpy) | ||
|
||
        return data_numpy, label | ||
|
||
    def top_k(self, score, top_k): | ||
        assert (all(self.label >= 0)) | ||
|
||
        rank = score.argsort() | ||
        hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)] | ||
        return sum(hit_top_k) * 1.0 / len(hit_top_k) | ||
|
||
    def top_k_by_category(self, score, top_k): | ||
        assert (all(self.label >= 0)) | ||
        return tools.top_k_by_category(self.label, score, top_k) | ||
|
||
    def calculate_recall_precision(self, score): | ||
        assert (all(self.label >= 0)) | ||
        return tools.calculate_recall_precision(self.label, score) |
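The `top_k` method above counts a sample as a hit when its true label is among the `top_k` highest-scoring classes. The following standalone sketch shows the same computation on a toy score matrix. The function name `top_k_accuracy` and the numbers are illustrative only.

```python
import numpy as np


def top_k_accuracy(score, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    # argsort is ascending, so the last k columns hold the top-k class indices.
    rank = np.argsort(score, axis=1)
    hits = [label in rank[i, -k:] for i, label in enumerate(labels)]
    return sum(hits) / len(hits)


# Three samples, three classes; the second sample's true class ranks second.
score = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.3, 0.2],
                  [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 2])

top1 = top_k_accuracy(score, labels, 1)  # 2 of 3 samples hit
top2 = top_k_accuracy(score, labels, 2)  # all 3 samples hit
```

Widening `k` can only add hits, so top-2 accuracy is never below top-1 accuracy on the same scores.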