Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/contrastive loss #50

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#ide
.idea/

# Data
output_dir/
datasets/


# checkpoints
experiments/

Expand Down
2 changes: 1 addition & 1 deletion bin_mean_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def merge_center(self, seed_point, bandwidth=0.25):

# merge center if distance between two points less than bandwidth
sorted_intensity, indices = torch.sort(intensity, descending=True)
is_center = np.ones(n, dtype=np.bool)
is_center = np.ones(n, dtype=bool)
indices = indices.cpu().numpy()
center = np.zeros(n, dtype=np.uint8)

Expand Down
10 changes: 6 additions & 4 deletions configs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
seed: 123
num_gpus: 1
num_epochs: 100
resume_dir: None
num_epochs: 5
resume_dir: /Users/dimafadeev/Desktop/Catalog/TUM/WS23/ML3D/repo/output/models/
print_interval: 10


solver:
method: adam
lr: 0.0001
weight_decay: 0.00001

dataset:
root_dir: /new_disk2/yuzh/PlaneNetData/
root_dir: /Users/dimafadeev/Desktop/Catalog/TUM/WS23/ML3D/repo/output/processed/
batch_size: 16
num_workers: 8

model:
arch: resnet101
arch: resnet101 # dpt #
pretrained: True
embed_dims: 2
fix_bn: False
semantic: True
3 changes: 2 additions & 1 deletion data_tools/RecordReaderAll.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# modified from https://github.com/art-programmer/PlaneNet
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

HEIGHT=192
WIDTH=256
Expand Down
16 changes: 9 additions & 7 deletions data_tools/convert_tfrecords.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import tensorflow as tf
import numpy as np
import os
import argparse

from RecordReaderAll import *

os.environ['CUDA_VISIBLE_DEVICES']=''
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

#os.environ['CUDA_VISIBLE_DEVICES']='0'

parser = argparse.ArgumentParser()
parser.add_argument('--input_tfrecords_file', type=str,
Expand All @@ -27,14 +29,14 @@
os.makedirs(output_dir)

if data_type == 'train':
file_list = open(output_dir + '/train.txt', 'w')
output_dir = os.path.join(output_dir, 'train')
os.makedirs(output_dir)
os.makedirs(output_dir,exist_ok=True)
file_list = open(output_dir + '/train.txt', 'w')
max_num = 50000
elif data_type == 'val':
file_list = open(output_dir + '/val.txt', 'w')
output_dir = os.path.join(output_dir, 'val')
os.makedirs(output_dir)
os.makedirs(output_dir, exist_ok=True)
file_list = open(output_dir + '/val.txt', 'w')
max_num = 760
else:
print("unsupported data type")
Expand Down Expand Up @@ -74,7 +76,7 @@

file_list.write('%d.npz\n' % (i, ))

if i % 100 == 99:
if i % 1000 == 99:
print(i)

file_list.close()
3 changes: 3 additions & 0 deletions embedding.pt
Git LFS file not shown
3 changes: 3 additions & 0 deletions instance.pt
Git LFS file not shown
Loading