forked from yangyanli/PointCNN
Commit 695b161 (0 parents)

82 changed files with 7,737 additions and 0 deletions.
.gitignore (new file):

```
.idea
arena/__pycache__/
__pycache__/
saver/
sampling/
data/
```
Empty file.
README.md (new file):
# PointCNN

Created by <a href="http://yangyan.li" target="_blank">Yangyan Li</a>, Rui Bu, <a href="http://www.mcsun.cn" target="_blank">Mingchao Sun</a>, and <a href="http://www.cs.sdu.edu.cn/~baoquan/" target="_blank">Baoquan Chen</a> from Shandong University.

### Introduction

PointCNN is a simple and general framework for feature learning from point clouds. It has refreshed five benchmark records in point cloud processing, including:

* classification accuracy on ModelNet40 (91.7%)
* classification accuracy on ScanNet (77.9%)
* segmentation part-averaged IoU on ShapeNet Parts (86.13%)
* segmentation mean IoU on S3DIS (62.74%)
* per-voxel labelling accuracy on ScanNet (85.1%)

See our <a href="http://arxiv.org/abs/1801.07791" target="_blank">research paper on arXiv</a> for more details.

### Code Organization

The core X-Conv and PointCNN architectures are defined in ./pointcnn.py.

The network/training/data augmentation hyperparameters for classification tasks are defined in ./pointcnn_cls/\*.py; those for segmentation tasks are defined in ./pointcnn_seg/\*.py.
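For illustration only, a minimal sketch of the kind of hyperparameters such a settings file typically holds; every variable name below is a hypothetical assumption for illustration, not confirmed from this commit:

```
# Hypothetical settings sketch for a classification task (all names are assumptions).
num_class = 40             # ModelNet40 has 40 shape categories
sample_num = 1024          # points sampled per shape (assumption)
batch_size = 128           # training batch size (assumption)
learning_rate_base = 0.01  # initial learning rate (assumption)
rotation_range = [0.0, 3.14159, 0.0]  # data augmentation: random rotation (assumption)
```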

### Usage

Commands for training and testing ModelNet40 classification:
```
cd data_conversions
python3 ./download_datasets.py -d modelnet
cd ../pointcnn_cls
./train_val_modelnet.sh -g 0 -x modelnet_x2_l4
```

Commands for training and testing ShapeNet Parts segmentation:
```
cd data_conversions
python3 ./download_datasets.py -d shapenet_partseg
cd ../pointcnn_seg
./train_val_shapenet.sh -g 0 -x shapenet_x8_2048_fps
./test_shapenet.sh -g 0 -x shapenet_x8_2048_fps -l ../../models/seg/pointcnn_seg_shapenet_x8_2048_fps_xxxx/ckpts/iter-xxxxx -r 10
cd ..
python3 ./evaluate_seg.py -g ../data/shapenet_partseg/test_label -p ../data/shapenet_partseg/test_data_pred_10
```

Other datasets can be processed in a similar way.
_config.yml (new file):

```
theme: jekyll-theme-hacker
```
data_conversions/download_datasets.py (new file):
```
#!/usr/bin/python3
'''Download datasets for this project.'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import gzip
import html
import shutil
import tarfile
import zipfile
import requests
import argparse
from tqdm import tqdm


# from https://gist.github.com/hrouault/1358474
def query_yes_no(question, default="yes"):
    """Ask a yes/no question via input() and return the answer.
    "question" is a string that is presented to the user.
    "default" is the presumed answer if the user just hits <Enter>.
    It must be "yes" (the default), "no" or None (meaning
    an answer is required of the user).
    The "answer" return value is one of "yes" or "no".
    """
    valid = {"yes": True, "y": True, "ye": True,
             "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")


def download_from_url(url, dst):
    # If the target already exists, ask before re-downloading.
    download = True
    if os.path.exists(dst):
        download = query_yes_no('Seems you have downloaded %s to %s, overwrite?' % (url, dst), default='no')
        if download:
            os.remove(dst)

    if download:
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))
        chunk_size = 1024 * 1024  # 1 MB chunks, so the progress bar counts megabytes
        bars = total_size // chunk_size
        with open(dst, "wb") as handle:
            for data in tqdm(response.iter_content(chunk_size=chunk_size), total=bars, desc=url.split('/')[-1],
                             unit='M'):
                handle.write(data)


def download_and_unzip(url, root, dataset):
    folder = os.path.join(root, dataset)
    folder_zips = os.path.join(folder, 'zips')
    if not os.path.exists(folder_zips):
        os.makedirs(folder_zips)
    filename_zip = os.path.join(folder_zips, url.split('/')[-1])

    download_from_url(url, filename_zip)

    # Extract according to the archive type.
    if filename_zip.endswith('.zip'):
        zip_ref = zipfile.ZipFile(filename_zip, 'r')
        zip_ref.extractall(folder)
        zip_ref.close()
    elif filename_zip.endswith(('.tar.gz', '.tgz')):
        tarfile.open(name=filename_zip, mode="r:gz").extractall(folder)
    elif filename_zip.endswith('.gz'):
        filename_no_gz = filename_zip[:-3]
        with gzip.open(filename_zip, 'rb') as f_in, open(filename_no_gz, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--folder', '-f', help='Path to data folder.')
    parser.add_argument('--dataset', '-d', help='Dataset to download.')
    args = parser.parse_args()
    print(args)

    root = args.folder if args.folder else '../../data'
    if args.dataset == 'tu_berlin':
        download_and_unzip('http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/sketches_svg.zip', root,
                           args.dataset)
    elif args.dataset == 'modelnet':
        download_and_unzip('https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip', root, args.dataset)
        # Flatten the extracted folder so the .h5 files sit directly under the dataset folder.
        folder = os.path.join(root, args.dataset)
        folder_h5 = os.path.join(folder, 'modelnet40_ply_hdf5_2048')
        for filename in os.listdir(folder_h5):
            shutil.move(os.path.join(folder_h5, filename), os.path.join(folder, filename))
        shutil.rmtree(folder_h5)
    elif args.dataset == 'shapenet_partseg':
        download_and_unzip('https://shapenet.cs.stanford.edu/iccv17/partseg/train_data.zip', root, args.dataset)
        download_and_unzip('https://shapenet.cs.stanford.edu/iccv17/partseg/train_label.zip', root, args.dataset)
        download_and_unzip('https://shapenet.cs.stanford.edu/iccv17/partseg/val_data.zip', root, args.dataset)
        download_and_unzip('https://shapenet.cs.stanford.edu/iccv17/partseg/val_label.zip', root, args.dataset)
        download_and_unzip('https://shapenet.cs.stanford.edu/iccv17/partseg/test_data.zip', root, args.dataset)
        download_and_unzip('https://shapenet.cs.stanford.edu/iccv17/partseg/test_label.zip', root, args.dataset)
    elif args.dataset == 'mnist':
        download_and_unzip('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', root, args.dataset)
        download_and_unzip('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', root, args.dataset)
        download_and_unzip('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', root, args.dataset)
        download_and_unzip('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', root, args.dataset)
    elif args.dataset == 'cifar10':
        download_and_unzip('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', root, args.dataset)
    elif args.dataset == 'quick_draw':
        url_categories = 'https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt'
        folder = os.path.join(root, args.dataset)
        folder_zips = os.path.join(folder, 'zips')
        if not os.path.exists(folder_zips):
            os.makedirs(folder_zips)
        filename_categories = os.path.join(folder_zips, url_categories.split('/')[-1])
        download_from_url(url_categories, filename_categories)

        categories = [line.strip() for line in open(filename_categories, 'r')]
        url_base = 'https://storage.googleapis.com/quickdraw_dataset/sketchrnn/'
        for category in categories:
            url = url_base + html.escape(category) + '.npz'
            filename_category = os.path.join(folder_zips, category + '.npz')
            download_from_url(url, filename_category)


if __name__ == '__main__':
    main()
```
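The script is invoked from data_conversions, as in the README above; for example (the -f override is shown only to illustrate the optional folder flag):

```
cd data_conversions
python3 ./download_datasets.py -d modelnet
python3 ./download_datasets.py -d cifar10 -f ../../data  # -f overrides the default data folder
```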
CIFAR-10 conversion script (new file):
```
#!/usr/bin/python3
'''Convert CIFAR-10 to points.'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import h5py
import random
import tarfile
import argparse
import numpy as np
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import data_utils


def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--folder', '-f', help='Path to data folder')
    parser.add_argument('--save_ply', '-s', help='Convert .pts to .ply', action='store_true')
    args = parser.parse_args()
    print(args)

    batch_size = 2048

    folder_cifar10 = args.folder if args.folder else '../../data/cifar10/cifar-10-batches-py'
    folder_pts = os.path.join(os.path.dirname(folder_cifar10), 'pts')

    train_test_files = [('train', ['data_batch_%d' % (idx + 1) for idx in range(5)]),
                        ('test', ['test_batch'])]

    data = np.zeros((batch_size, 1024, 6))
    label = np.zeros((batch_size), dtype=np.int32)
    for tag, filelist in train_test_files:
        data_list = []
        labels_list = []
        for filename in filelist:
            batch = unpickle(os.path.join(folder_cifar10, filename))
            data_list.append(np.reshape(batch[b'data'], (10000, 3, 32, 32)))
            labels_list.append(batch[b'labels'])
        images = np.concatenate(data_list, axis=0)
        labels = np.concatenate(labels_list, axis=0)

        idx_h5 = 0
        filename_filelist_h5 = os.path.join(os.path.dirname(folder_cifar10), '%s_files.txt' % tag)
        with open(filename_filelist_h5, 'w') as filelist_h5:
            for idx_img, image in enumerate(images):
                # Turn each 32x32 image into a planar point cloud: one point per
                # pixel, with a tiny random y-jitter to avoid a degenerate plane.
                points = []
                pixels = []
                for x in range(32):
                    for z in range(32):
                        points.append((x, random.random() * 1e-6, z))
                        pixels.append((image[0, x, z], image[1, x, z], image[2, x, z]))
                points_array = np.array(points)
                pixels_array = (np.array(pixels).astype(np.float32) / 255) - 0.5

                # Center the point set at the origin and scale its largest extent to [-0.8, 0.8].
                points_min = np.amin(points_array, axis=0)
                points_max = np.amax(points_array, axis=0)
                points_center = (points_min + points_max) / 2
                scale = np.amax(points_max - points_min) / 2
                points_array = (points_array - points_center) * (0.8 / scale)

                if args.save_ply:
                    filename_pts = os.path.join(folder_pts, tag, '{:06d}.ply'.format(idx_img))
                    data_utils.save_ply(points_array, filename_pts, colors=pixels_array + 0.5)

                # Accumulate samples and flush an .h5 file whenever the batch is full.
                idx_in_batch = idx_img % batch_size
                data[idx_in_batch, ...] = np.concatenate((points_array, pixels_array), axis=-1)
                label[idx_in_batch] = labels[idx_img]
                if ((idx_img + 1) % batch_size == 0) or idx_img == len(images) - 1:
                    item_num = idx_in_batch + 1
                    filename_h5 = os.path.join(os.path.dirname(folder_cifar10), '%s_%d.h5' % (tag, idx_h5))
                    print('{}-Saving {}...'.format(datetime.now(), filename_h5))
                    filelist_h5.write('./%s_%d.h5\n' % (tag, idx_h5))

                    file = h5py.File(filename_h5, 'w')
                    file.create_dataset('data', data=data[0:item_num, ...])
                    file.create_dataset('label', data=label[0:item_num, ...])
                    file.close()

                    idx_h5 = idx_h5 + 1


if __name__ == '__main__':
    main()
```
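The normalization step above centers each point set at the origin and scales its largest axis-aligned extent to [-0.8, 0.8]. A minimal standalone sketch of the same transform (the function name and toy data are for illustration only):

```
import numpy as np

def normalize_points(points_array, half_extent=0.8):
    # Center at the origin and scale the largest extent to [-half_extent, half_extent],
    # mirroring the math in the conversion loop above.
    points_min = np.amin(points_array, axis=0)
    points_max = np.amax(points_array, axis=0)
    points_center = (points_min + points_max) / 2
    scale = np.amax(points_max - points_min) / 2
    return (points_array - points_center) * (half_extent / scale)

# A 32x32 grid of (x, y, z) points, as produced for one CIFAR-10 image.
grid = np.array([(x, 0.0, z) for x in range(32) for z in range(32)], dtype=np.float64)
normalized = normalize_points(grid)
print(normalized.min(axis=0), normalized.max(axis=0))  # x and z now span [-0.8, 0.8]
```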
MNIST conversion script (new file):
```
#!/usr/bin/python3
'''Convert MNIST to points.'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import h5py
import random
import argparse
import numpy as np
from mnist import MNIST  # provided by the python-mnist package
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import data_utils


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--folder', '-f', help='Path to data folder')
    parser.add_argument('--point_num', '-p', help='Point number for each sample', type=int, default=256)
    parser.add_argument('--save_ply', '-s', help='Convert .pts to .ply', action='store_true')
    args = parser.parse_args()
    print(args)

    batch_size = 2048

    folder_mnist = args.folder if args.folder else '../../data/mnist/zips'
    folder_pts = os.path.join(os.path.dirname(folder_mnist), 'pts')

    mnist_data = MNIST(folder_mnist)
    mnist_train_test = [(mnist_data.load_training(), 'train'), (mnist_data.load_testing(), 'test')]

    data = np.zeros((batch_size, args.point_num, 4))
    label = np.zeros((batch_size), dtype=np.int32)
    for ((images, labels), tag) in mnist_train_test:
        idx_h5 = 0
        filename_filelist_h5 = os.path.join(os.path.dirname(folder_mnist), '%s_files.txt' % tag)
        point_num_total = 0
        with open(filename_filelist_h5, 'w') as filelist_h5:
            for idx_img, image in enumerate(images):
                # Collect one candidate point per non-zero pixel, with a tiny
                # random y-jitter to avoid a degenerate plane.
                points = []
                pixels = []
                for idx_pixel, pixel in enumerate(image):
                    if pixel == 0:
                        continue
                    x = idx_pixel // 28
                    z = idx_pixel % 28
                    points.append((x, random.random() * 1e-6, z))
                    pixels.append(pixel)
                point_num_total = point_num_total + len(points)
                # Sample a fixed-size point set, weighting pixels by intensity;
                # sample with replacement only if there are too few candidates.
                pixels_sum = sum(pixels)
                probs = [pixel / pixels_sum for pixel in pixels]
                indices = np.random.choice(list(range(len(points))), size=args.point_num,
                                           replace=(len(points) < args.point_num), p=probs)
                points_array = np.array(points)[indices]
                pixels_array_1d = (np.array(pixels)[indices].astype(np.float32) / 255) - 0.5
                pixels_array = np.expand_dims(pixels_array_1d, axis=-1)

                # Center the point set at the origin and scale its largest extent to [-0.8, 0.8].
                points_min = np.amin(points_array, axis=0)
                points_max = np.amax(points_array, axis=0)
                points_center = (points_min + points_max) / 2
                scale = np.amax(points_max - points_min) / 2
                points_array = (points_array - points_center) * (0.8 / scale)

                if args.save_ply:
                    filename_pts = os.path.join(folder_pts, tag, '{:06d}.ply'.format(idx_img))
                    data_utils.save_ply(points_array, filename_pts, colors=np.tile(pixels_array, (1, 3)) + 0.5)

                idx_in_batch = idx_img % batch_size
                data[idx_in_batch, ...] = np.concatenate((points_array, pixels_array), axis=-1)
                label[idx_in_batch] = labels[idx_img]
                if ((idx_img + 1) % batch_size == 0) or idx_img == len(images) - 1:
                    item_num = idx_in_batch + 1
                    filename_h5 = os.path.join(os.path.dirname(folder_mnist), '%s_%d.h5' % (tag, idx_h5))
                    print('{}-Saving {}...'.format(datetime.now(), filename_h5))
                    filelist_h5.write('./%s_%d.h5\n' % (tag, idx_h5))

                    file = h5py.File(filename_h5, 'w')
                    file.create_dataset('data', data=data[0:item_num, ...])
                    file.create_dataset('label', data=label[0:item_num, ...])
                    file.close()

                    idx_h5 = idx_h5 + 1
        print('Average point number in each sample is: %f' % (point_num_total / len(images)))


if __name__ == '__main__':
    main()
```
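Non-zero pixels are sampled with probability proportional to their intensity, so brighter strokes contribute more points; replacement is used only when an image has fewer non-zero pixels than the requested point count. A small self-contained illustration of that weighting (toy values; names are for illustration only):

```
import numpy as np

intensities = np.array([10, 40, 200, 255], dtype=np.float64)  # toy pixel intensities
probs = intensities / intensities.sum()                       # intensity-proportional weights
point_num = 8
indices = np.random.choice(len(intensities), size=point_num,
                           replace=(len(intensities) < point_num), p=probs)
print(indices)  # brighter pixels (indices 2 and 3) dominate the sample
```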