Skip to content

Commit

Permalink
Option to generate train_files.txt and test_files.txt for tu_berlin
Browse files Browse the repository at this point in the history
This change avoids having to manually concatenate `fold_*_files.txt`.
  • Loading branch information
Garrett Smith committed Nov 21, 2018
1 parent 7da9960 commit 78c298e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ Here we list the commands for training/evaluating PointCNN on classification and
```
cd data_conversions
python3 ./download_datasets.py -d tu_berlin
python3 ./prepare_tu_berlin_data.py -f ../../data/tu_berlin/ -a
cat ../../data/tu_berlin/fold_1_*.txt ../../data/tu_berlin/fold_0_*.txt > ../../data/tu_berlin/train_files.txt
cat ../../data/tu_berlin/fold_2_files.txt > ../../data/tu_berlin/test_files.txt
python3 ./prepare_tu_berlin_data.py -f ../../data/tu_berlin/ -a --create-train-test
cd ../pointcnn_cls
./train_val_tu_berlin.sh -g 0 -x tu_berlin_x3_l4
```
Expand Down
19 changes: 19 additions & 0 deletions data_conversions/prepare_tu_berlin_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def main():
parser.add_argument('--point_num', '-p', help='Point number for each sample', type=int, default=1024)
parser.add_argument('--save_ply', '-s', help='Convert .pts to .ply', action='store_true')
parser.add_argument('--augment', '-a', help='Data augmentation', action='store_true')
parser.add_argument('--create-train-test',
help='Concatenate file lists to generate train_files.txt and test_files.txt',
action='store_true')
args = parser.parse_args()
print(args)

Expand Down Expand Up @@ -129,6 +132,9 @@ def main():
random.shuffle(filelist_svg_fold)

filename_filelist_svg_fold = os.path.join(root_folder, 'filelist_fold_%d.txt' % (idx_fold))
if os.path.exists(filename_filelist_svg_fold):
print('{}-{} exists, skipping'.format(datetime.now(), filename_filelist_svg_fold))
continue
with open(filename_filelist_svg_fold, 'w') as filelist_svg_fold_file:
for filename in filelist_svg_fold:
filelist_svg_fold_file.write('%s\n' % (filename))
Expand Down Expand Up @@ -215,6 +221,19 @@ def main():
if len(filelist_svg_failed) != 0:
print('{}-Failed to parse {} sketches!'.format(datetime.now(), len(filelist_svg_failed)))

if args.create_train_test:
print('{}-Generating train_files.txt and test_files.txt'.format(datetime.now()))
train_files = open(os.path.join(root_folder, "train_files.txt"), "w")
test_files = open(os.path.join(root_folder, "test_files.txt"), "w")
with train_files, test_files:
for idx_fold in range(fold_num):
filename = os.path.join(root_folder, 'fold_%d_files%s.txt' % (idx_fold, tag_aug))
contents = open(filename, "r").read()
# Use folders 0..N-1 for train and N for test
if idx_fold < fold_num - 1:
train_files.write(contents)
else:
test_files.write(contents)

if __name__ == '__main__':
main()

0 comments on commit 78c298e

Please sign in to comment.