From 7676ac362114e4b2e88cece49f9e9800ba6f8964 Mon Sep 17 00:00:00 2001 From: Arianna Sole Date: Mon, 20 Jan 2025 12:30:27 +0100 Subject: [PATCH] Update and rename instancesegmentation.py to instance_segmentation.py --- torchgeo/trainers/instance_segmentation.py | 246 +++++++++++++++++++++ torchgeo/trainers/instancesegmentation.py | 174 --------------- 2 files changed, 246 insertions(+), 174 deletions(-) create mode 100644 torchgeo/trainers/instance_segmentation.py delete mode 100644 torchgeo/trainers/instancesegmentation.py diff --git a/torchgeo/trainers/instance_segmentation.py b/torchgeo/trainers/instance_segmentation.py new file mode 100644 index 0000000000..d28f47f610 --- /dev/null +++ b/torchgeo/trainers/instance_segmentation.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Trainers for instance segmentation.""" + +from typing import Any +import torch.nn as nn +import torch +from torch import Tensor +from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchmetrics import MetricCollection +from torchvision.models.detection import maskrcnn_resnet50_fpn +from base import BaseTask + +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from torchgeo.datasets import RGBBandsMissingError, unbind_samples + +# for testing +import pytorch_lightning as pl +from pytorch_lightning import LightningModule +from torch.utils.data import DataLoader +from torchgeo.datasets import VHR10 + +class InstanceSegmentationTask(BaseTask): + """Instance Segmentation.""" + + def __init__( + self, + model: str = 'mask_rcnn', + backbone: str = 'resnet50', + weights: str | bool | None = None, + num_classes: int = 2, + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + ) -> None: + """Initialize a new SemanticSegmentationTask instance. + + Args: + model: Name of the model to use. + backbone: Name of the backbone to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False or + None for random weights, or the path to a saved model state dict. + in_channels: Number of input channels to model. + num_classes: Number of prediction classes (including the background). + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + freeze_backbone: Freeze the backbone network to fine-tune the + decoder and segmentation head. + + .. versionadded:: 0.7 + """ + self.weights = weights + super().__init__() + # self.save_hyperparameters() + # self.model = None + # self.validation_outputs = [] + # self.test_outputs = [] + # self.configure_models() + # self.configure_metrics() + + def configure_models(self) -> None: + """Initialize the model. + + Raises: + ValueError: If *model* is invalid. + """ + model = self.hparams['model'].lower() + num_classes = self.hparams['num_classes'] + + if model == 'mask_rcnn': + # Load the Mask R-CNN model with a ResNet50 backbone + self.model = maskrcnn_resnet50_fpn(weights=self.weights is True) + + # Update the classification head to predict `num_classes` + in_features = self.model.roi_heads.box_predictor.cls_score.in_features + self.model.roi_heads.box_predictor = nn.Linear(in_features, num_classes) + + # Update the mask head for instance segmentation + in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels + self.model.roi_heads.mask_predictor = nn.ConvTranspose2d( + in_features_mask, num_classes, kernel_size=2, stride=2 + ) + + else: + raise ValueError( + f"Invalid model type '{model}'. Supported model: 'mask_rcnn'" + ) + + # Freeze backbone + if self.hparams['freeze_backbone']: + for param in self.model.backbone.parameters(): + param.requires_grad = False + + + def configure_metrics(self) -> None: + """Initialize the performance metrics. + + - Uses Mean Average Precision (mAP) for masks (IOU-based metric). + """ + self.metrics = MetricCollection([MeanAveragePrecision(iou_type="segm")]) + self.train_metrics = self.metrics.clone(prefix='train_') + self.val_metrics = self.metrics.clone(prefix='val_') + self.test_metrics = self.metrics.clone(prefix='test_') + + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + """Compute the training loss. + + Args: + batch: A batch of data from the DataLoader. Includes images and ground truth targets. + batch_idx: Index of the current batch. + + Returns: + The total loss for the batch. + """ + images, targets = batch['image'], batch['target'] + loss_dict = self.model(images, targets) + loss = sum(loss for loss in loss_dict.values()) + self.log('train_loss', loss, batch_size=len(images)) + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + """Compute the validation loss. + + Args: + batch: A batch of data from the DataLoader. Includes images and targets. + batch_idx: Index of the current batch. + + Updates metrics and stores predictions/targets for further analysis. + """ + images, targets = batch['image'], batch['target'] + outputs = self.model(images) + self.metrics.update(outputs, targets) + self.validation_outputs.append((outputs, targets)) + + metrics_dict = self.metrics.compute() + self.log_dict(metrics_dict) + self.metrics.reset() + + # check + if ( + batch_idx < 10 + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') + and self.logger + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') + ): + datamodule = self.trainer.datamodule + + batch['prediction_masks'] = [output['masks'].cpu() for output in outputs] + batch['image'] = batch['image'].cpu() + + sample = unbind_samples(batch)[0] + + fig: Figure | None = None + try: + fig = datamodule.plot(sample) + except RGBBandsMissingError: + pass + + if fig: + summary_writer = self.logger.experiment + summary_writer.add_figure( + f'image/{batch_idx}', fig, global_step=self.global_step + ) + plt.close() + + + def test_step(self, batch: Any, batch_idx: int) -> None: + """Compute the test loss and additional metrics.""" + + images, targets = batch['image'], batch['target'] + outputs = self.model(images) + self.metrics.update(outputs, targets) + self.test_outputs.append((outputs, targets)) + + metrics_dict = self.metrics.compute() + self.log_dict(metrics_dict) + + + def predict_step(self, batch: Any, batch_idx: int) -> Tensor: + """Perform inference on a batch of images. + + Args: + batch: A batch of images. + + Returns: + Predicted masks and bounding boxes for the batch. + """ + images = batch['image'] + y_hat: Tensor = self.model(images) + return y_hat + + + + + + + + + + + + + + + +#================================================================= +# TESTING +#================================================================= + +def collate_fn(batch): + return tuple(zip(*batch)) + +train_dataset = VHR10(root="data", split="positive", transforms=None, download=True) +val_dataset = VHR10(root="data", split="positive", transforms=None) + +train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn) +val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn) + +task = InstanceSegmentationTask( + model="mask_rcnn", + backbone="resnet50", + weights=True, + num_classes=11, + lr=1e-3, + freeze_backbone=False +) + +trainer = pl.Trainer( + max_epochs=10, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + devices=1 +) + +trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader) + +trainer.test(task, dataloaders=val_loader) + +test_sample = train_dataset[0] +test_image = test_sample["image"].unsqueeze(0) +predictions = task.predict_step({"image": test_image}, batch_idx=0) +print(predictions) diff --git a/torchgeo/trainers/instancesegmentation.py b/torchgeo/trainers/instancesegmentation.py deleted file mode 100644 index c07a9b38e2..0000000000 --- a/torchgeo/trainers/instancesegmentation.py +++ /dev/null @@ -1,174 +0,0 @@ -from typing import Any -import torch.nn as nn -import torch -from torch import Tensor -from torchmetrics.detection.mean_ap import MeanAveragePrecision -from torchvision.models.detection import maskrcnn_resnet50_fpn -from ultralytics import YOLO -from .base import BaseTask - -class InstanceSegmentationTask(BaseTask): - """ - Task class for training and evaluating instance segmentation models. - - This class supports Mask R-CNN and YOLO models and handles the following: - - Model configuration - - Loss computation - - Metric computation (e.g., Mean Average Precision) - - Training, validation, testing, and prediction steps - """ - - def __init__( - self, - model: str = 'mask_rcnn', # Model type, e.g., 'mask_rcnn' or 'yolo' - backbone: str = 'resnet50', # Backbone type for Mask R-CNN (ignored for YOLO) - weights: str | bool | None = None, # Pretrained weights or custom checkpoint path - num_classes: int = 2, # Number of classes, including background - lr: float = 1e-3, # Learning rate for the optimizer - patience: int = 10, # Patience for the learning rate scheduler - freeze_backbone: bool = False, # Whether to freeze backbone layers (useful for transfer learning) - ) -> None: - """ - Constructor for the InstanceSegmentationTask. - - Initializes the hyperparameters, sets up the model and metrics. - """ - self.weights = weights # Save weights for model initialization - super().__init__() # Initialize the BaseTask class (inherits common functionality) - self.save_hyperparameters() # Save input arguments for later use (e.g., in checkpoints or logs) - self.model = None # Placeholder for the model (to be initialized later) - self.validation_outputs = [] # List to store outputs during validation (used for debugging or analysis) - self.test_outputs = [] # List to store outputs during testing - self.configure_models() # Call method to set up the model - self.configure_metrics() # Call method to set up metrics - - def configure_models(self) -> None: - """ - Set up the instance segmentation model based on the specified type (Mask R-CNN or YOLO). - - Configures: - - Backbone (for Mask R-CNN) - - Classifier and mask heads - - Pretrained weights - """ - model = self.hparams['model'].lower() # Read the model type from hyperparameters (convert to lowercase) - num_classes = self.hparams['num_classes'] # Number of output classes - - if model == 'mask_rcnn': - # Load the Mask R-CNN model with a ResNet50 backbone - self.model = maskrcnn_resnet50_fpn(pretrained=self.weights is True) - - # Update the classification head to predict `num_classes` - in_features = self.model.roi_heads.box_predictor.cls_score.in_features - self.model.roi_heads.box_predictor = nn.Linear(in_features, num_classes) - - # Update the mask head for instance segmentation - in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels - self.model.roi_heads.mask_predictor = nn.ConvTranspose2d( - in_features_mask, num_classes, kernel_size=2, stride=2 - ) - - elif model == 'yolo': - # Initialize YOLOv8 for instance segmentation - self.model = YOLO('yolov8n-seg') # Load a small YOLOv8 segmentation model - self.model.model.args['nc'] = num_classes # Set the number of classes in YOLO - if self.weights: - # If weights are provided, load the custom checkpoint - self.model = YOLO(self.weights) - - else: - raise ValueError( - f"Invalid model type '{model}'. Supported models: 'mask_rcnn', 'yolo'." - ) - - # Freeze the backbone if specified (useful for transfer learning) - if self.hparams['freeze_backbone'] and model == 'mask_rcnn': - for param in self.model.backbone.parameters(): - param.requires_grad = False # Prevent these layers from being updated during training - - def configure_metrics(self) -> None: - """ - Set up metrics for evaluating instance segmentation models. - - - Uses Mean Average Precision (mAP) for masks (IOU-based metric). - """ - self.metrics = MeanAveragePrecision(iou_type="segm") # Track segmentation-specific mAP - - def training_step(self, batch: Any, batch_idx: int) -> Tensor: - """ - Perform a single training step. - - Args: - batch: A batch of data from the DataLoader. Includes images and ground truth targets. - batch_idx: Index of the current batch. - - Returns: - The total loss for the batch. - """ - images, targets = batch['image'], batch['target'] # Unpack images and targets - loss_dict = self.model(images, targets) # Compute losses (classification, box regression, mask loss, etc.) - loss = sum(loss for loss in loss_dict.values()) # Combine all losses into a single value - self.log('train_loss', loss, batch_size=len(images)) # Log the training loss for monitoring - return loss # Return the loss for optimization - - def validation_step(self, batch: Any, batch_idx: int) -> None: - """ - Perform a single validation step. - - Args: - batch: A batch of data from the DataLoader. Includes images and targets. - batch_idx: Index of the current batch. - - Updates metrics and stores predictions/targets for further analysis. - """ - images, targets = batch['image'], batch['target'] # Unpack images and targets - outputs = self.model(images) # Run inference on the model - self.metrics.update(outputs, targets) # Update mAP metrics with predictions and ground truths - self.validation_outputs.append((outputs, targets)) # Store outputs for debugging or visualization - - def on_validation_epoch_end(self) -> None: - """ - At the end of the validation epoch, compute and log metrics. - - Resets the stored outputs to free memory. - """ - metrics_dict = self.metrics.compute() # Calculate final mAP and other metrics - self.log_dict(metrics_dict) # Log all computed metrics - self.metrics.reset() # Reset metrics for the next epoch - self.validation_outputs.clear() # Clear stored outputs to free memory - - def test_step(self, batch: Any, batch_idx: int) -> None: - """ - Perform a single test step. - - Similar to validation but used for test data. - """ - images, targets = batch['image'], batch['target'] - outputs = self.model(images) - self.metrics.update(outputs, targets) - self.test_outputs.append((outputs, targets)) - - def on_test_epoch_end(self) -> None: - """ - At the end of the test epoch, compute and log metrics. - - Resets the stored outputs to free memory. - """ - metrics_dict = self.metrics.compute() - self.log_dict(metrics_dict) - self.metrics.reset() - self.test_outputs.clear() - - def predict_step(self, batch: Any, batch_idx: int) -> Tensor: - """ - Perform inference on a batch of images. - - Args: - batch: A batch of images. - - Returns: - Predicted masks and bounding boxes for the batch. - """ - images = batch['image'] # Extract images from the batch - predictions = self.model(images) # Run inference on the model - return predictions # Return the predictions