-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Add segmentation support for Clay - Add preprocessing script for Chesapeake Bay dataset
- Loading branch information
Showing
7 changed files
with
936 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# lightning.pytorch==2.1.2 | ||
seed_everything: 42 | ||
data: | ||
train_chip_dir: data/cvpr/ny/train/chips/ | ||
train_label_dir: data/cvpr/ny/train/labels/ | ||
val_chip_dir: data/cvpr/ny/val/chips/ | ||
val_label_dir: data/cvpr/ny/val/labels/ | ||
metadata_path: configs/metadata.yaml | ||
batch_size: 40 | ||
num_workers: 8 | ||
platform: naip | ||
model: | ||
num_classes: 7 | ||
feature_maps: | ||
- 3 | ||
- 5 | ||
- 7 | ||
- 11 | ||
ckpt_path: checkpoints/clay-v1-base.ckpt | ||
lr: 1e-5 | ||
wd: 0.05 | ||
b1: 0.9 | ||
b2: 0.95 | ||
trainer: | ||
accelerator: auto | ||
strategy: ddp | ||
devices: auto | ||
num_nodes: 1 | ||
precision: bf16-mixed | ||
log_every_n_steps: 5 | ||
max_epochs: 10 | ||
accumulate_grad_batches: 1 | ||
default_root_dir: checkpoints/segment | ||
fast_dev_run: False | ||
num_sanity_val_steps: 0 | ||
logger: | ||
- class_path: lightning.pytorch.loggers.WandbLogger | ||
init_args: | ||
entity: developmentseed | ||
project: clay-segment | ||
log_model: false | ||
callbacks: | ||
- class_path: lightning.pytorch.callbacks.ModelCheckpoint | ||
init_args: | ||
dirpath: checkpoints/segment | ||
auto_insert_metric_name: False | ||
filename: chesapeake-7class-segment_epoch-{epoch:02d}_val-iou-{val/iou:.4f} | ||
monitor: val/iou | ||
mode: max | ||
save_last: True | ||
save_top_k: 2 | ||
save_weights_only: True | ||
verbose: True | ||
- class_path: lightning.pytorch.callbacks.LearningRateMonitor | ||
init_args: | ||
logging_interval: step | ||
plugins: | ||
- class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Segmentor | ||
|
||
The `Segmentor` class is designed for semantic segmentation tasks, extracting feature maps from intermediate layers of the Clay Encoder and adding a Feature Pyramid Network (FPN) on top of it. | ||
|
||
Decoder is inspired by the Segformer paper. | ||
Todo: | ||
- Add neck & head for segmentation task from other papers like UperNet, PPANet, etc. to compare with other GeoAI models. | ||
|
||
|
||
## Parameters | ||
|
||
- `feature_maps (list)`: Indices of intermediate layers of the Clay Encoder used by FPN layers. | ||
- `ckpt_path (str)`: Path to the Clay model checkpoint. | ||
|
||
## Example | ||
|
||
In this example, we will use the `Segmentor` class to segment Land Use Land Cover (LULC) classes for the Chesapeake Bay CVPR dataset. The implementation includes data preprocessing, data loading, and model training workflow using PyTorch Lightning. | ||
|
||
## Dataset | ||
|
||
### Citation | ||
|
||
If you use this dataset, please cite the associated manuscript: | ||
|
||
Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N. | ||
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data. | ||
Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition (CVPR 2019). | ||
|
||
Dataset URL: [Chesapeake Bay Land Cover Dataset](https://lila.science/datasets/chesapeakelandcover) | ||
|
||
## Setup | ||
|
||
Follow the instructions in the [README](../../README.md) to install the required dependencies. | ||
|
||
```bash | ||
git clone <repo-url> | ||
cd model | ||
mamba env create --file environment.yml | ||
mamba activate claymodel | ||
``` | ||
|
||
## Usage | ||
|
||
### Preparing the Dataset | ||
|
||
Download the Chesapeake Bay Land Cover dataset and organize your dataset directory as recommended. | ||
|
||
1. Copy `*_lc.tif` and `*_naip-new.tif` files for segmentation downstream tasks using s5cmd: | ||
```bash | ||
# train | ||
s5cmd cp --include "*_lc.tif" --include "*_naip-new.tif" "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" data/cvpr/files/train/ | ||
|
||
# val | ||
s5cmd cp --include "*_lc.tif" --include "*_naip-new.tif" "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" data/cvpr/files/val/ | ||
``` | ||
|
||
2. Create chips of size `224 x 224` to feed them to the model: | ||
```bash | ||
python finetune/segment/preprocess_data.py data/cvpr/files data/cvpr/ny 224 | ||
``` | ||
|
||
Directory structure: | ||
``` | ||
data/ | ||
└── cvpr/ | ||
├── files/ | ||
│ ├── train/ | ||
│ └── val/ | ||
└── ny/ | ||
├── train/ | ||
│ ├── chips/ | ||
│ └── labels/ | ||
└── val/ | ||
├── chips/ | ||
└── labels/ | ||
``` | ||
### Training the Model | ||
The model can be run via LightningCLI using configurations in `finetune/segment/configs/segment_chesapeake.yaml`. | ||
1. Download the Clay model checkpoint from [Huggingface model hub](https://huggingface.co/made-with-clay/Clay/blob/main/clay-v1-base.ckpt) and save it in the `checkpoints/` directory. | ||
2. Modify the batch size, learning rate, and other hyperparameters in the configuration file as needed: | ||
```yaml | ||
data: | ||
batch_size: 40 | ||
num_workers: 8 | ||
model: | ||
num_classes: 7 | ||
feature_maps: | ||
- 3 | ||
- 5 | ||
- 7 | ||
- 11 | ||
ckpt_path: checkpoints/clay-v1-base.ckpt | ||
lr: 1e-5 | ||
wd: 0.05 | ||
b1: 0.9 | ||
b2: 0.95 | ||
``` | ||
3. Update the [WandB logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html#lightning.pytorch.loggers.WandbLogger) configuration in the configuration file with your WandB details or use [CSV Logger](https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.CSVLogger.html#lightning.pytorch.loggers.CSVLogger) if you don't want to log to WandB: | ||
```yaml | ||
logger: | ||
- class_path: lightning.pytorch.loggers.WandbLogger | ||
init_args: | ||
entity: <wandb-entity> | ||
project: <wandb-project> | ||
log_model: false | ||
``` | ||
4. Train the model: | ||
```bash | ||
python segment.py fit --config configs/segment_chesapeake.yaml | ||
``` | ||
## Acknowledgments | ||
Decoder implementation is inspired by the Segformer paper: | ||
``` | ||
Segformer: Simple and Efficient Design for Semantic Segmentation with Transformers | ||
Paper URL: https://arxiv.org/abs/2105.15203 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
""" | ||
DataModule for the Chesapeake Bay dataset for segmentation tasks. | ||
This implementation provides a structured way to handle the data loading and | ||
preprocessing required for training and validating a segmentation model. | ||
Dataset citation: | ||
Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N. | ||
Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data. | ||
Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition | ||
(CVPR 2019). | ||
Dataset URL: https://lila.science/datasets/chesapeakelandcover | ||
""" | ||
|
||
import re | ||
from pathlib import Path | ||
|
||
import lightning as L | ||
import numpy as np | ||
import torch | ||
import yaml | ||
from box import Box | ||
from torch.utils.data import DataLoader, Dataset | ||
from torchvision.transforms import v2 | ||
|
||
|
||
class ChesapeakeDataset(Dataset): | ||
""" | ||
Dataset class for the Chesapeake Bay segmentation dataset. | ||
Args: | ||
chip_dir (str): Directory containing the image chips. | ||
label_dir (str): Directory containing the labels. | ||
metadata (Box): Metadata for normalization and other dataset-specific details. | ||
platform (str): Platform identifier used in metadata. | ||
""" | ||
|
||
def __init__(self, chip_dir, label_dir, metadata, platform): | ||
self.chip_dir = Path(chip_dir) | ||
self.label_dir = Path(label_dir) | ||
self.metadata = metadata | ||
self.transform = self.create_transforms( | ||
mean=list(metadata[platform].bands.mean.values()), | ||
std=list(metadata[platform].bands.std.values()), | ||
) | ||
|
||
# Load chip and label file names | ||
self.chips = [chip_path.name for chip_path in self.chip_dir.glob("*.npy")] | ||
self.labels = [re.sub("_naip-new_", "_lc_", chip) for chip in self.chips] | ||
|
||
def create_transforms(self, mean, std): | ||
""" | ||
Create normalization transforms. | ||
Args: | ||
mean (list): Mean values for normalization. | ||
std (list): Standard deviation values for normalization. | ||
Returns: | ||
torchvision.transforms.Compose: A composition of transforms. | ||
""" | ||
return v2.Compose( | ||
[ | ||
v2.Normalize(mean=mean, std=std), | ||
], | ||
) | ||
|
||
def __len__(self): | ||
return len(self.chips) | ||
|
||
def __getitem__(self, idx): | ||
""" | ||
Get a sample from the dataset. | ||
Args: | ||
idx (int): Index of the sample. | ||
Returns: | ||
dict: A dictionary containing the image, label, and additional information. | ||
""" | ||
chip_name = self.chip_dir / self.chips[idx] | ||
label_name = self.label_dir / self.labels[idx] | ||
|
||
chip = np.load(chip_name).astype(np.float32) | ||
label = np.load(label_name) | ||
|
||
# Remap labels to match desired classes | ||
label_mapping = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 15: 6} | ||
remapped_label = np.vectorize(label_mapping.get)(label) | ||
|
||
# Apply transformations | ||
if self.transform: | ||
chip = self.transform(torch.from_numpy(chip)) | ||
|
||
sample = { | ||
"pixels": self.transform(torch.from_numpy(chip)), | ||
"label": torch.from_numpy(remapped_label[0]), | ||
"time": torch.zeros(4), # Placeholder for time information | ||
"latlon": torch.zeros(4), # Placeholder for latlon information | ||
} | ||
return sample | ||
|
||
|
||
class ChesapeakeDataModule(L.LightningDataModule): | ||
""" | ||
DataModule class for the Chesapeake Bay dataset. | ||
Args: | ||
train_chip_dir (str): Directory containing training image chips. | ||
train_label_dir (str): Directory containing training labels. | ||
val_chip_dir (str): Directory containing validation image chips. | ||
val_label_dir (str): Directory containing validation labels. | ||
metadata_path (str): Path to the metadata file. | ||
batch_size (int): Batch size for data loading. | ||
num_workers (int): Number of workers for data loading. | ||
platform (str): Platform identifier used in metadata. | ||
""" | ||
|
||
def __init__( # noqa: PLR0913 | ||
self, | ||
train_chip_dir, | ||
train_label_dir, | ||
val_chip_dir, | ||
val_label_dir, | ||
metadata_path, | ||
batch_size, | ||
num_workers, | ||
platform, | ||
): | ||
super().__init__() | ||
self.train_chip_dir = train_chip_dir | ||
self.train_label_dir = train_label_dir | ||
self.val_chip_dir = val_chip_dir | ||
self.val_label_dir = val_label_dir | ||
self.metadata = Box(yaml.safe_load(open(metadata_path))) | ||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
self.platform = platform | ||
|
||
def setup(self, stage=None): | ||
""" | ||
Setup datasets for training and validation. | ||
Args: | ||
stage (str): Stage identifier ('fit' or 'test'). | ||
""" | ||
if stage in {"fit", None}: | ||
self.trn_ds = ChesapeakeDataset( | ||
self.train_chip_dir, | ||
self.train_label_dir, | ||
self.metadata, | ||
self.platform, | ||
) | ||
self.val_ds = ChesapeakeDataset( | ||
self.val_chip_dir, | ||
self.val_label_dir, | ||
self.metadata, | ||
self.platform, | ||
) | ||
|
||
def train_dataloader(self): | ||
""" | ||
Create DataLoader for training data. | ||
Returns: | ||
DataLoader: DataLoader for training dataset. | ||
""" | ||
return DataLoader( | ||
self.trn_ds, | ||
batch_size=self.batch_size, | ||
shuffle=True, | ||
num_workers=self.num_workers, | ||
) | ||
|
||
def val_dataloader(self): | ||
""" | ||
Create DataLoader for validation data. | ||
Returns: | ||
DataLoader: DataLoader for validation dataset. | ||
""" | ||
return DataLoader( | ||
self.val_ds, | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
) |
Oops, something went wrong.