Skip to content

Commit

Permalink
Update and rename instancesegmentation.py to instance_segmentation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ariannasole23 authored Jan 20, 2025
1 parent e249883 commit 7676ac3
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 174 deletions.
246 changes: 246 additions & 0 deletions torchgeo/trainers/instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Trainers for instance segmentation."""

from typing import Any
import torch.nn as nn
import torch
from torch import Tensor
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchmetrics import MetricCollection
from torchvision.models.detection import maskrcnn_resnet50_fpn
from base import BaseTask

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from torchgeo.datasets import RGBBandsMissingError, unbind_samples

# for testing
import pytorch_lightning as pl
from pytorch_lightning import LightningModule

Check failure on line 21 in torchgeo/trainers/instance_segmentation.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

torchgeo/trainers/instance_segmentation.py:21:31: F401 `pytorch_lightning.LightningModule` imported but unused
from torch.utils.data import DataLoader
from torchgeo.datasets import VHR10

class InstanceSegmentationTask(BaseTask):
    """Instance segmentation trainer built around torchvision's Mask R-CNN."""

    def __init__(
        self,
        model: str = 'mask_rcnn',
        backbone: str = 'resnet50',
        weights: str | bool | None = None,
        num_classes: int = 2,
        lr: float = 1e-3,
        patience: int = 10,
        freeze_backbone: bool = False,
    ) -> None:
        """Initialize a new InstanceSegmentationTask instance.

        Args:
            model: Name of the model to use. Only ``'mask_rcnn'`` is accepted by
                :meth:`configure_models`.
            backbone: Name of the backbone to use. Not read by
                :meth:`configure_models`, which always builds the ResNet-50 FPN
                variant.
            weights: ``True`` to load torchvision's default pretrained Mask R-CNN
                weights. Any other value builds the model without pretrained
                detection weights.
            num_classes: Number of prediction classes (including the background).
            lr: Learning rate for optimizer (not read in this class; presumably
                consumed by the base task).
            patience: Patience for learning rate scheduler (not read in this
                class; presumably consumed by the base task).
            freeze_backbone: Freeze the backbone network so that only the
                remaining parts of the model are fine-tuned.

        .. versionadded:: 0.7
        """
        self.weights = weights
        # NOTE(review): assumes BaseTask.__init__ saves the hyperparameters and
        # calls configure_models()/configure_metrics(); confirm against BaseTask.
        super().__init__()

        # Per-batch (outputs, targets) pairs appended by validation_step/test_step.
        # NOTE(review): nothing in this class clears these lists, so they grow
        # with every evaluated batch.
        self.validation_outputs: list[tuple[Any, Any]] = []
        self.test_outputs: list[tuple[Any, Any]] = []

    def configure_models(self) -> None:
        """Initialize the model.

        Builds a Mask R-CNN (ResNet-50 FPN) and replaces its box and mask
        predictors with freshly initialized heads sized for ``num_classes``.

        Raises:
            ValueError: If *model* is invalid.
        """
        # Local imports: the predictor heads are only needed while building.
        from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
        from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

        model = self.hparams['model'].lower()
        num_classes = self.hparams['num_classes']

        if model != 'mask_rcnn':
            raise ValueError(
                f"Invalid model type '{model}'. Supported model: 'mask_rcnn'"
            )

        # torchvision expects a weights enum, its string name, or None; a bare
        # bool is rejected, so translate ``True`` into the default weights.
        pretrained = 'DEFAULT' if self.weights is True else None
        self.model = maskrcnn_resnet50_fpn(weights=pretrained)
        roi_heads = self.model.roi_heads

        # The box predictor must return (class_logits, box_regression), so it is
        # replaced by a FastRCNNPredictor rather than a single linear layer.
        in_features = roi_heads.box_predictor.cls_score.in_features
        roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # Same for the mask branch: keep the predictor's layout, resize its output.
        mask_deconv = roi_heads.mask_predictor.conv5_mask
        roi_heads.mask_predictor = MaskRCNNPredictor(
            mask_deconv.in_channels, mask_deconv.out_channels, num_classes
        )

        # Freeze backbone
        if self.hparams['freeze_backbone']:
            for param in self.model.backbone.parameters():
                param.requires_grad = False

    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        Uses Mean Average Precision (mAP) computed on masks (``iou_type='segm'``),
        cloned once per stage with a ``train_``/``val_``/``test_`` prefix.
        """
        self.metrics = MetricCollection([MeanAveragePrecision(iou_type='segm')])
        self.train_metrics = self.metrics.clone(prefix='train_')
        self.val_metrics = self.metrics.clone(prefix='val_')
        self.test_metrics = self.metrics.clone(prefix='test_')

    @staticmethod
    def _binarize_masks(outputs: list[dict[str, Tensor]]) -> list[dict[str, Tensor]]:
        """Return shallow copies of *outputs* with boolean ``[N, H, W]`` masks.

        Mask R-CNN predicts soft ``[N, 1, H, W]`` masks, whereas the ``segm`` mAP
        metric expects boolean masks without the channel axis. The input dicts
        are left untouched so the raw predictions can still be stored or plotted.
        """
        return [
            {**output, 'masks': output['masks'].squeeze(1) >= 0.5}
            for output in outputs
        ]

    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        """Compute the training loss.

        Args:
            batch: A batch of data from the DataLoader; ``batch['image']`` and
                ``batch['target']`` are passed to the model.
            batch_idx: Index of the current batch.

        Returns:
            The sum of all losses returned by the model for the batch.
        """
        images, targets = batch['image'], batch['target']
        loss_dict = self.model(images, targets)
        loss: Tensor = sum(loss_dict.values())
        self.log('train_loss', loss, batch_size=len(images))
        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Run inference on a validation batch and accumulate the mAP metric.

        Predictions and targets are also appended to ``validation_outputs``, and
        for the first ten batches a figure is logged when both the datamodule
        and the logger support it.

        Args:
            batch: A batch of data from the DataLoader; ``batch['image']`` and
                ``batch['target']`` are read.
            batch_idx: Index of the current batch.
        """
        images, targets = batch['image'], batch['target']
        outputs = self.model(images)
        self.val_metrics.update(self._binarize_masks(outputs), targets)
        self.validation_outputs.append((outputs, targets))

        can_plot = (
            batch_idx < 10
            and hasattr(self.trainer, 'datamodule')
            and hasattr(self.trainer.datamodule, 'plot')
            and self.logger
            and hasattr(self.logger, 'experiment')
            and hasattr(self.logger.experiment, 'add_figure')
        )
        if not can_plot:
            return

        datamodule = self.trainer.datamodule

        batch['prediction_masks'] = [output['masks'].cpu() for output in outputs]
        batch['image'] = batch['image'].cpu()

        sample = unbind_samples(batch)[0]

        fig: Figure | None = None
        try:
            fig = datamodule.plot(sample)
        except RGBBandsMissingError:
            # Plotting is best-effort: skip samples without RGB bands.
            pass

        if fig:
            summary_writer = self.logger.experiment
            summary_writer.add_figure(
                f'image/{batch_idx}', fig, global_step=self.global_step
            )
            plt.close()

    def on_validation_epoch_end(self) -> None:
        """Compute, log and reset the validation metrics once per epoch."""
        metrics = self.val_metrics.compute()
        # ``classes`` lists the observed class ids; it is not a scalar metric.
        metrics.pop('val_classes', None)
        self.log_dict(metrics)
        self.val_metrics.reset()

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Run inference on a test batch and accumulate the mAP metric.

        Predictions and targets are also appended to ``test_outputs``.

        Args:
            batch: A batch of data from the DataLoader; ``batch['image']`` and
                ``batch['target']`` are read.
            batch_idx: Index of the current batch.
        """
        images, targets = batch['image'], batch['target']
        outputs = self.model(images)
        self.test_metrics.update(self._binarize_masks(outputs), targets)
        self.test_outputs.append((outputs, targets))

    def on_test_epoch_end(self) -> None:
        """Compute, log and reset the test metrics once per epoch."""
        metrics = self.test_metrics.compute()
        # ``classes`` lists the observed class ids; it is not a scalar metric.
        metrics.pop('test_classes', None)
        self.log_dict(metrics)
        self.test_metrics.reset()

    def predict_step(self, batch: Any, batch_idx: int) -> list[dict[str, Tensor]]:
        """Perform inference on a batch of images.

        Args:
            batch: A batch of data; only ``batch['image']`` is read.
            batch_idx: Index of the current batch (unused).

        Returns:
            The model's output for ``batch['image']``, returned unchanged. In
            eval mode Mask R-CNN yields one dict of predictions per image.
        """
        images = batch['image']
        y_hat: list[dict[str, Tensor]] = self.model(images)
        return y_hat















#=================================================================
# TESTING
#=================================================================

def collate_fn(batch: list[Any]) -> tuple[Any, ...]:
    """Transpose a batch of samples into one tuple per sample field.

    Each sample is unpacked positionally, so ``[(img0, tgt0), (img1, tgt1)]``
    becomes ``((img0, img1), (tgt0, tgt1))``. Nothing is stacked, which lets
    samples of different sizes share a batch.

    NOTE(review): unpacking a mapping yields its keys, so a batch of dict
    samples is turned into tuples of key names, not values. Confirm this is
    what the consumer of the DataLoader expects.

    Args:
        batch: The samples drawn for one batch.

    Returns:
        A tuple with one entry per field, each a tuple over the samples; an
        empty tuple for an empty batch.
    """
    return tuple(zip(*batch))

# Manual smoke test. Guarded so that importing this module does not download
# data or start a training run.
if __name__ == '__main__':
    # NOTE(review): both datasets use the same 'positive' split, so validation
    # and testing run on the training data.
    train_dataset = VHR10(root='data', split='positive', transforms=None, download=True)
    val_dataset = VHR10(root='data', split='positive', transforms=None)

    train_loader = DataLoader(
        train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn
    )

    task = InstanceSegmentationTask(
        model='mask_rcnn',
        backbone='resnet50',
        weights=True,
        num_classes=11,
        lr=1e-3,
        freeze_backbone=False,
    )

    trainer = pl.Trainer(
        max_epochs=10,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
    )

    trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)

    trainer.test(task, dataloaders=val_loader)

    test_sample = train_dataset[0]
    test_image = test_sample['image'].unsqueeze(0)

    # Calling predict_step directly bypasses the Trainer, so switch to eval
    # mode (detection models require targets while training) and disable
    # gradient tracking by hand.
    task.eval()
    with torch.no_grad():
        predictions = task.predict_step({'image': test_image}, batch_idx=0)
    print(predictions)
Loading

0 comments on commit 7676ac3

Please sign in to comment.