diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3ca7fb5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,159 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +env/ +.venv +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +### VisualStudioCode +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace +**/.vscode + +# JetBrains +.idea/ + +# Data & Models +*.h5 +*.tar +*.tar.gz + +# Lightning-Hydra-Template +configs/local/default.yaml +/data/ +/logs/ +.env + +# Aim logging +.aim + +# local +exp/ +data/ +.cache/ \ No newline at end of file diff --git a/README.md b/README.md index d2c0515..52aff77 100644 --- a/README.md +++ b/README.md @@ -29,25 +29,94 @@ PonderV2 is a comprehensive 3D pre-training framework designed to facilitate the

## News: +- *Dec. 2023*: Multi-dataset training is now supported! More instructions on installation and usage are available below, please check them out! - *Nov. 2023*: [**Model files**](./ponder/models/ponder/) are released! Usage instructions, complete codes and checkpoints are coming soon! - *Oct. 2023*: **PonderV2** is released on [arXiv](https://arxiv.org/abs/2310.08586), code will be made public and supported by [Pointcept](https://github.com/Pointcept/Pointcept) soon. -## Example Usage: -Pre-train PonderV2 on single Structured3D dataset with 8 GPUs: +## Installation +This repository is mainly based on [Pointcept](https://github.com/Pointcept/Pointcept). + +### Requirements +- Ubuntu: 18.04 or higher +- CUDA: 11.3 or higher +- PyTorch: 1.10.0 or higher + +### Conda Environment ```bash -bash scripts/train.sh -g 8 -d s3dis -c pretrain-ponder-spunet-v1m1-0-base -n ponderv2-pretrain +conda create -n ponderv2 python=3.8 -y +conda activate ponderv2 +# Choose the version you want here: https://pytorch.org/get-started/previous-versions/ +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -y +conda install h5py pyyaml -c anaconda -y +conda install sharedarray tensorboard tensorboardx addict einops scipy plyfile termcolor timm -c conda-forge -y +conda install pytorch-cluster pytorch-scatter pytorch-sparse -c pyg -y +pip install torch-geometric yapf==0.40.1 opencv-python open3d==0.10.0 imageio +pip install git+https://github.com/openai/CLIP.git + +# spconv (SparseUNet) +# refer to https://github.com/traveller59/spconv +pip install spconv-cu113 + +# NeuS renderer +cd libs/smooth-sampler +# usual install +python setup.py install +# for docker builds or multiple GPU architectures, set the target arch list explicitly +TORCH_CUDA_ARCH_LIST="ARCH LIST" python setup.py install +# e.g. 7.5: Turing (RTX 20xx); 8.0: A100; 8.6: RTX 30xx. More available at: https://developer.nvidia.com/cuda-gpus +TORCH_CUDA_ARCH_LIST="7.5 8.0" python setup.py install +cd ../.. ``` -More detailed instructions on installation, data pre-processing, pre-training and finetuning will come soon! +## Data Preparation +Please check out [docs/data_preparation.md](docs/data_preparation.md). + +## Quick Start: +- **Pretraining**: Pretrain PonderV2 on indoor or outdoor datasets. + +Pre-train PonderV2 (indoor) on the ScanNet dataset alone with 8 GPUs: +```bash +# -g: number of GPUs +# -d: dataset +# -c: config file; the final config is ./configs/${-d}/${-c}.py +# -n: experiment name +bash scripts/train.sh -g 8 -d scannet -c pretrain-ponder-spunet-v1m1-0-base -n ponderv2-pretrain-sc +``` + +Pre-train PonderV2 (indoor) jointly on the ScanNet, S3DIS and Structured3D datasets using [Point Prompt Training (PPT)](https://arxiv.org/abs/2308.09718) with 8 GPUs: +```bash +bash scripts/train.sh -g 8 -d scannet -c pretrain-ponder-ppt-v1m1-0-sc-s3-st-spunet -n ponderv2-pretrain-sc-s3-st +``` + +Pre-train PonderV2 (outdoor) on the nuScenes dataset alone with 4 GPUs: +```bash +bash scripts/train.sh -g 4 -d nuscenes -c pretrain-ponder-spunet-v1m1-0-base -n ponderv2-pretrain-nu +``` + +- **Finetuning**: Finetune on downstream tasks with PonderV2 pre-trained checkpoints. + +Finetune PonderV2 on the ScanNet semantic segmentation downstream task with PPT: +```bash +# -w: path to checkpoint +bash scripts/train.sh -g 8 -d scannet -c semseg-ppt-v1m1-0-sc-s3-st-spunet-lovasz-ft -n ponderv2-semseg-ft -w ${PATH/TO/CHECKPOINT} +``` + +- **Testing**: Test a finetuned model on a downstream task. 
+```bash +# Based on experiment folder created by training script +bash scripts/test.sh -g 8 -d scannet -n ponderv2-semseg-ft -w ${CHECKPOINT/NAME} +``` + +For more detailed options and examples, please refer to [docs/getting_started.md](docs/getting_started.md). For more outdoor pre-training and downstream information, you can also refer to [UniPAD](https://github.com/Nightmare-n/UniPAD). ## Todo: -- [ ] add instructions on installation and usage -- [ ] add ScanNet w. RGB-D dataloader and data pre-processing scripts -- [ ] add multi-dataset loader and trainer -- [ ] add multi-dataset point prompt training model -- [ ] add more pre-training and finetuning scripts +- [x] add instructions on installation and usage +- [x] add ScanNet w. RGB-D dataloader and data pre-processing scripts +- [x] add multi-dataset loader and trainer +- [x] add multi-dataset point prompt training model +- [ ] add more pre-training and finetuning configs - [ ] add pre-trained checkpoints ## Citation @@ -74,3 +143,8 @@ For more outdoor pre-training and downstream information, you can also refer to year={2023}, } ``` + +## Acknowledgement +This project is mainly based on the following codebases. Thanks for their great works! +- [SDFStudio](https://github.com/autonomousvision/sdfstudio) +- [Pointcept](https://github.com/Pointcept/Pointcept) \ No newline at end of file diff --git a/configs/s3dis/pretrain-ponder-spunet-v1m1-0-base.py b/configs/s3dis/pretrain-ponder-spunet-v1m1-0-base.py index 4e619bb..e840ea2 100644 --- a/configs/s3dis/pretrain-ponder-spunet-v1m1-0-base.py +++ b/configs/s3dis/pretrain-ponder-spunet-v1m1-0-base.py @@ -93,7 +93,7 @@ pool_type="mean", share_volume=True, render_semantic=True, - conditions=("Structured3D",), + conditions=("S3DIS",), template=( "itap of a [x]", "a origami [x]", @@ -157,7 +157,7 @@ # dataset settings num_cameras = 5 data = dict( - num_classes=25, + num_classes=13, ignore_index=-1, names=( "ceiling", @@ -176,7 +176,7 @@ ), train=dict( type="S3DISRGBDDataset", - split="train", + split=("Area_1", "Area_2", "Area_3", "Area_4", "Area_6"), data_root="data/s3dis", render_semantic=True, num_cameras=num_cameras, @@ -248,7 +248,7 @@ ), dict(type="NormalizeColor"), dict(type="ShufflePoint"), - dict(type="Add", keys_dict={"condition": "Structured3D"}), + dict(type="Add", keys_dict={"condition": "S3DIS"}), dict(type="ToTensor"), dict( type="Collect", diff --git a/configs/scannet/pretrain-ponder-ppt-v1m1-0-sc-s3-st-spunet.py b/configs/scannet/pretrain-ponder-ppt-v1m1-0-sc-s3-st-spunet.py new file mode 100644 index 0000000..88a5dfa --- /dev/null +++ b/configs/scannet/pretrain-ponder-ppt-v1m1-0-sc-s3-st-spunet.py @@ -0,0 +1,565 @@ +_base_ = ["../_base_/default_runtime.py"] + +num_gpu = 8 +limit_num_coord = 2000000 + +# misc custom setting +batch_size = 8 * num_gpu # bs: total bs in all gpus +num_worker = 16 * num_gpu + +mix_prob = 0.0 +empty_cache = True +enable_amp = True +evaluate = False +find_unused_parameters = True + +# trainer +train = dict( + type="MultiDatasetTrainer", +) + +# model settings +model = dict( + type="PonderIndoor-v2", + backbone=dict( + type="SpUNet-v1m3", + in_channels=6, + num_classes=0, + base_channels=32, + context_channels=256, + channels=(32, 64, 128, 256, 256, 128, 96, 96), + layers=(2, 3, 4, 6, 2, 2, 2, 2), + cls_mode=False, + conditions=("ScanNet", "S3DIS", "Structured3D"), + zero_init=False, + norm_decouple=True, + norm_adaptive=True, + norm_affine=True, + ), + projection=dict( + type="UNet3D-v1m2", + in_channels=96, + out_channels=128, + ), + renderer=dict( + 
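+        # renderer: NeuS-style differentiable volume rendering head (an SDF field with RGB and 512-d CLIP-text semantic decoders); the loss weights below supervise depth, color, SDF and semantic renderings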
type="NeuSModel", + field=dict( + type="SDFField", + sdf_decoder=dict( + in_dim=128, + out_dim=65, # 64 + 1 + hidden_size=128, + n_blocks=1, + ), + rgb_decoder=dict( + in_dim=198, # 128 + 64 + 3 + 3 + out_dim=3, + hidden_size=128, + n_blocks=0, + ), + semantic_decoder=dict( + in_dim=195, # 128 + 64 + 3, no directions + out_dim=512, + hidden_size=128, + n_blocks=0, + ), + beta_init=0.3, + use_gradient=True, + volume_type="default", + padding_mode="zeros", + share_volume=True, + ), + collider=dict( + type="AABBBoxCollider", + near_plane=0.01, + bbox=[-0.55, -0.55, -0.55, 0.55, 0.55, 0.55], + ), + sampler=dict( + type="NeuSSampler", + initial_sampler="UniformSampler", + num_samples=96, + num_samples_importance=36, + num_upsample_steps=1, + train_stratified=True, + single_jitter=False, + ), + loss=dict( + sensor_depth_truncation=0.05, + temperature=0.01, + weights=dict( + eikonal_loss=0.01, + free_space_loss=1.0, + sdf_loss=10.0, + depth_loss=1.0, + rgb_loss=10.0, + semantic_loss=0.1, + ), + ), + ), + # mask=dict(ratio=0.8, size=8, channel=6), + mask=None, + grid_shape=(128, 128, 32), + grid_size=0.02, + val_ray_split=10240, + ray_nsample=256, + padding=0.1, + backbone_out_channels=96, + context_channels=256, + pool_type="mean", + share_volume=True, + render_semantic=True, + conditions=("Structured3D", "ScanNet", "S3DIS"), + template=( + "itap of a [x]", + "a origami [x]", + "a rendering of a [x]", + "a painting of a [x]", + "a photo of a [x]", + "a photo of one [x]", + "a photo of a nice [x]", + "a photo of a weird [x]", + "a cropped photo of a [x]", + "a bad photo of a [x]", + "a good photo of a [x]", + "a photo of the large [x]", + "a photo of the small [x]", + "a photo of a clean [x]", + "a photo of a dirty [x]", + "a bright photo of a [x]", + "a dark photo of a [x]", + "a [x] in a living room", + "a [x] in a bedroom", + "a [x] in a kitchen", + "a [x] in a bathroom", + ), + clip_model="ViT-B/16", + class_name=( + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "bookcase", + "picture", + "counter", + "desk", + "shelves", + "curtain", + "dresser", + "pillow", + "mirror", + "ceiling", + "refrigerator", + "television", + "shower curtain", + "nightstand", + "toilet", + "sink", + "lamp", + "bathtub", + "garbagebin", + "board", + "beam", + "column", + "clutter", + "other structure", + "other furniture", + "other property", + ), + valid_index=( + ( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 11, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 23, + 25, + 26, + 33, + 34, + 35, + ), + (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 15, 20, 22, 24, 25, 27, 34), + (0, 1, 4, 5, 6, 7, 8, 10, 19, 29, 30, 31, 32), + ), + ppt_loss_weight=0.0, + ppt_criteria=[dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1)], +) + +# scheduler settings +epoch = 400 +eval_epoch = 1 +optimizer = dict( + type="SGD", + lr=0.0001 * batch_size / 8, + momentum=0.9, + weight_decay=0.0001, + nesterov=True, +) +scheduler = dict( + type="OneCycleLR", + max_lr=optimizer["lr"], + pct_start=0.05, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=10000.0, +) + +# dataset settings +num_cameras = 5 +data = dict( + num_classes=20, + ignore_index=-1, + names=( + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "picture", + "counter", + "desk", + "curtain", + "refridgerator", + "shower curtain", + "toilet", + "sink", + "bathtub", + "otherfurniture", + ), + train=dict( + 
type="ConcatDataset", + datasets=[ + # Structured3D + dict( + type="Structured3DRGBDDataset", + split="train", + data_root="data/structured3d", + render_semantic=True, + num_cameras=num_cameras, + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.8, + dropout_application_ratio=1.0, + ), + dict( + type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="x", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="y", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomScale", + scale=[0.9, 1.1], + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomFlip", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict( + type="CenterShift", + apply_z=False, + keys=[ + "extrinsic", + ], + ), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "Structured3D"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=( + "coord", + "grid_coord", + "segment", + "condition", + "rgb", + "depth", + "depth_scale", + ), + stack_keys=( + "intrinsic", + "extrinsic", + "rgb", + "depth", + "semantic", + ), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=4, # sampling weight + ), + # ScanNet + dict( + type="ScanNetRGBDDataset", + split="train", + data_root="data/scannet", + rgbd_root="data/scannet/rgbd", + render_semantic=True, + num_cameras=num_cameras, + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.8, + dropout_application_ratio=1.0, + ), + dict( + type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="x", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="y", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomScale", + scale=[0.9, 1.1], + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomFlip", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict( + type="CenterShift", + apply_z=False, + keys=[ + "extrinsic", + ], + ), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "ScanNet"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=( + "coord", + "grid_coord", + "segment", + "condition", + "rgb", + "depth", + "depth_scale", + ), + stack_keys=( + "intrinsic", + "extrinsic", + "rgb", + "depth", + "semantic", + ), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=2, # sampling weight + ), + # S3DIS + dict( + type="S3DISRGBDDataset", + split=("Area_1", "Area_2", "Area_3", "Area_4", "Area_6"), + data_root="data/s3dis", + render_semantic=True, + num_cameras=num_cameras, + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.8, + dropout_application_ratio=1.0, + ), + # dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict( + type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + 
type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="x", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="y", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomScale", + scale=[0.9, 1.1], + keys=[ + "extrinsic", + ], + ), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict( + type="RandomFlip", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + # dict(type="SphereCrop", sample_rate=0.8, mode="random"), + dict( + type="CenterShift", + apply_z=False, + keys=[ + "extrinsic", + ], + ), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "S3DIS"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=( + "coord", + "grid_coord", + "segment", + "condition", + "rgb", + "depth", + "depth_scale", + ), + stack_keys=( + "intrinsic", + "extrinsic", + "rgb", + "depth", + "semantic", + ), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=1, # sampling weight + ), + ], + ) +) + +hooks = [ + dict(type="CheckpointLoader"), + dict(type="IterationTimer", warmup_iter=2), + dict(type="InformationWriter"), + dict(type="CheckpointSaver", save_freq=None), +] diff --git a/configs/scannet/pretrain-ponder-spunet-v1m1-0-base.py b/configs/scannet/pretrain-ponder-spunet-v1m1-0-base.py new file mode 100644 index 0000000..512a7c4 --- /dev/null +++ b/configs/scannet/pretrain-ponder-spunet-v1m1-0-base.py @@ -0,0 +1,298 @@ +_base_ = ["../_base_/default_runtime.py"] + +num_gpu = 8 +limit_num_coord = 2000000 + +# misc custom setting +batch_size = 8 * num_gpu # bs: total bs in all gpus +num_worker = 16 * num_gpu + +mix_prob = 0.0 +empty_cache = True +enable_amp = True +evaluate = False +find_unused_parameters = True + +# model settings +model = dict( + type="PonderIndoor-v2", + backbone=dict( + type="SpUNet-v1m1", + in_channels=6, + num_classes=0, + channels=(32, 64, 128, 256, 256, 128, 96, 96), + layers=(2, 3, 4, 6, 2, 2, 2, 2), + ), + projection=dict( + type="UNet3D-v1m2", + in_channels=96, + out_channels=128, + ), + renderer=dict( + type="NeuSModel", + field=dict( + type="SDFField", + sdf_decoder=dict( + in_dim=128, + out_dim=65, # 64 + 1 + hidden_size=128, + n_blocks=1, + ), + rgb_decoder=dict( + in_dim=198, # 128 + 64 + 3 + 3 + out_dim=3, + hidden_size=128, + n_blocks=0, + ), + semantic_decoder=dict( + in_dim=195, # 128 + 64 + 3, no directions + out_dim=512, + hidden_size=128, + n_blocks=0, + ), + beta_init=0.3, + use_gradient=True, + volume_type="default", + padding_mode="zeros", + share_volume=True, + ), + collider=dict( + type="AABBBoxCollider", + near_plane=0.01, + bbox=[-0.55, -0.55, -0.55, 0.55, 0.55, 0.55], + ), + sampler=dict( + type="NeuSSampler", + initial_sampler="UniformSampler", + num_samples=96, + num_samples_importance=36, + num_upsample_steps=1, + train_stratified=True, + single_jitter=False, + ), + loss=dict( + sensor_depth_truncation=0.05, + temperature=0.01, + weights=dict( + eikonal_loss=0.01, + free_space_loss=1.0, + sdf_loss=10.0, + depth_loss=1.0, + rgb_loss=10.0, + semantic_loss=0.1, + ), + ), + ), + # mask=dict(ratio=0.8, size=8, channel=6), + mask=None, + grid_shape=(128, 128, 32), + grid_size=0.02, + val_ray_split=10240, + ray_nsample=256, + padding=0.1, + pool_type="mean", + share_volume=True, + render_semantic=True, + conditions=("ScanNet",), + template=( + "itap of a [x]", + "a origami [x]", + "a rendering of a [x]", + "a 
painting of a [x]", + "a photo of a [x]", + "a photo of one [x]", + "a photo of a nice [x]", + "a photo of a weird [x]", + "a cropped photo of a [x]", + "a bad photo of a [x]", + "a good photo of a [x]", + "a photo of the large [x]", + "a photo of the small [x]", + "a photo of a clean [x]", + "a photo of a dirty [x]", + "a bright photo of a [x]", + "a dark photo of a [x]", + "a [x] in a living room", + "a [x] in a bedroom", + "a [x] in a kitchen", + "a [x] in a bathroom", + ), + clip_model="ViT-B/16", + class_name=( + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "picture", + "counter", + "desk", + "curtain", + "refridgerator", + "shower curtain", + "toilet", + "sink", + "bathtub", + "otherfurniture", + ), + valid_index=(tuple(range(20)),), +) + +# scheduler settings +epoch = 800 +optimizer = dict( + type="SGD", + lr=0.0001 * batch_size / 8, + momentum=0.9, + weight_decay=0.0001, + nesterov=True, +) +scheduler = dict( + type="OneCycleLR", + max_lr=optimizer["lr"], + pct_start=0.05, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=10000.0, +) + +# dataset settings +num_cameras = 5 +data = dict( + num_classes=20, + ignore_index=-1, + names=( + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "picture", + "counter", + "desk", + "curtain", + "refridgerator", + "shower curtain", + "toilet", + "sink", + "bathtub", + "otherfurniture", + ), + train=dict( + type="ScanNetRGBDDataset", + split="train", + data_root="data/scannet", + render_semantic=True, + num_cameras=num_cameras, + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.8, + dropout_application_ratio=1.0, + ), + # dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict( + type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="x", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomRotate", + angle=[-1 / 64, 1 / 64], + axis="y", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="RandomScale", + scale=[0.9, 1.1], + keys=[ + "extrinsic", + ], + ), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict( + type="RandomFlip", + p=0.5, + keys=[ + "extrinsic", + ], + ), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + # dict(type="SphereCrop", sample_rate=0.8, mode="random"), + dict( + type="CenterShift", + apply_z=False, + keys=[ + "extrinsic", + ], + ), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "ScanNet"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=( + "coord", + "grid_coord", + "segment", + "condition", + "rgb", + "depth", + "depth_scale", + ), + stack_keys=( + "intrinsic", + "extrinsic", + "rgb", + "depth", + "semantic", + ), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=1, # sampling weight + ), +) + +hooks = [ + dict(type="CheckpointLoader"), + dict(type="IterationTimer", warmup_iter=2), + dict(type="InformationWriter"), + dict(type="CheckpointSaver", save_freq=None), +] diff --git a/configs/scannet/semseg-ppt-v1m1-0-sc-s3-st-spunet-lovasz-ft.py b/configs/scannet/semseg-ppt-v1m1-0-sc-s3-st-spunet-lovasz-ft.py new file mode 100644 index 0000000..cac5c86 --- /dev/null +++ 
b/configs/scannet/semseg-ppt-v1m1-0-sc-s3-st-spunet-lovasz-ft.py @@ -0,0 +1,528 @@ +_base_ = ["../_base_/default_runtime.py"] + +# misc custom setting +batch_size = 24 # bs: total bs in all gpus +num_worker = 48 +mix_prob = 0.8 +empty_cache = False +enable_amp = True +find_unused_parameters = True + +# trainer +train = dict( + type="MultiDatasetTrainer", +) + +# model settings +model = dict( + type="PPT-v1m1", + backbone=dict( + type="SpUNet-v1m3", + in_channels=6, + num_classes=0, + base_channels=32, + context_channels=256, + channels=(32, 64, 128, 256, 256, 128, 96, 96), + layers=(2, 3, 4, 6, 2, 2, 2, 2), + cls_mode=False, + conditions=("ScanNet", "S3DIS", "Structured3D"), + zero_init=False, + norm_decouple=True, + norm_adaptive=True, + norm_affine=True, + ), + criteria=[ + dict(type="CrossEntropyLoss", loss_weight=1.0, ignore_index=-1), + dict(type="LovaszLoss", mode="multiclass", loss_weight=1.0, ignore_index=-1), + ], + backbone_out_channels=96, + context_channels=256, + conditions=("Structured3D", "ScanNet", "S3DIS"), + template=( + "itap of a [x]", + "a origami [x]", + "a rendering of a [x]", + "a painting of a [x]", + "a photo of a [x]", + "a photo of one [x]", + "a photo of a nice [x]", + "a photo of a weird [x]", + "a cropped photo of a [x]", + "a bad photo of a [x]", + "a good photo of a [x]", + "a photo of the large [x]", + "a photo of the small [x]", + "a photo of a clean [x]", + "a photo of a dirty [x]", + "a bright photo of a [x]", + "a dark photo of a [x]", + "a [x] in a living room", + "a [x] in a bedroom", + "a [x] in a kitchen", + "a [x] in a bathroom", + ), + clip_model="ViT-B/16", + class_name=( + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "bookcase", + "picture", + "counter", + "desk", + "shelves", + "curtain", + "dresser", + "pillow", + "mirror", + "ceiling", + "refrigerator", + "television", + "shower curtain", + "nightstand", + "toilet", + "sink", + "lamp", + "bathtub", + "garbagebin", + "board", + "beam", + "column", + "clutter", + "other structure ", + "other furniture", + "other property", + ), + valid_index=( + ( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 11, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 23, + 25, + 26, + 33, + 34, + 35, + ), + (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 15, 20, 22, 24, 25, 27, 34), + (0, 1, 4, 5, 6, 7, 8, 10, 19, 29, 30, 31, 32), + ), + backbone_mode=False, +) + +# scheduler settings +epoch = 100 +optimizer = dict(type="SGD", lr=0.05, momentum=0.9, weight_decay=0.0001, nesterov=True) +scheduler = dict( + type="OneCycleLR", + max_lr=optimizer["lr"], + pct_start=0.05, + anneal_strategy="cos", + div_factor=10.0, + final_div_factor=10000.0, +) +# param_dicts = [dict(keyword="modulation", lr=0.005)] + +# dataset settings +data = dict( + num_classes=20, + ignore_index=-1, + names=[ + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "picture", + "counter", + "desk", + "curtain", + "refridgerator", + "shower curtain", + "toilet", + "sink", + "bathtub", + "otherfurniture", + ], + train=dict( + type="ConcatDataset", + datasets=[ + # Structured3D + dict( + type="Structured3DDataset", + split="train", + data_root="data/structured3d", + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.2, + dropout_application_ratio=0.2, + ), + # dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict( + 
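+                    # random rotation about the gravity-aligned z axis (the main rotation augmentation); the x/y rotations below only apply small tilts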
type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + ), + dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="x", p=0.5), + dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="y", p=0.5), + dict(type="RandomScale", scale=[0.9, 1.1]), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict(type="RandomFlip", p=0.5), + dict(type="RandomJitter", sigma=0.005, clip=0.02), + dict( + type="ElasticDistortion", + distortion_params=[[0.2, 0.4], [0.8, 1.6]], + ), + dict(type="ChromaticAutoContrast", p=0.2, blend_factor=None), + dict(type="ChromaticTranslation", p=0.95, ratio=0.05), + dict(type="ChromaticJitter", p=0.95, std=0.05), + # dict(type="HueSaturationTranslation", hue_max=0.2, saturation_max=0.2), + # dict(type="RandomColorDrop", p=0.2, color_augment=0.0), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="SphereCrop", sample_rate=0.8, mode="random"), + dict(type="CenterShift", apply_z=False), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "Structured3D"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment", "condition"), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=4, # sampling weight + ), + # ScanNet + dict( + type="ScanNetDataset", + split="train", + data_root="data/scannet", + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.2, + dropout_application_ratio=0.2, + ), + # dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict( + type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + ), + dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="x", p=0.5), + dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="y", p=0.5), + dict(type="RandomScale", scale=[0.9, 1.1]), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict(type="RandomFlip", p=0.5), + dict(type="RandomJitter", sigma=0.005, clip=0.02), + dict( + type="ElasticDistortion", + distortion_params=[[0.2, 0.4], [0.8, 1.6]], + ), + dict(type="ChromaticAutoContrast", p=0.2, blend_factor=None), + dict(type="ChromaticTranslation", p=0.95, ratio=0.05), + dict(type="ChromaticJitter", p=0.95, std=0.05), + # dict(type="HueSaturationTranslation", hue_max=0.2, saturation_max=0.2), + # dict(type="RandomColorDrop", p=0.2, color_augment=0.0), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="SphereCrop", point_max=100000, mode="random"), + dict(type="CenterShift", apply_z=False), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "ScanNet"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment", "condition"), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=2, # sampling weight + ), + # S3DIS + dict( + type="S3DISDataset", + split=("Area_1", "Area_2", "Area_3", "Area_4", "Area_6"), + data_root="data/s3dis", + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="RandomDropout", + dropout_ratio=0.2, + dropout_application_ratio=0.2, + ), + # dict(type="RandomRotateTargetAngle", angle=(1/2, 1, 3/2), center=[0, 0, 0], axis="z", p=0.75), + dict( + type="RandomRotate", + angle=[-1, 1], + axis="z", + center=[0, 0, 0], + p=0.5, + ), + dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], 
axis="x", p=0.5), + dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="y", p=0.5), + dict(type="RandomScale", scale=[0.9, 1.1]), + # dict(type="RandomShift", shift=[0.2, 0.2, 0.2]), + dict(type="RandomFlip", p=0.5), + dict(type="RandomJitter", sigma=0.005, clip=0.02), + dict( + type="ElasticDistortion", + distortion_params=[[0.2, 0.4], [0.8, 1.6]], + ), + dict(type="ChromaticAutoContrast", p=0.2, blend_factor=None), + dict(type="ChromaticTranslation", p=0.95, ratio=0.05), + dict(type="ChromaticJitter", p=0.95, std=0.05), + # dict(type="HueSaturationTranslation", hue_max=0.2, saturation_max=0.2), + # dict(type="RandomColorDrop", p=0.2, color_augment=0.0), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + dict(type="SphereCrop", sample_rate=0.6, mode="random"), + dict(type="CenterShift", apply_z=False), + dict(type="NormalizeColor"), + dict(type="ShufflePoint"), + dict(type="Add", keys_dict={"condition": "S3DIS"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment", "condition"), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + loop=1, # sampling weight + ), + ], + ), + val=dict( + type="ScanNetDataset", + split="val", + data_root="data/scannet", + transform=[ + dict(type="CenterShift", apply_z=True), + dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="train", + return_grid_coord=True, + ), + # dict(type="SphereCrop", point_max=1000000, mode="center"), + dict(type="CenterShift", apply_z=False), + dict(type="NormalizeColor"), + dict(type="ToTensor"), + dict(type="Add", keys_dict={"condition": "ScanNet"}), + dict( + type="Collect", + keys=("coord", "grid_coord", "segment", "condition"), + feat_keys=("color", "normal"), + ), + ], + test_mode=False, + ), + test=dict( + type="ScanNetDataset", + split="val", + data_root="data/scannet", + transform=[ + dict(type="CenterShift", apply_z=True), + dict(type="NormalizeColor"), + ], + test_mode=True, + test_cfg=dict( + voxelize=dict( + type="GridSample", + grid_size=0.02, + hash_type="fnv", + mode="test", + return_grid_coord=True, + keys=("coord", "color", "normal"), + ), + crop=None, + post_transform=[ + dict(type="CenterShift", apply_z=False), + dict(type="Add", keys_dict={"condition": "ScanNet"}), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "index", "condition"), + feat_keys=("color", "normal"), + ), + ], + aug_transform=[ + [ + dict( + type="RandomRotateTargetAngle", + angle=[0], + axis="z", + center=[0, 0, 0], + p=1, + ) + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[1 / 2], + axis="z", + center=[0, 0, 0], + p=1, + ) + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[1], + axis="z", + center=[0, 0, 0], + p=1, + ) + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[3 / 2], + axis="z", + center=[0, 0, 0], + p=1, + ) + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[0], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[0.95, 0.95]), + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[1 / 2], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[0.95, 0.95]), + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[1], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[0.95, 0.95]), + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[3 / 2], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[0.95, 0.95]), + 
], + [ + dict( + type="RandomRotateTargetAngle", + angle=[0], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[1.05, 1.05]), + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[1 / 2], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[1.05, 1.05]), + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[1], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[1.05, 1.05]), + ], + [ + dict( + type="RandomRotateTargetAngle", + angle=[3 / 2], + axis="z", + center=[0, 0, 0], + p=1, + ), + dict(type="RandomScale", scale=[1.05, 1.05]), + ], + [dict(type="RandomFlip", p=1)], + ], + ), + ), +) diff --git a/docs/data_preparation.md b/docs/data_preparation.md new file mode 100644 index 0000000..0ad369f --- /dev/null +++ b/docs/data_preparation.md @@ -0,0 +1,125 @@ +## Data Preparation + +### ScanNet v2 + +The preprocessing supports semantic and instance segmentation for ScanNet20, ScanNet200 and ScanNet Data Efficient. + +- Download the [ScanNet](http://www.scan-net.org/) v2 dataset. +- Run preprocessing code for raw ScanNet as follows: +```bash +# RAW_SCANNET_DIR: the directory of downloaded ScanNet v2 raw dataset. +# PROCESSED_SCANNET_DIR: the directory of processed ScanNet dataset (output dir). +python ponder/datasets/preprocessing/scannet/preprocess_scannet.py --dataset_root ${RAW_SCANNET_DIR} --output_root ${PROCESSED_SCANNET_DIR} +# extract RGB-D images and 2D semantic labels: +python ponder/datasets/preprocessing/scannet/reader.py --scans_path ${RAW_SCANNET_DIR}/scans --output_path ${PROCESSED_SCANNET_DIR}/rgbd --export_depth_images --export_color_images --export_poses --export_intrinsics --export_label +``` + +- (Optional) Download ScanNet Data Efficient files: +```bash +# download-scannet.py is the official download script +# or follow the instructions here: https://kaldir.vc.in.tum.de/scannet_benchmark/data_efficient/documentation#download +python download-scannet.py --data_efficient -o ${RAW_SCANNET_DIR} +# unzip downloads +cd ${RAW_SCANNET_DIR}/tasks +unzip limited-annotation-points.zip +unzip limited-bboxes.zip +unzip limited-reconstruction-scenes.zip +# copy files to processed dataset folder +cp -r ${RAW_SCANNET_DIR}/tasks ${PROCESSED_SCANNET_DIR} +``` + +- Link processed dataset to codebase: +```bash +# PROCESSED_SCANNET_DIR: the directory of processed ScanNet dataset. +mkdir data +ln -s ${PROCESSED_SCANNET_DIR} ${CODEBASE_DIR}/data/scannet +``` + +### S3DIS +- Download S3DIS data by filling out this [Google form](https://docs.google.com/forms/d/e/1FAIpQLScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1). Download the `Stanford3dDataset_v1.2.zip` file and unzip it. +- Run preprocessing code for S3DIS as follows: +```bash +# S3DIS_DIR: the directory of downloaded Stanford3dDataset_v1.2 dataset. +# RAW_S3DIS_DIR: the directory of Stanford2d3dDataset_noXYZ dataset. (optional, for parsing normals) +# PROCESSED_S3DIS_DIR: the directory of processed S3DIS dataset (output dir). 
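+# (hypothetical example values) S3DIS_DIR=/path/to/Stanford3dDataset_v1.2, RAW_S3DIS_DIR=/path/to/Stanford2d3dDataset_noXYZ, PROCESSED_S3DIS_DIR=/path/to/s3dis_processed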
+ +# S3DIS with normal vector, RGB-D images and 2D semantic labels +python ponder/datasets/preprocessing/s3dis/preprocess_s3dis.py --dataset_root ${S3DIS_DIR} --output_root ${PROCESSED_S3DIS_DIR} --raw_root ${RAW_S3DIS_DIR} --parse_normal --parse_rgbd +# if you want S3DIS with aligned angle: +python ponder/datasets/preprocessing/s3dis/preprocess_s3dis.py --dataset_root ${S3DIS_DIR} --output_root ${PROCESSED_S3DIS_DIR} --raw_root ${RAW_S3DIS_DIR} --align_angle --parse_normal --parse_rgbd +``` +- Link processed dataset to codebase. +```bash +# PROCESSED_S3DIS_DIR: the directory of processed S3DIS dataset. +mkdir data +ln -s ${PROCESSED_S3DIS_DIR} ${CODEBASE_DIR}/data/s3dis +``` + +### Structured3D +- Download the Structured3D panorama-related and perspective (full)-related zip files by filling out this [Google form](https://docs.google.com/forms/d/e/1FAIpQLSc0qtvh4vHSoZaW6UvlXYy79MbcGdZfICjh4_t4bYofQIVIdw/viewform?pli=1) (no need to unzip them). +- Organize all downloaded zip files in one folder (`${STRUCT3D_DIR}`). +- Run preprocessing code for Structured3D as follows: +```bash +# STRUCT3D_DIR: the directory of downloaded Structured3D dataset. +# PROCESSED_STRUCT3D_DIR: the directory of processed Structured3D dataset (output dir). +# NUM_WORKERS: number of workers for preprocessing; defaults to the CPU count (might OOM). +export PYTHONPATH=./ +python ponder/datasets/preprocessing/structured3d/preprocess_structured3d.py --dataset_root ${STRUCT3D_DIR} --output_root ${PROCESSED_STRUCT3D_DIR} --num_workers ${NUM_WORKERS} --grid_size 0.01 --fuse_prsp --fuse_pano --parse_rgbd +``` + +Following [Swin3D](https://arxiv.org/abs/2304.06906), we keep the 25 categories whose frequencies are greater than 0.001, out of the original 40 categories. + +- Link processed dataset to codebase. +```bash +# PROCESSED_STRUCT3D_DIR: the directory of processed Structured3D dataset (output dir). +mkdir data +ln -s ${PROCESSED_STRUCT3D_DIR} ${CODEBASE_DIR}/data/structured3d +``` + +### nuScenes +- Download the official [nuScenes](https://www.nuscenes.org/nuscenes#download) dataset (with Lidar Segmentation) and organize the downloaded files as follows: +```bash +NUSCENES_DIR +│── samples +│── sweeps +│── lidarseg +... +│── v1.0-trainval +│── v1.0-test +``` + +- Run information preprocessing code (modified from OpenPCDet) for nuScenes as follows: +```bash +# NUSCENES_DIR: the directory of downloaded nuScenes dataset. +# PROCESSED_NUSCENES_DIR: the directory of processed nuScenes dataset (output dir). +# MAX_SWEEPS: max number of sweeps. Default: 10. +pip install nuscenes-devkit pyquaternion +python ponder/datasets/preprocessing/nuscenes/preprocess_nuscenes_info.py --dataset_root ${NUSCENES_DIR} --output_root ${PROCESSED_NUSCENES_DIR} --max_sweeps ${MAX_SWEEPS} --with_camera +``` + +- Link raw dataset to the processed nuScenes dataset folder: +```bash +# NUSCENES_DIR: the directory of downloaded nuScenes dataset. +# PROCESSED_NUSCENES_DIR: the directory of processed nuScenes dataset (output dir). +ln -s ${NUSCENES_DIR} ${PROCESSED_NUSCENES_DIR}/raw +``` + +Then the processed nuScenes folder is organized as follows: +```bash +nuscene +│── raw + │── samples + │── sweeps + │── lidarseg + ... + │── v1.0-trainval + │── v1.0-test +│── info +``` + +- Link processed dataset to codebase. +```bash +# PROCESSED_NUSCENES_DIR: the directory of processed nuScenes dataset (output dir). 
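+# CODEBASE_DIR: the root directory of this codebase (your PonderV2 repository checkout)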
+mkdir data +ln -s ${PROCESSED_NUSCENES_DIR} ${CODEBASE_DIR}/data/nuscenes +``` \ No newline at end of file diff --git a/docs/getting_started.md b/docs/getting_started.md new file mode 100644 index 0000000..a524411 --- /dev/null +++ b/docs/getting_started.md @@ -0,0 +1,74 @@ +## Getting Started + +### (Pre)Training + +**Train from scratch.** The training process is driven by the configs in the `configs` folder. +The training script will generate an experiment folder under the `exp` folder and back up essential code in it. +The training config, log, tensorboard records and checkpoints will also be saved into the experiment folder during the training process. +```bash +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} +# Script (Recommended) +sh scripts/train.sh -p ${INTERPRETER_PATH} -g ${NUM_GPU} -d ${DATASET_NAME} -c ${CONFIG_NAME} -n ${EXP_NAME} +# Direct +export PYTHONPATH=./ +python tools/train.py --config-file ${CONFIG_PATH} --num-gpus ${NUM_GPU} --options save_path=${SAVE_PATH} +``` + +For example: +```bash +# By script (Recommended) +# -p defaults to python and can be omitted +sh scripts/train.sh -p python -d scannet -c semseg-pt-v2m2-0-base -n semseg-pt-v2m2-0-base +# Direct +export PYTHONPATH=./ +python tools/train.py --config-file configs/scannet/semseg-pt-v2m2-0-base.py --options save_path=exp/scannet/semseg-pt-v2m2-0-base +``` +**Resume training from checkpoint.** If the training process is interrupted by accident, the following script can resume training from a given checkpoint. +```bash +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} +# Script (Recommended) +# simply add "-r true" +sh scripts/train.sh -p ${INTERPRETER_PATH} -g ${NUM_GPU} -d ${DATASET_NAME} -c ${CONFIG_NAME} -n ${EXP_NAME} -r true +# Direct +export PYTHONPATH=./ +python tools/train.py --config-file ${CONFIG_PATH} --num-gpus ${NUM_GPU} --options save_path=${SAVE_PATH} resume=True weight=${CHECKPOINT_PATH} +``` + +### Testing +During training, model evaluation is performed on point clouds after grid sampling (voxelization), providing an initial assessment of model performance. However, to obtain precise evaluation results, testing is **essential**. The testing process involves subsampling a dense point cloud into a sequence of voxelized point clouds, ensuring comprehensive coverage of all points. Predictions on these sub-clouds are then collected to form a complete prediction of the entire point cloud. This approach yields higher evaluation results compared to simply mapping/interpolating the prediction. In addition, our testing code supports TTA (test time augmentation) testing, which further enhances the stability of evaluation performance. 
+ +```bash +# By script (Based on experiment folder created by training script) +sh scripts/test.sh -p ${INTERPRETER_PATH} -g ${NUM_GPU} -d ${DATASET_NAME} -n ${EXP_NAME} -w ${CHECKPOINT_NAME} +# Direct +export PYTHONPATH=./ +python tools/test.py --config-file ${CONFIG_PATH} --num-gpus ${NUM_GPU} --options save_path=${SAVE_PATH} weight=${CHECKPOINT_PATH} +``` +For example: +```bash +# By script (Based on experiment folder created by training script) +# -p defaults to python and can be omitted +# -w defaults to model_best and can be omitted +sh scripts/test.sh -p python -d scannet -n semseg-pt-v2m2-0-base -w model_best +# Direct +export PYTHONPATH=./ +python tools/test.py --config-file configs/scannet/semseg-pt-v2m2-0-base.py --options save_path=exp/scannet/semseg-pt-v2m2-0-base weight=exp/scannet/semseg-pt-v2m2-0-base/model/model_best.pth +``` + +TTA can be disabled by replacing `data.test.test_cfg.aug_transform = [...]` with: + +```python +data = dict( + train = dict(...), + val = dict(...), + test = dict( + ..., + test_cfg = dict( + ..., + aug_transform = [ + [dict(type="RandomRotateTargetAngle", angle=[0], axis="z", center=[0, 0, 0], p=1)] + ] + ) + ) +) +``` \ No newline at end of file diff --git a/ponder/datasets/__init__.py b/ponder/datasets/__init__.py index 42ff63e..9d5a9be 100644 --- a/ponder/datasets/__init__.py +++ b/ponder/datasets/__init__.py @@ -1,7 +1,8 @@ from .builder import build_dataset +from .dataloader import MultiDatasetDataloader from .defaults import DefaultDataset from .nuscenes import NuScenesDataset from .s3dis import S3DISDataset, S3DISRGBDDataset -from .scannet import ScanNet200Dataset, ScanNetDataset +from .scannet import ScanNet200Dataset, ScanNetDataset, ScanNetRGBDDataset from .structure3d import Structured3DDataset, Structured3DRGBDDataset from .utils import collate_fn, point_collate_fn diff --git a/ponder/datasets/dataloader.py b/ponder/datasets/dataloader.py new file mode 100644 index 0000000..46bf003 --- /dev/null +++ b/ponder/datasets/dataloader.py @@ -0,0 +1,117 @@ +import weakref +from functools import partial + +import torch +import torch.utils.data + +import ponder.utils.comm as comm +from ponder.utils.env import set_seed + +from .defaults import ConcatDataset +from .utils import point_collate_fn + + +class MultiDatasetDummySampler: + def __init__(self): + self.dataloader = None + + def set_epoch(self, epoch): + if comm.get_world_size() > 1: + for dataloader in self.dataloader.dataloaders: + dataloader.sampler.set_epoch(epoch) + return + + +class MultiDatasetDataloader: + """ + Dataloader over multiple datasets: each batch is drawn from a single sub-dataset, and the mixing ratio between sub-datasets is determined by each sub-dataset's `loop` value. + The overall length is determined by the main (first) dataset and the `loop` of the concat dataset. 
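+    For example, with sub-dataset loops (4, 2, 1), each cycle of iteration yields 4 batches from the first dataset, then 2 from the second and 1 from the third; iteration stops once the first (main) dataset is exhausted, while the other iterators are restarted as needed.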
+ """ + + def __init__( + self, + concat_dataset: ConcatDataset, + batch_size_per_gpu: int, + num_worker_per_gpu: int, + mix_prob=0, + seed=None, + max_point=-1, + ): + self.datasets = concat_dataset.datasets + self.ratios = [dataset.loop for dataset in self.datasets] + # reset data loop, original loop serve as ratios + for dataset in self.datasets: + dataset.loop = 1 + # determine union training epoch by main dataset + self.datasets[0].loop = concat_dataset.loop + # build sub-dataloaders + num_workers = num_worker_per_gpu // len(self.datasets) + self.dataloaders = [] + for dataset_id, dataset in enumerate(self.datasets): + if comm.get_world_size() > 1: + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + sampler = None + + init_fn = ( + partial( + self._worker_init_fn, + dataset_id=dataset_id, + num_workers=num_workers, + num_datasets=len(self.datasets), + rank=comm.get_rank(), + seed=seed, + ) + if seed is not None + else None + ) + self.dataloaders.append( + torch.utils.data.DataLoader( + dataset, + batch_size=batch_size_per_gpu, + shuffle=(sampler is None), + num_workers=num_worker_per_gpu, + sampler=sampler, + collate_fn=partial( + point_collate_fn, mix_prob=mix_prob, max_point=max_point + ), + pin_memory=True, + worker_init_fn=init_fn, + drop_last=True, + persistent_workers=True, + ) + ) + self.sampler = MultiDatasetDummySampler() + self.sampler.dataloader = weakref.proxy(self) + + def __iter__(self): + iterator = [iter(dataloader) for dataloader in self.dataloaders] + while True: + for i in range(len(self.ratios)): + for _ in range(self.ratios[i]): + try: + batch = next(iterator[i]) + except StopIteration: + if i == 0: + return + else: + iterator[i] = iter(self.dataloaders[i]) + batch = next(iterator[i]) + yield batch + + def __len__(self): + main_data_loader_length = len(self.dataloaders[0]) + return ( + main_data_loader_length // self.ratios[0] * sum(self.ratios) + + main_data_loader_length % self.ratios[0] + ) + + @staticmethod + def _worker_init_fn(worker_id, num_workers, dataset_id, num_datasets, rank, seed): + worker_seed = ( + num_workers * num_datasets * rank + + num_workers * dataset_id + + worker_id + + seed + ) + set_seed(worker_seed) diff --git a/ponder/datasets/defaults.py b/ponder/datasets/defaults.py index 33733bf..688a68d 100644 --- a/ponder/datasets/defaults.py +++ b/ponder/datasets/defaults.py @@ -138,3 +138,42 @@ def __getitem__(self, idx): def __len__(self): return len(self.data_list) * self.loop + + +@DATASETS.register_module() +class ConcatDataset(Dataset): + def __init__(self, datasets, loop=1): + super(ConcatDataset, self).__init__() + self.datasets = [build_dataset(dataset) for dataset in datasets] + self.loop = loop + self.data_list = self.get_data_list() + logger = get_root_logger() + logger.info( + "Totally {} x {} samples in the concat set.".format( + len(self.data_list), self.loop + ) + ) + + def get_data_list(self): + data_list = [] + for i in range(len(self.datasets)): + data_list.extend( + zip( + np.ones(len(self.datasets[i])) * i, np.arange(len(self.datasets[i])) + ) + ) + return data_list + + def get_data(self, idx): + dataset_idx, data_idx = self.data_list[idx % len(self.data_list)] + return self.datasets[dataset_idx][data_idx] + + def get_data_name(self, idx): + dataset_idx, data_idx = self.data_list[idx % len(self.data_list)] + return self.datasets[dataset_idx].get_data_name(data_idx) + + def __getitem__(self, idx): + return self.get_data(idx) + + def __len__(self): + return len(self.data_list) * self.loop diff --git 
a/ponder/datasets/preprocessing/scannet/SensorData.py b/ponder/datasets/preprocessing/scannet/SensorData.py new file mode 100644 index 0000000..a873e8b --- /dev/null +++ b/ponder/datasets/preprocessing/scannet/SensorData.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Code borrowed from: https://github.com/ScanNet/ScanNet +""" +import os +import struct +import zlib + +import cv2 +import imageio +import numpy as np + +COMPRESSION_TYPE_COLOR = {-1: "unknown", 0: "raw", 1: "png", 2: "jpeg"} +COMPRESSION_TYPE_DEPTH = { + -1: "unknown", + 0: "raw_ushort", + 1: "zlib_ushort", + 2: "occi_ushort", +} + + +class RGBDFrame: + def load(self, file_handle): + self.camera_to_world = np.asarray( + struct.unpack("f" * 16, file_handle.read(16 * 4)), dtype=np.float32 + ).reshape(4, 4) + self.timestamp_color = struct.unpack("Q", file_handle.read(8))[0] + self.timestamp_depth = struct.unpack("Q", file_handle.read(8))[0] + self.color_size_bytes = struct.unpack("Q", file_handle.read(8))[0] + self.depth_size_bytes = struct.unpack("Q", file_handle.read(8))[0] + self.color_data = b"".join( + struct.unpack( + "c" * self.color_size_bytes, file_handle.read(self.color_size_bytes) + ) + ) + self.depth_data = b"".join( + struct.unpack( + "c" * self.depth_size_bytes, file_handle.read(self.depth_size_bytes) + ) + ) + + def decompress_depth(self, compression_type): + if compression_type == "zlib_ushort": + return self.decompress_depth_zlib() + else: + raise + + def decompress_depth_zlib(self): + return zlib.decompress(self.depth_data) + + def decompress_color(self, compression_type): + if compression_type == "jpeg": + return self.decompress_color_jpeg() + else: + raise + + def decompress_color_jpeg(self): + return imageio.imread(self.color_data) + + +class SensorData: + def __init__(self, filename): + self.version = 4 + self.load(filename) + + def load(self, filename): + with open(filename, "rb") as f: + version = struct.unpack("I", f.read(4))[0] + assert self.version == version + strlen = struct.unpack("Q", f.read(8))[0] + self.sensor_name = b"".join(struct.unpack("c" * strlen, f.read(strlen))) + self.intrinsic_color = np.asarray( + struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 + ).reshape(4, 4) + self.extrinsic_color = np.asarray( + struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 + ).reshape(4, 4) + self.intrinsic_depth = np.asarray( + struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 + ).reshape(4, 4) + self.extrinsic_depth = np.asarray( + struct.unpack("f" * 16, f.read(16 * 4)), dtype=np.float32 + ).reshape(4, 4) + self.color_compression_type = COMPRESSION_TYPE_COLOR[ + struct.unpack("i", f.read(4))[0] + ] + self.depth_compression_type = COMPRESSION_TYPE_DEPTH[ + struct.unpack("i", f.read(4))[0] + ] + self.color_width = struct.unpack("I", f.read(4))[0] + self.color_height = struct.unpack("I", f.read(4))[0] + self.depth_width = struct.unpack("I", f.read(4))[0] + self.depth_height = struct.unpack("I", f.read(4))[0] + self.depth_shift = struct.unpack("f", f.read(4))[0] + num_frames = struct.unpack("Q", f.read(8))[0] + self.frames = [] + for i in range(num_frames): + frame = RGBDFrame() + frame.load(f) + self.frames.append(frame) + + def export_depth_images(self, output_path, image_size=None, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print( + "exporting", len(self.frames) // 
frame_skip, " depth frames to", output_path + ) + for f in range(0, len(self.frames), frame_skip): + depth_data = self.frames[f].decompress_depth(self.depth_compression_type) + depth = np.fromstring(depth_data, dtype=np.uint16).reshape( + self.depth_height, self.depth_width + ) + if image_size is not None: + depth = cv2.resize( + depth, + (image_size[1], image_size[0]), + interpolation=cv2.INTER_NEAREST, + ) + imageio.imwrite(os.path.join(output_path, str(f) + ".png"), depth) + + def export_color_images(self, output_path, image_size=None, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print( + "exporting", len(self.frames) // frame_skip, "color frames to", output_path + ) + for f in range(0, len(self.frames), frame_skip): + color = self.frames[f].decompress_color(self.color_compression_type) + if image_size is not None: + color = cv2.resize( + color, + (image_size[1], image_size[0]), + interpolation=cv2.INTER_NEAREST, + ) + imageio.imwrite(os.path.join(output_path, str(f) + ".jpg"), color) + + def save_mat_to_file(self, matrix, filename): + with open(filename, "w") as f: + for line in matrix: + np.savetxt(f, line[np.newaxis], fmt="%f") + + def export_poses(self, output_path, frame_skip=1): + if not os.path.exists(output_path): + os.makedirs(output_path) + print( + "exporting", len(self.frames) // frame_skip, "camera poses to", output_path + ) + for f in range(0, len(self.frames), frame_skip): + self.save_mat_to_file( + self.frames[f].camera_to_world, + os.path.join(output_path, str(f) + ".txt"), + ) + + def export_intrinsics(self, output_path): + if not os.path.exists(output_path): + os.makedirs(output_path) + print("exporting camera intrinsics to", output_path) + self.save_mat_to_file( + self.intrinsic_color, os.path.join(output_path, "intrinsic_color.txt") + ) + self.save_mat_to_file( + self.extrinsic_color, os.path.join(output_path, "extrinsic_color.txt") + ) + self.save_mat_to_file( + self.intrinsic_depth, os.path.join(output_path, "intrinsic_depth.txt") + ) + self.save_mat_to_file( + self.extrinsic_depth, os.path.join(output_path, "extrinsic_depth.txt") + ) diff --git a/ponder/datasets/preprocessing/scannet/reader.py b/ponder/datasets/preprocessing/scannet/reader.py new file mode 100644 index 0000000..b36afe8 --- /dev/null +++ b/ponder/datasets/preprocessing/scannet/reader.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
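+# reader.py: exports per-frame depth/color images, camera poses, intrinsics and (optionally)
+# 2D semantic label images from ScanNet .sens files; see docs/data_preparation.md for usage.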
+ +import argparse +import csv +import os +import sys +import zipfile +from glob import glob + +import cv2 +import imageio.v2 as imageio +import numpy as np + +from ponder.datasets.preprocessing.scannet.SensorData import SensorData + +# params +parser = argparse.ArgumentParser() +# data paths +parser.add_argument("--scans_path", required=True, help="path to scans folder") +parser.add_argument("--output_path", required=True, help="path to output folder") +parser.add_argument( + "--export_depth_images", dest="export_depth_images", action="store_true" +) +parser.add_argument( + "--export_color_images", dest="export_color_images", action="store_true" +) +parser.add_argument("--export_poses", dest="export_poses", action="store_true") +parser.add_argument( + "--export_intrinsics", dest="export_intrinsics", action="store_true" +) +parser.add_argument("--export_label", dest="export_label", action="store_true") +parser.set_defaults( + export_depth_images=False, + export_color_images=False, + export_poses=False, + export_intrinsics=False, + export_label=False, +) + +opt = parser.parse_args() +print(opt) + + +def represents_int(s): + try: + int(s) + return True + except ValueError: + return False + + +def read_label_mapping(filename, label_from="raw_category", label_to="nyu40id"): + # assert os.path.isfile(filename) + mapping = dict() + # print(filename) + with open(filename, "r") as csvfile: + reader = csv.DictReader(csvfile, delimiter="\t") + for row in reader: + mapping[row[label_from]] = int(row[label_to]) + # if ints convert + if represents_int(list(mapping.keys())[0]): + mapping = {int(k): v for k, v in mapping.items()} + return mapping + + +def main(): + scans = glob(opt.scans_path + "/*") + scans.sort() + + label_mapping = None + if opt.export_label: + root = os.path.dirname(opt.scans_path) + label_map = read_label_mapping( + filename=os.path.join(root, "scannetv2-labels.combined.tsv"), + label_from="id", + label_to="nyu40id", + ) + + for scan in scans: + scenename = scan.split("/")[-1] + filename = os.path.join(scan, scenename + ".sens") + if not os.path.exists(opt.output_path): + os.makedirs(opt.output_path) + # os.makedirs(os.path.join(opt.output_path, 'depth')) + # os.makedirs(os.path.join(opt.output_path, 'color')) + # os.makedirs(os.path.join(opt.output_path, 'pose')) + # os.makedirs(os.path.join(opt.output_path, 'intrinsic')) + os.makedirs(os.path.join(opt.output_path, scenename)) + # load the data + print("loading %s..." 
% filename) + sd = SensorData(filename) + print("loaded!\n") + if opt.export_depth_images: + # sd.export_depth_images(os.path.join(opt.output_path, 'depth', scenename)) + sd.export_depth_images(os.path.join(opt.output_path, scenename, "depth")) + if opt.export_color_images: + # sd.export_color_images(os.path.join(opt.output_path, 'color', scenename)) + sd.export_color_images(os.path.join(opt.output_path, scenename, "color")) + if opt.export_poses: + # sd.export_poses(os.path.join(opt.output_path, 'pose', scenename)) + sd.export_poses(os.path.join(opt.output_path, scenename, "pose")) + if opt.export_intrinsics: + # sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic', scenename)) + sd.export_intrinsics(os.path.join(opt.output_path, scenename, "intrinsic")) + + os.system(f"cp {scan}/scene*.txt {opt.output_path}/{scenename}/") + + if opt.export_label: + + def map_label_image(image, label_mapping): + mapped = np.copy(image) + for k, v in label_mapping.items(): + mapped[image == k] = v + return mapped.astype(np.uint8) + + label_zip_path = os.path.join( + opt.scans_path, scenename, f"{scenename}_2d-label-filt.zip" + ) + print("process labels") + with open(label_zip_path, "rb") as f: + zip_file = zipfile.ZipFile(f) + for frame in range(0, len(sd.frames)): + label_file = f"label-filt/{frame}.png" + with zip_file.open(label_file) as lf: + image = np.array(imageio.imread(lf)) + + mapped_image = map_label_image(image, label_map) + output_path = os.path.join(opt.output_path, scenename, "label") + os.makedirs(output_path, exist_ok=True) + print("output:", output_path) + cv2.imwrite(os.path.join(output_path, f"{frame}.png"), mapped_image) + + +if __name__ == "__main__": + main() diff --git a/ponder/datasets/s3dis.py b/ponder/datasets/s3dis.py index 15a9025..02a513d 100644 --- a/ponder/datasets/s3dis.py +++ b/ponder/datasets/s3dis.py @@ -196,7 +196,7 @@ def get_data_list(self): filtered_data_list = [] for data_path in data_list: rgbd_paths = glob.glob( - os.path.join(data_path.split(".pth")[0], "rgbd", "*.pth") + os.path.join(data_path.split(".pth")[0] + "_rgbd", "*.pth") ) if len(rgbd_paths) <= 0: # print(f"{data_path} has no rgbd data.") @@ -205,7 +205,7 @@ def get_data_list(self): print( f"Finish filtering! Totally {len(filtered_data_list)} from {len(data_list)} data." 
) - return data_list + return filtered_data_list def get_data(self, idx): data_path = self.data_list[idx % len(self.data_list)] @@ -226,20 +226,14 @@ def get_data(self, idx): print(f"{data_path} has no rgbd data.") return self.get_data(np.random.randint(0, self.__len__())) - # if self.num_cameras > len(rgbd_paths): - # print(f"Warning: {data_path.split('.pth')[0]} has only {len(rgbd_paths)} frames, but {self.num_cameras} cameras are required.") rgbd_paths = np.random.choice( rgbd_paths, self.num_cameras, replace=self.num_cameras > len(rgbd_paths) ) rgbd_dicts = [torch.load(p) for p in rgbd_paths] - has_bad = False for i in range(len(rgbd_dicts)): if (rgbd_dicts[i]["depth_mask"]).mean() < 0.25: - os.rename(rgbd_paths[i], rgbd_paths[i] + ".bad") - has_bad = True - if has_bad: - return self.get_data(idx) + return self.get_data(idx) name = ( os.path.basename(self.data_list[idx % len(self.data_list)]) diff --git a/ponder/datasets/scannet.py b/ponder/datasets/scannet.py index d88815b..60a5c09 100644 --- a/ponder/datasets/scannet.py +++ b/ponder/datasets/scannet.py @@ -6,10 +6,13 @@ """ import glob +import json import os +from collections import defaultdict from collections.abc import Sequence from copy import deepcopy +import cv2 import numpy as np import torch from torch.utils.data import Dataset @@ -204,3 +207,391 @@ def get_data(self, idx): data_dict["segment"] = segment data_dict["sampled_index"] = sampled_index return data_dict + + +@DATASETS.register_module() +class ScanNetRGBDDataset(Dataset): + def __init__( + self, + split="train", + data_root="data/scannet", + rgbd_root="data/scannet/rgbd", + transform=None, + lr_file=None, + la_file=None, + ignore_index=-1, + test_mode=False, + test_cfg=None, + cache=False, + frame_interval=10, + nearby_num=2, + nearby_interval=20, + num_cameras=5, + render_semantic=True, + align_axis=False, + loop=1, + ): + super(ScanNetRGBDDataset, self).__init__() + self.data_root = data_root + self.split = split + self.rgbd_root = rgbd_root + self.frame_interval = frame_interval + self.nearby_num = nearby_num + self.nearby_interval = nearby_interval + self.num_cameras = num_cameras + self.render_semantic = render_semantic + self.align_axis = align_axis + + self.transform = Compose(transform) + self.cache = cache + self.loop = ( + loop if not test_mode else 1 + ) # force make loop = 1 while in test mode + self.test_mode = test_mode + self.test_cfg = test_cfg if test_mode else None + + if test_mode: + self.test_voxelize = TRANSFORMS.build(self.test_cfg.voxelize) + self.test_crop = ( + TRANSFORMS.build(self.test_cfg.crop) if self.test_cfg.crop else None + ) + self.post_transform = Compose(self.test_cfg.post_transform) + self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform] + + self.logger = get_root_logger() + + if lr_file: + full_data_list = self.get_data_list() + self.data_list = [] + lr_list = np.loadtxt(lr_file, dtype=str) + for data_dict in full_data_list: + if data_dict["scene"] in lr_list: + self.data_list.append(data_dict) + else: + self.data_list = self.get_data_list() + self.la = torch.load(la_file) if la_file else None + self.ignore_index = ignore_index + + self.logger.info( + "Totally {} x {} samples in {} set.".format( + len(self.data_list), self.loop, split + ) + ) + + def get_data_list(self): + self.axis_align_matrix_list = {} + self.intrinsic_list = {} + self.frame_lists = {} + + # Get all models + data_list = [] + split_json = os.path.join(os.path.join(self.data_root, self.split + ".json")) + + if os.path.exists(split_json): + 
with open(split_json, "r") as f: + data_list = json.load(f) + else: + scene_list = [ + filename.split(".")[0] + for filename in os.listdir(os.path.join(self.data_root, self.split)) + ] + + skip_list = [] + skip_counter = 0 + skip_file = os.path.join(os.path.join(self.data_root, "skip.lst")) + if os.path.exists(skip_file): + with open(skip_file, "r") as f: + for i in f.read().split("\n"): + scene_name, frame_idx = i.split() + skip_list.append((scene_name, int(frame_idx))) + + # walk through the subfolder + from tqdm import tqdm + + for scene_name in tqdm(scene_list): + # filenames = os.listdir(os.path.join(subpath, m, 'pointcloud')) + frame_list = self.get_frame_list(scene_name) + + # for test and val, we only use 1/10 of the data, since those data will not affect + # the training and we use them just for visualization and debugging + if self.split == "val": + frame_list = frame_list[::10] + if self.split == "test": + frame_list = frame_list[::10] + + for frame_idx in frame_list[ + self.nearby_num + * self.nearby_interval : -(self.nearby_num + 1) + * self.nearby_interval : self.frame_interval + ]: + frame_idx = int(frame_idx.split(".")[0]) + if (scene_name, frame_idx) in skip_list: + skip_counter += 1 + continue + data_list.append({"scene": scene_name, "frame": frame_idx}) + + self.logger.info( + f"ScanNet: <{skip_counter} Frames will be skipped in {self.split} data.>" + ) + + with open(split_json, "w") as f: + json.dump(data_list, f) + + data_dict = defaultdict(list) + for data in data_list: + data_dict[data["scene"]].append(data["frame"]) + + data_list = [] + for scene_name, frame_list in data_dict.items(): + data_list.append({"scene": scene_name, "frame": frame_list}) + + return data_list + + def get_data(self, idx): + scene_name = self.data_list[idx % len(self.data_list)]["scene"] + frame_list = self.data_list[idx % len(self.data_list)]["frame"] + scene_path = os.path.join(self.data_root, self.split, f"{scene_name}.pth") + if not self.cache: + data = torch.load(scene_path) + else: + data_name = scene_path.replace(os.path.dirname(self.data_root), "").split( + "." + )[0] + cache_name = "ponder" + data_name.replace(os.path.sep, "-") + data = shared_dict(cache_name) + + if self.num_cameras > len(frame_list): + print( + f"Warning: {scene_name} has only {len(frame_list)} frames, " + f"but {self.num_cameras} cameras are required." 
+ ) + frame_idxs = np.random.choice( + frame_list, self.num_cameras, replace=self.num_cameras > len(frame_list) + ) + intrinsic, extrinsic, rgb, depth = ( + [], + [], + [], + [], + ) + + if self.render_semantic: + semantic = [] + for frame_idx in frame_idxs: + if not self.render_semantic: + intri, rot, transl, rgb_im, depth_im = self.get_2d_meta( + scene_name, frame_idx + ) + else: + intri, rot, transl, rgb_im, depth_im, semantic_im = self.get_2d_meta( + scene_name, frame_idx + ) + assert semantic_im.max() <= 20, semantic_im + semantic.append(semantic_im) + intrinsic.append(intri) + extri = np.eye(4) + extri[:3, :3] = rot + extri[:3, 3] = transl + extrinsic.append(extri) + rgb.append(rgb_im) + depth.append(depth_im) + + intrinsic = np.stack(intrinsic, axis=0) + extrinsic = np.stack(extrinsic, axis=0) + rgb = np.stack(rgb, axis=0) + depth = np.stack(depth, axis=0) + + coord = data["coord"] + color = data["color"] + normal = data["normal"] + scene_id = data["scene_id"] + if "semantic_gt20" in data.keys(): + segment = data["semantic_gt20"].reshape([-1]) + else: + segment = np.ones(coord.shape[0]) * -1 + if "instance_gt" in data.keys(): + instance = data["instance_gt"].reshape([-1]) + else: + instance = np.ones(coord.shape[0]) * -1 + data_dict = dict( + coord=coord, + normal=normal, + color=color, + segment=segment, + instance=instance, + scene_id=scene_id, + intrinsic=intrinsic, + extrinsic=extrinsic, + rgb=rgb, + depth=depth, + depth_scale=1.0 / 1000.0, + id=f"{scene_name}/{frame_idxs[0]}", + ) + if self.render_semantic: + semantic = np.stack(semantic, axis=0) + data_dict.update(dict(semantic=semantic)) + + if self.la: + sampled_index = self.la[self.get_data_name(scene_path)] + mask = np.ones_like(segment).astype(np.bool) + mask[sampled_index] = False + segment[mask] = self.ignore_index + data_dict["segment"] = segment + data_dict["sampled_index"] = sampled_index + data_dict["semantic"] = np.zeros_like(data_dict["semantic"]) - 1 + + return data_dict + + def get_data_name(self, scene_path): + return os.path.basename(scene_path).split(".")[0] + + def get_frame_list(self, scene_name): + if scene_name in self.frame_lists: + return self.frame_lists[scene_name] + + if not os.path.exists(os.path.join(self.rgbd_root, scene_name, "color")): + return [] + + frame_list = os.listdir(os.path.join(self.rgbd_root, scene_name, "color")) + frame_list = list(frame_list) + frame_list = [frame for frame in frame_list if frame.endswith(".jpg")] + frame_list.sort(key=lambda x: int(x.split(".")[0])) + self.frame_lists[scene_name] = frame_list + return self.frame_lists[scene_name] + + def get_axis_align_matrix(self, scene_name): + if scene_name in self.axis_align_matrix_list: + return self.axis_align_matrix_list[scene_name] + txt_file = os.path.join(self.rgbd_root, scene_name, "%s.txt" % scene_name) + # align axis + with open(txt_file, "r") as f: + lines = f.readlines() + for line in lines: + if "axisAlignment" in line: + self.axis_align_matrix_list[scene_name] = [ + float(x) for x in line.rstrip().strip("axisAlignment = ").split(" ") + ] + break + self.axis_align_matrix_list[scene_name] = np.array( + self.axis_align_matrix_list[scene_name] + ).reshape((4, 4)) + return self.axis_align_matrix_list[scene_name] + + def get_intrinsic(self, scene_name): + if scene_name in self.intrinsic_list: + return self.intrinsic_list[scene_name] + self.intrinsic_list[scene_name] = np.loadtxt( + os.path.join(self.rgbd_root, scene_name, "intrinsic", "intrinsic_depth.txt") + ) + return self.intrinsic_list[scene_name] + + def 
get_2d_meta(self, scene_name, frame_idx): + # framelist + frame_list = self.get_frame_list(scene_name) + intrinsic = self.get_intrinsic(scene_name) + if self.align_axis: + axis_align_matrix = self.get_axis_align_matrix(scene_name) + + if not self.render_semantic: + rgb_im, depth_im, pose = self.read_data(scene_name, frame_list[frame_idx]) + else: + rgb_im, depth_im, pose, semantic_im = self.read_data( + scene_name, frame_list[frame_idx] + ) + semantic_im_40 = cv2.resize( + semantic_im, + (depth_im.shape[1], depth_im.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + semantic_im_40 = semantic_im_40.astype(np.int16) + semantic_im = np.zeros_like(semantic_im_40) - 1 + for i, id in enumerate(VALID_CLASS_IDS_20): + semantic_im[semantic_im_40 == id] = i + + rgb_im = cv2.resize(rgb_im, (depth_im.shape[1], depth_im.shape[0])) + rgb_im = cv2.cvtColor(rgb_im, cv2.COLOR_BGR2RGB) # H, W, 3 + depth_im = depth_im.astype(np.float32) # H, W + + if self.align_axis: + pose = np.matmul(axis_align_matrix, pose) + pose = np.linalg.inv(pose) + + intrinsic = np.array(intrinsic) + rotation = np.array(pose)[:3, :3] + translation = np.array(pose)[:3, 3] + + if not self.render_semantic: + return intrinsic, rotation, translation, rgb_im, depth_im + else: + return intrinsic, rotation, translation, rgb_im, depth_im, semantic_im + + def read_data(self, scene_name, frame_name): + color_path = os.path.join(self.rgbd_root, scene_name, "color", frame_name) + depth_path = os.path.join( + self.rgbd_root, scene_name, "depth", frame_name.replace(".jpg", ".png") + ) + + depth_im = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) + rgb_im = cv2.imread(color_path, cv2.IMREAD_UNCHANGED) + + pose = np.loadtxt( + os.path.join( + self.rgbd_root, + scene_name, + "pose", + frame_name.replace(".jpg", ".txt"), + ) + ) + + if not self.render_semantic: + return rgb_im, depth_im, pose + else: + seg_path = os.path.join( + self.rgbd_root, + scene_name, + "label", + frame_name.replace(".jpg", ".png"), + ) + semantic_im = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED) + return rgb_im, depth_im, pose, semantic_im + + def prepare_train_data(self, idx): + # load data + data_dict = self.get_data(idx) + data_dict = self.transform(data_dict) + return data_dict + + def prepare_test_data(self, idx): + # load data + data_dict = self.get_data(idx) + segment = data_dict.pop("segment") + data_dict = self.transform(data_dict) + data_dict_list = [] + for aug in self.aug_transform: + data_dict_list.append(aug(deepcopy(data_dict))) + + input_dict_list = [] + for data in data_dict_list: + data_part_list = self.test_voxelize(data) + for data_part in data_part_list: + if self.test_crop: + data_part = self.test_crop(data_part) + else: + data_part = [data_part] + input_dict_list += data_part + + for i in range(len(input_dict_list)): + input_dict_list[i] = self.post_transform(input_dict_list[i]) + data_dict = dict( + fragment_list=input_dict_list, segment=segment, name=self.get_data_name(idx) + ) + return data_dict + + def __getitem__(self, idx): + if self.test_mode: + return self.prepare_test_data(idx) + else: + return self.prepare_train_data(idx) + + def __len__(self): + return len(self.data_list) * self.loop diff --git a/ponder/datasets/structure3d.py b/ponder/datasets/structure3d.py index bfe92fd..51eff43 100644 --- a/ponder/datasets/structure3d.py +++ b/ponder/datasets/structure3d.py @@ -76,7 +76,7 @@ def get_data_list(self): filtered_data_list = [] for data_path in data_list: rgbd_paths = glob.glob( - os.path.join(data_path.split(".pth")[0], "rgbd", "*.pth") + 
os.path.join(data_path.split(".pth")[0] + "_rgbd", "*.pth") ) if len(rgbd_paths) <= 0: # print(f"{data_path} has no rgbd data.") @@ -111,16 +111,9 @@ def get_data(self, idx): ) rgbd_dicts = [torch.load(p) for p in rgbd_paths] - has_bad = False for i in range(len(rgbd_dicts)): if (rgbd_dicts[i]["depth_mask"]).mean() < 0.25: - print( - f"{rgbd_paths[i]} has bad depth data. ({rgbd_dicts[i]['depth_mask'].mean()})" - ) - os.rename(rgbd_paths[i], rgbd_paths[i] + ".bad") - has_bad = True - if has_bad: - return self.get_data(idx) + return self.get_data(idx) for d in rgbd_dicts: d["extrinsic"] = np.array( diff --git a/ponder/engines/train.py b/ponder/engines/train.py index 63560d2..ec9cde3 100644 --- a/ponder/engines/train.py +++ b/ponder/engines/train.py @@ -289,3 +289,21 @@ def build_scheduler(self): def build_scaler(self): scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None return scaler + + +@TRAINERS.register_module("MultiDatasetTrainer") +class MultiDatasetTrainer(Trainer): + def build_train_loader(self): + from ponder.datasets import MultiDatasetDataloader + + train_data = build_dataset(self.cfg.data.train) + train_loader = MultiDatasetDataloader( + train_data, + self.cfg.batch_size_per_gpu, + self.cfg.num_worker_per_gpu, + self.cfg.mix_prob, + self.cfg.seed, + max_point=self.cfg.get("max_point", -1), + ) + self.comm_info["iter_per_epoch"] = len(train_loader) + return train_loader diff --git a/ponder/models/__init__.py b/ponder/models/__init__.py index 3f88e7a..10764dc 100644 --- a/ponder/models/__init__.py +++ b/ponder/models/__init__.py @@ -3,6 +3,9 @@ # Semantic Segmentation from .default import DefaultClassifier, DefaultSegmentor +# PPT +from .point_prompt_training import * + # Pretraining from .ponder import * diff --git a/ponder/models/point_prompt_training/__init__.py b/ponder/models/point_prompt_training/__init__.py new file mode 100644 index 0000000..3a73588 --- /dev/null +++ b/ponder/models/point_prompt_training/__init__.py @@ -0,0 +1,2 @@ +from .point_prompt_training_v1m1_language_guided import * +from .point_prompt_training_v1m2_decoupled import * diff --git a/ponder/models/point_prompt_training/point_prompt_training_v1m1_language_guided.py b/ponder/models/point_prompt_training/point_prompt_training_v1m1_language_guided.py new file mode 100644 index 0000000..1a16078 --- /dev/null +++ b/ponder/models/point_prompt_training/point_prompt_training_v1m1_language_guided.py @@ -0,0 +1,185 @@ +""" +Point Prompt Training + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" +from collections import OrderedDict +from collections.abc import Sequence +from functools import partial + +import torch +import torch.nn as nn + +from ponder.models.builder import MODELS +from ponder.models.losses import build_criteria + + +@MODELS.register_module("PPT-v1m1") +class PointPromptTraining(nn.Module): + """ + PointPromptTraining provides Data-driven Context and enables multi-dataset training with + Language-driven Categorical Alignment. PDNorm is supported by SpUNet-v1m3 to adapt the + backbone to a specific dataset with a given dataset condition and context. 
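+
+    Backbone features are projected into the CLIP text-embedding space and
+    compared against the text embeddings of the class names via scaled cosine
+    similarity; the per-dataset valid_index restricts the output logits to the
+    classes annotated in the current dataset condition.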
+ """ + + def __init__( + self, + backbone=None, + criteria=None, + backbone_out_channels=96, + context_channels=256, + conditions=("Structured3D", "ScanNet", "S3DIS"), + template="[x]", + clip_model="ViT-B/16", + class_name=( + "wall", + "floor", + "cabinet", + "bed", + "chair", + "sofa", + "table", + "door", + "window", + "bookshelf", + "bookcase", + "picture", + "counter", + "desk", + "shelves", + "curtain", + "dresser", + "pillow", + "mirror", + "ceiling", + "refrigerator", + "television", + "shower curtain", + "nightstand", + "toilet", + "sink", + "lamp", + "bathtub", + "garbagebin", + "board", + "beam", + "column", + "clutter", + "otherstructure", + "otherfurniture", + "otherprop", + ), + valid_index=( + ( + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 11, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 23, + 25, + 26, + 33, + 34, + 35, + ), + (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 15, 20, 22, 24, 25, 27, 34), + (0, 1, 4, 5, 6, 7, 8, 10, 19, 29, 30, 31, 32), + ), + backbone_mode=False, + ): + super().__init__() + assert len(conditions) == len(valid_index) + assert backbone.type in [ + "SpUNet-v1m3", + "PT-v2m3", + ] # SpUNet v1m3: Sparse UNet with PDNorm + self.backbone = MODELS.build(backbone) + self.criteria = build_criteria(criteria) + self.conditions = conditions + self.valid_index = valid_index + self.embedding_table = nn.Embedding(len(conditions), context_channels) + self.backbone_mode = backbone_mode + if not self.backbone_mode: + import clip + + clip_model, _ = clip.load( + clip_model, device="cpu", download_root="./.cache/clip" + ) + clip_model.requires_grad_(False) + if isinstance(template, str): + class_prompt = [template.replace("[x]", name) for name in class_name] + elif isinstance(template, Sequence): + class_prompt = [ + temp.replace("[x]", name) + for name in class_name + for temp in template + ] + class_token = clip.tokenize(class_prompt) + class_embedding = clip_model.encode_text(class_token) + class_embedding = class_embedding / class_embedding.norm( + dim=-1, keepdim=True + ) + if (not isinstance(template, str)) and isinstance(template, Sequence): + class_embedding = class_embedding.reshape( + len(template), len(class_name), clip_model.text_projection.shape[1] + ) + class_embedding = class_embedding.mean(0) + class_embedding = class_embedding / class_embedding.norm( + dim=-1, keepdim=True + ) + self.register_buffer("class_embedding", class_embedding) + self.proj_head = nn.Linear( + backbone_out_channels, clip_model.text_projection.shape[1] + ) + self.logit_scale = clip_model.logit_scale + + def forward(self, data_dict): + condition = data_dict["condition"][0] + assert condition in self.conditions + context = self.embedding_table( + torch.tensor( + [self.conditions.index(condition)], device=data_dict["coord"].device + ) + ) + data_dict["context"] = context + feat = self.backbone(data_dict) + if self.backbone_mode: + # PPT serve as a multi-dataset backbone when enable backbone mode + return feat + feat = self.proj_head(feat) + feat = feat / feat.norm(dim=-1, keepdim=True) + sim = ( + feat + @ self.class_embedding[ + self.valid_index[self.conditions.index(condition)], : + ].t() + ) + logit_scale = self.logit_scale.exp() + seg_logits = logit_scale * sim + # train + if self.training: + loss = self.criteria(seg_logits, data_dict["segment"]) + return dict(loss=loss) + # eval + elif "segment" in data_dict.keys(): + loss = self.criteria(seg_logits, data_dict["segment"]) + return dict(loss=loss, seg_logits=seg_logits) + # test + else: + return 
dict(seg_logits=seg_logits) diff --git a/ponder/models/point_prompt_training/point_prompt_training_v1m2_decoupled.py b/ponder/models/point_prompt_training/point_prompt_training_v1m2_decoupled.py new file mode 100644 index 0000000..e6ce62e --- /dev/null +++ b/ponder/models/point_prompt_training/point_prompt_training_v1m2_decoupled.py @@ -0,0 +1,67 @@ +""" +Point Prompt Training with decoupled segmentation head + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn + +from ponder.models.builder import MODELS +from ponder.models.losses import build_criteria + + +@MODELS.register_module("PPT-v1m2") +class PointPromptTraining(nn.Module): + """ + PointPromptTraining v1m2 provides Data-driven Context and enables multi-dataset training with + Decoupled Segmentation Head. PDNorm is supported by SpUNet-v1m3 to adapt the + backbone to a specific dataset with a given dataset condition and context. + """ + + def __init__( + self, + backbone=None, + criteria=None, + backbone_out_channels=96, + context_channels=256, + conditions=("Structured3D", "ScanNet", "S3DIS"), + num_classes=(25, 20, 13), + ): + super().__init__() + assert len(conditions) == len(num_classes) + assert backbone.type in ["SpUNet-v1m3"] # SpUNet v1m3: Sparse UNet with PDNorm + self.backbone = MODELS.build(backbone) + self.criteria = build_criteria(criteria) + self.conditions = conditions + self.embedding_table = nn.Embedding(len(conditions), context_channels) + self.seg_heads = nn.ModuleList( + [nn.Linear(backbone_out_channels, num_cls) for num_cls in num_classes] + ) + + def forward(self, data_dict): + condition = data_dict["condition"][0] + assert condition in self.conditions + context = self.embedding_table( + torch.tensor( + [self.conditions.index(condition)], device=data_dict["coord"].device + ) + ) + data_dict["context"] = context + feat = self.backbone(data_dict) + seg_head = self.seg_heads[self.conditions.index(condition)] + seg_logits = seg_head(feat) + # train + if self.training: + loss = self.criteria(seg_logits, data_dict["segment"]) + return dict(loss=loss) + # eval + elif "segment" in data_dict.keys(): + loss = self.criteria(seg_logits, data_dict["segment"]) + return dict(loss=loss, seg_logits=seg_logits) + # test + else: + return dict(seg_logits=seg_logits) diff --git a/ponder/models/ponder/ponder_indoor_base.py b/ponder/models/ponder/ponder_indoor_base.py index a5b6c2d..e51bf0d 100644 --- a/ponder/models/ponder/ponder_indoor_base.py +++ b/ponder/models/ponder/ponder_indoor_base.py @@ -29,6 +29,8 @@ def __init__( val_ray_split=10240, ray_nsample=128, padding=0.1, + backbone_out_channels=96, + context_channels=256, pool_type="mean", share_volume=True, render_semantic=False, # whether to render 2D semantic maps. 
@@ -37,6 +39,8 @@ def __init__( clip_model=None, class_name=None, valid_index=None, + ppt_loss_weight=0.0, # whether and how much to use PPT's loss + ppt_criteria=None, ): super().__init__() self.grid_shape = ( @@ -66,8 +70,19 @@ def __init__( self.render_semantic = render_semantic self.conditions = conditions self.valid_index = valid_index + self.embedding_table = nn.Embedding(len(conditions), context_channels) + self.backbone_out_channels = backbone_out_channels if render_semantic: + self.ppt_loss_weight = ppt_loss_weight self.load_semantic(template, clip_model, class_name) + else: + self.ppt_loss_weight = ( + 0.0 # ppt loss is not available when render_semantic is `False` + ) + + if self.ppt_loss_weight > 0: + assert ppt_criteria is not None, "Please provide PPT's loss function." + self.ppt_criteria = build_criteria(ppt_criteria) def load_semantic(self, template, clip_model, class_name): import clip @@ -95,6 +110,11 @@ def load_semantic(self, template, clip_model, class_name): dim=-1, keepdim=True ) self.register_buffer("class_embedding", class_embedding.float().cpu()) + self.logit_scale = clip_model.logit_scale + if self.ppt_loss_weight > 0: + self.proj_head = nn.Linear( + self.backbone_out_channels, clip_model.text_projection.shape[1] + ) del clip_model, class_prompt, class_token torch.cuda.empty_cache() @@ -143,6 +163,16 @@ def random_masking(B, H, W, ratio, device): feat[~grid_mask] = self.mtoken data_dict["feat"] = feat + if "condition" in data_dict: + condition = data_dict["condition"][0] + assert condition in self.conditions + context = self.embedding_table( + torch.tensor( + [self.conditions.index(condition)], device=data_dict["coord"].device + ) + ) + data_dict["context"] = context + data_dict["sparse_backbone_feat"] = self.backbone(data_dict) return data_dict @@ -642,6 +672,19 @@ def render_loss(self, render_out, ray_dict): loss = sum(_value for _key, _value in loss_dict.items() if "loss" in _key) return loss, loss_dict + def ppt_loss(self, data_dict): + feat = self.proj_head(data_dict["sparse_backbone_feat"]) + feat = feat / feat.norm(dim=-1, keepdim=True) + sim = ( + feat + @ self.class_embedding[ + self.valid_index[self.conditions.index(data_dict["condition"][0])], : + ].t() + ) + logit_scale = self.logit_scale.exp() + seg_logits = logit_scale * sim + return self.ppt_criteria(seg_logits, data_dict["segment"]) + def forward(self, data_dict): data_dict = self.extract_feature(data_dict) ray_dict, data_dict = self.prepare_ray(data_dict) @@ -649,4 +692,9 @@ def forward(self, data_dict): render_out = self.render_func(ray_dict, volume_feature) loss, loss_dict = self.render_loss(render_out, ray_dict) out_dict = dict(loss=loss, **loss_dict) + + if self.ppt_loss_weight > 0: + ppt_loss = self.ppt_loss(data_dict) + out_dict["ppt_loss"] = ppt_loss + return out_dict diff --git a/ponder/models/sparse_unet/__init__.py b/ponder/models/sparse_unet/__init__.py index 2040db0..7636301 100644 --- a/ponder/models/sparse_unet/__init__.py +++ b/ponder/models/sparse_unet/__init__.py @@ -1 +1,4 @@ +# from .mink_unet import * from .spconv_unet_v1m1_base import * +from .spconv_unet_v1m2_bn_momentum import * +from .spconv_unet_v1m3_pdnorm import * diff --git a/ponder/models/sparse_unet/mink_unet.py b/ponder/models/sparse_unet/mink_unet.py new file mode 100644 index 0000000..fb47067 --- /dev/null +++ b/ponder/models/sparse_unet/mink_unet.py @@ -0,0 +1,441 @@ +""" +SparseUNet Driven by MinkowskiEngine + +Modified from chrischoy/SpatioTemporalSegmentation + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 
+Please cite our work if the code is helpful to you. +""" + +import torch +import torch.nn as nn + +try: + import MinkowskiEngine as ME +except ImportError: + import warnings + + warnings.warn("Please follow `README.md` to install MinkowskiEngine.`") + +from pointcept.models.builder import MODELS + + +def offset2batch(offset): + return ( + torch.cat( + [ + torch.tensor([i] * (o - offset[i - 1])) + if i > 0 + else torch.tensor([i] * o) + for i, o in enumerate(offset) + ], + dim=0, + ) + .long() + .to(offset.device) + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + bn_momentum=0.1, + dimension=-1, + ): + super(BasicBlock, self).__init__() + assert dimension > 0 + + self.conv1 = ME.MinkowskiConvolution( + inplanes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + dimension=dimension, + ) + self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + self.conv2 = ME.MinkowskiConvolution( + planes, + planes, + kernel_size=3, + stride=1, + dilation=dilation, + dimension=dimension, + ) + self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + self.relu = ME.MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + bn_momentum=0.1, + dimension=-1, + ): + super(Bottleneck, self).__init__() + assert dimension > 0 + + self.conv1 = ME.MinkowskiConvolution( + inplanes, planes, kernel_size=1, dimension=dimension + ) + self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + + self.conv2 = ME.MinkowskiConvolution( + planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + dimension=dimension, + ) + self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + + self.conv3 = ME.MinkowskiConvolution( + planes, planes * self.expansion, kernel_size=1, dimension=dimension + ) + self.norm3 = ME.MinkowskiBatchNorm( + planes * self.expansion, momentum=bn_momentum + ) + + self.relu = ME.MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class MinkUNetBase(nn.Module): + BLOCK = None + PLANES = None + DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) + INIT_DIM = 32 + OUT_TENSOR_STRIDE = 1 + + def __init__(self, in_channels, out_channels, dimension=3): + super().__init__() + self.D = dimension + assert self.BLOCK is not None + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv0p1s1 = ME.MinkowskiConvolution( + in_channels, self.inplanes, kernel_size=5, dimension=self.D + ) + + self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) + + self.conv1p1s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D + ) + 
self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) + + self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0]) + + self.conv2p2s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D + ) + self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) + + self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1]) + + self.conv3p4s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D + ) + + self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) + self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2]) + + self.conv4p8s2 = ME.MinkowskiConvolution( + self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=self.D + ) + self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) + self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3]) + + self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=self.D + ) + self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion + self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4]) + self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=self.D + ) + self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion + self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5]) + self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=self.D + ) + self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) + + self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion + self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6]) + self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=self.D + ) + self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) + + self.inplanes = self.PLANES[7] + self.INIT_DIM + self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7]) + + self.final = ME.MinkowskiConvolution( + self.PLANES[7] * self.BLOCK.expansion, + out_channels, + kernel_size=1, + bias=True, + dimension=self.D, + ) + self.relu = ME.MinkowskiReLU(inplace=True) + + self.weight_initialization() + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, ME.MinkowskiConvolution): + ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") + + if isinstance(m, ME.MinkowskiBatchNorm): + nn.init.constant_(m.bn.weight, 1) + nn.init.constant_(m.bn.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + ME.MinkowskiConvolution( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + dimension=self.D, + ), + ME.MinkowskiBatchNorm(planes * block.expansion), + ) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride=stride, + dilation=dilation, + downsample=downsample, + dimension=self.D, + ) + ) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D + ) + ) + + return nn.Sequential(*layers) + + def forward(self, data_dict): + grid_coord = 
data_dict["grid_coord"] + feat = data_dict["feat"] + offset = data_dict["offset"] + batch = offset2batch(offset) + in_field = ME.TensorField( + feat, + coordinates=torch.cat([batch.unsqueeze(-1).int(), grid_coord.int()], dim=1), + quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE, + minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED, + device=feat.device, + ) + x = in_field.sparse() + + out = self.conv0p1s1(x) + out = self.bn0(out) + out_p1 = self.relu(out) + + out = self.conv1p1s2(out_p1) + out = self.bn1(out) + out = self.relu(out) + out_b1p2 = self.block1(out) + + out = self.conv2p2s2(out_b1p2) + out = self.bn2(out) + out = self.relu(out) + out_b2p4 = self.block2(out) + + out = self.conv3p4s2(out_b2p4) + out = self.bn3(out) + out = self.relu(out) + out_b3p8 = self.block3(out) + + # tensor_stride=16 + out = self.conv4p8s2(out_b3p8) + out = self.bn4(out) + out = self.relu(out) + out = self.block4(out) + + # tensor_stride=8 + out = self.convtr4p16s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = ME.cat(out, out_b3p8) + out = self.block5(out) + + # tensor_stride=4 + out = self.convtr5p8s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = ME.cat(out, out_b2p4) + out = self.block6(out) + + # tensor_stride=2 + out = self.convtr6p4s2(out) + out = self.bntr6(out) + out = self.relu(out) + + out = ME.cat(out, out_b1p2) + out = self.block7(out) + + # tensor_stride=1 + out = self.convtr7p2s2(out) + out = self.bntr7(out) + out = self.relu(out) + + out = ME.cat(out, out_p1) + out = self.block8(out) + + return self.final(out).slice(in_field).F + + +@MODELS.register_module() +class MinkUNet14(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) + + +@MODELS.register_module() +class MinkUNet18(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + + +@MODELS.register_module() +class MinkUNet34(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +@MODELS.register_module() +class MinkUNet50(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +@MODELS.register_module() +class MinkUNet101(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) + + +@MODELS.register_module() +class MinkUNet14A(MinkUNet14): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +@MODELS.register_module() +class MinkUNet14B(MinkUNet14): + PLANES = (32, 64, 128, 256, 128, 128, 128, 128) + + +@MODELS.register_module() +class MinkUNet14C(MinkUNet14): + PLANES = (32, 64, 128, 256, 192, 192, 128, 128) + + +@MODELS.register_module() +class MinkUNet14D(MinkUNet14): + PLANES = (32, 64, 128, 256, 384, 384, 384, 384) + + +@MODELS.register_module() +class MinkUNet18A(MinkUNet18): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +@MODELS.register_module() +class MinkUNet18B(MinkUNet18): + PLANES = (32, 64, 128, 256, 128, 128, 128, 128) + + +@MODELS.register_module() +class MinkUNet18D(MinkUNet18): + PLANES = (32, 64, 128, 256, 384, 384, 384, 384) + + +@MODELS.register_module() +class MinkUNet34A(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) + + +@MODELS.register_module() +class MinkUNet34B(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 32) + + +@MODELS.register_module() +class MinkUNet34C(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) diff --git a/ponder/models/sparse_unet/spconv_unet_v1m2_bn_momentum.py b/ponder/models/sparse_unet/spconv_unet_v1m2_bn_momentum.py new file mode 100644 index 0000000..dedb491 --- /dev/null +++ 
b/ponder/models/sparse_unet/spconv_unet_v1m2_bn_momentum.py @@ -0,0 +1,287 @@ +""" +SparseUNet Driven by SpConv (recommend) + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn + +try: + import spconv.pytorch as spconv +except ImportError: + import warnings + + warnings.warn("Please follow `README.md` to install spconv2.`") + +from timm.models.layers import trunc_normal_ + +from ponder.models.builder import MODELS + + +def offset2batch(offset): + return ( + torch.cat( + [ + torch.tensor([i] * (o - offset[i - 1])) + if i > 0 + else torch.tensor([i] * o) + for i, o in enumerate(offset) + ], + dim=0, + ) + .long() + .to(offset.device) + ) + + +class BasicBlock(spconv.SparseModule): + expansion = 1 + + def __init__( + self, + in_channels, + embed_channels, + stride=1, + norm_fn=None, + indice_key=None, + bias=False, + ): + super().__init__() + + assert norm_fn is not None + + if in_channels == embed_channels: + self.proj = spconv.SparseSequential(nn.Identity()) + else: + self.proj = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, embed_channels, kernel_size=1, bias=False + ), + norm_fn(embed_channels, momentum=0.02), + ) + + self.conv1 = spconv.SubMConv3d( + in_channels, + embed_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + indice_key=indice_key, + ) + self.bn1 = norm_fn(embed_channels) + self.relu = nn.ReLU() + self.conv2 = spconv.SubMConv3d( + embed_channels, + embed_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + indice_key=indice_key, + ) + self.bn2 = norm_fn(embed_channels) + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = out.replace_feature(self.bn1(out.features)) + out = out.replace_feature(self.relu(out.features)) + + out = self.conv2(out) + out = out.replace_feature(self.bn2(out.features)) + + out = out.replace_feature(out.features + self.proj(residual).features) + out = out.replace_feature(self.relu(out.features)) + + return out + + +@MODELS.register_module("SpUNet-v1m2") +class SpUNetBase(nn.Module): + def __init__( + self, + in_channels, + num_classes, + base_channels=32, + channels=(32, 64, 128, 256, 256, 128, 96, 96), + layers=(2, 3, 4, 6, 2, 2, 2, 2), + bn_momentum=0.1, + ): + super().__init__() + assert len(layers) % 2 == 0 + assert len(layers) == len(channels) + self.in_channels = in_channels + self.num_classes = num_classes + self.base_channels = base_channels + self.channels = channels + self.layers = layers + self.num_stages = len(layers) // 2 + + norm_fn = partial(nn.BatchNorm1d, eps=1e-5, momentum=bn_momentum) + block = BasicBlock + + self.conv_input = spconv.SparseSequential( + spconv.SubMConv3d( + in_channels, + base_channels, + kernel_size=5, + padding=1, + bias=False, + indice_key="stem", + ), + norm_fn(base_channels, momentum=0.02), + nn.ReLU(), + ) + + enc_channels = base_channels + dec_channels = channels[-1] + self.down = nn.ModuleList() + self.up = nn.ModuleList() + self.enc = nn.ModuleList() + self.dec = nn.ModuleList() + + for s in range(self.num_stages): + # encode num_stages + self.down.append( + spconv.SparseSequential( + spconv.SparseConv3d( + enc_channels, + channels[s], + kernel_size=2, + stride=2, + bias=False, + indice_key=f"spconv{s + 1}", + ), + norm_fn(channels[s], momentum=0.02), + nn.ReLU(), + ) + ) + self.enc.append( + spconv.SparseSequential( + OrderedDict( + [ + # 
(f"block{i}", block(enc_channels, channels[s], norm_fn=norm_fn, indice_key=f"subm{s + 1}")) + # if i == 0 else + ( + f"block{i}", + block( + channels[s], + channels[s], + norm_fn=norm_fn, + indice_key=f"subm{s + 1}", + ), + ) + for i in range(layers[s]) + ] + ) + ) + ) + + # decode num_stages + self.up.append( + spconv.SparseSequential( + spconv.SparseInverseConv3d( + channels[len(channels) - s - 2], + dec_channels, + kernel_size=2, + bias=False, + indice_key=f"spconv{s + 1}", + ), + norm_fn(dec_channels, momentum=0.02), + nn.ReLU(), + ) + ) + self.dec.append( + spconv.SparseSequential( + OrderedDict( + [ + ( + f"block{i}", + block( + dec_channels + enc_channels, + dec_channels, + norm_fn=norm_fn, + indice_key=f"subm{s}", + ), + ) + if i == 0 + else ( + f"block{i}", + block( + dec_channels, + dec_channels, + norm_fn=norm_fn, + indice_key=f"subm{s}", + ), + ) + for i in range(layers[len(channels) - s - 1]) + ] + ) + ) + ) + enc_channels = channels[s] + dec_channels = channels[len(channels) - s - 2] + + self.final = ( + spconv.SubMConv3d( + channels[-1], num_classes, kernel_size=1, padding=1, bias=True + ) + if num_classes > 0 + else spconv.Identity() + ) + self.apply(self._init_weights) + + @staticmethod + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, spconv.SubMConv3d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, data_dict): + grid_coord = data_dict["grid_coord"] + feat = data_dict["feat"] + offset = data_dict["offset"] + + batch = offset2batch(offset) + sparse_shape = torch.add(torch.max(grid_coord, dim=0).values, 1).tolist() + x = spconv.SparseConvTensor( + features=feat, + indices=torch.cat( + [batch.unsqueeze(-1).int(), grid_coord.int()], dim=1 + ).contiguous(), + spatial_shape=sparse_shape, + batch_size=batch[-1].tolist() + 1, + ) + x = self.conv_input(x) + skips = [x] + # enc forward + for s in range(self.num_stages): + x = self.down[s](x) + x = self.enc[s](x) + skips.append(x) + x = skips.pop(-1) + # dec forward + for s in reversed(range(self.num_stages)): + x = self.up[s](x) + skip = skips.pop(-1) + x = x.replace_feature(torch.cat((x.features, skip.features), dim=1)) + x = self.dec[s](x) + + x = self.final(x) + return x.features diff --git a/ponder/models/sparse_unet/spconv_unet_v1m3_pdnorm.py b/ponder/models/sparse_unet/spconv_unet_v1m3_pdnorm.py new file mode 100644 index 0000000..123d46f --- /dev/null +++ b/ponder/models/sparse_unet/spconv_unet_v1m3_pdnorm.py @@ -0,0 +1,424 @@ +""" +SparseUNet V1M3 + +Enable Prompt-Driven Normalization for Point Prompt Training + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
+""" +from collections import OrderedDict +from functools import partial + +import spconv.pytorch as spconv +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ +from torch_geometric.utils import scatter + +from ponder.models.builder import MODELS +from ponder.models.utils import offset2batch + + +class PDBatchNorm(torch.nn.Module): + def __init__( + self, + num_features, + context_channels=256, + eps=1e-3, + momentum=0.01, + conditions=("ScanNet", "S3DIS", "Structured3D"), + decouple=True, + adaptive=False, + affine=True, + ): + super().__init__() + self.conditions = conditions + self.decouple = decouple + self.adaptive = adaptive + self.affine = affine + if self.decouple: + self.bns = nn.ModuleList( + [ + nn.BatchNorm1d( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + ) + for _ in conditions + ] + ) + else: + self.bn = nn.BatchNorm1d( + num_features=num_features, eps=eps, momentum=momentum, affine=affine + ) + if self.adaptive: + self.modulation = nn.Sequential( + nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True) + ) + + def forward(self, feat, condition=None, context=None): + if self.decouple: + assert condition in self.conditions + bn = self.bns[self.conditions.index(condition)] + else: + bn = self.bn + feat = bn(feat) + if self.adaptive: + assert context is not None + shift, scale = self.modulation(context).chunk(2, dim=1) + feat = feat * (1.0 + scale) + shift + return feat + + +class BasicBlock(spconv.SparseModule): + expansion = 1 + + def __init__( + self, + in_channels, + embed_channels, + stride=1, + norm_fn=None, + indice_key=None, + bias=False, + ): + super().__init__() + + assert norm_fn is not None + + self.in_channels = in_channels + self.embed_channels = embed_channels + if in_channels == embed_channels: + self.proj = spconv.SparseSequential(nn.Identity()) + else: + # TODO remove norm after project + self.proj_conv = spconv.SubMConv3d( + in_channels, embed_channels, kernel_size=1, bias=False + ) + self.proj_norm = norm_fn(embed_channels) + + self.conv1 = spconv.SubMConv3d( + in_channels, + embed_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + indice_key=indice_key, + ) + self.bn1 = norm_fn(embed_channels) + self.relu = nn.ReLU() + self.conv2 = spconv.SubMConv3d( + embed_channels, + embed_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + indice_key=indice_key, + ) + self.bn2 = norm_fn(embed_channels) + self.stride = stride + + def forward(self, x): + x, condition, context = x + residual = x + + out = self.conv1(x) + out = out.replace_feature(self.bn1(out.features, condition, context)) + out = out.replace_feature(self.relu(out.features)) + + out = self.conv2(out) + out = out.replace_feature(self.bn2(out.features, condition, context)) + + if self.in_channels == self.embed_channels: + residual = self.proj(residual) + else: + residual = residual.replace_feature( + self.proj_norm(self.proj_conv(residual).features, condition, context) + ) + out = out.replace_feature(out.features + residual.features) + out = out.replace_feature(self.relu(out.features)) + return out, condition, context + + +class SPConvDown(nn.Module): + def __init__( + self, + in_channels, + out_channels, + indice_key, + kernel_size=2, + bias=False, + norm_fn=None, + ): + super().__init__() + self.conv = spconv.SparseConv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=kernel_size, + bias=bias, + indice_key=indice_key, + ) + self.bn = norm_fn(out_channels) + self.relu = 
nn.ReLU() + + def forward(self, x): + x, condition, context = x + out = self.conv(x) + out = out.replace_feature(self.bn(out.features, condition, context)) + out = out.replace_feature(self.relu(out.features)) + return out + + +class SPConvUp(nn.Module): + def __init__( + self, + in_channels, + out_channels, + indice_key, + kernel_size=2, + bias=False, + norm_fn=None, + ): + super().__init__() + self.conv = spconv.SparseInverseConv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + bias=bias, + indice_key=indice_key, + ) + self.bn = norm_fn(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x, condition, context = x + out = self.conv(x) + out = out.replace_feature(self.bn(out.features, condition, context)) + out = out.replace_feature(self.relu(out.features)) + return out + + +class SPConvPatchEmbedding(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=5, norm_fn=None): + super().__init__() + self.conv = spconv.SubMConv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=1, + bias=False, + indice_key="stem", + ) + self.bn = norm_fn(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x, condition, context = x + out = self.conv(x) + out = out.replace_feature(self.bn(out.features, condition, context)) + out = out.replace_feature(self.relu(out.features)) + return out + + +@MODELS.register_module("SpUNet-v1m3") +class SpUNetBase(nn.Module): + def __init__( + self, + in_channels, + num_classes=0, + base_channels=32, + context_channels=256, + channels=(32, 64, 128, 256, 256, 128, 96, 96), + layers=(2, 3, 4, 6, 2, 2, 2, 2), + cls_mode=False, + conditions=("ScanNet", "S3DIS", "Structured3D"), + zero_init=True, + norm_decouple=True, + norm_adaptive=True, + norm_affine=False, + ): + super().__init__() + assert len(layers) % 2 == 0 + assert len(layers) == len(channels) + self.in_channels = in_channels + self.num_classes = num_classes + self.base_channels = base_channels + self.channels = channels + self.layers = layers + self.num_stages = len(layers) // 2 + self.cls_mode = cls_mode + self.conditions = conditions + self.zero_init = zero_init + + norm_fn = partial( + PDBatchNorm, + eps=1e-3, + momentum=0.01, + conditions=conditions, + context_channels=context_channels, + decouple=norm_decouple, + adaptive=norm_adaptive, + affine=norm_affine, + ) + block = BasicBlock + + self.conv_input = SPConvPatchEmbedding( + in_channels, base_channels, kernel_size=5, norm_fn=norm_fn + ) + + enc_channels = base_channels + dec_channels = channels[-1] + self.down = nn.ModuleList() + self.up = nn.ModuleList() + self.enc = nn.ModuleList() + self.dec = nn.ModuleList() if not self.cls_mode else None + + for s in range(self.num_stages): + # encode num_stages + self.down.append( + SPConvDown( + enc_channels, + channels[s], + kernel_size=2, + bias=False, + indice_key=f"spconv{s + 1}", + norm_fn=norm_fn, + ) + ) + self.enc.append( + spconv.SparseSequential( + OrderedDict( + [ + # (f"block{i}", block(enc_channels, channels[s], norm_fn=norm_fn, indice_key=f"subm{s + 1}")) + # if i == 0 else + ( + f"block{i}", + block( + channels[s], + channels[s], + norm_fn=norm_fn, + indice_key=f"subm{s + 1}", + ), + ) + for i in range(layers[s]) + ] + ) + ) + ) + if not self.cls_mode: + # decode num_stages + self.up.append( + SPConvUp( + channels[len(channels) - s - 2], + dec_channels, + kernel_size=2, + bias=False, + indice_key=f"spconv{s + 1}", + norm_fn=norm_fn, + ) + ) + self.dec.append( + spconv.SparseSequential( + OrderedDict( + [ + ( + f"block{i}", + block( + 
dec_channels + enc_channels, + dec_channels, + norm_fn=norm_fn, + indice_key=f"subm{s}", + ), + ) + if i == 0 + else ( + f"block{i}", + block( + dec_channels, + dec_channels, + norm_fn=norm_fn, + indice_key=f"subm{s}", + ), + ) + for i in range(layers[len(channels) - s - 1]) + ] + ) + ) + ) + + enc_channels = channels[s] + dec_channels = channels[len(channels) - s - 2] + + final_in_channels = ( + channels[-1] if not self.cls_mode else channels[self.num_stages - 1] + ) + self.final = ( + spconv.SubMConv3d( + final_in_channels, num_classes, kernel_size=1, padding=1, bias=True + ) + if num_classes > 0 + else spconv.Identity() + ) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, spconv.SubMConv3d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + if m.affine: + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, PDBatchNorm): + if self.zero_init: + nn.init.constant_(m.modulation[-1].weight, 0) + nn.init.constant_(m.modulation[-1].bias, 0) + + def forward(self, input_dict): + grid_coord = input_dict["grid_coord"] + feat = input_dict["feat"] + offset = input_dict["offset"] + condition = input_dict["condition"][0] + context = input_dict["context"] if "context" in input_dict.keys() else None + + batch = offset2batch(offset) + sparse_shape = torch.add(torch.max(grid_coord, dim=0).values, 96).tolist() + x = spconv.SparseConvTensor( + features=feat, + indices=torch.cat( + [batch.unsqueeze(-1).int(), grid_coord.int()], dim=1 + ).contiguous(), + spatial_shape=sparse_shape, + batch_size=batch[-1].tolist() + 1, + ) + x = self.conv_input([x, condition, context]) + skips = [x] + # enc forward + for s in range(self.num_stages): + x = self.down[s]([x, condition, context]) + x, _, _ = self.enc[s]([x, condition, context]) + skips.append(x) + x = skips.pop(-1) + if not self.cls_mode: + # dec forward + for s in reversed(range(self.num_stages)): + x = self.up[s]([x, condition, context]) + skip = skips.pop(-1) + x = x.replace_feature(torch.cat((x.features, skip.features), dim=1)) + x, _, _ = self.dec[s]([x, condition, context]) + + x = self.final(x) + if self.cls_mode: + x = x.replace_feature( + scatter(x.features, x.indices[:, 0].long(), reduce="mean", dim=0) + ) + return x.features diff --git a/tools/test_s3dis_6fold.py b/tools/test_s3dis_6fold.py new file mode 100644 index 0000000..988d60c --- /dev/null +++ b/tools/test_s3dis_6fold.py @@ -0,0 +1,103 @@ +""" +Test script for S3DIS 6-fold cross validation + +Gathering Area_X.pth from result folder of experiment record of each area as follows: +|- RECORDS_PATH + |- Area_1.pth + |- Area_2.pth + |- Area_3.pth + |- Area_4.pth + |- Area_5.pth + |- Area_6.pth + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. 
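+
+Example:
+    python tools/test_s3dis_6fold.py --record_root ${RECORDS_PATH}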
+""" + +import argparse +import glob +import os + +import numpy as np +import torch + +from ponder.utils.logger import get_root_logger + +CLASS_NAMES = [ + "ceiling", + "floor", + "wall", + "beam", + "column", + "window", + "door", + "table", + "chair", + "sofa", + "bookcase", + "board", + "clutter", +] + + +def evaluation(intersection, union, target, logger=None): + iou_class = intersection / (union + 1e-10) + accuracy_class = intersection / (target + 1e-10) + mIoU = np.mean(iou_class) + mAcc = np.mean(accuracy_class) + allAcc = sum(intersection) / (sum(target) + 1e-10) + + if logger is not None: + logger.info( + "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}".format( + mIoU, mAcc, allAcc + ) + ) + for i in range(len(CLASS_NAMES)): + logger.info( + "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}".format( + idx=i, + name=CLASS_NAMES[i], + iou=iou_class[i], + accuracy=accuracy_class[i], + ) + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--record_root", + required=True, + help="Path to the S3DIS record of each split", + ) + config = parser.parse_args() + logger = get_root_logger( + log_file=os.path.join(config.record_root, "6-fold.log"), + file_mode="w", + ) + + records = sorted(glob.glob(os.path.join(config.record_root, "Area_*.pth"))) + assert len(records) == 6 + intersection_ = np.zeros(len(CLASS_NAMES), dtype=int) + union_ = np.zeros(len(CLASS_NAMES), dtype=int) + target_ = np.zeros(len(CLASS_NAMES), dtype=int) + + for record in records: + area = os.path.basename(record).split(".")[0] + info = torch.load(record) + logger.info(f"<<<<<<<<<<<<<<<<< Parsing {area} <<<<<<<<<<<<<<<<<") + intersection = info["intersection"] + union = info["union"] + target = info["target"] + evaluation(intersection, union, target, logger=logger) + intersection_ += intersection + union_ += union + target_ += target + + logger.info(f"<<<<<<<<<<<<<<<<< Parsing 6-fold <<<<<<<<<<<<<<<<<") + evaluation(intersection_, union_, target_, logger=logger) + + +if __name__ == "__main__": + main()